// ws_reconnect_client/stream.rs

1use futures_util::{Stream, StreamExt, SinkExt, Future};
2use serde::de::DeserializeOwned;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6use tokio::sync::Mutex;
7use tokio_tungstenite::tungstenite::Message;
8
9use crate::{PingManager, Result, WebSocketError, WsReader, WsWriter};
10
11/// A WebSocket message stream that automatically handles PING/PONG and yields parsed messages
12///
13/// This stream:
14/// - Responds to server PINGs automatically with PONGs
15/// - Sends periodic client PINGs to keep the connection alive
16/// - Parses incoming text messages as JSON (supports both single messages and arrays)
17/// - Filters out control messages (PING, PONG, empty strings)
18/// - Only yields successfully parsed data messages of type T
19pub struct MessageStream<T>
20where
21    T: DeserializeOwned,
22{
23    reader: WsReader,
24    writer: Arc<Mutex<Option<WsWriter>>>,
25    ping_manager: Option<PingManager>,
26    // Buffer for parsed messages (one WS message can contain multiple T messages)
27    message_buffer: Vec<T>,
28    _phantom: std::marker::PhantomData<T>,
29}
30
31impl<T> MessageStream<T>
32where
33    T: DeserializeOwned,
34{
35    /// Create a new MessageStream
36    ///
37    /// # Arguments
38    /// * `reader` - The WebSocket reader
39    /// * `writer` - Shared WebSocket writer (Arc<Mutex<Option<WsWriter>>>)
40    /// * `ping_interval_secs` - Optional interval for sending client pings (0 = disabled)
41    pub fn new(reader: WsReader, writer: Arc<Mutex<Option<WsWriter>>>, ping_interval_secs: u64) -> Self {
42        let ping_manager = if ping_interval_secs > 0 {
43            Some(PingManager::new(ping_interval_secs))
44        } else {
45            None
46        };
47
48        Self {
49            reader,
50            writer,
51            ping_manager,
52            message_buffer: Vec::new(),
53            _phantom: std::marker::PhantomData,
54        }
55    }
56
57    /// Parse a text message as either a single message or array of messages
58    fn parse_text_message(&self, text: &str) -> Result<Vec<T>> {
59        // Skip control messages
60        let trimmed = text.trim();
61        if trimmed.is_empty() || trimmed == "PING" || trimmed == "PONG" {
62            return Ok(Vec::new());
63        }
64
65        // Log raw message for debugging
66        if trimmed.len() < 500 {
67        } else {
68        }
69
70        // Try array first if starts with '[', then single message, then empty vec
71        if trimmed.starts_with('[') {
72            serde_json::from_str::<Vec<T>>(text).ok()
73        } else {
74            None
75        }
76        .or_else(|| serde_json::from_str::<T>(text).ok().map(|msg| vec![msg]))
77        .ok_or(())
78        .or(Ok(Vec::new()))
79    }
80
81    /// Handle an incoming message, possibly sending PONG in response to PING
82    async fn handle_incoming_message(&mut self, msg: Message) -> Result<Vec<T>> {
83        match msg {
84            Message::Text(text) => self.parse_text_message(&text),
85
86            Message::Binary(data) => {
87                let text = String::from_utf8_lossy(&data);
88                self.parse_text_message(&text)
89            }
90
91            Message::Ping(ping) => {
92                // Respond to server ping with pong
93                let mut writer_guard = self.writer.lock().await;
94                if let Some(writer) = writer_guard.as_mut() {
95                    writer
96                        .send(Message::Pong(ping))
97                        .await
98                        .map_err(|_| WebSocketError::SendError)?;
99                }
100                Ok(Vec::new())
101            }
102
103            Message::Pong(_) => {
104                // Pong received, connection is alive
105                Ok(Vec::new())
106            }
107
108            Message::Close(_) => Err(WebSocketError::ConnectionClosed),
109
110            Message::Frame(_) => Ok(Vec::new()),
111        }
112    }
113
114    /// Send a periodic ping to keep the connection alive
115    async fn send_ping(&mut self) -> Result<()> {
116        let mut writer_guard = self.writer.lock().await;
117        if let Some(writer) = writer_guard.as_mut() {
118            writer
119                .send(Message::Ping(vec![].into()))
120                .await
121                .map_err(|_| WebSocketError::SendError)?;
122        }
123        Ok(())
124    }
125
126    /// Get the next parsed message from the stream
127    ///
128    /// This method handles:
129    /// - Incoming messages (responding to PINGs, parsing text)
130    /// - Outgoing periodic PINGs (if enabled)
131    /// - Returns the next successfully parsed message
132    async fn next_message(&mut self) -> Option<Result<T>> {
133        // Buffer for parsed messages (since one websocket message can contain multiple T messages)
134        let mut buffer = Vec::new();
135
136        loop {
137            // If we have buffered messages, return the first one
138            if !buffer.is_empty() {
139                return Some(Ok(buffer.remove(0)));
140            }
141
142            tokio::select! {
143                // Handle incoming messages
144                msg = self.reader.next() => {
145                    match msg {
146                        Some(Ok(msg)) => {
147                            match self.handle_incoming_message(msg).await {
148                                Ok(messages) => {
149                                    buffer = messages;
150                                    // Continue loop to check buffer and return message if available
151                                }
152                                Err(e) => return Some(Err(e)),
153                            }
154                        }
155                        Some(Err(e)) => {
156                            return Some(Err(WebSocketError::Tungstenite(e)));
157                        }
158                        None => {
159                            return None; // Stream ended
160                        }
161                    }
162                }
163
164                // Send periodic pings if enabled
165                _ = async {
166                    if let Some(ref mut pm) = self.ping_manager {
167                        pm.wait_for_next_ping().await;
168                    } else {
169                        // Never completes if pings are disabled
170                        std::future::pending::<()>().await;
171                    }
172                } => {
173                    if let Err(e) = self.send_ping().await {
174                        return Some(Err(e));
175                    }
176                }
177            }
178        }
179    }
180}
181
182impl<T> Stream for MessageStream<T>
183where
184    T: DeserializeOwned + Unpin + 'static,
185{
186    type Item = Result<T>;
187
188    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
189        loop {
190            // First check if we have buffered messages
191            if !self.message_buffer.is_empty() {
192                return Poll::Ready(Some(Ok(self.message_buffer.remove(0))));
193            }
194
195            // No buffered messages, poll next_message to fill the buffer
196            let fut = self.next_message();
197            tokio::pin!(fut);
198
199            match fut.poll(cx) {
200                Poll::Ready(Some(Ok(msg))) => {
201                    // Got a message, return it immediately
202                    return Poll::Ready(Some(Ok(msg)));
203                }
204                Poll::Ready(Some(Err(e))) => {
205                    return Poll::Ready(Some(Err(e)));
206                }
207                Poll::Ready(None) => {
208                    return Poll::Ready(None);
209                }
210                Poll::Pending => {
211                    return Poll::Pending;
212                }
213            }
214        }
215    }
216}