// wrpc_transport/frame/conn/mod.rs

1use core::future::Future;
2use core::mem;
3use core::pin::Pin;
4use core::task::{ready, Context, Poll};
5
6use std::sync::Arc;
7
8use anyhow::ensure;
9use bytes::{Buf as _, BufMut as _, Bytes, BytesMut};
10use futures::Sink as _;
11use pin_project_lite::pin_project;
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt as _};
13use tokio::sync::mpsc;
14use tokio::task::JoinSet;
15use tokio_stream::wrappers::ReceiverStream;
16use tokio_util::codec::Encoder;
17use tokio_util::io::StreamReader;
18use tokio_util::sync::PollSender;
19use tracing::{debug, error, instrument, trace, Instrument as _, Span};
20use wasm_tokio::{AsyncReadLeb128 as _, Leb128Encoder};
21
22use crate::Index;
23
24mod accept;
25mod client;
26mod server;
27
28pub use accept::*;
29pub use client::*;
30pub use server::*;
31
/// Index trie containing async stream subscriptions.
///
/// Each node may hold the sending (`tx`) and receiving (`rx`) halves of the
/// channel for the stream addressed by the path leading to that node. Senders are
/// cloned out via `get_tx` (used by the ingress task to route frames), while
/// receivers are moved out at most once via `take_rx`.
#[derive(Debug, Default)]
enum IndexTrie {
    /// No subscriptions.
    #[default]
    Empty,
    /// Node without children, created when a path ends.
    Leaf {
        tx: Option<mpsc::Sender<std::io::Result<Bytes>>>,
        rx: Option<mpsc::Receiver<std::io::Result<Bytes>>>,
    },
    /// Node whose children are addressed by concrete (`Some(i)`) path elements;
    /// `nested[i]` holds the subtree for index `i`.
    IndexNode {
        tx: Option<mpsc::Sender<std::io::Result<Bytes>>>,
        rx: Option<mpsc::Receiver<std::io::Result<Bytes>>>,
        nested: Vec<Option<IndexTrie>>,
    },
    // TODO: Add partially-indexed `WildcardIndexNode`
    /// Node whose single child corresponds to a wildcard (`None`) path element.
    /// Lookups below such a node are not implemented yet (`get_tx`/`take_rx`
    /// return `None` for them).
    WildcardNode {
        tx: Option<mpsc::Sender<std::io::Result<Bytes>>>,
        rx: Option<mpsc::Receiver<std::io::Result<Bytes>>>,
        nested: Option<Box<IndexTrie>>,
    },
}
53
54impl<'a>
55    From<(
56        &'a [Option<usize>],
57        mpsc::Sender<std::io::Result<Bytes>>,
58        Option<mpsc::Receiver<std::io::Result<Bytes>>>,
59    )> for IndexTrie
60{
61    fn from(
62        (path, tx, rx): (
63            &'a [Option<usize>],
64            mpsc::Sender<std::io::Result<Bytes>>,
65            Option<mpsc::Receiver<std::io::Result<Bytes>>>,
66        ),
67    ) -> Self {
68        match path {
69            [] => Self::Leaf { tx: Some(tx), rx },
70            [None, path @ ..] => Self::WildcardNode {
71                tx: None,
72                rx: None,
73                nested: Some(Box::new(Self::from((path, tx, rx)))),
74            },
75            [Some(i), path @ ..] => Self::IndexNode {
76                tx: None,
77                rx: None,
78                nested: {
79                    let n = i.saturating_add(1);
80                    let mut nested = Vec::with_capacity(n);
81                    nested.resize_with(n, Option::default);
82                    nested[*i] = Some(Self::from((path, tx, rx)));
83                    nested
84                },
85            },
86        }
87    }
88}
89
90impl<'a>
91    From<(
92        &'a [Option<usize>],
93        mpsc::Sender<std::io::Result<Bytes>>,
94        mpsc::Receiver<std::io::Result<Bytes>>,
95    )> for IndexTrie
96{
97    fn from(
98        (path, tx, rx): (
99            &'a [Option<usize>],
100            mpsc::Sender<std::io::Result<Bytes>>,
101            mpsc::Receiver<std::io::Result<Bytes>>,
102        ),
103    ) -> Self {
104        Self::from((path, tx, Some(rx)))
105    }
106}
107
108impl<'a> From<(&'a [Option<usize>], mpsc::Sender<std::io::Result<Bytes>>)> for IndexTrie {
109    fn from((path, tx): (&'a [Option<usize>], mpsc::Sender<std::io::Result<Bytes>>)) -> Self {
110        Self::from((path, tx, None))
111    }
112}
113
114impl<P: AsRef<[Option<usize>]>> FromIterator<P> for IndexTrie {
115    fn from_iter<T: IntoIterator<Item = P>>(iter: T) -> Self {
116        let mut root = Self::Empty;
117        for path in iter {
118            let (tx, rx) = mpsc::channel(16);
119            if !root.insert(path.as_ref(), tx, Some(rx)) {
120                return Self::Empty;
121            }
122        }
123        root
124    }
125}
126
impl IndexTrie {
    /// Takes the receiver stored at `path`, leaving `None` in its place.
    ///
    /// Returns `None` if `path` does not lead to a node, if the receiver has already
    /// been taken, or if `path` continues below a wildcard node (not supported yet).
    /// An index or wildcard node that is left without receiver, sender and children
    /// is collapsed into `Empty`.
    #[instrument(level = "trace", skip(self), ret(level = "trace"))]
    fn take_rx(&mut self, path: &[usize]) -> Option<mpsc::Receiver<std::io::Result<Bytes>>> {
        let Some((i, path)) = path.split_first() else {
            // `path` ends at this node: take its own receiver.
            return match self {
                Self::Empty => None,
                Self::Leaf { rx, .. } => rx.take(),
                Self::IndexNode { tx, rx, nested } => {
                    let rx = rx.take();
                    if nested.is_empty() && tx.is_none() {
                        *self = Self::Empty;
                    }
                    rx
                }
                Self::WildcardNode { tx, rx, nested } => {
                    let rx = rx.take();
                    if nested.is_none() && tx.is_none() {
                        *self = Self::Empty;
                    }
                    rx
                }
            };
        };
        // Descend into the child addressed by the next path element.
        match self {
            Self::Empty | Self::Leaf { .. } | Self::WildcardNode { .. } => None,
            Self::IndexNode { ref mut nested, .. } => nested
                .get_mut(*i)
                .and_then(|nested| nested.as_mut().and_then(|nested| nested.take_rx(path))),
            // TODO: Demux the subscription
            //Self::WildcardNode { ref mut nested, .. } => {
            //    nested.as_mut().and_then(|nested| nested.take(path))
            //}
        }
    }

    /// Returns a clone of the sender stored at `path`.
    ///
    /// Returns `None` if `path` does not lead to a node holding a sender, or if
    /// `path` continues below a wildcard node (not supported yet).
    #[instrument(level = "trace", skip(self), ret(level = "trace"))]
    fn get_tx(&mut self, path: &[usize]) -> Option<mpsc::Sender<std::io::Result<Bytes>>> {
        let Some((i, path)) = path.split_first() else {
            return match self {
                Self::Empty => None,
                Self::Leaf { tx, .. } => tx.clone(),
                Self::IndexNode { tx, .. } | Self::WildcardNode { tx, .. } => tx.clone(),
            };
        };
        match self {
            Self::Empty | Self::Leaf { .. } | Self::WildcardNode { .. } => None,
            Self::IndexNode { ref mut nested, .. } => {
                let nested = nested.get_mut(*i)?;
                let nested = nested.as_mut()?;
                nested.get_tx(path)
            } // TODO: Demux the subscription
              //Self::WildcardNode { ref mut nested, .. } => {
              //    nested.as_mut().and_then(|nested| nested.take(path))
              //}
        }
    }

    /// Drops the sender held by every node in the trie (receivers are left in place).
    ///
    /// Once no other clones of a sender remain, the corresponding receiver observes
    /// end-of-stream.
    #[instrument(level = "trace", skip(self), ret(level = "trace"))]
    fn close_tx(&mut self) {
        match self {
            Self::Empty => {}
            Self::Leaf { tx, .. } => {
                mem::take(tx);
            }
            Self::IndexNode {
                tx, ref mut nested, ..
            } => {
                mem::take(tx);
                for nested in nested.iter_mut().flatten() {
                    nested.close_tx();
                }
            }
            Self::WildcardNode {
                tx, ref mut nested, ..
            } => {
                mem::take(tx);
                if let Some(nested) = nested {
                    nested.close_tx();
                }
            }
        }
    }

    /// Inserts `sender` and `receiver` under a `path` - returns `false` if it failed and `true` if it succeeded.
    /// Tree state after `false` is returned is undefined
    ///
    /// Insertion fails when `path` ends at a node that already holds a subscription,
    /// or when a path element kind (`Some(i)` vs. `None`) does not match the kind of
    /// an existing node along the path.
    #[instrument(level = "trace", skip(self, sender, receiver), ret(level = "trace"))]
    fn insert(
        &mut self,
        path: &[Option<usize>],
        sender: mpsc::Sender<std::io::Result<Bytes>>,
        receiver: Option<mpsc::Receiver<std::io::Result<Bytes>>>,
    ) -> bool {
        match self {
            Self::Empty => {
                *self = Self::from((path, sender, receiver));
                true
            }
            Self::Leaf { .. } => {
                // A leaf already holds a subscription, so `path` must continue below it.
                let Some((i, path)) = path.split_first() else {
                    return false;
                };
                // Cannot fail: `self` was just matched as `Leaf`.
                let Self::Leaf { tx, rx } = mem::take(self) else {
                    return false;
                };
                // Promote the leaf to an inner node, keeping its own `tx`/`rx`.
                if let Some(i) = i {
                    let n = i.saturating_add(1);
                    let mut nested = Vec::with_capacity(n);
                    nested.resize_with(n, Option::default);
                    nested[*i] = Some(Self::from((path, sender, receiver)));
                    *self = Self::IndexNode { tx, rx, nested };
                } else {
                    *self = Self::WildcardNode {
                        tx,
                        rx,
                        nested: Some(Box::new(Self::from((path, sender, receiver)))),
                    };
                }
                true
            }
            Self::IndexNode {
                ref mut tx,
                ref mut rx,
                ref mut nested,
            } => match (&tx, &rx, path) {
                // `path` ends here: only allowed if this node has no subscription yet.
                (None, None, []) => {
                    *tx = Some(sender);
                    *rx = receiver;
                    true
                }
                (_, _, [Some(i), path @ ..]) => {
                    // Grow the child slots so that index `i` exists.
                    let cap = i.saturating_add(1);
                    if nested.len() < cap {
                        nested.resize_with(cap, Option::default);
                    }
                    let nested = &mut nested[*i];
                    if let Some(nested) = nested {
                        nested.insert(path, sender, receiver)
                    } else {
                        *nested = Some(Self::from((path, sender, receiver)));
                        true
                    }
                }
                _ => false,
            },
            Self::WildcardNode {
                ref mut tx,
                ref mut rx,
                ref mut nested,
            } => match (&tx, &rx, path) {
                // `path` ends here: only allowed if this node has no subscription yet.
                (None, None, []) => {
                    *tx = Some(sender);
                    *rx = receiver;
                    true
                }
                (_, _, [None, path @ ..]) => {
                    if let Some(nested) = nested {
                        nested.insert(path, sender, receiver)
                    } else {
                        *nested = Some(Box::new(Self::from((path, sender, receiver))));
                        true
                    }
                }
                _ => false,
            },
        }
    }
}
297
pin_project! {
    /// Incoming framed stream
    ///
    /// Yields the data of the frames that the ingress task routed to this stream's
    /// path. Nested streams are obtained via `Index::index`.
    #[project = IncomingProj]
    pub struct Incoming {
        // Reader over the subscription channel; `None` once EOF has been observed,
        // or when no subscription was available for `path`.
        #[pin]
        rx: Option<StreamReader<ReceiverStream<std::io::Result<Bytes>>, Bytes>>,
        // Absolute path of this stream; empty for the root stream.
        path: Arc<[usize]>,
        // Subscription trie shared with the ingress task.
        index: Arc<std::sync::Mutex<IndexTrie>>,
        // Owns the ingress task; a `JoinSet` aborts its tasks once it is dropped.
        io: Arc<JoinSet<()>>,
    }
}
309
310impl Index<Self> for Incoming {
311    #[instrument(level = "trace", skip(self), fields(path = ?self.path))]
312    fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
313        ensure!(!path.is_empty());
314        let path = if self.path.is_empty() {
315            Arc::from(path)
316        } else {
317            Arc::from([self.path.as_ref(), path].concat())
318        };
319        trace!("locking index trie");
320        let mut index = self
321            .index
322            .lock()
323            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.to_string()))?;
324        trace!(?path, "taking index subscription");
325        let rx = index
326            .take_rx(&path)
327            .map(|rx| StreamReader::new(ReceiverStream::new(rx)));
328        Ok(Self {
329            rx,
330            path,
331            index: Arc::clone(&self.index),
332            io: Arc::clone(&self.io),
333        })
334    }
335}
336
337impl AsyncRead for Incoming {
338    #[instrument(level = "trace", skip_all, fields(path = ?self.path), ret(level = "trace"))]
339    fn poll_read(
340        mut self: Pin<&mut Self>,
341        cx: &mut Context<'_>,
342        buf: &mut tokio::io::ReadBuf<'_>,
343    ) -> Poll<std::io::Result<()>> {
344        if buf.remaining() == 0 {
345            return Poll::Ready(Ok(()));
346        }
347        trace!("reading");
348        let this = self.as_mut().project();
349        let Some(rx) = this.rx.as_pin_mut() else {
350            trace!("reader is closed");
351            return Poll::Ready(Ok(()));
352        };
353        ready!(rx.poll_read(cx, buf))?;
354        trace!(buf = ?buf.filled(), "read buffer");
355        if buf.filled().is_empty() {
356            self.rx.take();
357        }
358        Poll::Ready(Ok(()))
359    }
360}
361
pin_project! {
    /// Outgoing framed stream
    ///
    /// Every write is handed, as a separate frame prefixed with this stream's
    /// encoded path, to the egress task that writes it to the underlying transport.
    #[project = OutgoingProj]
    pub struct Outgoing {
        // Channel of `(encoded path, data)` frames consumed by the egress task.
        #[pin]
        tx: PollSender<(Bytes, Bytes)>,
        // Absolute path of this stream; empty for the root stream.
        path: Arc<[usize]>,
        // LEB128-encoded path length followed by the LEB128-encoded path elements.
        path_buf: Bytes,
    }
}
372
373impl Index<Self> for Outgoing {
374    #[instrument(level = "trace", skip(self), fields(path = ?self.path))]
375    fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
376        ensure!(!path.is_empty());
377        let path: Arc<[usize]> = if self.path.is_empty() {
378            Arc::from(path)
379        } else {
380            Arc::from([self.path.as_ref(), path].concat())
381        };
382        let mut buf = BytesMut::with_capacity(path.len().saturating_add(5));
383        let n = u32::try_from(path.len())
384            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
385        trace!(n, "encoding path length");
386        Leb128Encoder.encode(n, &mut buf)?;
387        for p in path.as_ref() {
388            let p = u32::try_from(*p)
389                .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
390            trace!(p, "encoding path element");
391            Leb128Encoder.encode(p, &mut buf)?;
392        }
393        Ok(Self {
394            tx: self.tx.clone(),
395            path,
396            path_buf: buf.freeze(),
397        })
398    }
399}
400
401impl AsyncWrite for Outgoing {
402    #[instrument(level = "trace", skip_all, fields(path = ?self.path, buf = format!("{buf:02x?}")), ret(level = "trace"))]
403    fn poll_write(
404        self: Pin<&mut Self>,
405        cx: &mut Context<'_>,
406        buf: &[u8],
407    ) -> Poll<std::io::Result<usize>> {
408        trace!("writing outgoing chunk");
409        let mut this = self.project();
410        ready!(this.tx.as_mut().poll_ready(cx))
411            .map_err(|err| std::io::Error::new(std::io::ErrorKind::BrokenPipe, err))?;
412        this.tx
413            .start_send((this.path_buf.clone(), Bytes::copy_from_slice(buf)))
414            .map_err(|err| std::io::Error::new(std::io::ErrorKind::BrokenPipe, err))?;
415        Poll::Ready(Ok(buf.len()))
416    }
417
418    #[instrument(level = "trace", skip_all, fields(path = ?self.path), ret(level = "trace"))]
419    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
420        Poll::Ready(Ok(()))
421    }
422
423    #[instrument(level = "trace", skip_all, fields(path = ?self.path), ret(level = "trace"))]
424    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
425        Poll::Ready(Ok(()))
426    }
427}
428
429#[instrument(level = "trace", skip_all, ret(level = "trace"))]
430async fn ingress(
431    mut rx: impl AsyncRead + Unpin,
432    index: &std::sync::Mutex<IndexTrie>,
433    param_tx: mpsc::Sender<std::io::Result<Bytes>>,
434) -> std::io::Result<()> {
435    loop {
436        trace!("reading path length");
437        let b = match rx.read_u8().await {
438            Ok(b) => b,
439            Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
440            Err(err) => return Err(err),
441        };
442        let n = AsyncReadExt::chain([b].as_slice(), &mut rx)
443            .read_u32_leb128()
444            .await?;
445        let n = n
446            .try_into()
447            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
448        trace!(n, "read path length");
449        let tx = if n == 0 {
450            &param_tx
451        } else {
452            let mut path = Vec::with_capacity(n);
453            for i in 0..n {
454                trace!(i, "reading path element");
455                let p = rx.read_u32_leb128().await?;
456                let p = usize::try_from(p)
457                    .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
458                path.push(p);
459            }
460            trace!(?path, "read path");
461
462            trace!("locking index trie");
463            let mut index = index
464                .lock()
465                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.to_string()))?;
466            &index.get_tx(&path).ok_or_else(|| {
467                std::io::Error::new(
468                    std::io::ErrorKind::NotFound,
469                    format!("`{path:?}` subscription not found"),
470                )
471            })?
472        };
473        trace!("reading data length");
474        let n = rx.read_u32_leb128().await?;
475        let n = n
476            .try_into()
477            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
478        trace!(n, "read data length");
479        let mut buf = BytesMut::with_capacity(n);
480        buf.put_bytes(0, n);
481        trace!("reading data");
482        rx.read_exact(&mut buf).await?;
483        trace!(?buf, "read data");
484        tx.send(Ok(buf.freeze())).await.map_err(|_| {
485            std::io::Error::new(std::io::ErrorKind::BrokenPipe, "stream receiver closed")
486        })?;
487    }
488}
489
490#[instrument(level = "trace", skip_all)]
491async fn egress(
492    mut tx: impl AsyncWrite + Unpin,
493    mut rx: mpsc::Receiver<(Bytes, Bytes)>,
494) -> std::io::Result<()> {
495    let mut buf = BytesMut::with_capacity(5);
496    trace!("waiting for next frame");
497    while let Some((path, data)) = rx.recv().await {
498        let data_len = u32::try_from(data.len())
499            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
500        buf.clear();
501        Leb128Encoder.encode(data_len, &mut buf)?;
502        let mut frame = path.chain(&mut buf).chain(data);
503        trace!(?frame, "writing egress frame");
504        tx.write_all_buf(&mut frame).await?;
505    }
506    trace!("shutting down outgoing stream");
507    tx.shutdown().await
508}
509
510/// Connection handler defines the connection I/O behavior.
511/// It is mostly useful for transports that may require additional clean up not already covered
512/// by [AsyncWrite::shutdown], for example.
513/// This API is experimental and may change in backwards-incompatible ways in the future.
514pub trait ConnHandler<Rx, Tx> {
515    /// Handle ingress completion
516    fn on_ingress(rx: Rx, res: std::io::Result<()>) -> impl Future<Output = ()> + Send {
517        _ = rx;
518        if let Err(err) = res {
519            error!(?err, "ingress failed");
520        } else {
521            debug!("ingress successfully complete");
522        }
523        async {}
524    }
525
526    /// Handle egress completion
527    fn on_egress(tx: Tx, res: std::io::Result<()>) -> impl Future<Output = ()> + Send {
528        _ = tx;
529        if let Err(err) = res {
530            error!(?err, "egress failed");
531        } else {
532            debug!("egress successfully complete");
533        }
534        async {}
535    }
536}
537
538impl<Rx, Tx> ConnHandler<Rx, Tx> for () {}
539
540/// Peer connection
541pub(crate) struct Conn {
542    rx: Incoming,
543    tx: Outgoing,
544}
545
546impl Conn {
547    /// Creates a new [Conn] given an [AsyncRead], [ConnHandler] and a set of async paths
548    fn new<H, Rx, Tx, P>(mut rx: Rx, mut tx: Tx, paths: impl IntoIterator<Item = P>) -> Self
549    where
550        Rx: AsyncRead + Unpin + Send + 'static,
551        Tx: AsyncWrite + Unpin + Send + 'static,
552        H: ConnHandler<Rx, Tx>,
553        P: AsRef<[Option<usize>]>,
554    {
555        let index = Arc::new(std::sync::Mutex::new(paths.into_iter().collect()));
556        let (rx_tx, rx_rx) = mpsc::channel(128);
557        let mut rx_io = JoinSet::new();
558        let span = Span::current();
559        rx_io.spawn({
560            let index = Arc::clone(&index);
561            async move {
562                let res = ingress(&mut rx, &index, rx_tx).await;
563                H::on_ingress(rx, res).await;
564                let Ok(mut index) = index.lock() else {
565                    error!("failed to lock index trie");
566                    return;
567                };
568                trace!("shutting down index trie");
569                index.close_tx();
570            }
571            .instrument(span.clone())
572        });
573        let (tx_tx, tx_rx) = mpsc::channel(128);
574        tokio::spawn(
575            async {
576                let res = egress(&mut tx, tx_rx).await;
577                H::on_egress(tx, res).await;
578            }
579            .instrument(span.clone()),
580        );
581        Conn {
582            tx: Outgoing {
583                tx: PollSender::new(tx_tx),
584                path: Arc::from([]),
585                path_buf: Bytes::from_static(&[0]),
586            },
587            rx: Incoming {
588                rx: Some(StreamReader::new(ReceiverStream::new(rx_rx))),
589                path: Arc::from([]),
590                index: Arc::clone(&index),
591                io: Arc::new(rx_io),
592            },
593        }
594    }
595}