1#![recursion_limit = "1024"]
2
3mod async_rt;
4mod backend;
5mod codec;
6mod dealer;
7mod endpoint;
8mod error;
9mod fair_queue;
10mod message;
11mod r#pub;
12mod pull;
13mod push;
14mod rep;
15mod req;
16mod router;
17mod sub;
18mod task_handle;
19mod transport;
20pub mod util;
21
/// Publicly reachable alias for the crate-private `async_rt` module; every
/// item of `async_rt` is re-exported from here.
///
/// The module is hidden from the generated documentation.
/// NOTE(review): the double-underscore name suggests it exists only for
/// generated code or companion crates rather than end users — confirm
/// against the actual users of `__async_rt`.
#[doc(hidden)]
pub mod __async_rt {
    pub use super::async_rt::*;
}
27
28pub use crate::dealer::*;
29pub use crate::endpoint::{Endpoint, Host, Transport, TryIntoEndpoint};
30pub use crate::error::{ZmqError, ZmqResult};
31pub use crate::message::*;
32pub use crate::pull::*;
33pub use crate::push::*;
34pub use crate::r#pub::*;
35pub use crate::rep::*;
36pub use crate::req::*;
37pub use crate::router::*;
38pub use crate::sub::*;
39
40use crate::codec::*;
41use crate::transport::AcceptStopHandle;
42use util::PeerIdentity;
43
44use async_trait::async_trait;
45use asynchronous_codec::FramedWrite;
46use futures_channel::mpsc;
47use futures_util::{select, FutureExt};
48use parking_lot::Mutex;
49
50use std::collections::HashMap;
51use std::convert::TryFrom;
52use std::fmt::{Debug, Display};
53use std::str::FromStr;
54use std::sync::Arc;
55
/// Width and height of `COMPATIBILITY_MATRIX`: one row and one column for
/// each socket type from `PAIR` (0) through `XSUB` (10).
const COMPATIBILITY_MATRIX_DIM: usize = 11;

/// Row-major socket compatibility table.
///
/// The cell at `row * COMPATIBILITY_MATRIX_DIM + col` is non-zero when the
/// socket type whose discriminant is `row` may be paired with the socket type
/// whose discriminant is `col`. `STREAM` (discriminant 11) has neither a row
/// nor a column in this table.
const COMPATIBILITY_MATRIX: [u8; COMPATIBILITY_MATRIX_DIM * COMPATIBILITY_MATRIX_DIM] = [
    // PAIR, PUB, SUB, REQ, REP, DEALER, ROUTER, PULL, PUSH, XPUB, XSUB
    1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // PAIR
    0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, // PUB
    0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, // SUB
    0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, // REQ
    0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, // REP
    0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, // DEALER
    0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, // ROUTER
    0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, // PULL
    0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, // PUSH
    0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, // XPUB
    0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, // XSUB
];

/// The kind of a socket.
///
/// The explicit discriminants double as row/column indices into
/// `COMPATIBILITY_MATRIX`, so they must not be renumbered without updating
/// that table.
// The variant names intentionally mirror the upper-case names used on the
// wire (see `as_str`), hence the lint exemption.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(usize)]
pub enum SocketType {
    PAIR = 0,
    PUB = 1,
    SUB = 2,
    REQ = 3,
    REP = 4,
    DEALER = 5,
    ROUTER = 6,
    PULL = 7,
    PUSH = 8,
    XPUB = 9,
    XSUB = 10,
    /// Not covered by the compatibility table; see [`SocketType::compatible`].
    STREAM = 11,
}

impl SocketType {
    /// Returns the upper-case textual name of the socket type, e.g. `"PUB"`.
    pub const fn as_str(&self) -> &'static str {
        match self {
            SocketType::PAIR => "PAIR",
            SocketType::PUB => "PUB",
            SocketType::SUB => "SUB",
            SocketType::REQ => "REQ",
            SocketType::REP => "REP",
            SocketType::DEALER => "DEALER",
            SocketType::ROUTER => "ROUTER",
            SocketType::PULL => "PULL",
            SocketType::PUSH => "PUSH",
            SocketType::XPUB => "XPUB",
            SocketType::XSUB => "XSUB",
            SocketType::STREAM => "STREAM",
        }
    }

    /// Reports whether a socket of type `self` may be paired with a socket of
    /// type `other`, according to the compatibility table.
    ///
    /// Socket types that have no entry in the table (currently only
    /// [`SocketType::STREAM`]) are reported as incompatible with everything,
    /// including themselves. This method never panics.
    pub fn compatible(&self, other: SocketType) -> bool {
        let row_index = *self as usize;
        let col_index = other as usize;
        // Without this guard an index of 11 (STREAM) would either run past
        // the end of the table or silently alias column 0 of the next row.
        if row_index >= COMPATIBILITY_MATRIX_DIM || col_index >= COMPATIBILITY_MATRIX_DIM {
            return false;
        }
        COMPATIBILITY_MATRIX[row_index * COMPATIBILITY_MATRIX_DIM + col_index] != 0
    }
}
121
122impl FromStr for SocketType {
123    type Err = ZmqError;
124
125    #[inline]
126    fn from_str(s: &str) -> Result<Self, ZmqError> {
127        Self::try_from(s.as_bytes())
128    }
129}
130
131impl TryFrom<&[u8]> for SocketType {
132    type Error = ZmqError;
133
134    fn try_from(s: &[u8]) -> Result<Self, ZmqError> {
135        Ok(match s {
136            b"PAIR" => SocketType::PAIR,
137            b"PUB" => SocketType::PUB,
138            b"SUB" => SocketType::SUB,
139            b"REQ" => SocketType::REQ,
140            b"REP" => SocketType::REP,
141            b"DEALER" => SocketType::DEALER,
142            b"ROUTER" => SocketType::ROUTER,
143            b"PULL" => SocketType::PULL,
144            b"PUSH" => SocketType::PUSH,
145            b"XPUB" => SocketType::XPUB,
146            b"XSUB" => SocketType::XSUB,
147            b"STREAM" => SocketType::STREAM,
148            _ => return Err(ZmqError::Other("Unknown socket type")),
149        })
150    }
151}
152
153impl Display for SocketType {
154    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155        f.write_str(self.as_str())
156    }
157}
158
/// Events delivered to a socket monitor (see `Socket::monitor`).
///
/// Only some of the variants are produced by code in this file; the others
/// are documented from their names alone and are marked as such.
#[derive(Debug)]
pub enum SocketEvent {
    /// Sent by `Socket::connect` once the outgoing connection to the given
    /// endpoint has been established and the peer has been registered.
    Connected(Endpoint, PeerIdentity),
    /// NOTE(review): not emitted in this file — presumably a connection
    /// attempt that could not complete immediately; confirm at the emitter.
    ConnectDelayed,
    /// NOTE(review): not emitted in this file — presumably a repeated
    /// connection attempt; confirm at the emitter.
    ConnectRetried,
    /// Sent by `Socket::bind` once the socket accepts connections on the
    /// given endpoint.
    Listening(Endpoint),
    /// Sent by the accept callback installed in `Socket::bind` when an
    /// incoming connection has been accepted and the peer registered.
    Accepted(Endpoint, PeerIdentity),
    /// Sent by the accept callback installed in `Socket::bind` when accepting
    /// a connection or registering its peer failed.
    AcceptFailed(ZmqError),
    /// NOTE(review): not emitted in this file — presumably the socket was
    /// closed; confirm at the emitter.
    Closed,
    /// NOTE(review): not emitted in this file — presumably closing the socket
    /// failed; confirm at the emitter.
    CloseFailed,
    /// NOTE(review): not emitted in this file — presumably the identified
    /// peer went away; confirm at the emitter.
    Disconnected(PeerIdentity),
}
171
/// Options applied when a socket is created via `Socket::with_options`.
///
/// The default value has no peer identity configured.
#[derive(Default)]
pub struct SocketOptions {
    // Identity configured through `SocketOptions::peer_identity`; `None`
    // until that setter is called.
    pub(crate) peer_id: Option<PeerIdentity>,
}
176
177impl SocketOptions {
178    pub fn peer_identity(&mut self, peer_id: PeerIdentity) -> &mut Self {
179        self.peer_id = Some(peer_id);
180        self
181    }
182}
183
/// A socket backend that is told about peers as they come and go.
///
/// `Socket::backend` hands out implementors of this trait as
/// `Arc<dyn MultiPeerBackend>`.
#[async_trait]
pub trait MultiPeerBackend: SocketBackend {
    /// Registers a newly connected peer, identified by `peer_id`, together
    /// with the framed I/O object used to talk to it.
    ///
    /// NOTE(review): the call site is not in this file — presumably invoked
    /// by `util::peer_connected`; confirm there.
    async fn peer_connected(self: Arc<Self>, peer_id: &PeerIdentity, io: FramedIo);
    /// Notifies the backend that the peer identified by `peer_id` is gone.
    ///
    /// NOTE(review): the call site is not in this file; what state the
    /// backend is expected to release is up to each implementor.
    fn peer_disconnected(&self, peer_id: &PeerIdentity);
}
192
/// State shared by every socket backend. Implementors must be `Send + Sync`.
pub trait SocketBackend: Send + Sync {
    /// Returns the type of the socket this backend belongs to.
    fn socket_type(&self) -> SocketType;
    /// Returns the options the socket was created with.
    fn socket_options(&self) -> &SocketOptions;
    /// Shuts the backend down.
    ///
    /// NOTE(review): not called in this file; the exact effect is defined by
    /// each implementor.
    fn shutdown(&self);
    /// Returns the slot holding the sender half of the monitor channel.
    ///
    /// Code in this file locks the mutex and, when a sender is present,
    /// uses `try_send` to publish a `SocketEvent`; when the slot is `None`
    /// the event is dropped.
    fn monitor(&self) -> &Mutex<Option<mpsc::Sender<SocketEvent>>>;
}
199
/// Implemented by sockets that can receive messages.
#[async_trait]
pub trait SocketRecv {
    /// Asynchronously receives one message, or returns the error that
    /// prevented it.
    async fn recv(&mut self) -> ZmqResult<ZmqMessage>;
}
204
/// Implemented by sockets that can send messages.
#[async_trait]
pub trait SocketSend {
    /// Asynchronously sends `message`, taking ownership of it, or returns the
    /// error that prevented it.
    async fn send(&mut self, message: ZmqMessage) -> ZmqResult<()>;
}
209
/// Marker for sending sockets that may be used as the `capture` argument of
/// `proxy`, where they receive a copy of every forwarded message. The trait
/// adds no methods of its own.
pub trait CaptureSocket: SocketSend {}
213
214#[async_trait]
215pub trait Socket: Sized + Send {
216    fn new() -> Self {
217        Self::with_options(SocketOptions::default())
218    }
219
220    fn with_options(options: SocketOptions) -> Self;
221
222    fn backend(&self) -> Arc<dyn MultiPeerBackend>;
223
224    async fn bind(&mut self, endpoint: &str) -> ZmqResult<Endpoint> {
230        let endpoint = TryIntoEndpoint::try_into(endpoint)?;
231
232        let cloned_backend = self.backend();
233        let cback = move |result| {
234            let cloned_backend = cloned_backend.clone();
235            async move {
236                let result = match result {
237                    Ok((socket, endpoint)) => {
238                        match util::peer_connected(socket, cloned_backend.clone()).await {
239                            Ok(peer_id) => Ok((endpoint, peer_id)),
240                            Err(e) => Err(e),
241                        }
242                    }
243                    Err(e) => Err(e),
244                };
245                match result {
246                    Ok((endpoint, peer_id)) => {
247                        if let Some(monitor) = cloned_backend.monitor().lock().as_mut() {
248                            let _ = monitor.try_send(SocketEvent::Accepted(endpoint, peer_id));
249                        }
250                    }
251                    Err(e) => {
252                        if let Some(monitor) = cloned_backend.monitor().lock().as_mut() {
253                            let _ = monitor.try_send(SocketEvent::AcceptFailed(e));
254                        }
255                    }
256                }
257            }
258        };
259
260        let (endpoint, stop_handle) = transport::begin_accept(endpoint, cback).await?;
261
262        if let Some(monitor) = self.backend().monitor().lock().as_mut() {
263            let _ = monitor.try_send(SocketEvent::Listening(endpoint.clone()));
264        }
265
266        self.binds().insert(endpoint.clone(), stop_handle);
267        Ok(endpoint)
268    }
269
270    fn binds(&mut self) -> &mut HashMap<Endpoint, AcceptStopHandle>;
271
272    async fn unbind(&mut self, endpoint: Endpoint) -> ZmqResult<()> {
279        let stop_handle = self.binds().remove(&endpoint);
280        let stop_handle = stop_handle.ok_or(ZmqError::NoSuchBind(endpoint))?;
281        stop_handle.0.shutdown().await
282    }
283
284    async fn unbind_all(&mut self) -> Vec<ZmqError> {
286        let mut errs = Vec::new();
287        let endpoints: Vec<_> = self
288            .binds()
289            .iter()
290            .map(|(endpoint, _)| endpoint.clone())
291            .collect();
292        for endpoint in endpoints {
293            if let Err(err) = self.unbind(endpoint).await {
294                errs.push(err);
295            }
296        }
297        errs
298    }
299
300    async fn connect(&mut self, endpoint: &str) -> ZmqResult<()> {
302        let backend = self.backend();
303        let endpoint = TryIntoEndpoint::try_into(endpoint)?;
304
305        let result = match util::connect_forever(endpoint).await {
306            Ok((socket, endpoint)) => match util::peer_connected(socket, backend).await {
307                Ok(peer_id) => Ok((endpoint, peer_id)),
308                Err(e) => Err(e),
309            },
310            Err(e) => Err(e),
311        };
312        match result {
313            Ok((endpoint, peer_id)) => {
314                if let Some(monitor) = self.backend().monitor().lock().as_mut() {
315                    let _ = monitor.try_send(SocketEvent::Connected(endpoint, peer_id));
316                }
317                Ok(())
318            }
319            Err(e) => Err(e),
320        }
321    }
322
323    fn monitor(&mut self) -> mpsc::Receiver<SocketEvent>;
328
329    async fn close(mut self) -> Vec<ZmqError> {
350        self.unbind_all().await
352    }
353}
354
355pub async fn proxy<Frontend: SocketSend + SocketRecv, Backend: SocketSend + SocketRecv>(
356    mut frontend: Frontend,
357    mut backend: Backend,
358    mut capture: Option<Box<dyn CaptureSocket>>,
359) -> ZmqResult<()> {
360    loop {
361        select! {
362            frontend_mess = frontend.recv().fuse() => {
363                match frontend_mess {
364                    Ok(message) => {
365                        if let Some(capture) = &mut capture {
366                            capture.send(message.clone()).await?;
367                        }
368                        backend.send(message).await?;
369                    }
370                    Err(_) => {
371                        todo!()
372                    }
373                }
374            },
375            backend_mess = backend.recv().fuse() => {
376                match backend_mess {
377                    Ok(message) => {
378                        if let Some(capture) = &mut capture {
379                            capture.send(message.clone()).await?;
380                        }
381                        frontend.send(message).await?;
382                    }
383                    Err(_) => {
384                        todo!()
385                    }
386                }
387            }
388        };
389    }
390}
391
/// Convenience re-exports of the traits needed to use sockets: importing
/// `prelude::*` brings `Socket`, `SocketRecv`, `SocketSend` and
/// `TryIntoEndpoint` into scope.
pub mod prelude {
    pub use crate::{Socket, SocketRecv, SocketSend, TryIntoEndpoint};
}