stream_tungstenite/message/
dispatcher.rs1use futures_util::stream::{SplitSink, SplitStream};
4use futures_util::{SinkExt, StreamExt};
5use std::future::Future;
6use std::marker::PhantomData;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::{broadcast, mpsc, RwLock};
11use tokio::task::JoinHandle;
12use tungstenite::Message;
13
14pub type SharedMessage = Arc<Message>;
16
17use crate::connection::WsStream;
18use crate::error::{ExtensionError, ReceiveError, SendError};
19
20#[derive(Debug, Clone)]
22pub struct DispatcherConfig {
23 pub receive_timeout: Duration,
25 pub broadcast_capacity: usize,
27 pub send_buffer_capacity: usize,
29 pub processor_error_policy: ProcessorErrorPolicy,
31}
32
/// What [`MessageDispatcher::receive_loop_with_processor`] does when the
/// message processor returns an error.
///
/// The enum is fieldless and `Copy`, so it also derives `PartialEq`, `Eq` and
/// `Hash`; callers can compare policies or key maps by them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProcessorErrorPolicy {
    /// Log the error at `warn` level and keep receiving.
    Ignore,
    /// Stop the receive loop and return the error to the caller as
    /// `ReceiveError::WebSocket`.
    Disconnect,
}
41
42impl Default for DispatcherConfig {
43 fn default() -> Self {
44 Self {
45 receive_timeout: Duration::from_secs(30),
46 broadcast_capacity: 1024,
47 send_buffer_capacity: 256,
48 processor_error_policy: ProcessorErrorPolicy::Ignore,
49 }
50 }
51}
52
53impl DispatcherConfig {
54 #[must_use]
56 pub fn new() -> Self {
57 Self::default()
58 }
59
60 #[must_use]
62 pub const fn with_receive_timeout(mut self, timeout: Duration) -> Self {
63 self.receive_timeout = timeout;
64 self
65 }
66
67 #[must_use]
69 pub const fn with_broadcast_capacity(mut self, capacity: usize) -> Self {
70 self.broadcast_capacity = capacity;
71 self
72 }
73
74 #[must_use]
76 pub const fn with_send_buffer_capacity(mut self, capacity: usize) -> Self {
77 self.send_buffer_capacity = capacity;
78 self
79 }
80
81 #[must_use]
83 pub const fn with_processor_error_policy(mut self, policy: ProcessorErrorPolicy) -> Self {
84 self.processor_error_policy = policy;
85 self
86 }
87}
88
89struct SenderState<S: WsStream> {
91 send_task: Option<JoinHandle<()>>,
93 send_tx: Option<mpsc::Sender<Message>>,
95 _marker: PhantomData<S>,
97}
98
99pub struct MessageDispatcher<S: WsStream = crate::connection::DefaultWsStream> {
101 config: DispatcherConfig,
103 sender_state: Arc<RwLock<SenderState<S>>>,
105 is_connected: Arc<AtomicBool>,
107 message_tx: broadcast::Sender<SharedMessage>,
109}
110
111#[allow(clippy::future_not_send)]
116impl<S: WsStream> MessageDispatcher<S> {
117 #[must_use]
119 pub fn new(config: DispatcherConfig) -> Self {
120 let (message_tx, _) = broadcast::channel(config.broadcast_capacity);
121
122 Self {
123 config,
124 sender_state: Arc::new(RwLock::new(SenderState::<S> {
125 send_task: None,
126 send_tx: None,
127 _marker: PhantomData,
128 })),
129 is_connected: Arc::new(AtomicBool::new(false)),
130 message_tx,
131 }
132 }
133
134 pub async fn attach(&self, sender: SplitSink<S, Message>) {
136 let (tx, mut rx) = mpsc::channel::<Message>(self.config.send_buffer_capacity);
138
139 let connected = self.is_connected.clone();
141 let send_task = tokio::spawn(async move {
142 let mut sink = sender;
143 while let Some(msg) = rx.recv().await {
144 if let Err(e) = sink.send(msg).await {
146 tracing::debug!(error = ?e, "Dispatcher send task encountered error");
147 connected.store(false, Ordering::Release);
148 break;
149 }
150 }
151 });
152
153 {
155 let mut state = self.sender_state.write().await;
156 if let Some(handle) = state.send_task.take() {
158 handle.abort();
159 }
160 state.send_tx = Some(tx);
161 state.send_task = Some(send_task);
162 }
163 self.is_connected.store(true, Ordering::Release);
165 tracing::debug!("Message dispatcher attached");
166 }
167
168 pub async fn detach(&self) {
170 self.is_connected.store(false, Ordering::Release);
171 {
172 let mut state = self.sender_state.write().await;
173 state.send_tx = None;
175 if let Some(handle) = state.send_task.take() {
177 handle.abort();
178 }
179 }
180 tracing::debug!("Message dispatcher detached");
181 }
182
183 #[must_use]
185 pub fn is_connected(&self) -> bool {
186 self.is_connected.load(Ordering::Acquire)
187 }
188
189 pub async fn send(&self, msg: Message) -> Result<(), SendError> {
196 if !self.is_connected() {
198 return Err(SendError::NotConnected);
199 }
200 let tx = {
202 let state = self.sender_state.read().await;
203 state.send_tx.clone()
204 };
205 match tx {
206 Some(tx) => tx.send(msg).await.map_err(|_| SendError::ChannelClosed),
207 None => Err(SendError::NotConnected),
208 }
209 }
210
211 #[must_use]
218 pub fn subscribe(&self) -> broadcast::Receiver<SharedMessage> {
219 self.message_tx.subscribe()
220 }
221
222 #[must_use]
224 pub fn subscriber_count(&self) -> usize {
225 self.message_tx.receiver_count()
226 }
227
228 pub async fn receive_loop(&self, mut receiver: SplitStream<S>) -> Result<(), ReceiveError> {
239 let timeout = self.config.receive_timeout;
240
241 loop {
242 let result = tokio::time::timeout(timeout, receiver.next()).await;
243
244 match result {
245 Ok(Some(Ok(msg))) => {
246 let _ = self.message_tx.send(Arc::new(msg));
249 }
250 Ok(Some(Err(e))) => {
251 tracing::debug!(error = ?e, "WebSocket receive error");
252 return Err(ReceiveError::WebSocket(e.to_string()));
253 }
254 Ok(None) => {
255 tracing::debug!("WebSocket stream closed");
256 return Err(ReceiveError::StreamClosed);
257 }
258 Err(_) => {
259 tracing::debug!(timeout = ?timeout, "Receive timeout");
260 return Err(ReceiveError::Timeout(timeout));
261 }
262 }
263 }
264 }
265
266 pub async fn receive_loop_with_activity<F>(
276 &self,
277 mut receiver: SplitStream<S>,
278 on_activity: F,
279 ) -> Result<(), ReceiveError>
280 where
281 F: Fn() + Send + Sync,
282 {
283 let timeout = self.config.receive_timeout;
284
285 loop {
286 let result = tokio::time::timeout(timeout, receiver.next()).await;
287
288 match result {
289 Ok(Some(Ok(msg))) => {
290 on_activity();
292
293 let _ = self.message_tx.send(Arc::new(msg));
295 }
296 Ok(Some(Err(e))) => {
297 return Err(ReceiveError::WebSocket(e.to_string()));
298 }
299 Ok(None) => {
300 return Err(ReceiveError::StreamClosed);
301 }
302 Err(_) => {
303 return Err(ReceiveError::Timeout(timeout));
304 }
305 }
306 }
307 }
308
309 pub async fn receive_loop_with_processor<FAct, FActFut, FProc, FProcFut>(
320 &self,
321 mut receiver: SplitStream<S>,
322 on_activity: FAct,
323 processor: FProc,
324 ) -> Result<(), ReceiveError>
325 where
326 FAct: Fn() -> FActFut + Send + Sync,
327 FActFut: Future<Output = ()> + Send,
328 FProc: Fn(Message) -> FProcFut + Send + Sync,
329 FProcFut: Future<Output = Result<Option<Message>, ExtensionError>> + Send,
330 {
331 let timeout = self.config.receive_timeout;
332
333 loop {
334 let result = tokio::time::timeout(timeout, receiver.next()).await;
335
336 match result {
337 Ok(Some(Ok(msg))) => {
338 on_activity().await;
340
341 match processor(msg).await {
343 Ok(Some(broadcast_msg)) => {
344 let _ = self.message_tx.send(Arc::new(broadcast_msg));
345 }
346 Ok(None) => {
347 }
349 Err(e) => match self.config.processor_error_policy {
350 ProcessorErrorPolicy::Ignore => {
351 tracing::warn!(error = ?e, "Message processor failed");
352 }
353 ProcessorErrorPolicy::Disconnect => {
354 return Err(ReceiveError::WebSocket(e.to_string()));
355 }
356 },
357 }
358 }
359 Ok(Some(Err(e))) => {
360 return Err(ReceiveError::WebSocket(e.to_string()));
361 }
362 Ok(None) => {
363 return Err(ReceiveError::StreamClosed);
364 }
365 Err(_) => {
366 return Err(ReceiveError::Timeout(timeout));
367 }
368 }
369 }
370 }
371}
372
373impl<S: WsStream> Default for MessageDispatcher<S> {
374 fn default() -> Self {
375 Self::new(DispatcherConfig::default())
376 }
377}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dispatcher_config() {
        let config = DispatcherConfig::new()
            .with_receive_timeout(Duration::from_secs(60))
            .with_broadcast_capacity(2048);

        assert_eq!(config.receive_timeout, Duration::from_secs(60));
        assert_eq!(config.broadcast_capacity, 2048);
    }

    #[test]
    fn test_dispatcher_config_defaults() {
        let config = DispatcherConfig::default();

        assert_eq!(config.receive_timeout, Duration::from_secs(30));
        assert_eq!(config.broadcast_capacity, 1024);
        assert_eq!(config.send_buffer_capacity, 256);
        assert!(matches!(
            config.processor_error_policy,
            ProcessorErrorPolicy::Ignore
        ));
    }

    #[test]
    fn test_dispatcher_config_remaining_builders() {
        let config = DispatcherConfig::new()
            .with_send_buffer_capacity(8)
            .with_processor_error_policy(ProcessorErrorPolicy::Disconnect);

        assert_eq!(config.send_buffer_capacity, 8);
        assert!(matches!(
            config.processor_error_policy,
            ProcessorErrorPolicy::Disconnect
        ));
    }

    #[tokio::test]
    async fn test_dispatcher_not_connected() {
        let dispatcher = MessageDispatcher::<crate::connection::DefaultWsStream>::default();

        let result = dispatcher.send(Message::Text("test".into())).await;
        assert!(matches!(result, Err(SendError::NotConnected)));
    }

    #[tokio::test]
    async fn test_detach_without_attach_is_noop() {
        let dispatcher = MessageDispatcher::<crate::connection::DefaultWsStream>::default();

        dispatcher.detach().await;
        assert!(!dispatcher.is_connected());

        let result = dispatcher.send(Message::Text("test".into())).await;
        assert!(matches!(result, Err(SendError::NotConnected)));
    }

    #[test]
    fn test_subscriber_count_tracks_receivers() {
        let dispatcher = MessageDispatcher::<crate::connection::DefaultWsStream>::default();
        assert_eq!(dispatcher.subscriber_count(), 0);

        let rx = dispatcher.subscribe();
        assert_eq!(dispatcher.subscriber_count(), 1);

        drop(rx);
        assert_eq!(dispatcher.subscriber_count(), 0);
    }
}