Skip to main content

zeromq/
xpub.rs

1use crate::codec::*;
2use crate::endpoint::Endpoint;
3use crate::error::ZmqResult;
4use crate::fair_queue::{FairQueue, QueueInner};
5use crate::message::*;
6use crate::transport::AcceptStopHandle;
7use crate::util::PeerIdentity;
8use crate::{CaptureSocket, SocketOptions};
9use crate::{
10    MultiPeerBackend, Socket, SocketBackend, SocketEvent, SocketRecv, SocketSend, SocketType,
11    ZmqError,
12};
13
14use async_trait::async_trait;
15use futures::channel::mpsc;
16use futures::lock::Mutex as AsyncMutex;
17use futures::{SinkExt, StreamExt};
18use parking_lot::Mutex;
19
20use std::collections::HashMap;
21use std::io::ErrorKind;
22use std::pin::Pin;
23use std::sync::Arc;
24
/// Per-peer state an XPUB socket keeps for every connected subscriber.
pub(crate) struct XPubSubscriber {
    /// Topic prefixes this peer is subscribed to. Duplicates are allowed: every subscribe
    /// message appends an entry and every unsubscribe removes one equal entry.
    pub(crate) subscriptions: Vec<Vec<u8>>,
    /// Write half of the peer's framed connection. It sits behind an `Arc` so that
    /// `XPubSocket::send` can clone it out of the subscriber map and lock it while writing.
    pub(crate) send_queue: Arc<AsyncMutex<Pin<Box<ZmqFramedWrite>>>>,
}
29
/// Shared state behind an [`XPubSocket`]; also handed out as a [`MultiPeerBackend`]
/// through `Socket::backend`.
pub(crate) struct XPubSocketBackend {
    /// Connected peers and their subscriptions, keyed by peer identity.
    subscribers: scc::HashMap<PeerIdentity, XPubSubscriber>,
    /// Inner state of the socket's `FairQueue`. A peer's read half is inserted here on
    /// connect and removed on disconnect.
    fair_queue_inner: Arc<Mutex<QueueInner<ZmqFramedRead, PeerIdentity>>>,
    /// Optional sender for socket events; installed by `XPubSocket::monitor`.
    socket_monitor: Mutex<Option<mpsc::Sender<SocketEvent>>>,
    /// Options the socket was created with.
    socket_options: SocketOptions,
}
36
37impl XPubSocketBackend {
38    fn message_received(&self, peer_id: &PeerIdentity, message: Message) {
39        let data = match message {
40            Message::Message(m) => {
41                if m.len() != 1 {
42                    return;
43                }
44                m.into_vec().pop().unwrap_or_default()
45            }
46            _ => return,
47        };
48
49        if data.is_empty() {
50            return;
51        }
52
53        match data.first() {
54            Some(1) => {
55                // Subscribe
56                if let Some(mut entry) = self.subscribers.get_sync(peer_id) {
57                    entry.subscriptions.push(Vec::from(&data[1..]));
58                }
59            }
60            Some(0) => {
61                // Unsubscribe
62                let sub = Vec::from(&data[1..]);
63                if let Some(mut entry) = self.subscribers.get_sync(peer_id) {
64                    if let Some(index) = entry.subscriptions.iter().position(|s| s == &sub) {
65                        entry.subscriptions.remove(index);
66                    }
67                }
68            }
69            _ => {}
70        }
71    }
72}
73
74impl SocketBackend for XPubSocketBackend {
75    fn socket_type(&self) -> SocketType {
76        SocketType::XPUB
77    }
78
79    fn socket_options(&self) -> &SocketOptions {
80        &self.socket_options
81    }
82
83    fn shutdown(&self) {
84        self.subscribers.clear_sync();
85    }
86
87    fn monitor(&self) -> &Mutex<Option<mpsc::Sender<SocketEvent>>> {
88        &self.socket_monitor
89    }
90}
91
92#[async_trait]
93impl MultiPeerBackend for XPubSocketBackend {
94    async fn peer_connected(self: Arc<Self>, peer_id: &PeerIdentity, io: FramedIo) {
95        let (recv_queue, send_queue) = io.into_parts();
96
97        self.subscribers
98            .upsert_async(
99                peer_id.clone(),
100                XPubSubscriber {
101                    subscriptions: vec![],
102                    send_queue: Arc::new(AsyncMutex::new(Box::pin(send_queue))),
103                },
104            )
105            .await;
106
107        self.fair_queue_inner
108            .lock()
109            .insert(peer_id.clone(), recv_queue);
110    }
111
112    fn peer_disconnected(&self, peer_id: &PeerIdentity) {
113        log::info!("Client disconnected {:?}", peer_id);
114        self.subscribers.remove_sync(peer_id);
115        self.fair_queue_inner.lock().remove(peer_id);
116    }
117}
118
/// An XPUB socket: publishes each message to the peers whose subscribed prefixes match
/// its first frame. Messages received from peers (such as subscribe/unsubscribe
/// requests) are also returned to the application by `recv`.
pub struct XPubSocket {
    /// Shared backend holding subscriber and monitor state.
    pub(crate) backend: Arc<XPubSocketBackend>,
    /// Source of `(peer_id, frame result)` items from connected peers, polled by `recv`.
    fair_queue: FairQueue<ZmqFramedRead, PeerIdentity>,
    /// Bound endpoints and their `AcceptStopHandle`s, exposed through `Socket::binds`.
    binds: HashMap<Endpoint, AcceptStopHandle>,
}
124
/// Dropping the socket handle shuts the backend down, clearing its subscriber map.
impl Drop for XPubSocket {
    fn drop(&mut self) {
        // Clear explicitly: the backend is reference-counted, and other clones (e.g. those
        // returned by `Socket::backend`) may outlive this handle.
        self.backend.shutdown();
    }
}
130
131#[async_trait]
132impl SocketSend for XPubSocket {
133    async fn send(&mut self, message: ZmqMessage) -> ZmqResult<()> {
134        let first_frame = match message.get(0) {
135            Some(frame) => frame,
136            None => return Ok(()), // Empty message, nothing to publish
137        };
138        let mut targets = Vec::new();
139        let mut iter = self.backend.subscribers.begin_async().await;
140        while let Some(subscriber) = iter {
141            if subscriber.subscriptions.iter().any(|sub_filter| {
142                sub_filter.len() <= first_frame.len()
143                    && sub_filter.as_slice() == &first_frame[0..sub_filter.len()]
144            }) {
145                targets.push((subscriber.key().clone(), subscriber.send_queue.clone()));
146            }
147            iter = subscriber.next_async().await;
148        }
149
150        let mut dead_peers = Vec::new();
151        for (peer_id, send_queue) in targets {
152            let res = send_queue
153                .lock()
154                .await
155                .as_mut()
156                .send(Message::Message(message.clone()))
157                .await;
158            match res {
159                Ok(()) => {}
160                Err(CodecError::Io(e)) => {
161                    if e.kind() == ErrorKind::BrokenPipe {
162                        dead_peers.push(peer_id);
163                    } else {
164                        log::error!("Error sending message: {:?}", e);
165                    }
166                }
167                Err(e) => {
168                    log::error!("Error sending message: {:?}", e);
169                    return Err(e.into());
170                }
171            }
172        }
173        for peer in dead_peers {
174            self.backend.peer_disconnected(&peer);
175        }
176        Ok(())
177    }
178}
179
180#[async_trait]
181impl SocketRecv for XPubSocket {
182    async fn recv(&mut self) -> ZmqResult<ZmqMessage> {
183        loop {
184            match self.fair_queue.next().await {
185                Some((peer_id, Ok(Message::Message(message)))) => {
186                    // Process the subscription message internally to update tracking
187                    self.backend
188                        .message_received(&peer_id, Message::Message(message.clone()));
189                    // Also expose it to the application
190                    return Ok(message);
191                }
192                Some((_peer_id, Ok(_msg))) => {
193                    // Ignore non-message frames
194                }
195                Some((peer_id, Err(e))) => {
196                    self.backend.peer_disconnected(&peer_id);
197                    return Err(e.into());
198                }
199                None => {
200                    return Err(ZmqError::NoMessage);
201                }
202            }
203        }
204    }
205}
206
// XPubSocket relies entirely on the trait's provided `CaptureSocket` methods (none are
// overridden here).
impl CaptureSocket for XPubSocket {}
208
209#[async_trait]
210impl Socket for XPubSocket {
211    fn with_options(options: SocketOptions) -> Self {
212        let mut fair_queue = FairQueue::new(true);
213        let backend = Arc::new(XPubSocketBackend {
214            subscribers: scc::HashMap::new(),
215            fair_queue_inner: fair_queue.inner(),
216            socket_monitor: Mutex::new(None),
217            socket_options: options,
218        });
219
220        let backend_weak = Arc::downgrade(&backend);
221        fair_queue.set_on_disconnect(move |peer_id: PeerIdentity| {
222            if let Some(backend) = backend_weak.upgrade() {
223                backend.peer_disconnected(&peer_id);
224            }
225        });
226
227        Self {
228            backend,
229            fair_queue,
230            binds: HashMap::new(),
231        }
232    }
233
234    fn backend(&self) -> Arc<dyn MultiPeerBackend> {
235        self.backend.clone()
236    }
237
238    fn binds(&mut self) -> &mut HashMap<Endpoint, AcceptStopHandle> {
239        &mut self.binds
240    }
241
242    fn monitor(&mut self) -> mpsc::Receiver<SocketEvent> {
243        let (sender, receiver) = mpsc::channel(1024);
244        self.backend.socket_monitor.lock().replace(sender);
245        receiver
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::async_rt;
    use crate::util::tests::{
        test_bind_to_any_port_helper, test_bind_to_unspecified_interface_helper,
    };
    use crate::ZmqResult;
    use std::net::IpAddr;

    /// Parses a literal wildcard address; panics on malformed input (test-only helper).
    fn wildcard(addr: &str) -> IpAddr {
        addr.parse().unwrap()
    }

    #[async_rt::test]
    async fn test_bind_to_any_port() -> ZmqResult<()> {
        test_bind_to_any_port_helper(XPubSocket::new()).await
    }

    #[async_rt::test]
    async fn test_bind_to_any_ipv4_interface() -> ZmqResult<()> {
        let addr = wildcard("0.0.0.0");
        let socket = XPubSocket::new();
        test_bind_to_unspecified_interface_helper(addr, socket, 4020).await
    }

    #[async_rt::test]
    async fn test_bind_to_any_ipv6_interface() -> ZmqResult<()> {
        let addr = wildcard("::");
        let socket = XPubSocket::new();
        test_bind_to_unspecified_interface_helper(addr, socket, 4030).await
    }
}