1use crate::codec::*;
2use crate::endpoint::Endpoint;
3use crate::error::ZmqResult;
4use crate::fair_queue::{FairQueue, QueueInner};
5use crate::message::*;
6use crate::transport::AcceptStopHandle;
7use crate::util::PeerIdentity;
8use crate::{CaptureSocket, SocketOptions};
9use crate::{
10 MultiPeerBackend, Socket, SocketBackend, SocketEvent, SocketRecv, SocketSend, SocketType,
11 ZmqError,
12};
13
14use async_trait::async_trait;
15use futures::channel::mpsc;
16use futures::lock::Mutex as AsyncMutex;
17use futures::{SinkExt, StreamExt};
18use parking_lot::Mutex;
19
20use std::collections::HashMap;
21use std::io::ErrorKind;
22use std::pin::Pin;
23use std::sync::Arc;
24
/// Per-peer state an XPUB socket keeps for one connected peer.
pub(crate) struct XPubSubscriber {
    /// Topic prefixes this peer is subscribed to. Duplicates are kept: every subscribe
    /// message appends one entry and every unsubscribe message removes one matching entry.
    pub(crate) subscriptions: Vec<Vec<u8>>,
    /// Write half of the peer's connection. It sits behind an async mutex inside an `Arc` so a
    /// sender can clone the handle out of the subscriber map and hold the lock across `.await`.
    pub(crate) send_queue: Arc<AsyncMutex<Pin<Box<ZmqFramedWrite>>>>,
}
29
/// Shared state behind an `XPubSocket`; also exposed as the socket's `MultiPeerBackend`.
pub(crate) struct XPubSocketBackend {
    /// Connected peers keyed by identity, each with its subscriptions and write half.
    subscribers: scc::HashMap<PeerIdentity, XPubSubscriber>,
    /// Handle to the inner state of the socket's fair queue. Each peer's read half is inserted
    /// here on connect and removed on disconnect.
    fair_queue_inner: Arc<Mutex<QueueInner<ZmqFramedRead, PeerIdentity>>>,
    /// Event sender installed by `XPubSocket::monitor`; `None` until a monitor is requested.
    socket_monitor: Mutex<Option<mpsc::Sender<SocketEvent>>>,
    /// Options the socket was created with.
    socket_options: SocketOptions,
}
36
37impl XPubSocketBackend {
38 fn message_received(&self, peer_id: &PeerIdentity, message: Message) {
39 let data = match message {
40 Message::Message(m) => {
41 if m.len() != 1 {
42 return;
43 }
44 m.into_vec().pop().unwrap_or_default()
45 }
46 _ => return,
47 };
48
49 if data.is_empty() {
50 return;
51 }
52
53 match data.first() {
54 Some(1) => {
55 if let Some(mut entry) = self.subscribers.get_sync(peer_id) {
57 entry.subscriptions.push(Vec::from(&data[1..]));
58 }
59 }
60 Some(0) => {
61 let sub = Vec::from(&data[1..]);
63 if let Some(mut entry) = self.subscribers.get_sync(peer_id) {
64 if let Some(index) = entry.subscriptions.iter().position(|s| s == &sub) {
65 entry.subscriptions.remove(index);
66 }
67 }
68 }
69 _ => {}
70 }
71 }
72}
73
74impl SocketBackend for XPubSocketBackend {
75 fn socket_type(&self) -> SocketType {
76 SocketType::XPUB
77 }
78
79 fn socket_options(&self) -> &SocketOptions {
80 &self.socket_options
81 }
82
83 fn shutdown(&self) {
84 self.subscribers.clear_sync();
85 }
86
87 fn monitor(&self) -> &Mutex<Option<mpsc::Sender<SocketEvent>>> {
88 &self.socket_monitor
89 }
90}
91
92#[async_trait]
93impl MultiPeerBackend for XPubSocketBackend {
94 async fn peer_connected(self: Arc<Self>, peer_id: &PeerIdentity, io: FramedIo) {
95 let (recv_queue, send_queue) = io.into_parts();
96
97 self.subscribers
98 .upsert_async(
99 peer_id.clone(),
100 XPubSubscriber {
101 subscriptions: vec![],
102 send_queue: Arc::new(AsyncMutex::new(Box::pin(send_queue))),
103 },
104 )
105 .await;
106
107 self.fair_queue_inner
108 .lock()
109 .insert(peer_id.clone(), recv_queue);
110 }
111
112 fn peer_disconnected(&self, peer_id: &PeerIdentity) {
113 log::info!("Client disconnected {:?}", peer_id);
114 self.subscribers.remove_sync(peer_id);
115 self.fair_queue_inner.lock().remove(peer_id);
116 }
117}
118
/// An XPUB socket.
///
/// `send` delivers a message to every peer with a subscription that is a prefix of the
/// message's first frame. `recv` returns messages sent by peers, applying any subscribe or
/// unsubscribe commands they carry.
pub struct XPubSocket {
    /// State shared with the transport layer through the `MultiPeerBackend` trait.
    pub(crate) backend: Arc<XPubSocketBackend>,
    /// Fair-queued read halves of all connected peers; drained by `recv`.
    fair_queue: FairQueue<ZmqFramedRead, PeerIdentity>,
    /// Bound endpoints and their `AcceptStopHandle`s.
    binds: HashMap<Endpoint, AcceptStopHandle>,
}
124
125impl Drop for XPubSocket {
126 fn drop(&mut self) {
127 self.backend.shutdown();
128 }
129}
130
131#[async_trait]
132impl SocketSend for XPubSocket {
133 async fn send(&mut self, message: ZmqMessage) -> ZmqResult<()> {
134 let first_frame = match message.get(0) {
135 Some(frame) => frame,
136 None => return Ok(()), };
138 let mut targets = Vec::new();
139 let mut iter = self.backend.subscribers.begin_async().await;
140 while let Some(subscriber) = iter {
141 if subscriber.subscriptions.iter().any(|sub_filter| {
142 sub_filter.len() <= first_frame.len()
143 && sub_filter.as_slice() == &first_frame[0..sub_filter.len()]
144 }) {
145 targets.push((subscriber.key().clone(), subscriber.send_queue.clone()));
146 }
147 iter = subscriber.next_async().await;
148 }
149
150 let mut dead_peers = Vec::new();
151 for (peer_id, send_queue) in targets {
152 let res = send_queue
153 .lock()
154 .await
155 .as_mut()
156 .send(Message::Message(message.clone()))
157 .await;
158 match res {
159 Ok(()) => {}
160 Err(CodecError::Io(e)) => {
161 if e.kind() == ErrorKind::BrokenPipe {
162 dead_peers.push(peer_id);
163 } else {
164 log::error!("Error sending message: {:?}", e);
165 }
166 }
167 Err(e) => {
168 log::error!("Error sending message: {:?}", e);
169 return Err(e.into());
170 }
171 }
172 }
173 for peer in dead_peers {
174 self.backend.peer_disconnected(&peer);
175 }
176 Ok(())
177 }
178}
179
180#[async_trait]
181impl SocketRecv for XPubSocket {
182 async fn recv(&mut self) -> ZmqResult<ZmqMessage> {
183 loop {
184 match self.fair_queue.next().await {
185 Some((peer_id, Ok(Message::Message(message)))) => {
186 self.backend
188 .message_received(&peer_id, Message::Message(message.clone()));
189 return Ok(message);
191 }
192 Some((_peer_id, Ok(_msg))) => {
193 }
195 Some((peer_id, Err(e))) => {
196 self.backend.peer_disconnected(&peer_id);
197 return Err(e.into());
198 }
199 None => {
200 return Err(ZmqError::NoMessage);
201 }
202 }
203 }
204 }
205}
206
// Empty impl: XPubSocket uses `CaptureSocket`'s provided methods unchanged.
impl CaptureSocket for XPubSocket {}
208
209#[async_trait]
210impl Socket for XPubSocket {
211 fn with_options(options: SocketOptions) -> Self {
212 let mut fair_queue = FairQueue::new(true);
213 let backend = Arc::new(XPubSocketBackend {
214 subscribers: scc::HashMap::new(),
215 fair_queue_inner: fair_queue.inner(),
216 socket_monitor: Mutex::new(None),
217 socket_options: options,
218 });
219
220 let backend_weak = Arc::downgrade(&backend);
221 fair_queue.set_on_disconnect(move |peer_id: PeerIdentity| {
222 if let Some(backend) = backend_weak.upgrade() {
223 backend.peer_disconnected(&peer_id);
224 }
225 });
226
227 Self {
228 backend,
229 fair_queue,
230 binds: HashMap::new(),
231 }
232 }
233
234 fn backend(&self) -> Arc<dyn MultiPeerBackend> {
235 self.backend.clone()
236 }
237
238 fn binds(&mut self) -> &mut HashMap<Endpoint, AcceptStopHandle> {
239 &mut self.binds
240 }
241
242 fn monitor(&mut self) -> mpsc::Receiver<SocketEvent> {
243 let (sender, receiver) = mpsc::channel(1024);
244 self.backend.socket_monitor.lock().replace(sender);
245 receiver
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::async_rt;
    use crate::util::tests::{
        test_bind_to_any_port_helper, test_bind_to_unspecified_interface_helper,
    };
    use crate::ZmqResult;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    /// Binding an XPUB socket to an OS-assigned port works.
    #[async_rt::test]
    async fn test_bind_to_any_port() -> ZmqResult<()> {
        test_bind_to_any_port_helper(XPubSocket::new()).await
    }

    /// Binding to the IPv4 unspecified address (0.0.0.0) on port 4020 works.
    #[async_rt::test]
    async fn test_bind_to_any_ipv4_interface() -> ZmqResult<()> {
        let unspecified = IpAddr::V4(Ipv4Addr::UNSPECIFIED);
        test_bind_to_unspecified_interface_helper(unspecified, XPubSocket::new(), 4020).await
    }

    /// Binding to the IPv6 unspecified address (::) on port 4030 works.
    #[async_rt::test]
    async fn test_bind_to_any_ipv6_interface() -> ZmqResult<()> {
        let unspecified = IpAddr::V6(Ipv6Addr::UNSPECIFIED);
        test_bind_to_unspecified_interface_helper(unspecified, XPubSocket::new(), 4030).await
    }
}