ws_reconnect_client/
stream.rs1use futures_util::{Stream, StreamExt, SinkExt, Future};
2use serde::de::DeserializeOwned;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6use tokio::sync::Mutex;
7use tokio_tungstenite::tungstenite::Message;
8
9use crate::{PingManager, Result, WebSocketError, WsReader, WsWriter};
10
11pub struct MessageStream<T>
20where
21 T: DeserializeOwned,
22{
23 reader: WsReader,
24 writer: Arc<Mutex<Option<WsWriter>>>,
25 ping_manager: Option<PingManager>,
26 message_buffer: Vec<T>,
28 _phantom: std::marker::PhantomData<T>,
29}
30
31impl<T> MessageStream<T>
32where
33 T: DeserializeOwned,
34{
35 pub fn new(reader: WsReader, writer: Arc<Mutex<Option<WsWriter>>>, ping_interval_secs: u64) -> Self {
42 let ping_manager = if ping_interval_secs > 0 {
43 Some(PingManager::new(ping_interval_secs))
44 } else {
45 None
46 };
47
48 Self {
49 reader,
50 writer,
51 ping_manager,
52 message_buffer: Vec::new(),
53 _phantom: std::marker::PhantomData,
54 }
55 }
56
57 fn parse_text_message(&self, text: &str) -> Result<Vec<T>> {
59 let trimmed = text.trim();
61 if trimmed.is_empty() || trimmed == "PING" || trimmed == "PONG" {
62 return Ok(Vec::new());
63 }
64
65 if trimmed.len() < 500 {
67 } else {
68 }
69
70 if trimmed.starts_with('[') {
72 serde_json::from_str::<Vec<T>>(text).ok()
73 } else {
74 None
75 }
76 .or_else(|| serde_json::from_str::<T>(text).ok().map(|msg| vec![msg]))
77 .ok_or(())
78 .or(Ok(Vec::new()))
79 }
80
81 async fn handle_incoming_message(&mut self, msg: Message) -> Result<Vec<T>> {
83 match msg {
84 Message::Text(text) => self.parse_text_message(&text),
85
86 Message::Binary(data) => {
87 let text = String::from_utf8_lossy(&data);
88 self.parse_text_message(&text)
89 }
90
91 Message::Ping(ping) => {
92 let mut writer_guard = self.writer.lock().await;
94 if let Some(writer) = writer_guard.as_mut() {
95 writer
96 .send(Message::Pong(ping))
97 .await
98 .map_err(|_| WebSocketError::SendError)?;
99 }
100 Ok(Vec::new())
101 }
102
103 Message::Pong(_) => {
104 Ok(Vec::new())
106 }
107
108 Message::Close(_) => Err(WebSocketError::ConnectionClosed),
109
110 Message::Frame(_) => Ok(Vec::new()),
111 }
112 }
113
114 async fn send_ping(&mut self) -> Result<()> {
116 let mut writer_guard = self.writer.lock().await;
117 if let Some(writer) = writer_guard.as_mut() {
118 writer
119 .send(Message::Ping(vec![].into()))
120 .await
121 .map_err(|_| WebSocketError::SendError)?;
122 }
123 Ok(())
124 }
125
126 async fn next_message(&mut self) -> Option<Result<T>> {
133 let mut buffer = Vec::new();
135
136 loop {
137 if !buffer.is_empty() {
139 return Some(Ok(buffer.remove(0)));
140 }
141
142 tokio::select! {
143 msg = self.reader.next() => {
145 match msg {
146 Some(Ok(msg)) => {
147 match self.handle_incoming_message(msg).await {
148 Ok(messages) => {
149 buffer = messages;
150 }
152 Err(e) => return Some(Err(e)),
153 }
154 }
155 Some(Err(e)) => {
156 return Some(Err(WebSocketError::Tungstenite(e)));
157 }
158 None => {
159 return None; }
161 }
162 }
163
164 _ = async {
166 if let Some(ref mut pm) = self.ping_manager {
167 pm.wait_for_next_ping().await;
168 } else {
169 std::future::pending::<()>().await;
171 }
172 } => {
173 if let Err(e) = self.send_ping().await {
174 return Some(Err(e));
175 }
176 }
177 }
178 }
179 }
180}
181
182impl<T> Stream for MessageStream<T>
183where
184 T: DeserializeOwned + Unpin + 'static,
185{
186 type Item = Result<T>;
187
188 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
189 loop {
190 if !self.message_buffer.is_empty() {
192 return Poll::Ready(Some(Ok(self.message_buffer.remove(0))));
193 }
194
195 let fut = self.next_message();
197 tokio::pin!(fut);
198
199 match fut.poll(cx) {
200 Poll::Ready(Some(Ok(msg))) => {
201 return Poll::Ready(Some(Ok(msg)));
203 }
204 Poll::Ready(Some(Err(e))) => {
205 return Poll::Ready(Some(Err(e)));
206 }
207 Poll::Ready(None) => {
208 return Poll::Ready(None);
209 }
210 Poll::Pending => {
211 return Poll::Pending;
212 }
213 }
214 }
215 }
216}