wrpc_transport_ants/
lib.rs

1//! wRPC NATS.io transport
2
3#![allow(clippy::type_complexity)]
4
5use core::future::Future;
6use core::iter::zip;
7use core::ops::{Deref, DerefMut};
8use core::pin::{pin, Pin};
9use core::sync::atomic::AtomicUsize;
10use core::task::{ready, Context, Poll};
11use core::{mem, str};
12
13use std::collections::HashMap;
14use std::sync::Arc;
15
16use anyhow::{anyhow, bail, ensure, Context as _};
17use bytes::{Buf as _, BufMut as _, Bytes, BytesMut};
18use futures::sink::SinkExt as _;
19use futures::{Stream, StreamExt};
20use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
21use tokio::select;
22use tokio::sync::{mpsc, oneshot, watch};
23use tokio::task::JoinSet;
24use tokio_stream::wrappers::ReceiverStream;
25use tokio_util::sync::PollSender;
26use tracing::{debug, error, instrument, trace, warn};
27use wrpc_transport::Index as _;
28
/// Protocol version segment included in every subject built by [`invocation_subject`].
pub const PROTOCOL: &str = "wrpc.0.0.1";
30
31fn spawn_async(fut: impl Future<Output = ()> + Send + 'static) {
32    match tokio::runtime::Handle::try_current() {
33        Ok(rt) => {
34            rt.spawn(fut);
35        }
36        Err(_) => match tokio::runtime::Runtime::new() {
37            Ok(rt) => {
38                rt.spawn(fut);
39            }
40            Err(err) => error!(?err, "failed to create a new Tokio runtime"),
41        },
42    }
43}
44
45fn new_inbox(inbox: &[u8]) -> Bytes {
46    let id = nuid::next();
47    let mut s = BytesMut::with_capacity(inbox.len().saturating_add(id.len()));
48    s.extend_from_slice(inbox);
49    s.extend_from_slice(id.as_bytes());
50    s.freeze()
51}
52
53#[must_use]
54#[inline]
55pub fn param_subject(prefix: &[u8]) -> Bytes {
56    let mut s = BytesMut::with_capacity(prefix.len().saturating_add(".params".len()));
57    s.extend_from_slice(prefix);
58    s.extend_from_slice(b".params");
59    s.freeze()
60}
61
62#[must_use]
63#[inline]
64pub fn result_subject(prefix: &[u8]) -> Bytes {
65    let mut s = BytesMut::with_capacity(prefix.len().saturating_add(".results".len()));
66    s.extend_from_slice(prefix);
67    s.extend_from_slice(b".results");
68    s.freeze()
69}
70
71#[must_use]
72#[inline]
73pub fn index_path(prefix: &[u8], path: &[usize]) -> Bytes {
74    let mut s = BytesMut::with_capacity(prefix.len().saturating_add(path.len().saturating_mul(2)));
75    if !prefix.is_empty() {
76        s.extend_from_slice(prefix);
77    }
78    for p in path {
79        if !s.is_empty() {
80            s.put_u8(b'.');
81        }
82        s.extend_from_slice(p.to_string().as_bytes());
83    }
84    s.freeze()
85}
86
87#[must_use]
88#[inline]
89pub fn subscribe_path(prefix: &[u8], path: &[Option<usize>]) -> Bytes {
90    let mut s = BytesMut::with_capacity(prefix.len().saturating_add(path.len().saturating_mul(2)));
91    if !prefix.is_empty() {
92        s.extend_from_slice(prefix);
93    }
94    for p in path {
95        if !s.is_empty() {
96            s.put_u8(b'.');
97        }
98        if let Some(p) = p {
99            s.extend_from_slice(p.to_string().as_bytes());
100        } else {
101            s.put_u8(b'*');
102        }
103    }
104    s.freeze()
105}
106
107#[must_use]
108#[inline]
109pub fn invocation_subject(prefix: &[u8], instance: &str, func: &str) -> Bytes {
110    let mut s = BytesMut::with_capacity(
111        3_usize
112            .saturating_add(prefix.len())
113            .saturating_add(PROTOCOL.len())
114            .saturating_add(instance.len())
115            .saturating_add(func.len()),
116    );
117    if !prefix.is_empty() {
118        s.extend_from_slice(prefix);
119        s.put_u8(b'.');
120    }
121    s.extend_from_slice(PROTOCOL.as_bytes());
122    s.put_u8(b'.');
123    if !instance.is_empty() {
124        s.extend_from_slice(instance.as_bytes());
125        s.put_u8(b'.');
126    }
127    s.extend_from_slice(func.as_bytes());
128    s.freeze()
129}
130
/// Error returned when a writer is found in its `Corrupted` placeholder state or
/// otherwise not in the state it was just moved into.
fn corrupted_memory_error() -> std::io::Error {
    use std::io::{Error, ErrorKind};
    Error::new(ErrorKind::Other, "corrupted memory state")
}
134
/// Transport subscriber
///
/// A stream of [`Message`]s (exposed through `Deref` to [`ReceiverStream`]). Dropping it
/// sends `Command::Unsubscribe` for its subject to the client's background task.
pub struct Subscriber {
    /// Channel on which the messages routed to this subscriber arrive.
    rx: ReceiverStream<Message>,
    /// Subject sent with `Command::Unsubscribe` when this subscriber is dropped.
    subject: Bytes,
    /// Command channel of the client's demultiplexing task.
    commands: mpsc::Sender<Command>,
    /// Shared handle to the client's background task set, held while this subscriber lives.
    tasks: Arc<JoinSet<()>>,
}
142
143impl Drop for Subscriber {
144    fn drop(&mut self) {
145        let commands = self.commands.clone();
146        let subject = mem::take(&mut self.subject);
147        let tasks = Arc::clone(&self.tasks);
148        spawn_async(async move {
149            trace!(?subject, "shutting down subscriber");
150            if let Err(err) = commands.send(Command::Unsubscribe(subject)).await {
151                warn!(?err, "failed to shutdown subscriber");
152            }
153            drop(tasks);
154        });
155    }
156}
157
158impl Deref for Subscriber {
159    type Target = ReceiverStream<Message>;
160
161    fn deref(&self) -> &Self::Target {
162        &self.rx
163    }
164}
165
166impl DerefMut for Subscriber {
167    fn deref_mut(&mut self) -> &mut Self::Target {
168        &mut self.rx
169    }
170}
171
/// Requests handled by the client's demultiplexing task.
enum Command {
    /// Route messages received on the subject to the given sender (replacing any previous one).
    Subscribe(Bytes, mpsc::Sender<Message>),
    /// Stop routing messages received on the subject.
    Unsubscribe(Bytes),
    /// Apply each contained command, in order.
    Batch(Box<[Command]>),
}
177
/// Subset of [`ants::Message`](ants::Message) used by this crate
pub struct Message {
    /// Reply subject of the received message; empty when the sender did not set one.
    reply: Bytes,
    /// Message body; an empty payload marks the end of a stream for [`Reader`].
    payload: Bytes,
}
183
/// wRPC transport client on top of an `ants` NATS client.
///
/// Clones share the command channel, the subscription-ID counter and the background task set.
#[derive(Clone, Debug)]
pub struct Client {
    /// Underlying NATS client.
    nats: ants::Client,
    /// Subject prefix, taken from [`ClientBuilder`].
    prefix: Bytes,
    /// Inbox prefix `_INBOX.<nuid>.` (with trailing `.`) whose wildcard this client listens on.
    inbox: Bytes,
    /// Queue group, taken from [`ClientBuilder`].
    queue_group: Bytes,
    /// Command channel of the demultiplexing task spawned by [`ClientBuilder::build`].
    commands: mpsc::Sender<Command>,
    /// Counter backing `next_sid`; starts at 1 because SID `"0"` is the inbox subscription.
    sid: Arc<AtomicUsize>,
    /// Background task set, shared so the demultiplexing task lives as long as any holder.
    tasks: Arc<JoinSet<()>>,
}
194
195impl Client {
196    fn next_sid(&self) -> Bytes {
197        Bytes::from(
198            self.sid
199                .fetch_add(1, core::sync::atomic::Ordering::Relaxed)
200                .to_string(),
201        )
202    }
203}
204
/// Builder for [`Client`]; both fields default to empty.
#[derive(Default)]
pub struct ClientBuilder {
    /// Subject prefix copied into the built [`Client`].
    prefix: Bytes,
    /// Queue group copied into the built [`Client`].
    queue_group: Bytes,
}
210
211impl ClientBuilder {
212    pub async fn build(self, nats: ants::Client) -> anyhow::Result<Client> {
213        let id = nuid::next();
214        let mut subject = BytesMut::with_capacity(9_usize.saturating_add(id.len()));
215        subject.extend_from_slice(b"_INBOX.");
216        subject.extend_from_slice(id.as_bytes());
217        subject.put_u8(b'.');
218        let inbox = subject.clone().freeze();
219
220        subject.put_u8(b'>');
221        let (sub_tx, mut sub_rx) = mpsc::channel(8196);
222        nats.subscribe(subject, Bytes::default(), "0", sub_tx)
223            .await
224            .context("failed to subscribe on an inbox subject")?;
225
226        let mut tasks = JoinSet::new();
227        let (cmd_tx, mut cmd_rx) = mpsc::channel(8192);
228        tasks.spawn({
229            async move {
230                fn handle_command(subs: &mut HashMap<Bytes, mpsc::Sender<Message>>, cmd: Command) {
231                    match cmd {
232                        Command::Subscribe(s, tx) => {
233                            subs.insert(s, tx);
234                        }
235                        Command::Unsubscribe(s) => {
236                            subs.remove(&s);
237                        }
238                        Command::Batch(cmds) => {
239                            for cmd in cmds {
240                                handle_command(subs, cmd);
241                            }
242                        }
243                    }
244                }
245                async fn handle_message(
246                    subs: &mut HashMap<Bytes, mpsc::Sender<Message>>,
247                    ants::Message {
248                        subject,
249                        reply,
250                        payload,
251                        ..
252                    }: ants::Message,
253                ) {
254                    let Some(sub) = subs.get_mut(&subject) else {
255                        debug!(?subject, "drop message with no subscriber");
256                        return;
257                    };
258                    let Ok(sub) = sub.reserve().await else {
259                        debug!(?subject, "drop message with closed subscriber");
260                        subs.remove(&subject);
261                        return;
262                    };
263                    sub.send(Message { reply, payload });
264                }
265
266                let mut subs = HashMap::new();
267                loop {
268                    select! {
269                        Some(msg) = sub_rx.recv() => handle_message(&mut subs, msg).await,
270                        Some(cmd) = cmd_rx.recv() => handle_command(&mut subs, cmd),
271                    }
272                }
273            }
274        });
275        Ok(Client {
276            nats,
277            prefix: self.prefix,
278            inbox,
279            queue_group: self.queue_group,
280            commands: cmd_tx,
281            sid: Arc::new(AtomicUsize::new(1)),
282            tasks: Arc::new(tasks),
283        })
284    }
285}
286
287impl Client {
288    pub async fn new(nats: ants::Client) -> anyhow::Result<Self> {
289        ClientBuilder::default().build(nats).await
290    }
291}
292
/// Stream of the raw payloads of the messages received by the wrapped [`Subscriber`].
pub struct ByteSubscription(Subscriber);
294
295impl Stream for ByteSubscription {
296    type Item = std::io::Result<Bytes>;
297
298    #[instrument(level = "trace", skip_all)]
299    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
300        match self.0.poll_next_unpin(cx) {
301            Poll::Ready(Some(Message { payload, .. })) => Poll::Ready(Some(Ok(payload))),
302            Poll::Ready(None) => Poll::Ready(None),
303            Poll::Pending => Poll::Pending,
304        }
305    }
306}
307
/// Trie of [`Subscriber`]s keyed by index paths.
///
/// Each path segment is either a concrete index (`Some(i)`) or a wildcard (`None`).
#[derive(Default)]
enum IndexTrie {
    /// No subscriber and no children.
    #[default]
    Empty,
    /// Subscriber at the end of a path, with no deeper entries.
    Leaf(Subscriber),
    /// Node whose children are addressed by concrete index: `nested[i]` holds segment `Some(i)`.
    IndexNode {
        /// Subscriber for the path ending at this node, if any.
        subscriber: Option<Subscriber>,
        /// Children by index; `None` marks an index without entries.
        nested: Vec<Option<IndexTrie>>,
    },
    /// Node whose single child is reached through a wildcard (`None`) segment.
    WildcardNode {
        /// Subscriber for the path ending at this node, if any.
        subscriber: Option<Subscriber>,
        /// Child reached through the wildcard segment, if any.
        nested: Option<Box<IndexTrie>>,
    },
}
322
323impl<'a> From<(&'a [Option<usize>], Subscriber)> for IndexTrie {
324    fn from((path, sub): (&'a [Option<usize>], Subscriber)) -> Self {
325        match path {
326            [] => Self::Leaf(sub),
327            [None, path @ ..] => Self::WildcardNode {
328                subscriber: None,
329                nested: Some(Box::new(Self::from((path, sub)))),
330            },
331            [Some(i), path @ ..] => Self::IndexNode {
332                subscriber: None,
333                nested: {
334                    let n = i.saturating_add(1);
335                    let mut nested = Vec::with_capacity(n);
336                    nested.resize_with(n, Option::default);
337                    nested[*i] = Some(Self::from((path, sub)));
338                    nested
339                },
340            },
341        }
342    }
343}
344
345impl<P: AsRef<[Option<usize>]>> FromIterator<(P, Subscriber)> for IndexTrie {
346    fn from_iter<T: IntoIterator<Item = (P, Subscriber)>>(iter: T) -> Self {
347        let mut root = Self::Empty;
348        for (path, sub) in iter {
349            if !root.insert(path.as_ref(), sub) {
350                return Self::Empty;
351            }
352        }
353        root
354    }
355}
356
impl IndexTrie {
    /// Returns `true` if this node is [`IndexTrie::Empty`].
    #[inline]
    fn is_empty(&self) -> bool {
        matches!(self, IndexTrie::Empty)
    }

    /// Removes and returns the subscriber stored at exactly `path`.
    ///
    /// - A `Leaf` at `path` is returned and replaced by `Empty`.
    /// - An `IndexNode` at `path` hands out its subscriber; its children are kept (the node
    ///   becomes `Empty` if it has none).
    /// - A `WildcardNode` at `path` is taken out and dropped, together with its subscriber
    ///   and nested entries, and `None` is returned. Paths that continue through a wildcard
    ///   node return `None` without modifying the trie. (Wildcard demuxing is a TODO.)
    #[instrument(level = "trace", skip_all)]
    fn take(&mut self, path: &[usize]) -> Option<Subscriber> {
        let Some((i, path)) = path.split_first() else {
            // End of the path: move the node out and decide below what, if anything, to put back.
            return match mem::take(self) {
                // TODO: Demux the subscription
                //IndexTrie::WildcardNode { subscriber, nested } => {
                //    if let Some(nested) = nested {
                //        *self = IndexTrie::WildcardNode {
                //            subscriber: None,
                //            nested: Some(nested),
                //        }
                //    }
                //    subscriber
                //}
                IndexTrie::Empty | IndexTrie::WildcardNode { .. } => None,
                IndexTrie::Leaf(subscriber) => Some(subscriber),
                IndexTrie::IndexNode { subscriber, nested } => {
                    // Keep the children so they can still be taken later.
                    if !nested.is_empty() {
                        *self = IndexTrie::IndexNode {
                            subscriber: None,
                            nested,
                        }
                    }
                    subscriber
                }
            };
        };
        match self {
            // TODO: Demux the subscription
            //Self::WildcardNode { ref mut nested, .. } => {
            //    nested.as_mut().and_then(|nested| nested.take(path))
            //}
            Self::Empty | Self::Leaf(..) | Self::WildcardNode { .. } => None,
            // Descend into child `i`; a missing or vacant slot means nothing is stored there.
            Self::IndexNode { ref mut nested, .. } => nested
                .get_mut(*i)
                .and_then(|nested| nested.as_mut().and_then(|nested| nested.take(path))),
        }
    }

    /// Inserts `sub` under a `path` - returns `false` if it failed and `true` if it succeeded.
    ///
    /// Insertion fails when a subscriber already exists at `path`, or when `path` mixes a
    /// concrete index and a wildcard at the same depth. On failure `sub` is dropped and the
    /// tree state is unspecified (intermediate nodes may already have been modified).
    #[instrument(level = "trace", skip_all)]
    fn insert(&mut self, path: &[Option<usize>], sub: Subscriber) -> bool {
        match self {
            Self::Empty => {
                *self = Self::from((path, sub));
                true
            }
            Self::Leaf(..) => {
                // An empty path would collide with the existing leaf subscriber.
                let Some((i, path)) = path.split_first() else {
                    return false;
                };
                let Self::Leaf(subscriber) = mem::take(self) else {
                    return false;
                };
                // Turn the leaf into an inner node that keeps its subscriber and gains one child.
                if let Some(i) = i {
                    let n = i.saturating_add(1);
                    let mut nested = Vec::with_capacity(n);
                    nested.resize_with(n, Option::default);
                    nested[*i] = Some(Self::from((path, sub)));
                    *self = Self::IndexNode {
                        subscriber: Some(subscriber),
                        nested,
                    };
                } else {
                    *self = Self::WildcardNode {
                        subscriber: Some(subscriber),
                        nested: Some(Box::new(Self::from((path, sub)))),
                    };
                }
                true
            }
            Self::WildcardNode {
                ref mut subscriber,
                ref mut nested,
            } => match (&subscriber, path) {
                (None, []) => {
                    *subscriber = Some(sub);
                    true
                }
                (_, [None, path @ ..]) => {
                    if let Some(nested) = nested {
                        nested.insert(path, sub)
                    } else {
                        *nested = Some(Box::new(Self::from((path, sub))));
                        true
                    }
                }
                // Duplicate subscriber, or a concrete index below a wildcard node.
                _ => false,
            },
            Self::IndexNode {
                ref mut subscriber,
                ref mut nested,
            } => match (&subscriber, path) {
                (None, []) => {
                    *subscriber = Some(sub);
                    true
                }
                (_, [Some(i), path @ ..]) => {
                    // Grow the child table so that slot `i` exists.
                    let cap = i.saturating_add(1);
                    if nested.len() < cap {
                        nested.resize_with(cap, Option::default);
                    }
                    let nested = &mut nested[*i];
                    if let Some(nested) = nested {
                        nested.insert(path, sub)
                    } else {
                        *nested = Some(Self::from((path, sub)));
                        true
                    }
                }
                // Duplicate subscriber, or a wildcard below an index node.
                _ => false,
            },
        }
    }
}
479
/// Byte reader over the messages of a (possibly nested) stream received over NATS.
pub struct Reader {
    /// Bytes of the last received message not yet returned to the caller.
    buffer: Bytes,
    /// Subscription delivering this stream's messages; `None` if none was found for `path`.
    incoming: Option<Subscriber>,
    /// Subscriptions for nested paths, shared with every reader created via `index`.
    nested: Arc<std::sync::Mutex<IndexTrie>>,
    /// Absolute index path of this reader; `index` extends it for child readers.
    path: Box<[usize]>,
}
486
487impl wrpc_transport::Index<Self> for Reader {
488    #[instrument(level = "trace", skip(self))]
489    fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
490        ensure!(!path.is_empty());
491        trace!("locking index tree");
492        let mut nested = self
493            .nested
494            .lock()
495            .map_err(|err| anyhow!(err.to_string()).context("failed to lock map"))?;
496        trace!("taking index subscription");
497        let mut p = self.path.to_vec();
498        p.extend_from_slice(path);
499        let incoming = nested.take(&p);
500        Ok(Self {
501            buffer: Bytes::default(),
502            incoming,
503            nested: Arc::clone(&self.nested),
504            path: p.into_boxed_slice(),
505        })
506    }
507}
508
509impl AsyncRead for Reader {
510    #[instrument(level = "trace", skip_all, ret)]
511    fn poll_read(
512        mut self: Pin<&mut Self>,
513        cx: &mut Context<'_>,
514        buf: &mut ReadBuf<'_>,
515    ) -> Poll<std::io::Result<()>> {
516        let cap = buf.remaining();
517        if cap == 0 {
518            trace!("attempt to read empty buffer");
519            return Poll::Ready(Ok(()));
520        }
521
522        if !self.buffer.is_empty() {
523            if self.buffer.len() > cap {
524                trace!(cap, len = self.buffer.len(), "reading part of buffer");
525                buf.put_slice(&self.buffer.split_to(cap));
526            } else {
527                trace!(cap, len = self.buffer.len(), "reading full buffer");
528                buf.put_slice(&mem::take(&mut self.buffer));
529            }
530            return Poll::Ready(Ok(()));
531        }
532        let Some(incoming) = self.incoming.as_mut() else {
533            return Poll::Ready(Err(std::io::Error::new(
534                std::io::ErrorKind::NotFound,
535                format!("subscription not found for path {:?}", self.path),
536            )));
537        };
538        trace!("polling for next message");
539        match incoming.poll_next_unpin(cx) {
540            Poll::Ready(Some(Message { mut payload, .. })) => {
541                trace!(?payload, "received message");
542                if payload.is_empty() {
543                    trace!("received stream shutdown message");
544                    return Poll::Ready(Ok(()));
545                }
546                if payload.len() > cap {
547                    trace!(len = payload.len(), cap, "partially reading the message");
548                    buf.put_slice(&payload.split_to(cap));
549                    self.buffer = payload;
550                } else {
551                    trace!(len = payload.len(), cap, "filling the buffer with payload");
552                    buf.put_slice(&payload);
553                }
554                Poll::Ready(Ok(()))
555            }
556            Poll::Ready(None) => {
557                trace!("subscription finished");
558                Poll::Ready(Ok(()))
559            }
560            Poll::Pending => Poll::Pending,
561        }
562    }
563}
564
/// Writer that publishes written bytes as NATS messages on the subject `tx`.
///
/// NOTE(review): this type derives `Clone`, and the `Drop` impl publishes an empty
/// "stream shutdown" message whenever `shutdown` is `false`. Every clone that is dropped
/// without being shut down therefore sends its own shutdown message. Confirm this is intended.
#[derive(Clone)]
pub struct SubjectWriter {
    /// Command sink of the underlying NATS client.
    nats: PollSender<ants::Command>,
    /// Server info; its `max_payload` bounds the size of each published message.
    info: watch::Receiver<Arc<ants::protocol::InfoOptions>>,
    /// Subject to publish on.
    tx: Bytes,
    /// Set once `poll_shutdown` has published the shutdown message, so `Drop` skips it.
    shutdown: bool,
    /// Shared handle to the client's background task set.
    tasks: Arc<JoinSet<()>>,
}
573
574impl SubjectWriter {
575    fn new(
576        nats: PollSender<ants::Command>,
577        info: watch::Receiver<Arc<ants::protocol::InfoOptions>>,
578        tx: Bytes,
579        tasks: Arc<JoinSet<()>>,
580    ) -> Self {
581        Self {
582            nats,
583            info,
584            tx,
585            shutdown: false,
586            tasks,
587        }
588    }
589}
590
591impl wrpc_transport::Index<Self> for SubjectWriter {
592    #[instrument(level = "trace", skip(self))]
593    fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
594        ensure!(!path.is_empty());
595        let tx = index_path(&self.tx, path);
596        Ok(Self {
597            nats: self.nats.clone(),
598            info: self.info.clone(),
599            tx,
600            shutdown: false,
601            tasks: Arc::clone(&self.tasks),
602        })
603    }
604}
605
606impl AsyncWrite for SubjectWriter {
607    #[instrument(level = "trace", skip_all, ret, fields(subject = ?self.tx, buf = format!("{buf:02x?}")))]
608    fn poll_write(
609        mut self: Pin<&mut Self>,
610        cx: &mut Context<'_>,
611        mut buf: &[u8],
612    ) -> Poll<std::io::Result<usize>> {
613        trace!("polling for readiness");
614        match self.nats.poll_ready_unpin(cx) {
615            Poll::Pending => return Poll::Pending,
616            Poll::Ready(Err(..)) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
617            Poll::Ready(Ok(())) => {}
618        }
619        let max_payload = self.info.borrow().max_payload;
620        if max_payload == 0 {
621            return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
622        }
623        let total = buf.len();
624        let cmd = if total < max_payload {
625            ants::Command::Publish {
626                subject: self.tx.clone(),
627                payload: Bytes::copy_from_slice(buf),
628                reply: Bytes::default(),
629                headers: Bytes::default(),
630            }
631        } else {
632            let mut cap = max_payload.saturating_div(total);
633            let rem = max_payload % total;
634            if rem > 0 {
635                cap = cap.saturating_add(1);
636            }
637            let mut cmds = Vec::with_capacity(cap);
638            while buf.len() > max_payload {
639                (buf, _) = buf.split_at(max_payload);
640                cmds.push(ants::Command::Publish {
641                    subject: self.tx.clone(),
642                    payload: Bytes::copy_from_slice(buf),
643                    reply: Bytes::default(),
644                    headers: Bytes::default(),
645                })
646            }
647            if !buf.is_empty() {
648                cmds.push(ants::Command::Publish {
649                    subject: self.tx.clone(),
650                    payload: Bytes::copy_from_slice(buf),
651                    reply: Bytes::default(),
652                    headers: Bytes::default(),
653                })
654            }
655            ants::Command::Batch(cmds.into_boxed_slice())
656        };
657        trace!("starting send");
658        match self.nats.start_send_unpin(cmd) {
659            Ok(()) => Poll::Ready(Ok(total)),
660            Err(..) => Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe))),
661        }
662    }
663
664    #[instrument(level = "trace", skip_all, ret, fields(subject = ?self.tx))]
665    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
666        trace!("polling for readiness");
667        match self.nats.poll_ready_unpin(cx) {
668            Poll::Pending => return Poll::Pending,
669            Poll::Ready(Err(..)) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
670            Poll::Ready(Ok(())) => {}
671        }
672        trace!("flushing");
673        match self.nats.start_send_unpin(ants::Command::Flush) {
674            Ok(()) => Poll::Ready(Ok(())),
675            Err(..) => Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe))),
676        }
677    }
678
679    #[instrument(level = "trace", skip_all, ret, fields(subject = ?self.tx))]
680    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
681        trace!("writing stream shutdown message");
682        ready!(self.as_mut().poll_write(cx, &[]))?;
683        self.shutdown = true;
684        Poll::Ready(Ok(()))
685    }
686}
687
688impl Drop for SubjectWriter {
689    fn drop(&mut self) {
690        if !self.shutdown {
691            let mut nats = self.nats.clone();
692            let subject = mem::take(&mut self.tx);
693            let tasks = Arc::clone(&self.tasks);
694            spawn_async(async move {
695                trace!("writing stream shutdown message");
696                if let Err(_) = nats
697                    .send(ants::Command::Publish {
698                        subject,
699                        reply: Bytes::default(),
700                        headers: Bytes::default(),
701                        payload: Bytes::default(),
702                    })
703                    .await
704                {
705                    warn!("failed to publish stream shutdown message");
706                }
707                drop(tasks);
708            });
709        }
710    }
711}
712
/// Writer for the root of an invocation's parameter stream.
///
/// Starts in `Handshaking`, where it waits for the peer's first message. That message's
/// reply subject `R` selects the parameter subject `R.params`. The writer then publishes any
/// pre-buffered bytes (`Draining`) and finally forwards writes directly (`Active`).
#[derive(Default)]
pub enum RootParamWriter {
    /// Placeholder state: left behind by `mem::take` while the state is moved out, or set
    /// when the handshake subscription ends. Every operation on it fails.
    #[default]
    Corrupted,
    /// Waiting for the handshake message on `sub`.
    Handshaking {
        /// NATS client whose command channel backs the parameter [`SubjectWriter`].
        nats: ants::Client,
        /// Server info passed on to the created [`SubjectWriter`].
        info: watch::Receiver<Arc<ants::protocol::InfoOptions>>,
        /// Subscription on which the handshake message is awaited.
        sub: Subscriber,
        /// Nested writers requested via `index` during the handshake. Each one is sent its
        /// indexed [`SubjectWriter`] once the handshake succeeds.
        indexed: std::sync::Mutex<Vec<(Vec<usize>, oneshot::Sender<SubjectWriter>)>>,
        /// Bytes to publish right after the handshake.
        buffer: Bytes,
        /// Shared handle to the client's background task set.
        tasks: Arc<JoinSet<()>>,
    },
    /// Handshake done; publishing the remaining `buffer` through `tx`.
    Draining {
        tx: SubjectWriter,
        buffer: Bytes,
    },
    /// Handshake done and buffer drained; writes go straight to the wrapped writer.
    Active(SubjectWriter),
}
731
732impl RootParamWriter {
733    fn new(
734        nats: ants::Client,
735        info: watch::Receiver<Arc<ants::protocol::InfoOptions>>,
736        sub: Subscriber,
737        buffer: Bytes,
738        tasks: Arc<JoinSet<()>>,
739    ) -> Self {
740        Self::Handshaking {
741            nats,
742            info,
743            sub,
744            indexed: std::sync::Mutex::default(),
745            buffer,
746            tasks,
747        }
748    }
749}
750
impl RootParamWriter {
    /// Drives the writer towards the `Active` state.
    ///
    /// - `Handshaking`: waits for the first message on `sub`. Its reply subject `R` selects
    ///   the parameter subject (`R.params`). Writers queued by `index` are then sent their
    ///   indexed [`SubjectWriter`]s, and any buffered bytes are drained.
    /// - `Draining`: publishes the remaining buffer, then switches to `Active`.
    /// - `Active`: ready immediately.
    ///
    /// Returns `Ready(Ok(()))` only once the state is `Active`.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the state is `Corrupted`;
    /// - the handshake message has no reply subject;
    /// - the handshake subscription ends (the state becomes `Corrupted`);
    /// - the queue mutex is poisoned;
    /// - a queued indexed writer cannot be created or delivered (its receiver was dropped);
    /// - publishing the buffered bytes fails.
    #[instrument(level = "trace", skip_all, ret)]
    fn poll_active(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match &mut *self {
            Self::Corrupted => Poll::Ready(Err(corrupted_memory_error())),
            Self::Handshaking { sub, .. } => {
                trace!("polling for handshake response");
                match sub.poll_next_unpin(cx) {
                    // The peer's reply subject determines where parameters are published.
                    Poll::Ready(Some(Message { reply: tx, .. })) => {
                        if tx.is_empty() {
                            return Poll::Ready(Err(std::io::Error::new(
                                std::io::ErrorKind::InvalidInput,
                                "peer did not specify a reply subject",
                            )));
                        }
                        // Move the handshake state out. `self` stays `Corrupted` until it is
                        // reassigned below, including on the early-return error paths.
                        let Self::Handshaking {
                            nats,
                            info,
                            indexed,
                            buffer,
                            tasks,
                            ..
                        } = mem::take(&mut *self)
                        else {
                            return Poll::Ready(Err(corrupted_memory_error()));
                        };
                        let tx = SubjectWriter::new(
                            PollSender::new(nats.commands().clone()),
                            info,
                            param_subject(&tx),
                            tasks,
                        );
                        let indexed = indexed.into_inner().map_err(|err| {
                            std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
                        })?;
                        // Hand each writer indexed during the handshake its nested subject writer.
                        for (path, tx_tx) in indexed {
                            let tx = tx.index(&path).map_err(|err| {
                                std::io::Error::new(std::io::ErrorKind::Other, err)
                            })?;
                            tx_tx.send(tx).map_err(|_| {
                                std::io::Error::from(std::io::ErrorKind::BrokenPipe)
                            })?;
                        }
                        trace!("handshake succeeded");
                        if buffer.is_empty() {
                            *self = Self::Active(tx);
                            Poll::Ready(Ok(()))
                        } else {
                            *self = Self::Draining { tx, buffer };
                            // Start draining right away; this registers the waker if it pends.
                            self.poll_active(cx)
                        }
                    }
                    Poll::Ready(None) => {
                        *self = Self::Corrupted;
                        Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)))
                    }
                    Poll::Pending => Poll::Pending,
                }
            }
            Self::Draining { tx, buffer } => {
                let mut tx = pin!(tx);
                // `buffer` is advanced in place, so a `Pending` resumes where it left off.
                while !buffer.is_empty() {
                    trace!(?tx.tx, "draining parameter buffer");
                    match tx.as_mut().poll_write(cx, buffer) {
                        Poll::Ready(Ok(n)) => {
                            buffer.advance(n);
                        }
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Pending => return Poll::Pending,
                    }
                }
                let Self::Draining { tx, .. } = mem::take(&mut *self) else {
                    return Poll::Ready(Err(corrupted_memory_error()));
                };
                trace!("parameter buffer draining succeeded");
                *self = Self::Active(tx);
                Poll::Ready(Ok(()))
            }
            Self::Active(..) => Poll::Ready(Ok(())),
        }
    }
}
833
834impl wrpc_transport::Index<IndexedParamWriter> for RootParamWriter {
835    #[instrument(level = "trace", skip(self))]
836    fn index(&self, path: &[usize]) -> anyhow::Result<IndexedParamWriter> {
837        ensure!(!path.is_empty());
838        match self {
839            Self::Corrupted => Err(anyhow!(corrupted_memory_error())),
840            Self::Handshaking { indexed, .. } => {
841                let (tx_tx, tx_rx) = oneshot::channel();
842                let mut indexed = indexed.lock().map_err(|err| {
843                    std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
844                })?;
845                indexed.push((path.to_vec(), tx_tx));
846                Ok(IndexedParamWriter::Handshaking {
847                    tx_rx,
848                    indexed: std::sync::Mutex::default(),
849                })
850            }
851            Self::Draining { tx, .. } | Self::Active(tx) => {
852                tx.index(path).map(IndexedParamWriter::Active)
853            }
854        }
855    }
856}
857
858impl AsyncWrite for RootParamWriter {
859    #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
860    fn poll_write(
861        mut self: Pin<&mut Self>,
862        cx: &mut Context<'_>,
863        buf: &[u8],
864    ) -> Poll<std::io::Result<usize>> {
865        match self.as_mut().poll_active(cx)? {
866            Poll::Ready(()) => {
867                let Self::Active(tx) = &mut *self else {
868                    return Poll::Ready(Err(corrupted_memory_error()));
869                };
870                trace!("writing buffer");
871                pin!(tx).poll_write(cx, buf)
872            }
873            Poll::Pending => Poll::Pending,
874        }
875    }
876
877    #[instrument(level = "trace", skip_all, ret)]
878    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
879        match self.as_mut().poll_active(cx)? {
880            Poll::Ready(()) => {
881                let Self::Active(tx) = &mut *self else {
882                    return Poll::Ready(Err(corrupted_memory_error()));
883                };
884                trace!("flushing");
885                pin!(tx).poll_flush(cx)
886            }
887            Poll::Pending => Poll::Pending,
888        }
889    }
890
891    #[instrument(level = "trace", skip_all, ret)]
892    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
893        match self.as_mut().poll_active(cx)? {
894            Poll::Ready(()) => {
895                let Self::Active(tx) = &mut *self else {
896                    return Poll::Ready(Err(corrupted_memory_error()));
897                };
898                trace!("shutting down");
899                pin!(tx).poll_shutdown(cx)
900            }
901            Poll::Pending => Poll::Pending,
902        }
903    }
904}
905
/// Parameter writer for a nested (indexed) value of an invocation's parameters.
///
/// Obtained by indexing a `RootParamWriter`, another `IndexedParamWriter` or a `ParamWriter`.
/// If the parent was still handshaking at that time, the writer starts out `Handshaking` and
/// becomes `Active` once the parent delivers its `SubjectWriter` through `tx_rx`.
#[derive(Default)]
pub enum IndexedParamWriter {
    /// Placeholder left behind by `mem::take` while transitioning states; any use of a writer
    /// in this state results in an error.
    #[default]
    Corrupted,
    /// Waiting for the parent writer to hand over this writer's `SubjectWriter`.
    Handshaking {
        /// Receives this writer's `SubjectWriter` from the parent.
        tx_rx: oneshot::Receiver<SubjectWriter>,
        /// Writers indexed from this one while it was still handshaking, each with its path
        /// relative to this writer and the sender used to deliver its `SubjectWriter`.
        indexed: std::sync::Mutex<Vec<(Vec<usize>, oneshot::Sender<SubjectWriter>)>>,
    },
    /// Ready: writes are forwarded to the `SubjectWriter`.
    Active(SubjectWriter),
}
916
917impl IndexedParamWriter {
918    #[instrument(level = "trace", skip_all, ret)]
919    fn poll_active(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
920        match &mut *self {
921            Self::Corrupted => Poll::Ready(Err(corrupted_memory_error())),
922            Self::Handshaking { tx_rx, .. } => {
923                trace!("polling for handshake");
924                match pin!(tx_rx).poll(cx) {
925                    Poll::Ready(Ok(tx)) => {
926                        let Self::Handshaking { indexed, .. } = mem::take(&mut *self) else {
927                            return Poll::Ready(Err(corrupted_memory_error()));
928                        };
929                        let indexed = indexed.into_inner().map_err(|err| {
930                            std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
931                        })?;
932                        for (path, tx_tx) in indexed {
933                            let tx = tx.index(&path).map_err(|err| {
934                                std::io::Error::new(std::io::ErrorKind::Other, err)
935                            })?;
936                            tx_tx.send(tx).map_err(|_| {
937                                std::io::Error::from(std::io::ErrorKind::BrokenPipe)
938                            })?;
939                        }
940                        *self = Self::Active(tx);
941                        Poll::Ready(Ok(()))
942                    }
943                    Poll::Ready(Err(..)) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
944                    Poll::Pending => Poll::Pending,
945                }
946            }
947            Self::Active(..) => Poll::Ready(Ok(())),
948        }
949    }
950}
951
952impl wrpc_transport::Index<Self> for IndexedParamWriter {
953    #[instrument(level = "trace", skip_all)]
954    fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
955        ensure!(!path.is_empty());
956        match self {
957            Self::Corrupted => Err(anyhow!(corrupted_memory_error())),
958            Self::Handshaking { indexed, .. } => {
959                let (tx_tx, tx_rx) = oneshot::channel();
960                let mut indexed = indexed.lock().map_err(|err| {
961                    std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
962                })?;
963                indexed.push((path.to_vec(), tx_tx));
964                Ok(Self::Handshaking {
965                    tx_rx,
966                    indexed: std::sync::Mutex::default(),
967                })
968            }
969            Self::Active(tx) => tx.index(path).map(Self::Active),
970        }
971    }
972}
973
974impl AsyncWrite for IndexedParamWriter {
975    #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
976    fn poll_write(
977        mut self: Pin<&mut Self>,
978        cx: &mut Context<'_>,
979        buf: &[u8],
980    ) -> Poll<std::io::Result<usize>> {
981        match self.as_mut().poll_active(cx)? {
982            Poll::Ready(()) => {
983                let Self::Active(tx) = &mut *self else {
984                    return Poll::Ready(Err(corrupted_memory_error()));
985                };
986                trace!("writing buffer");
987                pin!(tx).poll_write(cx, buf)
988            }
989            Poll::Pending => Poll::Pending,
990        }
991    }
992
993    #[instrument(level = "trace", skip_all, ret)]
994    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
995        match self.as_mut().poll_active(cx)? {
996            Poll::Ready(()) => {
997                let Self::Active(tx) = &mut *self else {
998                    return Poll::Ready(Err(corrupted_memory_error()));
999                };
1000                trace!("flushing");
1001                pin!(tx).poll_flush(cx)
1002            }
1003            Poll::Pending => Poll::Pending,
1004        }
1005    }
1006
1007    #[instrument(level = "trace", skip_all, ret)]
1008    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1009        match self.as_mut().poll_active(cx)? {
1010            Poll::Ready(()) => {
1011                let Self::Active(tx) = &mut *self else {
1012                    return Poll::Ready(Err(corrupted_memory_error()));
1013                };
1014                trace!("shutting down");
1015                pin!(tx).poll_shutdown(cx)
1016            }
1017            Poll::Pending => Poll::Pending,
1018        }
1019    }
1020}
1021
/// Outgoing parameter writer of an invocation, as returned by `Client::invoke`.
pub enum ParamWriter {
    /// The top-level writer returned by `invoke`, constructed with the handshake subscription.
    Root(RootParamWriter),
    /// A writer for a nested parameter value, obtained through `Index::index`.
    Nested(IndexedParamWriter),
}
1026
1027impl wrpc_transport::Index<Self> for ParamWriter {
1028    fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
1029        ensure!(!path.is_empty());
1030        match self {
1031            ParamWriter::Root(w) => w.index(path),
1032            ParamWriter::Nested(w) => w.index(path),
1033        }
1034        .map(Self::Nested)
1035    }
1036}
1037
1038impl AsyncWrite for ParamWriter {
1039    #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
1040    fn poll_write(
1041        mut self: Pin<&mut Self>,
1042        cx: &mut Context<'_>,
1043        buf: &[u8],
1044    ) -> Poll<std::io::Result<usize>> {
1045        match &mut *self {
1046            ParamWriter::Root(w) => pin!(w).poll_write(cx, buf),
1047            ParamWriter::Nested(w) => pin!(w).poll_write(cx, buf),
1048        }
1049    }
1050
1051    #[instrument(level = "trace", skip_all, ret)]
1052    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1053        match &mut *self {
1054            ParamWriter::Root(w) => pin!(w).poll_flush(cx),
1055            ParamWriter::Nested(w) => pin!(w).poll_flush(cx),
1056        }
1057    }
1058
1059    #[instrument(level = "trace", skip_all, ret)]
1060    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1061        match &mut *self {
1062            ParamWriter::Root(w) => pin!(w).poll_shutdown(cx),
1063            ParamWriter::Nested(w) => pin!(w).poll_shutdown(cx),
1064        }
1065    }
1066}
1067
1068impl wrpc_transport::Invoke for Client {
1069    type Context = Bytes;
1070    type Outgoing = ParamWriter;
1071    type Incoming = Reader;
1072
1073    #[instrument(level = "trace", skip(self, paths, params), fields(params = format!("{params:02x?}")))]
1074    async fn invoke<P: AsRef<[Option<usize>]> + Send + Sync>(
1075        &self,
1076        headers: Self::Context,
1077        instance: &str,
1078        func: &str,
1079        mut params: Bytes,
1080        paths: impl AsRef<[P]> + Send,
1081    ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)> {
1082        let paths = paths.as_ref();
1083        let mut cmds = Vec::with_capacity(paths.len().saturating_add(2));
1084
1085        let rx = new_inbox(&self.inbox);
1086        let (handshake_tx, handshake_rx) = mpsc::channel(1);
1087        cmds.push(Command::Subscribe(rx.clone(), handshake_tx));
1088
1089        let result = result_subject(&rx);
1090        let (result_tx, result_rx) = mpsc::channel(16);
1091        cmds.push(Command::Subscribe(result.clone(), result_tx));
1092
1093        let nested = paths.iter().map(|path| {
1094            let (tx, rx) = mpsc::channel(16);
1095            let subject = subscribe_path(&result, path.as_ref());
1096            cmds.push(Command::Subscribe(subject.clone(), tx));
1097            Subscriber {
1098                rx: ReceiverStream::new(rx),
1099                commands: self.commands.clone(),
1100                subject,
1101                tasks: Arc::clone(&self.tasks),
1102            }
1103        });
1104        let nested: IndexTrie = zip(paths.iter(), nested).collect();
1105        ensure!(
1106            paths.is_empty() == nested.is_empty(),
1107            "failed to construct subscription tree"
1108        );
1109
1110        self.commands
1111            .send(Command::Batch(cmds.into_boxed_slice()))
1112            .await
1113            .context("failed to subscribe")?;
1114
1115        let info = self
1116            .nats
1117            .server_info()
1118            .await
1119            .context("failed to get server info")?;
1120        let mut max_payload = info.borrow().max_payload;
1121        let param_tx = invocation_subject(&self.prefix, instance, func);
1122        if !headers.is_empty() {
1123            max_payload = max_payload.saturating_sub(headers.len());
1124        }
1125        trace!("publishing handshake");
1126        self.nats
1127            .publish(
1128                param_tx,
1129                rx.clone(),
1130                headers,
1131                params.split_to(max_payload.min(params.len())),
1132            )
1133            .await
1134            .context("failed to publish handshake")?;
1135        Ok((
1136            ParamWriter::Root(RootParamWriter::new(
1137                self.nats.clone(),
1138                info,
1139                Subscriber {
1140                    rx: ReceiverStream::new(handshake_rx),
1141                    commands: self.commands.clone(),
1142                    subject: rx,
1143                    tasks: Arc::clone(&self.tasks),
1144                },
1145                params,
1146                Arc::clone(&self.tasks),
1147            )),
1148            Reader {
1149                buffer: Bytes::default(),
1150                incoming: Some(Subscriber {
1151                    rx: ReceiverStream::new(result_rx),
1152                    commands: self.commands.clone(),
1153                    subject: result,
1154                    tasks: Arc::clone(&self.tasks),
1155                }),
1156                nested: Arc::new(std::sync::Mutex::new(nested)),
1157                path: Box::default(),
1158            },
1159        ))
1160    }
1161}
1162
1163async fn handle_message(
1164    nats: &ants::Client,
1165    rx: Bytes,
1166    commands: mpsc::Sender<Command>,
1167    ants::Message {
1168        reply: tx,
1169        payload,
1170        headers,
1171        ..
1172    }: ants::Message,
1173    paths: &[Box<[Option<usize>]>],
1174    tasks: Arc<JoinSet<()>>,
1175) -> anyhow::Result<(Bytes, SubjectWriter, Reader)> {
1176    if tx.is_empty() {
1177        bail!("peer did not specify a reply subject")
1178    }
1179
1180    let mut cmds = Vec::with_capacity(paths.len().saturating_add(1));
1181
1182    let param = Bytes::from(param_subject(&rx));
1183    let (param_tx, param_rx) = mpsc::channel(16);
1184    cmds.push(Command::Subscribe(param.clone(), param_tx));
1185
1186    let nested = paths.iter().map(|path| {
1187        let (tx, rx) = mpsc::channel(16);
1188        let subject = Bytes::from(subscribe_path(&param, path.as_ref()));
1189        cmds.push(Command::Subscribe(subject.clone(), tx));
1190        Subscriber {
1191            rx: ReceiverStream::new(rx),
1192            commands: commands.clone(),
1193            subject,
1194            tasks: Arc::clone(&tasks),
1195        }
1196    });
1197    let nested: IndexTrie = zip(paths.iter(), nested).collect();
1198    ensure!(
1199        paths.is_empty() == nested.is_empty(),
1200        "failed to construct subscription tree"
1201    );
1202
1203    commands
1204        .send(Command::Batch(cmds.into_boxed_slice()))
1205        .await
1206        .context("failed to subscribe")?;
1207
1208    trace!("publishing handshake response");
1209    nats.publish(tx.clone(), rx, Bytes::default(), Bytes::default())
1210        .await
1211        .context("failed to publish handshake accept")?;
1212    let info = nats
1213        .server_info()
1214        .await
1215        .context("failed to get server info")?;
1216    Ok((
1217        headers,
1218        SubjectWriter::new(
1219            PollSender::new(nats.commands().clone()),
1220            info,
1221            result_subject(&tx),
1222            Arc::clone(&tasks),
1223        ),
1224        Reader {
1225            buffer: payload,
1226            incoming: Some(Subscriber {
1227                rx: ReceiverStream::new(param_rx),
1228                commands,
1229                subject: param,
1230                tasks,
1231            }),
1232            nested: Arc::new(std::sync::Mutex::new(nested)),
1233            path: Box::default(),
1234        },
1235    ))
1236}
1237
1238impl wrpc_transport::Serve for Client {
1239    type Context = Bytes;
1240    type Outgoing = SubjectWriter;
1241    type Incoming = Reader;
1242
1243    #[instrument(level = "trace", skip(self, paths))]
1244    async fn serve(
1245        &self,
1246        instance: &str,
1247        func: &str,
1248        paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
1249    ) -> anyhow::Result<
1250        impl Stream<Item = anyhow::Result<(Self::Context, Self::Outgoing, Self::Incoming)>> + 'static,
1251    > {
1252        let subject = invocation_subject(&self.prefix, instance, func);
1253        debug!(?subject, "subscribing on invocation subject");
1254        let (sub_tx, sub_rx) = mpsc::channel(256);
1255        self.nats
1256            .subscribe(subject, self.queue_group.clone(), self.next_sid(), sub_tx)
1257            .await?;
1258        let nats = self.nats.clone();
1259        let paths = paths.into();
1260        let commands = self.commands.clone();
1261        let inbox = self.inbox.clone();
1262        let tasks = Arc::clone(&self.tasks);
1263        Ok(ReceiverStream::new(sub_rx).then(move |msg| {
1264            let tasks = Arc::clone(&tasks);
1265            let nats = nats.clone();
1266            let paths = Arc::clone(&paths);
1267            let commands = commands.clone();
1268            let rx = new_inbox(&inbox);
1269            async move { handle_message(&nats, rx, commands, msg, &paths, tasks).await }
1270        }))
1271    }
1272}