Skip to main content

wae_websocket/
lib.rs

1//! WAE WebSocket - 实时通信抽象层
2//!
3//! 提供统一的 WebSocket 通信能力,支持服务端和客户端模式。
4//!
5//! 深度融合 tokio 运行时,所有 API 都是异步优先设计。
6//! 微服务架构友好,支持房间管理、广播、自动重连、心跳检测等特性。
7
8#![warn(missing_docs)]
9
10use async_trait::async_trait;
11use futures_util::{SinkExt, StreamExt};
12use serde::{Serialize, de::DeserializeOwned};
13use std::{collections::HashMap, fmt, net::SocketAddr, sync::Arc, time::Duration};
14use tokio::sync::{RwLock, broadcast, mpsc};
15use tokio_tungstenite::tungstenite::protocol::Message as WsMessage;
16
/// Errors produced by the WebSocket layer.
#[derive(Debug)]
pub enum WebSocketError {
    /// Establishing a connection failed (e.g. binding the listener or the WebSocket handshake).
    ConnectionFailed(String),

    /// The connection has already been closed.
    ConnectionClosed,

    /// A message could not be sent.
    SendFailed(String),

    /// A message could not be received.
    ReceiveFailed(String),

    /// A value could not be serialized.
    SerializationFailed(String),

    /// A value could not be deserialized.
    DeserializationFailed(String),

    /// The requested room does not exist.
    RoomNotFound(String),

    /// The requested connection does not exist.
    ConnectionNotFound(String),

    /// The connection limit (carried in the variant) was reached.
    MaxConnectionsExceeded(u32),

    /// An operation did not finish in time.
    Timeout(String),

    /// Any other internal failure.
    Internal(String),
}

impl fmt::Display for WebSocketError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ConnectionFailed(detail) => write!(f, "WebSocket connection failed: {detail}"),
            Self::ConnectionClosed => f.write_str("WebSocket connection closed"),
            Self::SendFailed(detail) => write!(f, "Failed to send message: {detail}"),
            Self::ReceiveFailed(detail) => write!(f, "Failed to receive message: {detail}"),
            Self::SerializationFailed(detail) => write!(f, "Serialization failed: {detail}"),
            Self::DeserializationFailed(detail) => write!(f, "Deserialization failed: {detail}"),
            Self::RoomNotFound(room) => write!(f, "Room not found: {room}"),
            Self::ConnectionNotFound(conn) => write!(f, "Connection not found: {conn}"),
            Self::MaxConnectionsExceeded(limit) => write!(f, "Maximum connections exceeded: {limit}"),
            Self::Timeout(detail) => write!(f, "Operation timeout: {detail}"),
            Self::Internal(detail) => write!(f, "WebSocket internal error: {detail}"),
        }
    }
}

impl std::error::Error for WebSocketError {}

/// Shorthand for results whose error type is [`WebSocketError`].
pub type WebSocketResult<T> = std::result::Result<T, WebSocketError>;

/// Identifier of a single connection (the server assigns a UUID v4 string).
pub type ConnectionId = String;

/// Identifier of a room.
pub type RoomId = String;
82
/// A transport-independent WebSocket message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// UTF-8 text payload.
    Text(String),
    /// Raw binary payload.
    Binary(Vec<u8>),
    /// Ping control frame (no payload is carried).
    Ping,
    /// Pong control frame (no payload is carried).
    Pong,
    /// Close control frame.
    Close,
}

impl Message {
    /// Builds a [`Message::Text`] from anything convertible into a `String`.
    pub fn text(content: impl Into<String>) -> Self {
        let body: String = content.into();
        Self::Text(body)
    }

    /// Builds a [`Message::Binary`] from anything convertible into a `Vec<u8>`.
    pub fn binary(data: impl Into<Vec<u8>>) -> Self {
        let bytes: Vec<u8> = data.into();
        Self::Binary(bytes)
    }

    /// Returns `true` for [`Message::Text`].
    pub fn is_text(&self) -> bool {
        self.as_text().is_some()
    }

    /// Returns `true` for [`Message::Binary`].
    pub fn is_binary(&self) -> bool {
        self.as_binary().is_some()
    }

    /// Borrows the text payload, or `None` for any other variant.
    pub fn as_text(&self) -> Option<&str> {
        if let Self::Text(body) = self { Some(body.as_str()) } else { None }
    }

    /// Borrows the binary payload, or `None` for any other variant.
    pub fn as_binary(&self) -> Option<&[u8]> {
        if let Self::Binary(bytes) = self { Some(bytes.as_slice()) } else { None }
    }
}
135
136impl From<WsMessage> for Message {
137    fn from(msg: WsMessage) -> Self {
138        match msg {
139            WsMessage::Text(s) => Message::Text(s.to_string()),
140            WsMessage::Binary(data) => Message::Binary(data.to_vec()),
141            WsMessage::Ping(_) => Message::Ping,
142            WsMessage::Pong(_) => Message::Pong,
143            WsMessage::Close(_) => Message::Close,
144            _ => Message::Close,
145        }
146    }
147}
148
149impl From<Message> for WsMessage {
150    fn from(msg: Message) -> Self {
151        match msg {
152            Message::Text(s) => WsMessage::Text(s.into()),
153            Message::Binary(data) => WsMessage::Binary(data.into()),
154            Message::Ping => WsMessage::Ping(Vec::new().into()),
155            Message::Pong => WsMessage::Pong(Vec::new().into()),
156            Message::Close => WsMessage::Close(None),
157        }
158    }
159}
160
161/// 连接信息
162#[derive(Debug, Clone)]
163pub struct Connection {
164    /// 连接 ID
165    pub id: ConnectionId,
166    /// 客户端地址
167    pub addr: SocketAddr,
168    /// 连接时间
169    pub connected_at: std::time::Instant,
170    /// 用户自定义数据
171    pub metadata: HashMap<String, String>,
172    /// 所属房间列表
173    pub rooms: Vec<RoomId>,
174}
175
176impl Connection {
177    /// 创建新连接
178    pub fn new(id: ConnectionId, addr: SocketAddr) -> Self {
179        Self { id, addr, connected_at: std::time::Instant::now(), metadata: HashMap::new(), rooms: Vec::new() }
180    }
181
182    /// 设置元数据
183    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
184        self.metadata.insert(key.into(), value.into());
185        self
186    }
187
188    /// 获取连接持续时间
189    pub fn duration(&self) -> Duration {
190        self.connected_at.elapsed()
191    }
192}
193
194/// 连接管理器
195pub struct ConnectionManager {
196    connections: Arc<RwLock<HashMap<ConnectionId, Connection>>>,
197    max_connections: u32,
198}
199
200impl ConnectionManager {
201    /// 创建新的连接管理器
202    pub fn new(max_connections: u32) -> Self {
203        Self { connections: Arc::new(RwLock::new(HashMap::new())), max_connections }
204    }
205
206    /// 添加连接
207    pub async fn add(&self, connection: Connection) -> WebSocketResult<()> {
208        let mut connections = self.connections.write().await;
209        if connections.len() >= self.max_connections as usize {
210            return Err(WebSocketError::MaxConnectionsExceeded(self.max_connections));
211        }
212        connections.insert(connection.id.clone(), connection);
213        Ok(())
214    }
215
216    /// 移除连接
217    pub async fn remove(&self, id: &str) -> Option<Connection> {
218        let mut connections = self.connections.write().await;
219        connections.remove(id)
220    }
221
222    /// 获取连接
223    pub async fn get(&self, id: &str) -> Option<Connection> {
224        let connections = self.connections.read().await;
225        connections.get(id).cloned()
226    }
227
228    /// 检查连接是否存在
229    pub async fn exists(&self, id: &str) -> bool {
230        let connections = self.connections.read().await;
231        connections.contains_key(id)
232    }
233
234    /// 获取连接数量
235    pub async fn count(&self) -> usize {
236        let connections = self.connections.read().await;
237        connections.len()
238    }
239
240    /// 获取所有连接 ID
241    pub async fn all_ids(&self) -> Vec<ConnectionId> {
242        let connections = self.connections.read().await;
243        connections.keys().cloned().collect()
244    }
245
246    /// 更新连接的房间列表
247    pub async fn join_room(&self, id: &str, room: &str) -> WebSocketResult<()> {
248        let mut connections = self.connections.write().await;
249        if let Some(conn) = connections.get_mut(id) {
250            if !conn.rooms.contains(&room.to_string()) {
251                conn.rooms.push(room.to_string());
252            }
253            return Ok(());
254        }
255        Err(WebSocketError::ConnectionNotFound(id.to_string()))
256    }
257
258    /// 离开房间
259    pub async fn leave_room(&self, id: &str, room: &str) -> WebSocketResult<()> {
260        let mut connections = self.connections.write().await;
261        if let Some(conn) = connections.get_mut(id) {
262            conn.rooms.retain(|r| r != room);
263            return Ok(());
264        }
265        Err(WebSocketError::ConnectionNotFound(id.to_string()))
266    }
267}
268
269/// 房间管理器
270pub struct RoomManager {
271    rooms: Arc<RwLock<HashMap<RoomId, Vec<ConnectionId>>>>,
272}
273
274impl RoomManager {
275    /// 创建新的房间管理器
276    pub fn new() -> Self {
277        Self { rooms: Arc::new(RwLock::new(HashMap::new())) }
278    }
279
280    /// 创建房间
281    pub async fn create_room(&self, room_id: &str) {
282        let mut rooms = self.rooms.write().await;
283        rooms.entry(room_id.to_string()).or_insert_with(Vec::new);
284    }
285
286    /// 删除房间
287    pub async fn delete_room(&self, room_id: &str) -> Option<Vec<ConnectionId>> {
288        let mut rooms = self.rooms.write().await;
289        rooms.remove(room_id)
290    }
291
292    /// 加入房间
293    pub async fn join(&self, room_id: &str, connection_id: &str) {
294        let mut rooms = self.rooms.write().await;
295        let room = rooms.entry(room_id.to_string()).or_insert_with(Vec::new);
296        if !room.contains(&connection_id.to_string()) {
297            room.push(connection_id.to_string());
298        }
299    }
300
301    /// 离开房间
302    pub async fn leave(&self, room_id: &str, connection_id: &str) {
303        let mut rooms = self.rooms.write().await;
304        if let Some(room) = rooms.get_mut(room_id) {
305            room.retain(|id| id != connection_id);
306            if room.is_empty() {
307                rooms.remove(room_id);
308            }
309        }
310    }
311
312    /// 获取房间内的所有连接
313    pub async fn get_members(&self, room_id: &str) -> Vec<ConnectionId> {
314        let rooms = self.rooms.read().await;
315        rooms.get(room_id).cloned().unwrap_or_default()
316    }
317
318    /// 检查房间是否存在
319    pub async fn room_exists(&self, room_id: &str) -> bool {
320        let rooms = self.rooms.read().await;
321        rooms.contains_key(room_id)
322    }
323
324    /// 获取房间数量
325    pub async fn room_count(&self) -> usize {
326        let rooms = self.rooms.read().await;
327        rooms.len()
328    }
329
330    /// 获取房间成员数量
331    pub async fn member_count(&self, room_id: &str) -> usize {
332        let rooms = self.rooms.read().await;
333        rooms.get(room_id).map(|r| r.len()).unwrap_or(0)
334    }
335
336    /// 广播消息到房间
337    pub async fn broadcast(&self, room_id: &str, sender: &Sender, message: &Message) -> WebSocketResult<Vec<ConnectionId>> {
338        let members = self.get_members(room_id).await;
339        let mut sent_to = Vec::new();
340        for conn_id in &members {
341            if sender.send_to(conn_id, message.clone()).await.is_ok() {
342                sent_to.push(conn_id.clone());
343            }
344        }
345        Ok(sent_to)
346    }
347}
348
349impl Default for RoomManager {
350    fn default() -> Self {
351        Self::new()
352    }
353}
354
355/// 消息发送器
356#[derive(Clone)]
357pub struct Sender {
358    senders: Arc<RwLock<HashMap<ConnectionId, mpsc::UnboundedSender<Message>>>>,
359}
360
361impl Sender {
362    /// 创建新的发送器
363    pub fn new() -> Self {
364        Self { senders: Arc::new(RwLock::new(HashMap::new())) }
365    }
366
367    /// 注册连接的发送通道
368    pub async fn register(&self, connection_id: ConnectionId, sender: mpsc::UnboundedSender<Message>) {
369        let mut senders = self.senders.write().await;
370        senders.insert(connection_id, sender);
371    }
372
373    /// 注销连接的发送通道
374    pub async fn unregister(&self, connection_id: &str) {
375        let mut senders = self.senders.write().await;
376        senders.remove(connection_id);
377    }
378
379    /// 发送消息到指定连接
380    pub async fn send_to(&self, connection_id: &str, message: Message) -> WebSocketResult<()> {
381        let senders = self.senders.read().await;
382        if let Some(sender) = senders.get(connection_id) {
383            sender.send(message).map_err(|e| WebSocketError::SendFailed(e.to_string()))?;
384            return Ok(());
385        }
386        Err(WebSocketError::ConnectionNotFound(connection_id.to_string()))
387    }
388
389    /// 广播消息到所有连接
390    pub async fn broadcast(&self, message: Message) -> WebSocketResult<usize> {
391        let senders = self.senders.read().await;
392        let mut count = 0;
393        for sender in senders.values() {
394            if sender.send(message.clone()).is_ok() {
395                count += 1;
396            }
397        }
398        Ok(count)
399    }
400
401    /// 获取连接数量
402    pub async fn count(&self) -> usize {
403        let senders = self.senders.read().await;
404        senders.len()
405    }
406}
407
408impl Default for Sender {
409    fn default() -> Self {
410        Self::new()
411    }
412}
413
/// Configuration for [`WebSocketServer`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Host/interface to bind (default `"0.0.0.0"`).
    pub host: String,
    /// TCP port to bind (default `8080`).
    pub port: u16,
    /// Maximum number of concurrent connections (default `1000`).
    pub max_connections: u32,
    /// Heartbeat interval (default 30 s).
    ///
    /// NOTE(review): `WebSocketServer` in this file does not currently read this value.
    pub heartbeat_interval: Duration,
    /// Connection timeout (default 60 s).
    ///
    /// NOTE(review): `WebSocketServer` in this file does not currently read this value.
    pub connection_timeout: Duration,
}

impl Default for ServerConfig {
    fn default() -> Self {
        Self {
            host: "0.0.0.0".to_string(),
            port: 8080,
            max_connections: 1000,
            heartbeat_interval: Duration::from_secs(30),
            connection_timeout: Duration::from_secs(60),
        }
    }
}

impl ServerConfig {
    /// Creates a configuration with the default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the bind host.
    pub fn host(mut self, host: impl Into<String>) -> Self {
        self.host = host.into();
        self
    }

    /// Sets the bind port.
    pub fn port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Sets the maximum number of concurrent connections.
    pub fn max_connections(mut self, max: u32) -> Self {
        self.max_connections = max;
        self
    }

    /// Sets the heartbeat interval.
    pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
        self.heartbeat_interval = interval;
        self
    }

    /// Sets the connection timeout. This completes the builder, which previously had no
    /// setter for this field.
    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
        self.connection_timeout = timeout;
        self
    }
}
471
472/// 客户端处理器 trait
473#[async_trait]
474pub trait ClientHandler: Send + Sync {
475    /// 连接建立时调用
476    async fn on_connect(&self, connection: &Connection) -> WebSocketResult<()>;
477
478    /// 收到消息时调用
479    async fn on_message(&self, connection: &Connection, message: Message) -> WebSocketResult<()>;
480
481    /// 连接关闭时调用
482    async fn on_disconnect(&self, connection: &Connection);
483}
484
485/// 默认客户端处理器
486pub struct DefaultClientHandler;
487
488#[async_trait]
489impl ClientHandler for DefaultClientHandler {
490    async fn on_connect(&self, _connection: &Connection) -> WebSocketResult<()> {
491        Ok(())
492    }
493
494    async fn on_message(&self, _connection: &Connection, _message: Message) -> WebSocketResult<()> {
495        Ok(())
496    }
497
498    async fn on_disconnect(&self, _connection: &Connection) {}
499}
500
501/// WebSocket 服务端
502pub struct WebSocketServer {
503    config: ServerConfig,
504    connection_manager: Arc<ConnectionManager>,
505    room_manager: Arc<RoomManager>,
506    sender: Sender,
507    shutdown_tx: broadcast::Sender<()>,
508}
509
510impl WebSocketServer {
511    /// 创建新的 WebSocket 服务端
512    pub fn new(config: ServerConfig) -> Self {
513        let (shutdown_tx, _) = broadcast::channel(1);
514        Self {
515            config,
516            connection_manager: Arc::new(ConnectionManager::new(1000)),
517            room_manager: Arc::new(RoomManager::new()),
518            sender: Sender::new(),
519            shutdown_tx,
520        }
521    }
522
523    /// 获取连接管理器
524    pub fn connection_manager(&self) -> &Arc<ConnectionManager> {
525        &self.connection_manager
526    }
527
528    /// 获取房间管理器
529    pub fn room_manager(&self) -> &Arc<RoomManager> {
530        &self.room_manager
531    }
532
533    /// 获取消息发送器
534    pub fn sender(&self) -> &Sender {
535        &self.sender
536    }
537
538    /// 获取配置
539    pub fn config(&self) -> &ServerConfig {
540        &self.config
541    }
542
543    /// 启动服务端
544    pub async fn start<H: ClientHandler + 'static>(&self, handler: H) -> WebSocketResult<()> {
545        let addr = format!("{}:{}", self.config.host, self.config.port);
546        let listener =
547            tokio::net::TcpListener::bind(&addr).await.map_err(|e| WebSocketError::ConnectionFailed(e.to_string()))?;
548
549        tracing::info!("WebSocket server listening on {}", addr);
550
551        let mut shutdown_rx = self.shutdown_tx.subscribe();
552        let handler = Arc::new(handler);
553
554        loop {
555            tokio::select! {
556                accept_result = listener.accept() => {
557                    match accept_result {
558                        Ok((stream, addr)) => {
559                            let connection_manager = self.connection_manager.clone();
560                            let room_manager = self.room_manager.clone();
561                            let sender = self.sender.clone();
562                            let handler = handler.clone();
563                            let config = self.config.clone();
564
565                            tokio::spawn(async move {
566                                if let Err(e) = Self::handle_connection(
567                                    stream,
568                                    addr,
569                                    connection_manager,
570                                    room_manager,
571                                    sender,
572                                    handler,
573                                    config,
574                                ).await {
575                                    tracing::error!("Connection error: {}", e);
576                                }
577                            });
578                        }
579                        Err(e) => {
580                            tracing::error!("Accept error: {}", e);
581                        }
582                    }
583                }
584                _ = shutdown_rx.recv() => {
585                    tracing::info!("WebSocket server shutting down");
586                    break;
587                }
588            }
589        }
590
591        Ok(())
592    }
593
594    async fn handle_connection<H: ClientHandler>(
595        stream: tokio::net::TcpStream,
596        addr: SocketAddr,
597        connection_manager: Arc<ConnectionManager>,
598        room_manager: Arc<RoomManager>,
599        sender: Sender,
600        handler: Arc<H>,
601        config: ServerConfig,
602    ) -> WebSocketResult<()> {
603        let ws_stream =
604            tokio_tungstenite::accept_async(stream).await.map_err(|e| WebSocketError::ConnectionFailed(e.to_string()))?;
605
606        let connection_id = uuid::Uuid::new_v4().to_string();
607        let connection = Connection::new(connection_id.clone(), addr);
608
609        if connection_manager.add(connection.clone()).await.is_err() {
610            return Err(WebSocketError::MaxConnectionsExceeded(config.max_connections));
611        }
612
613        handler.on_connect(&connection).await?;
614        tracing::info!("Client connected: {} from {}", connection_id, addr);
615
616        let (ws_sender, mut ws_receiver) = ws_stream.split();
617        let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
618
619        sender.register(connection_id.clone(), tx).await;
620
621        let send_task = async move {
622            let mut ws_sender = ws_sender;
623            while let Some(msg) = rx.recv().await {
624                if ws_sender.send(msg.into()).await.is_err() {
625                    break;
626                }
627            }
628            let _ = ws_sender.close().await;
629        };
630
631        let connection_manager_clone = connection_manager.clone();
632        let room_manager_clone = room_manager.clone();
633        let sender_clone = sender.clone();
634        let connection_id_clone = connection_id.clone();
635        let connection_clone = connection.clone();
636        let handler_clone = handler.clone();
637        let recv_task = async move {
638            while let Some(msg_result) = ws_receiver.next().await {
639                match msg_result {
640                    Ok(ws_msg) => {
641                        let msg: Message = ws_msg.into();
642                        if matches!(msg, Message::Close) {
643                            break;
644                        }
645                        if handler_clone.on_message(&connection_clone, msg).await.is_err() {
646                            break;
647                        }
648                    }
649                    Err(_) => break,
650                }
651            }
652        };
653
654        tokio::select! {
655            _ = send_task => {},
656            _ = recv_task => {},
657        }
658
659        for room_id in &connection.rooms {
660            room_manager_clone.leave(room_id, &connection_id_clone).await;
661        }
662
663        connection_manager_clone.remove(&connection_id_clone).await;
664        sender_clone.unregister(&connection_id_clone).await;
665        handler.on_disconnect(&connection).await;
666
667        tracing::info!("Client disconnected: {}", connection_id);
668
669        Ok(())
670    }
671
672    /// 停止服务端
673    pub fn shutdown(&self) {
674        let _ = self.shutdown_tx.send(());
675    }
676
677    /// 广播消息到所有连接
678    pub async fn broadcast(&self, message: Message) -> WebSocketResult<usize> {
679        self.sender.broadcast(message).await
680    }
681
682    /// 广播消息到房间
683    pub async fn broadcast_to_room(&self, room_id: &str, message: Message) -> WebSocketResult<Vec<ConnectionId>> {
684        self.room_manager.broadcast(room_id, &self.sender, &message).await
685    }
686}
687
/// Configuration for [`WebSocketClient`].
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Server URL (default `"ws://127.0.0.1:8080"`).
    pub url: String,
    /// Delay between reconnect attempts (default 5 s).
    pub reconnect_interval: Duration,
    /// Heartbeat interval (default 30 s).
    pub heartbeat_interval: Duration,
    /// Connection timeout (default 10 s).
    pub connection_timeout: Duration,
    /// Maximum number of reconnect attempts (0 means reconnect forever; default 0).
    pub max_reconnect_attempts: u32,
}

impl Default for ClientConfig {
    fn default() -> Self {
        Self {
            url: "ws://127.0.0.1:8080".to_string(),
            reconnect_interval: Duration::from_secs(5),
            heartbeat_interval: Duration::from_secs(30),
            connection_timeout: Duration::from_secs(10),
            max_reconnect_attempts: 0,
        }
    }
}

impl ClientConfig {
    /// Creates a configuration for `url` with every other field at its default.
    pub fn new(url: impl Into<String>) -> Self {
        Self { url: url.into(), ..Self::default() }
    }

    /// Sets the delay between reconnect attempts.
    pub fn reconnect_interval(mut self, interval: Duration) -> Self {
        self.reconnect_interval = interval;
        self
    }

    /// Sets the heartbeat interval.
    pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
        self.heartbeat_interval = interval;
        self
    }

    /// Sets the connection timeout. This completes the builder, which previously had no
    /// setter for this field.
    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
        self.connection_timeout = timeout;
        self
    }

    /// Sets the maximum number of reconnect attempts (0 = unlimited).
    pub fn max_reconnect_attempts(mut self, attempts: u32) -> Self {
        self.max_reconnect_attempts = attempts;
        self
    }
}
739
740/// WebSocket 客户端
741pub struct WebSocketClient {
742    config: ClientConfig,
743    sender: mpsc::UnboundedSender<Message>,
744    receiver: mpsc::UnboundedReceiver<Message>,
745}
746
747impl WebSocketClient {
748    /// 创建新的 WebSocket 客户端
749    pub fn new(config: ClientConfig) -> Self {
750        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded_channel::<Message>();
751        let (incoming_tx, incoming_rx) = mpsc::unbounded_channel::<Message>();
752
753        let config_clone = config.clone();
754
755        tokio::spawn(async move {
756            let mut attempt = 0u32;
757            loop {
758                match tokio_tungstenite::connect_async(&config_clone.url).await {
759                    Ok((ws_stream, _)) => {
760                        tracing::info!("WebSocket client connected to {}", config_clone.url);
761                        attempt = 0;
762
763                        let (mut ws_sender, mut ws_receiver) = ws_stream.split();
764
765                        let send_task = async {
766                            while let Some(msg) = outgoing_rx.recv().await {
767                                if ws_sender.send(msg.into()).await.is_err() {
768                                    break;
769                                }
770                            }
771                        };
772
773                        let recv_task = async {
774                            while let Some(msg_result) = ws_receiver.next().await {
775                                match msg_result {
776                                    Ok(ws_msg) => {
777                                        let msg: Message = ws_msg.into();
778                                        if matches!(msg, Message::Close) {
779                                            break;
780                                        }
781                                        if incoming_tx.send(msg).is_err() {
782                                            break;
783                                        }
784                                    }
785                                    Err(_) => break,
786                                }
787                            }
788                        };
789
790                        tokio::select! {
791                            _ = send_task => {},
792                            _ = recv_task => {},
793                        }
794
795                        tracing::warn!("WebSocket client disconnected, attempting to reconnect...");
796                    }
797                    Err(e) => {
798                        tracing::error!("WebSocket connection failed: {}", e);
799                    }
800                }
801
802                attempt += 1;
803                if config_clone.max_reconnect_attempts > 0 && attempt >= config_clone.max_reconnect_attempts {
804                    tracing::error!("Max reconnect attempts reached, giving up");
805                    break;
806                }
807
808                tokio::time::sleep(config_clone.reconnect_interval).await;
809            }
810        });
811
812        Self { config, sender: outgoing_tx, receiver: incoming_rx }
813    }
814
815    /// 发送消息
816    pub async fn send(&self, message: Message) -> WebSocketResult<()> {
817        self.sender.send(message).map_err(|e| WebSocketError::SendFailed(e.to_string()))
818    }
819
820    /// 发送文本消息
821    pub async fn send_text(&self, text: impl Into<String>) -> WebSocketResult<()> {
822        self.send(Message::text(text)).await
823    }
824
825    /// 发送二进制消息
826    pub async fn send_binary(&self, data: impl Into<Vec<u8>>) -> WebSocketResult<()> {
827        self.send(Message::binary(data)).await
828    }
829
830    /// 发送 JSON 消息
831    pub async fn send_json<T: Serialize + ?Sized>(&self, value: &T) -> WebSocketResult<()> {
832        let json = serde_json::to_string(value).map_err(|e| WebSocketError::SerializationFailed(e.to_string()))?;
833        self.send_text(json).await
834    }
835
836    /// 接收消息
837    pub async fn receive(&mut self) -> Option<Message> {
838        self.receiver.recv().await
839    }
840
841    /// 接收并解析 JSON 消息
842    pub async fn receive_json<T: DeserializeOwned>(&mut self) -> WebSocketResult<Option<T>> {
843        match self.receive().await {
844            Some(msg) => {
845                let text =
846                    msg.as_text().ok_or_else(|| WebSocketError::DeserializationFailed("Expected text message".into()))?;
847                let value: T = serde_json::from_str(text).map_err(|e| WebSocketError::DeserializationFailed(e.to_string()))?;
848                Ok(Some(value))
849            }
850            None => Ok(None),
851        }
852    }
853
854    /// 获取配置
855    pub fn config(&self) -> &ClientConfig {
856        &self.config
857    }
858
859    /// 关闭连接
860    pub async fn close(&self) -> WebSocketResult<()> {
861        self.send(Message::Close).await
862    }
863}
864
865/// 便捷函数:创建 WebSocket 服务端
866pub fn websocket_server(config: ServerConfig) -> WebSocketServer {
867    WebSocketServer::new(config)
868}
869
870/// 便捷函数:创建 WebSocket 客户端
871pub fn websocket_client(config: ClientConfig) -> WebSocketClient {
872    WebSocketClient::new(config)
873}