Skip to main content

rustapi_ws/
broadcast.rs

//! Broadcast channel for fanning out WebSocket messages to many subscribers.
2
3use crate::Message;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6use tokio::sync::broadcast;
7
8/// A broadcast channel for sending messages to multiple WebSocket clients
9///
10/// This is useful for implementing pub/sub patterns, chat rooms, or any
11/// scenario where you need to send the same message to multiple clients.
12///
13/// # Example
14///
15/// ```rust,ignore
16/// use rustapi_ws::{Broadcast, Message};
17/// use std::sync::Arc;
18///
19/// let broadcast = Arc::new(Broadcast::new());
20///
21/// // Subscribe to receive messages
22/// let mut rx = broadcast.subscribe();
23///
24/// // Send a message to all subscribers
25/// broadcast.send(Message::text("Hello everyone!"));
26///
27/// // Receive the message
28/// let msg = rx.recv().await.unwrap();
29/// ```
30#[derive(Clone)]
31pub struct Broadcast {
32    sender: broadcast::Sender<Message>,
33    subscriber_count: Arc<AtomicUsize>,
34}
35
36impl Broadcast {
37    /// Create a new broadcast channel with default capacity (100 messages)
38    pub fn new() -> Self {
39        Self::with_capacity(100)
40    }
41
42    /// Create a new broadcast channel with specified capacity
43    pub fn with_capacity(capacity: usize) -> Self {
44        let (sender, _) = broadcast::channel(capacity);
45        Self {
46            sender,
47            subscriber_count: Arc::new(AtomicUsize::new(0)),
48        }
49    }
50
51    /// Subscribe to receive broadcast messages
52    pub fn subscribe(&self) -> BroadcastReceiver {
53        self.subscriber_count.fetch_add(1, Ordering::SeqCst);
54        BroadcastReceiver {
55            inner: self.sender.subscribe(),
56            subscriber_count: self.subscriber_count.clone(),
57        }
58    }
59
60    /// Send a message to all subscribers
61    ///
62    /// Returns the number of receivers that received the message.
63    /// Returns 0 if there are no active subscribers.
64    pub fn send(&self, msg: Message) -> usize {
65        self.sender.send(msg).unwrap_or(0)
66    }
67
68    /// Send a text message to all subscribers
69    pub fn send_text(&self, text: impl Into<String>) -> usize {
70        self.send(Message::text(text))
71    }
72
73    /// Send a JSON message to all subscribers
74    pub fn send_json<T: serde::Serialize>(
75        &self,
76        value: &T,
77    ) -> Result<usize, crate::WebSocketError> {
78        let msg = Message::json(value)?;
79        Ok(self.send(msg))
80    }
81
82    /// Get the current number of subscribers
83    pub fn subscriber_count(&self) -> usize {
84        self.subscriber_count.load(Ordering::SeqCst)
85    }
86
87    /// Check if there are any active subscribers
88    pub fn has_subscribers(&self) -> bool {
89        self.subscriber_count() > 0
90    }
91}
92
93impl Default for Broadcast {
94    fn default() -> Self {
95        Self::new()
96    }
97}
98
99/// Receiver for broadcast messages
100pub struct BroadcastReceiver {
101    inner: broadcast::Receiver<Message>,
102    subscriber_count: Arc<AtomicUsize>,
103}
104
105impl BroadcastReceiver {
106    /// Receive the next broadcast message
107    ///
108    /// Returns `None` if the broadcast channel is closed.
109    /// Returns `Err` if messages were missed due to slow consumption.
110    pub async fn recv(&mut self) -> Option<Result<Message, BroadcastRecvError>> {
111        match self.inner.recv().await {
112            Ok(msg) => Some(Ok(msg)),
113            Err(broadcast::error::RecvError::Closed) => None,
114            Err(broadcast::error::RecvError::Lagged(count)) => {
115                Some(Err(BroadcastRecvError::Lagged(count)))
116            }
117        }
118    }
119
120    /// Try to receive a message without waiting
121    pub fn try_recv(&mut self) -> Option<Result<Message, BroadcastRecvError>> {
122        match self.inner.try_recv() {
123            Ok(msg) => Some(Ok(msg)),
124            Err(broadcast::error::TryRecvError::Empty) => None,
125            Err(broadcast::error::TryRecvError::Closed) => None,
126            Err(broadcast::error::TryRecvError::Lagged(count)) => {
127                Some(Err(BroadcastRecvError::Lagged(count)))
128            }
129        }
130    }
131}
132
133impl Drop for BroadcastReceiver {
134    fn drop(&mut self) {
135        self.subscriber_count.fetch_sub(1, Ordering::SeqCst);
136    }
137}
138
/// Error when receiving broadcast messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BroadcastRecvError {
    /// Some messages were missed because the receiver is too slow.
    ///
    /// The payload is the number of messages that were skipped.
    Lagged(u64),
}

impl std::fmt::Display for BroadcastRecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            BroadcastRecvError::Lagged(skipped) => {
                write!(f, "Lagged behind by {} messages", skipped)
            }
        }
    }
}

/// Lets `BroadcastRecvError` be boxed as `dyn std::error::Error` and
/// propagated with `?` into such boxes. It has no underlying source.
impl std::error::Error for BroadcastRecvError {}