1use core::future::Future;
2use core::mem;
3use core::pin::Pin;
4use core::task::{ready, Context, Poll};
5
6use std::sync::Arc;
7
8use anyhow::ensure;
9use bytes::{Buf as _, BufMut as _, Bytes, BytesMut};
10use futures::Sink as _;
11use pin_project_lite::pin_project;
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt as _};
13use tokio::sync::mpsc;
14use tokio::task::JoinSet;
15use tokio_stream::wrappers::ReceiverStream;
16use tokio_util::codec::Encoder;
17use tokio_util::io::StreamReader;
18use tokio_util::sync::PollSender;
19use tracing::{debug, error, instrument, trace, Instrument as _, Span};
20use wasm_tokio::{AsyncReadLeb128 as _, Leb128Encoder};
21
22use crate::Index;
23
24mod accept;
25mod client;
26mod server;
27
28pub use accept::*;
29pub use client::*;
30pub use server::*;
31
/// Trie mapping stream index paths to the channels that carry each stream's data.
///
/// Paths are inserted as `[Option<usize>]`: a `Some(i)` element creates/uses an
/// `IndexNode` child at index `i`, a `None` element creates/uses a `WildcardNode` child.
/// Any node other than `Empty` may also own the channel for the path that ends at it.
#[derive(Debug, Default)]
enum IndexTrie {
    /// The default node: no channels and no children.
    #[default]
    Empty,
    /// A node without children that owns the channel for the path ending here.
    Leaf {
        // Sender for this path; set to `None` by `close_tx`.
        tx: Option<mpsc::Sender<std::io::Result<Bytes>>>,
        // Receiver for this path; moved out (leaving `None`) by `take_rx`.
        rx: Option<mpsc::Receiver<std::io::Result<Bytes>>>,
    },
    /// An inner node whose children are addressed by concrete indices.
    IndexNode {
        // Channel for the path ending at this node, if one was registered.
        tx: Option<mpsc::Sender<std::io::Result<Bytes>>>,
        rx: Option<mpsc::Receiver<std::io::Result<Bytes>>>,
        // `nested[i]` is the child for index `i`; gaps are `None`.
        nested: Vec<Option<IndexTrie>>,
    },
    /// An inner node created from a `None` (wildcard) path element, with a single child.
    ///
    /// NOTE: `take_rx`/`get_tx` currently do not descend into wildcard children.
    WildcardNode {
        // Channel for the path ending at this node, if one was registered.
        tx: Option<mpsc::Sender<std::io::Result<Bytes>>>,
        rx: Option<mpsc::Receiver<std::io::Result<Bytes>>>,
        nested: Option<Box<IndexTrie>>,
    },
}
53
54impl<'a>
55 From<(
56 &'a [Option<usize>],
57 mpsc::Sender<std::io::Result<Bytes>>,
58 Option<mpsc::Receiver<std::io::Result<Bytes>>>,
59 )> for IndexTrie
60{
61 fn from(
62 (path, tx, rx): (
63 &'a [Option<usize>],
64 mpsc::Sender<std::io::Result<Bytes>>,
65 Option<mpsc::Receiver<std::io::Result<Bytes>>>,
66 ),
67 ) -> Self {
68 match path {
69 [] => Self::Leaf { tx: Some(tx), rx },
70 [None, path @ ..] => Self::WildcardNode {
71 tx: None,
72 rx: None,
73 nested: Some(Box::new(Self::from((path, tx, rx)))),
74 },
75 [Some(i), path @ ..] => Self::IndexNode {
76 tx: None,
77 rx: None,
78 nested: {
79 let n = i.saturating_add(1);
80 let mut nested = Vec::with_capacity(n);
81 nested.resize_with(n, Option::default);
82 nested[*i] = Some(Self::from((path, tx, rx)));
83 nested
84 },
85 },
86 }
87 }
88}
89
90impl<'a>
91 From<(
92 &'a [Option<usize>],
93 mpsc::Sender<std::io::Result<Bytes>>,
94 mpsc::Receiver<std::io::Result<Bytes>>,
95 )> for IndexTrie
96{
97 fn from(
98 (path, tx, rx): (
99 &'a [Option<usize>],
100 mpsc::Sender<std::io::Result<Bytes>>,
101 mpsc::Receiver<std::io::Result<Bytes>>,
102 ),
103 ) -> Self {
104 Self::from((path, tx, Some(rx)))
105 }
106}
107
108impl<'a> From<(&'a [Option<usize>], mpsc::Sender<std::io::Result<Bytes>>)> for IndexTrie {
109 fn from((path, tx): (&'a [Option<usize>], mpsc::Sender<std::io::Result<Bytes>>)) -> Self {
110 Self::from((path, tx, None))
111 }
112}
113
114impl<P: AsRef<[Option<usize>]>> FromIterator<P> for IndexTrie {
115 fn from_iter<T: IntoIterator<Item = P>>(iter: T) -> Self {
116 let mut root = Self::Empty;
117 for path in iter {
118 let (tx, rx) = mpsc::channel(16);
119 if !root.insert(path.as_ref(), tx, Some(rx)) {
120 return Self::Empty;
121 }
122 }
123 root
124 }
125}
126
impl IndexTrie {
    /// Takes the receiver registered for exactly `path`, leaving `None` in its place.
    ///
    /// Each element of `path` selects a child of an `IndexNode`. Returns `None` when nothing
    /// is stored at `path`, when the receiver was already taken, or when `path` would have
    /// to descend through an `Empty`, `Leaf` or `WildcardNode` (wildcard children are not
    /// searched here).
    #[instrument(level = "trace", skip(self), ret(level = "trace"))]
    fn take_rx(&mut self, path: &[usize]) -> Option<mpsc::Receiver<std::io::Result<Bytes>>> {
        let Some((i, path)) = path.split_first() else {
            // `path` is exhausted: this node is the target.
            return match self {
                Self::Empty => None,
                Self::Leaf { rx, .. } => rx.take(),
                Self::IndexNode { tx, rx, nested } => {
                    let rx = rx.take();
                    // Reset to `Empty` when there is no sender and `nested` is empty.
                    if nested.is_empty() && tx.is_none() {
                        *self = Self::Empty;
                    }
                    rx
                }
                Self::WildcardNode { tx, rx, nested } => {
                    let rx = rx.take();
                    // Reset to `Empty` when there is no sender and no child.
                    if nested.is_none() && tx.is_none() {
                        *self = Self::Empty;
                    }
                    rx
                }
            };
        };
        match self {
            // Only index nodes are descended into; wildcard lookup is not implemented.
            Self::Empty | Self::Leaf { .. } | Self::WildcardNode { .. } => None,
            Self::IndexNode { ref mut nested, .. } => nested
                .get_mut(*i)
                .and_then(|nested| nested.as_mut().and_then(|nested| nested.take_rx(path))),
        }
    }

    /// Returns a clone of the sender registered for exactly `path`, if any.
    ///
    /// Lookup follows the same rules as `take_rx`: only `IndexNode` children are descended
    /// into, so paths that pass through a `WildcardNode` yield `None`.
    #[instrument(level = "trace", skip(self), ret(level = "trace"))]
    fn get_tx(&mut self, path: &[usize]) -> Option<mpsc::Sender<std::io::Result<Bytes>>> {
        let Some((i, path)) = path.split_first() else {
            // `path` is exhausted: this node is the target.
            return match self {
                Self::Empty => None,
                Self::Leaf { tx, .. } => tx.clone(),
                Self::IndexNode { tx, .. } | Self::WildcardNode { tx, .. } => tx.clone(),
            };
        };
        match self {
            // Only index nodes are descended into; wildcard lookup is not implemented.
            Self::Empty | Self::Leaf { .. } | Self::WildcardNode { .. } => None,
            Self::IndexNode { ref mut nested, .. } => {
                let nested = nested.get_mut(*i)?;
                let nested = nested.as_mut()?;
                nested.get_tx(path)
            }
        }
    }

    /// Drops the sender stored in this node and, recursively, in every descendant.
    ///
    /// Receivers are left in place.
    #[instrument(level = "trace", skip(self), ret(level = "trace"))]
    fn close_tx(&mut self) {
        match self {
            Self::Empty => {}
            Self::Leaf { tx, .. } => {
                mem::take(tx);
            }
            Self::IndexNode {
                tx, ref mut nested, ..
            } => {
                mem::take(tx);
                for nested in nested.iter_mut().flatten() {
                    nested.close_tx();
                }
            }
            Self::WildcardNode {
                tx, ref mut nested, ..
            } => {
                mem::take(tx);
                if let Some(nested) = nested {
                    nested.close_tx();
                }
            }
        }
    }

    /// Registers `sender` (and `receiver`, if any) at `path`, where `Some(i)` selects index
    /// `i` and `None` is a wildcard.
    ///
    /// Missing nodes along the path are created, and a `Leaf` that a longer path passes
    /// through is converted into an `IndexNode`/`WildcardNode` that keeps the leaf's channels.
    ///
    /// Returns `false` (dropping `sender` and `receiver`) if `path` ends at an existing `Leaf`
    /// or at an inner node that already holds a channel, or if it uses a wildcard where an
    /// `IndexNode` exists (or a concrete index where a `WildcardNode` exists). Returns `true`
    /// otherwise.
    #[instrument(level = "trace", skip(self, sender, receiver), ret(level = "trace"))]
    fn insert(
        &mut self,
        path: &[Option<usize>],
        sender: mpsc::Sender<std::io::Result<Bytes>>,
        receiver: Option<mpsc::Receiver<std::io::Result<Bytes>>>,
    ) -> bool {
        match self {
            Self::Empty => {
                *self = Self::from((path, sender, receiver));
                true
            }
            Self::Leaf { .. } => {
                // An exact match on an existing leaf is a duplicate registration.
                let Some((i, path)) = path.split_first() else {
                    return false;
                };
                // Cannot fail: `self` was just matched as `Leaf`.
                let Self::Leaf { tx, rx } = mem::take(self) else {
                    return false;
                };
                // Turn the leaf into an inner node that keeps the leaf's own channels.
                if let Some(i) = i {
                    let n = i.saturating_add(1);
                    let mut nested = Vec::with_capacity(n);
                    nested.resize_with(n, Option::default);
                    nested[*i] = Some(Self::from((path, sender, receiver)));
                    *self = Self::IndexNode { tx, rx, nested };
                } else {
                    *self = Self::WildcardNode {
                        tx,
                        rx,
                        nested: Some(Box::new(Self::from((path, sender, receiver)))),
                    };
                }
                true
            }
            Self::IndexNode {
                ref mut tx,
                ref mut rx,
                ref mut nested,
            } => match (&tx, &rx, path) {
                // Path ends here and no channel is registered yet.
                (None, None, []) => {
                    *tx = Some(sender);
                    *rx = receiver;
                    true
                }
                (_, _, [Some(i), path @ ..]) => {
                    // Grow `nested` so that index `i` is addressable.
                    let cap = i.saturating_add(1);
                    if nested.len() < cap {
                        nested.resize_with(cap, Option::default);
                    }
                    let nested = &mut nested[*i];
                    if let Some(nested) = nested {
                        nested.insert(path, sender, receiver)
                    } else {
                        *nested = Some(Self::from((path, sender, receiver)));
                        true
                    }
                }
                // Duplicate registration, or a wildcard element at an index node.
                _ => false,
            },
            Self::WildcardNode {
                ref mut tx,
                ref mut rx,
                ref mut nested,
            } => match (&tx, &rx, path) {
                // Path ends here and no channel is registered yet.
                (None, None, []) => {
                    *tx = Some(sender);
                    *rx = receiver;
                    true
                }
                (_, _, [None, path @ ..]) => {
                    if let Some(nested) = nested {
                        nested.insert(path, sender, receiver)
                    } else {
                        *nested = Some(Box::new(Self::from((path, sender, receiver))));
                        true
                    }
                }
                // Duplicate registration, or a concrete index at a wildcard node.
                _ => false,
            },
        }
    }
}
297
pin_project! {
    #[project = IncomingProj]
    /// Readable half of one stream of a multiplexed connection.
    ///
    /// The root stream is created by `Conn::new`; nested streams are obtained through
    /// `Index::index`, which takes the matching subscription out of the shared `IndexTrie`.
    pub struct Incoming {
        // Byte stream of this path's subscription. `None` when no subscription was found
        // for the path, or after end-of-stream has been observed by `poll_read`.
        #[pin]
        rx: Option<StreamReader<ReceiverStream<std::io::Result<Bytes>>, Bytes>>,
        // Absolute index path of this stream; empty for the root stream.
        path: Arc<[usize]>,
        // Subscription trie shared with the ingress task and every derived `Incoming`.
        index: Arc<std::sync::Mutex<IndexTrie>>,
        // Set holding the ingress task, shared by every derived `Incoming`. Presumably kept
        // here so the task is not aborted (as a dropped `JoinSet` does) while a reader lives.
        io: Arc<JoinSet<()>>,
    }
}
309
310impl Index<Self> for Incoming {
311 #[instrument(level = "trace", skip(self), fields(path = ?self.path))]
312 fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
313 ensure!(!path.is_empty());
314 let path = if self.path.is_empty() {
315 Arc::from(path)
316 } else {
317 Arc::from([self.path.as_ref(), path].concat())
318 };
319 trace!("locking index trie");
320 let mut index = self
321 .index
322 .lock()
323 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.to_string()))?;
324 trace!(?path, "taking index subscription");
325 let rx = index
326 .take_rx(&path)
327 .map(|rx| StreamReader::new(ReceiverStream::new(rx)));
328 Ok(Self {
329 rx,
330 path,
331 index: Arc::clone(&self.index),
332 io: Arc::clone(&self.io),
333 })
334 }
335}
336
337impl AsyncRead for Incoming {
338 #[instrument(level = "trace", skip_all, fields(path = ?self.path), ret(level = "trace"))]
339 fn poll_read(
340 mut self: Pin<&mut Self>,
341 cx: &mut Context<'_>,
342 buf: &mut tokio::io::ReadBuf<'_>,
343 ) -> Poll<std::io::Result<()>> {
344 if buf.remaining() == 0 {
345 return Poll::Ready(Ok(()));
346 }
347 trace!("reading");
348 let this = self.as_mut().project();
349 let Some(rx) = this.rx.as_pin_mut() else {
350 trace!("reader is closed");
351 return Poll::Ready(Ok(()));
352 };
353 ready!(rx.poll_read(cx, buf))?;
354 trace!(buf = ?buf.filled(), "read buffer");
355 if buf.filled().is_empty() {
356 self.rx.take();
357 }
358 Poll::Ready(Ok(()))
359 }
360}
361
pin_project! {
    #[project = OutgoingProj]
    /// Writable half of one stream of a multiplexed connection.
    ///
    /// Every write is handed to the connection's egress task as a `(path_buf, data)` pair.
    pub struct Outgoing {
        // Channel to the egress task; cloned into every `Outgoing` derived via `index`.
        #[pin]
        tx: PollSender<(Bytes, Bytes)>,
        // Absolute index path of this stream; empty for the root stream.
        path: Arc<[usize]>,
        // Pre-encoded frame header for `path`: LEB128 path length followed by each element
        // (the root stream uses the single byte `0`).
        path_buf: Bytes,
    }
}
372
373impl Index<Self> for Outgoing {
374 #[instrument(level = "trace", skip(self), fields(path = ?self.path))]
375 fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
376 ensure!(!path.is_empty());
377 let path: Arc<[usize]> = if self.path.is_empty() {
378 Arc::from(path)
379 } else {
380 Arc::from([self.path.as_ref(), path].concat())
381 };
382 let mut buf = BytesMut::with_capacity(path.len().saturating_add(5));
383 let n = u32::try_from(path.len())
384 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
385 trace!(n, "encoding path length");
386 Leb128Encoder.encode(n, &mut buf)?;
387 for p in path.as_ref() {
388 let p = u32::try_from(*p)
389 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
390 trace!(p, "encoding path element");
391 Leb128Encoder.encode(p, &mut buf)?;
392 }
393 Ok(Self {
394 tx: self.tx.clone(),
395 path,
396 path_buf: buf.freeze(),
397 })
398 }
399}
400
401impl AsyncWrite for Outgoing {
402 #[instrument(level = "trace", skip_all, fields(path = ?self.path, buf = format!("{buf:02x?}")), ret(level = "trace"))]
403 fn poll_write(
404 self: Pin<&mut Self>,
405 cx: &mut Context<'_>,
406 buf: &[u8],
407 ) -> Poll<std::io::Result<usize>> {
408 trace!("writing outgoing chunk");
409 let mut this = self.project();
410 ready!(this.tx.as_mut().poll_ready(cx))
411 .map_err(|err| std::io::Error::new(std::io::ErrorKind::BrokenPipe, err))?;
412 this.tx
413 .start_send((this.path_buf.clone(), Bytes::copy_from_slice(buf)))
414 .map_err(|err| std::io::Error::new(std::io::ErrorKind::BrokenPipe, err))?;
415 Poll::Ready(Ok(buf.len()))
416 }
417
418 #[instrument(level = "trace", skip_all, fields(path = ?self.path), ret(level = "trace"))]
419 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
420 Poll::Ready(Ok(()))
421 }
422
423 #[instrument(level = "trace", skip_all, fields(path = ?self.path), ret(level = "trace"))]
424 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
425 Poll::Ready(Ok(()))
426 }
427}
428
429#[instrument(level = "trace", skip_all, ret(level = "trace"))]
430async fn ingress(
431 mut rx: impl AsyncRead + Unpin,
432 index: &std::sync::Mutex<IndexTrie>,
433 param_tx: mpsc::Sender<std::io::Result<Bytes>>,
434) -> std::io::Result<()> {
435 loop {
436 trace!("reading path length");
437 let b = match rx.read_u8().await {
438 Ok(b) => b,
439 Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
440 Err(err) => return Err(err),
441 };
442 let n = AsyncReadExt::chain([b].as_slice(), &mut rx)
443 .read_u32_leb128()
444 .await?;
445 let n = n
446 .try_into()
447 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
448 trace!(n, "read path length");
449 let tx = if n == 0 {
450 ¶m_tx
451 } else {
452 let mut path = Vec::with_capacity(n);
453 for i in 0..n {
454 trace!(i, "reading path element");
455 let p = rx.read_u32_leb128().await?;
456 let p = usize::try_from(p)
457 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
458 path.push(p);
459 }
460 trace!(?path, "read path");
461
462 trace!("locking index trie");
463 let mut index = index
464 .lock()
465 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.to_string()))?;
466 &index.get_tx(&path).ok_or_else(|| {
467 std::io::Error::new(
468 std::io::ErrorKind::NotFound,
469 format!("`{path:?}` subscription not found"),
470 )
471 })?
472 };
473 trace!("reading data length");
474 let n = rx.read_u32_leb128().await?;
475 let n = n
476 .try_into()
477 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
478 trace!(n, "read data length");
479 let mut buf = BytesMut::with_capacity(n);
480 buf.put_bytes(0, n);
481 trace!("reading data");
482 rx.read_exact(&mut buf).await?;
483 trace!(?buf, "read data");
484 tx.send(Ok(buf.freeze())).await.map_err(|_| {
485 std::io::Error::new(std::io::ErrorKind::BrokenPipe, "stream receiver closed")
486 })?;
487 }
488}
489
490#[instrument(level = "trace", skip_all)]
491async fn egress(
492 mut tx: impl AsyncWrite + Unpin,
493 mut rx: mpsc::Receiver<(Bytes, Bytes)>,
494) -> std::io::Result<()> {
495 let mut buf = BytesMut::with_capacity(5);
496 trace!("waiting for next frame");
497 while let Some((path, data)) = rx.recv().await {
498 let data_len = u32::try_from(data.len())
499 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
500 buf.clear();
501 Leb128Encoder.encode(data_len, &mut buf)?;
502 let mut frame = path.chain(&mut buf).chain(data);
503 trace!(?frame, "writing egress frame");
504 tx.write_all_buf(&mut frame).await?;
505 }
506 trace!("shutting down outgoing stream");
507 tx.shutdown().await
508}
509
510pub trait ConnHandler<Rx, Tx> {
515 fn on_ingress(rx: Rx, res: std::io::Result<()>) -> impl Future<Output = ()> + Send {
517 _ = rx;
518 if let Err(err) = res {
519 error!(?err, "ingress failed");
520 } else {
521 debug!("ingress successfully complete");
522 }
523 async {}
524 }
525
526 fn on_egress(tx: Tx, res: std::io::Result<()>) -> impl Future<Output = ()> + Send {
528 _ = tx;
529 if let Err(err) = res {
530 error!(?err, "egress failed");
531 } else {
532 debug!("egress successfully complete");
533 }
534 async {}
535 }
536}
537
538impl<Rx, Tx> ConnHandler<Rx, Tx> for () {}
539
/// A multiplexed connection: the root incoming and outgoing streams over one transport.
pub(crate) struct Conn {
    // Root reader; nested readers are derived from it via `Index::index`.
    rx: Incoming,
    // Root writer; nested writers are derived from it via `Index::index`.
    tx: Outgoing,
}
545
546impl Conn {
547 fn new<H, Rx, Tx, P>(mut rx: Rx, mut tx: Tx, paths: impl IntoIterator<Item = P>) -> Self
549 where
550 Rx: AsyncRead + Unpin + Send + 'static,
551 Tx: AsyncWrite + Unpin + Send + 'static,
552 H: ConnHandler<Rx, Tx>,
553 P: AsRef<[Option<usize>]>,
554 {
555 let index = Arc::new(std::sync::Mutex::new(paths.into_iter().collect()));
556 let (rx_tx, rx_rx) = mpsc::channel(128);
557 let mut rx_io = JoinSet::new();
558 let span = Span::current();
559 rx_io.spawn({
560 let index = Arc::clone(&index);
561 async move {
562 let res = ingress(&mut rx, &index, rx_tx).await;
563 H::on_ingress(rx, res).await;
564 let Ok(mut index) = index.lock() else {
565 error!("failed to lock index trie");
566 return;
567 };
568 trace!("shutting down index trie");
569 index.close_tx();
570 }
571 .instrument(span.clone())
572 });
573 let (tx_tx, tx_rx) = mpsc::channel(128);
574 tokio::spawn(
575 async {
576 let res = egress(&mut tx, tx_rx).await;
577 H::on_egress(tx, res).await;
578 }
579 .instrument(span.clone()),
580 );
581 Conn {
582 tx: Outgoing {
583 tx: PollSender::new(tx_tx),
584 path: Arc::from([]),
585 path_buf: Bytes::from_static(&[0]),
586 },
587 rx: Incoming {
588 rx: Some(StreamReader::new(ReceiverStream::new(rx_rx))),
589 path: Arc::from([]),
590 index: Arc::clone(&index),
591 io: Arc::new(rx_io),
592 },
593 }
594 }
595}