
rustapi_ws/socket.rs

//! WebSocket stream implementation

use crate::{Message, WebSocketError, WsHeartbeatConfig};
use futures_util::{
    stream::{SplitSink, SplitStream},
    SinkExt, Stream, StreamExt,
};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc;
use tokio_tungstenite::{tungstenite, WebSocketStream as TungsteniteStream};

/// Type alias for the upgraded connection
type UpgradedConnection = TungsteniteStream<TokioIo<Upgraded>>;

/// Internal implementation of the WebSocket stream
#[allow(clippy::large_enum_variant)]
enum StreamImpl {
    /// Direct connection (no heartbeat/management)
    Direct(UpgradedConnection),
    /// Managed connection (heartbeat/cleanup running in background task)
    Managed {
        tx: mpsc::Sender<Message>,
        rx: mpsc::Receiver<Result<Message, WebSocketError>>,
    },
}

/// A WebSocket stream
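///
/// A minimal echo sketch; the `socket` binding and the error handling below
/// are illustrative, assuming a handler has already been given the upgraded
/// stream:
///
/// ```ignore
/// // `socket: WebSocketStream`, typically provided by the upgrade handler.
/// while let Some(Ok(msg)) = socket.recv().await {
///     socket.send(msg).await?; // echo the message back
/// }
/// ```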
pub struct WebSocketStream {
    inner: StreamImpl,
}

impl WebSocketStream {
    /// Create a new direct WebSocket stream
    pub(crate) fn new(inner: UpgradedConnection) -> Self {
        Self {
            inner: StreamImpl::Direct(inner),
        }
    }

    /// Create a new managed WebSocket stream with heartbeat.
    ///
    /// The socket is handed to a background task that sends periodic Pings,
    /// answers client Pings, filters out heartbeat Pongs, and closes the
    /// connection when no traffic is seen within the configured timeout.
    /// The returned stream talks to that task over a pair of channels.
    pub(crate) fn new_managed(inner: UpgradedConnection, config: WsHeartbeatConfig) -> Self {
        let (mut sender, mut receiver) = inner.split();
        let (user_tx, mut internal_rx) = mpsc::channel::<Message>(32);
        let (internal_tx, user_rx) = mpsc::channel::<Result<Message, WebSocketError>>(32);

        // Spawn the management task that owns both halves of the socket.
        tokio::spawn(async move {
            let mut heartbeat_interval = tokio::time::interval(config.interval);
            // The first tick completes immediately; consume it so the first
            // Ping goes out only after a full interval has elapsed.
            heartbeat_interval.tick().await;

            // Liveness tracking: instead of matching each Pong to the Ping that
            // triggered it, record the time of the last inbound frame (any
            // message, including Pongs, counts as activity) and compare it
            // against the configured timeout below.
            //
            // Tungstenite surfaces incoming Pong frames as ordinary messages;
            // since heartbeat Pongs are an implementation detail, they are
            // filtered out here rather than forwarded to the user.
            let mut last_heartbeat = tokio::time::Instant::now();
            let mut timeout_check = tokio::time::interval(config.timeout);

            loop {
                tokio::select! {
                    // 1. Receive message from socket
                    msg = receiver.next() => {
                        match msg {
                            Some(Ok(msg)) => {
                                last_heartbeat = tokio::time::Instant::now();
                                if msg.is_pong() {
                                    // Received a pong (response to our ping)
                                    continue;
                                }
                                if msg.is_ping() {
                                    // Received a ping from the client. With the stream
                                    // split into sink and stream halves, we don't rely on
                                    // tungstenite's queued auto-pong being flushed for us,
                                    // so reply explicitly with a Pong echoing the payload.
                                    let _ = sender.send(Message::Pong(msg.into_data()).into()).await;
                                    continue;
                                }

                                // Forward other messages to user
                                if internal_tx.send(Ok(Message::from(msg))).await.is_err() {
                                    break; // User dropped receiver
                                }
                            }
                            Some(Err(e)) => {
                                let _ = internal_tx.send(Err(WebSocketError::from(e))).await;
                                break;
                            }
                            None => break, // Connection closed
                        }
                    }

                    // 2. Receive message from user to send
                    msg = internal_rx.recv() => {
                        match msg {
                            Some(msg) => {
                                if sender.send(msg.into()).await.is_err() {
                                    break; // Connection closed
                                }
                            }
                            None => break, // User dropped sender
                        }
                    }

                    // 3. Send Ping
                    _ = heartbeat_interval.tick() => {
                        if sender.send(Message::Ping(vec![]).into()).await.is_err() {
                            break;
                        }
                    }

                    // 4. Check timeout
                    _ = timeout_check.tick() => {
                        if last_heartbeat.elapsed() > config.interval + config.timeout {
                            // Timed out; breaking drops `sender`, closing the connection
                            break;
                        }
                    }
                }
            }
            // A break out of the loop drops `sender` and `receiver`, closing the connection
        });

        Self {
            inner: StreamImpl::Managed {
                tx: user_tx,
                rx: user_rx,
            },
        }
    }

    /// Split the stream into sender and receiver halves
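    ///
    /// A sketch of driving the two halves independently; the `socket` binding
    /// and the spawned task body are illustrative only:
    ///
    /// ```ignore
    /// let (mut sender, mut receiver) = socket.split();
    ///
    /// // Drive reads on a background task while this task keeps the sender.
    /// tokio::spawn(async move {
    ///     while let Some(Ok(msg)) = receiver.recv().await {
    ///         // handle `msg` ...
    ///     }
    /// });
    ///
    /// sender.send_text("hello").await?;
    /// ```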
    pub fn split(self) -> (WebSocketSender, WebSocketReceiver) {
        match self.inner {
            StreamImpl::Direct(inner) => {
                let (sink, stream) = inner.split();
                (
                    WebSocketSender {
                        inner: SenderImpl::Direct(sink),
                    },
                    WebSocketReceiver {
                        inner: ReceiverImpl::Direct(stream),
                    },
                )
            }
            StreamImpl::Managed { tx, rx } => (
                WebSocketSender {
                    inner: SenderImpl::Managed(tx),
                },
                WebSocketReceiver {
                    inner: ReceiverImpl::Managed(rx),
                },
            ),
        }
    }
}

// Implement helper methods directly on WebSocketStream for convenience
impl WebSocketStream {
    /// Send a message
    pub async fn send(&mut self, msg: Message) -> Result<(), WebSocketError> {
        match &mut self.inner {
            StreamImpl::Direct(s) => s.send(msg.into()).await.map_err(WebSocketError::from),
            StreamImpl::Managed { tx, .. } => tx
                .send(msg)
                .await
                .map_err(|_| WebSocketError::ConnectionClosed),
        }
    }

    /// Receive the next message
    pub async fn recv(&mut self) -> Option<Result<Message, WebSocketError>> {
        match &mut self.inner {
            StreamImpl::Direct(s) => s
                .next()
                .await
                .map(|r| r.map(Message::from).map_err(WebSocketError::from)),
            StreamImpl::Managed { rx, .. } => rx.recv().await,
        }
    }

    /// Send a text message
    pub async fn send_text(&mut self, text: impl Into<String>) -> Result<(), WebSocketError> {
        self.send(Message::text(text)).await
    }

    /// Send a binary message
    pub async fn send_binary(&mut self, data: impl Into<Vec<u8>>) -> Result<(), WebSocketError> {
        self.send(Message::binary(data)).await
    }

    /// Send a JSON message
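    ///
    /// A sketch with a hypothetical serializable type (`Event` is not part of
    /// this crate):
    ///
    /// ```ignore
    /// #[derive(serde::Serialize)]
    /// struct Event {
    ///     kind: String,
    /// }
    ///
    /// socket.send_json(&Event { kind: "greeting".into() }).await?;
    /// ```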
    pub async fn send_json<T: serde::Serialize>(
        &mut self,
        value: &T,
    ) -> Result<(), WebSocketError> {
        self.send(Message::json(value)?).await
    }
}

// Inner implementations for Sender/Receiver

enum SenderImpl {
    Direct(SplitSink<UpgradedConnection, tungstenite::Message>),
    Managed(mpsc::Sender<Message>),
}

/// Sender half of a WebSocket stream
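///
/// Typically obtained from `WebSocketStream::split`. A hedged sketch, where
/// `sender` is assumed to come from a prior `split()` call:
///
/// ```ignore
/// sender.send_text("ping from server").await?;
/// sender.close().await?;
/// ```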
pub struct WebSocketSender {
    inner: SenderImpl,
}

impl WebSocketSender {
    /// Send a message
    pub async fn send(&mut self, msg: Message) -> Result<(), WebSocketError> {
        match &mut self.inner {
            SenderImpl::Direct(s) => s.send(msg.into()).await.map_err(WebSocketError::from),
            SenderImpl::Managed(s) => s
                .send(msg)
                .await
                .map_err(|_| WebSocketError::ConnectionClosed),
        }
    }

    /// Send a text message
    pub async fn send_text(&mut self, text: impl Into<String>) -> Result<(), WebSocketError> {
        self.send(Message::text(text)).await
    }

    /// Send a binary message
    pub async fn send_binary(&mut self, data: impl Into<Vec<u8>>) -> Result<(), WebSocketError> {
        self.send(Message::binary(data)).await
    }

    /// Send a JSON message
    pub async fn send_json<T: serde::Serialize>(
        &mut self,
        value: &T,
    ) -> Result<(), WebSocketError> {
        self.send(Message::json(value)?).await
    }

    /// Close the sender
    pub async fn close(mut self) -> Result<(), WebSocketError> {
        match &mut self.inner {
            SenderImpl::Direct(s) => s.close().await.map_err(WebSocketError::from),
            SenderImpl::Managed(_) => {
                // Dropping `self` closes the channel to the management task;
                // nothing else needs to happen here.
                Ok(())
            }
        }
    }
}

enum ReceiverImpl {
    Direct(SplitStream<UpgradedConnection>),
    Managed(mpsc::Receiver<Result<Message, WebSocketError>>),
}

/// Receiver half of a WebSocket stream
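///
/// Messages can be consumed either with `recv` or through the `Stream`
/// implementation. A sketch assuming a `receiver` obtained from `split()`:
///
/// ```ignore
/// use futures_util::StreamExt;
///
/// while let Some(msg) = receiver.next().await {
///     match msg {
///         Ok(msg) => { /* handle the message */ }
///         Err(e) => { /* connection error */ }
///     }
/// }
/// ```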
pub struct WebSocketReceiver {
    inner: ReceiverImpl,
}

impl WebSocketReceiver {
    /// Receive the next message
    pub async fn recv(&mut self) -> Option<Result<Message, WebSocketError>> {
        match &mut self.inner {
            ReceiverImpl::Direct(s) => s
                .next()
                .await
                .map(|r| r.map(Message::from).map_err(WebSocketError::from)),
            ReceiverImpl::Managed(s) => s.recv().await,
        }
    }
}

impl Stream for WebSocketReceiver {
    type Item = Result<Message, WebSocketError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match &mut self.inner {
            ReceiverImpl::Direct(s) => match Pin::new(s).poll_next(cx) {
                Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(Ok(Message::from(msg)))),
                Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(WebSocketError::from(e)))),
                Poll::Ready(None) => Poll::Ready(None),
                Poll::Pending => Poll::Pending,
            },
            ReceiverImpl::Managed(s) => s.poll_recv(cx),
        }
    }
}