1#![allow(clippy::type_complexity)]
4
5use core::future::Future;
6use core::iter::zip;
7use core::ops::{Deref, DerefMut};
8use core::pin::{pin, Pin};
9use core::sync::atomic::AtomicUsize;
10use core::task::{ready, Context, Poll};
11use core::{mem, str};
12
13use std::collections::HashMap;
14use std::sync::Arc;
15
16use anyhow::{anyhow, bail, ensure, Context as _};
17use bytes::{Buf as _, BufMut as _, Bytes, BytesMut};
18use futures::sink::SinkExt as _;
19use futures::{Stream, StreamExt};
20use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
21use tokio::select;
22use tokio::sync::{mpsc, oneshot, watch};
23use tokio::task::JoinSet;
24use tokio_stream::wrappers::ReceiverStream;
25use tokio_util::sync::PollSender;
26use tracing::{debug, error, instrument, trace, warn};
27use wrpc_transport::Index as _;
28
/// Protocol version identifier, embedded as a segment of every invocation subject built by
/// `invocation_subject`.
pub const PROTOCOL: &str = "wrpc.0.0.1";
30
31fn spawn_async(fut: impl Future<Output = ()> + Send + 'static) {
32 match tokio::runtime::Handle::try_current() {
33 Ok(rt) => {
34 rt.spawn(fut);
35 }
36 Err(_) => match tokio::runtime::Runtime::new() {
37 Ok(rt) => {
38 rt.spawn(fut);
39 }
40 Err(err) => error!(?err, "failed to create a new Tokio runtime"),
41 },
42 }
43}
44
45fn new_inbox(inbox: &[u8]) -> Bytes {
46 let id = nuid::next();
47 let mut s = BytesMut::with_capacity(inbox.len().saturating_add(id.len()));
48 s.extend_from_slice(inbox);
49 s.extend_from_slice(id.as_bytes());
50 s.freeze()
51}
52
53#[must_use]
54#[inline]
55pub fn param_subject(prefix: &[u8]) -> Bytes {
56 let mut s = BytesMut::with_capacity(prefix.len().saturating_add(".params".len()));
57 s.extend_from_slice(prefix);
58 s.extend_from_slice(b".params");
59 s.freeze()
60}
61
62#[must_use]
63#[inline]
64pub fn result_subject(prefix: &[u8]) -> Bytes {
65 let mut s = BytesMut::with_capacity(prefix.len().saturating_add(".results".len()));
66 s.extend_from_slice(prefix);
67 s.extend_from_slice(b".results");
68 s.freeze()
69}
70
71#[must_use]
72#[inline]
73pub fn index_path(prefix: &[u8], path: &[usize]) -> Bytes {
74 let mut s = BytesMut::with_capacity(prefix.len().saturating_add(path.len().saturating_mul(2)));
75 if !prefix.is_empty() {
76 s.extend_from_slice(prefix);
77 }
78 for p in path {
79 if !s.is_empty() {
80 s.put_u8(b'.');
81 }
82 s.extend_from_slice(p.to_string().as_bytes());
83 }
84 s.freeze()
85}
86
87#[must_use]
88#[inline]
89pub fn subscribe_path(prefix: &[u8], path: &[Option<usize>]) -> Bytes {
90 let mut s = BytesMut::with_capacity(prefix.len().saturating_add(path.len().saturating_mul(2)));
91 if !prefix.is_empty() {
92 s.extend_from_slice(prefix);
93 }
94 for p in path {
95 if !s.is_empty() {
96 s.put_u8(b'.');
97 }
98 if let Some(p) = p {
99 s.extend_from_slice(p.to_string().as_bytes());
100 } else {
101 s.put_u8(b'*');
102 }
103 }
104 s.freeze()
105}
106
107#[must_use]
108#[inline]
109pub fn invocation_subject(prefix: &[u8], instance: &str, func: &str) -> Bytes {
110 let mut s = BytesMut::with_capacity(
111 3_usize
112 .saturating_add(prefix.len())
113 .saturating_add(PROTOCOL.len())
114 .saturating_add(instance.len())
115 .saturating_add(func.len()),
116 );
117 if !prefix.is_empty() {
118 s.extend_from_slice(prefix);
119 s.put_u8(b'.');
120 }
121 s.extend_from_slice(PROTOCOL.as_bytes());
122 s.put_u8(b'.');
123 if !instance.is_empty() {
124 s.extend_from_slice(instance.as_bytes());
125 s.put_u8(b'.');
126 }
127 s.extend_from_slice(func.as_bytes());
128 s.freeze()
129}
130
/// Builds the error reported when a writer state machine is found in (or fails to leave)
/// an invalid state, such as its `Corrupted` placeholder.
fn corrupted_memory_error() -> std::io::Error {
    let kind = std::io::ErrorKind::Other;
    std::io::Error::new(kind, "corrupted memory state")
}
134
/// A subscription registered with the client's inbox demultiplexing task.
///
/// Messages routed to `subject` arrive on `rx`. Dropping the subscriber sends
/// `Command::Unsubscribe(subject)` over `commands`.
pub struct Subscriber {
    /// Stream of messages delivered for `subject`.
    rx: ReceiverStream<Message>,
    /// Subject this subscriber is registered under.
    subject: Bytes,
    /// Command channel of the demultiplexing task; used to unsubscribe on drop.
    commands: mpsc::Sender<Command>,
    /// Shared handle to the client's background `JoinSet`, held so it is not dropped while
    /// this subscriber exists.
    tasks: Arc<JoinSet<()>>,
}
142
143impl Drop for Subscriber {
144 fn drop(&mut self) {
145 let commands = self.commands.clone();
146 let subject = mem::take(&mut self.subject);
147 let tasks = Arc::clone(&self.tasks);
148 spawn_async(async move {
149 trace!(?subject, "shutting down subscriber");
150 if let Err(err) = commands.send(Command::Unsubscribe(subject)).await {
151 warn!(?err, "failed to shutdown subscriber");
152 }
153 drop(tasks);
154 });
155 }
156}
157
158impl Deref for Subscriber {
159 type Target = ReceiverStream<Message>;
160
161 fn deref(&self) -> &Self::Target {
162 &self.rx
163 }
164}
165
166impl DerefMut for Subscriber {
167 fn deref_mut(&mut self) -> &mut Self::Target {
168 &mut self.rx
169 }
170}
171
/// Requests processed by the client's inbox demultiplexing task.
enum Command {
    /// Route messages received on the subject to the sender, replacing any previous route.
    Subscribe(Bytes, mpsc::Sender<Message>),
    /// Stop routing messages received on the subject.
    Unsubscribe(Bytes),
    /// Apply the contained commands in order.
    Batch(Box<[Command]>),
}
177
/// A message delivered to a [`Subscriber`].
pub struct Message {
    /// Reply subject of the received message; empty if the peer did not set one.
    reply: Bytes,
    /// Message payload.
    payload: Bytes,
}
183
/// wRPC transport client on top of an `ants::Client` connection.
///
/// Clones share the same command channel, subscription ID counter and background task set.
#[derive(Clone, Debug)]
pub struct Client {
    /// Underlying connection.
    nats: ants::Client,
    /// Prefix prepended to invocation subjects.
    prefix: Bytes,
    /// Per-client inbox prefix (`_INBOX.<nuid>.`) under which reply subjects are allocated.
    inbox: Bytes,
    /// Queue group used when serving invocations.
    queue_group: Bytes,
    /// Command channel of the inbox demultiplexing task.
    commands: mpsc::Sender<Command>,
    /// Counter supplying subscription IDs for `serve` subscriptions.
    sid: Arc<AtomicUsize>,
    /// Background tasks (the inbox demultiplexer).
    tasks: Arc<JoinSet<()>>,
}
194
195impl Client {
196 fn next_sid(&self) -> Bytes {
197 Bytes::from(
198 self.sid
199 .fetch_add(1, core::sync::atomic::Ordering::Relaxed)
200 .to_string(),
201 )
202 }
203}
204
/// Builder for [`Client`].
#[derive(Default)]
pub struct ClientBuilder {
    /// Prefix prepended to invocation subjects; empty by default.
    prefix: Bytes,
    /// Queue group used when serving invocations; empty by default.
    queue_group: Bytes,
}
210
211impl ClientBuilder {
212 pub async fn build(self, nats: ants::Client) -> anyhow::Result<Client> {
213 let id = nuid::next();
214 let mut subject = BytesMut::with_capacity(9_usize.saturating_add(id.len()));
215 subject.extend_from_slice(b"_INBOX.");
216 subject.extend_from_slice(id.as_bytes());
217 subject.put_u8(b'.');
218 let inbox = subject.clone().freeze();
219
220 subject.put_u8(b'>');
221 let (sub_tx, mut sub_rx) = mpsc::channel(8196);
222 nats.subscribe(subject, Bytes::default(), "0", sub_tx)
223 .await
224 .context("failed to subscribe on an inbox subject")?;
225
226 let mut tasks = JoinSet::new();
227 let (cmd_tx, mut cmd_rx) = mpsc::channel(8192);
228 tasks.spawn({
229 async move {
230 fn handle_command(subs: &mut HashMap<Bytes, mpsc::Sender<Message>>, cmd: Command) {
231 match cmd {
232 Command::Subscribe(s, tx) => {
233 subs.insert(s, tx);
234 }
235 Command::Unsubscribe(s) => {
236 subs.remove(&s);
237 }
238 Command::Batch(cmds) => {
239 for cmd in cmds {
240 handle_command(subs, cmd);
241 }
242 }
243 }
244 }
245 async fn handle_message(
246 subs: &mut HashMap<Bytes, mpsc::Sender<Message>>,
247 ants::Message {
248 subject,
249 reply,
250 payload,
251 ..
252 }: ants::Message,
253 ) {
254 let Some(sub) = subs.get_mut(&subject) else {
255 debug!(?subject, "drop message with no subscriber");
256 return;
257 };
258 let Ok(sub) = sub.reserve().await else {
259 debug!(?subject, "drop message with closed subscriber");
260 subs.remove(&subject);
261 return;
262 };
263 sub.send(Message { reply, payload });
264 }
265
266 let mut subs = HashMap::new();
267 loop {
268 select! {
269 Some(msg) = sub_rx.recv() => handle_message(&mut subs, msg).await,
270 Some(cmd) = cmd_rx.recv() => handle_command(&mut subs, cmd),
271 }
272 }
273 }
274 });
275 Ok(Client {
276 nats,
277 prefix: self.prefix,
278 inbox,
279 queue_group: self.queue_group,
280 commands: cmd_tx,
281 sid: Arc::new(AtomicUsize::new(1)),
282 tasks: Arc::new(tasks),
283 })
284 }
285}
286
287impl Client {
288 pub async fn new(nats: ants::Client) -> anyhow::Result<Self> {
289 ClientBuilder::default().build(nats).await
290 }
291}
292
/// Stream adapter over a [`Subscriber`] that yields only message payloads.
pub struct ByteSubscription(Subscriber);
294
295impl Stream for ByteSubscription {
296 type Item = std::io::Result<Bytes>;
297
298 #[instrument(level = "trace", skip_all)]
299 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
300 match self.0.poll_next_unpin(cx) {
301 Poll::Ready(Some(Message { payload, .. })) => Poll::Ready(Some(Ok(payload))),
302 Poll::Ready(None) => Poll::Ready(None),
303 Poll::Pending => Poll::Pending,
304 }
305 }
306}
307
/// Trie of nested [`Subscriber`]s keyed by path. Each path element is either a concrete
/// index (`Some(i)`) or a wildcard (`None`).
#[derive(Default)]
enum IndexTrie {
    /// No subscriber at or below this node.
    #[default]
    Empty,
    /// Subscriber for exactly this path, with nothing nested below it.
    Leaf(Subscriber),
    /// Node whose children are addressed by concrete index.
    IndexNode {
        /// Subscriber for exactly this path, if any.
        subscriber: Option<Subscriber>,
        /// Children indexed by path element; `None` for unused indices.
        nested: Vec<Option<IndexTrie>>,
    },
    /// Node with a single child reached through a wildcard path element.
    WildcardNode {
        /// Subscriber for exactly this path, if any.
        subscriber: Option<Subscriber>,
        /// Child reached through the wildcard element, if any.
        nested: Option<Box<IndexTrie>>,
    },
}
322
323impl<'a> From<(&'a [Option<usize>], Subscriber)> for IndexTrie {
324 fn from((path, sub): (&'a [Option<usize>], Subscriber)) -> Self {
325 match path {
326 [] => Self::Leaf(sub),
327 [None, path @ ..] => Self::WildcardNode {
328 subscriber: None,
329 nested: Some(Box::new(Self::from((path, sub)))),
330 },
331 [Some(i), path @ ..] => Self::IndexNode {
332 subscriber: None,
333 nested: {
334 let n = i.saturating_add(1);
335 let mut nested = Vec::with_capacity(n);
336 nested.resize_with(n, Option::default);
337 nested[*i] = Some(Self::from((path, sub)));
338 nested
339 },
340 },
341 }
342 }
343}
344
345impl<P: AsRef<[Option<usize>]>> FromIterator<(P, Subscriber)> for IndexTrie {
346 fn from_iter<T: IntoIterator<Item = (P, Subscriber)>>(iter: T) -> Self {
347 let mut root = Self::Empty;
348 for (path, sub) in iter {
349 if !root.insert(path.as_ref(), sub) {
350 return Self::Empty;
351 }
352 }
353 root
354 }
355}
356
357impl IndexTrie {
358 #[inline]
359 fn is_empty(&self) -> bool {
360 matches!(self, IndexTrie::Empty)
361 }
362
363 #[instrument(level = "trace", skip_all)]
364 fn take(&mut self, path: &[usize]) -> Option<Subscriber> {
365 let Some((i, path)) = path.split_first() else {
366 return match mem::take(self) {
367 IndexTrie::Empty | IndexTrie::WildcardNode { .. } => None,
378 IndexTrie::Leaf(subscriber) => Some(subscriber),
379 IndexTrie::IndexNode { subscriber, nested } => {
380 if !nested.is_empty() {
381 *self = IndexTrie::IndexNode {
382 subscriber: None,
383 nested,
384 }
385 }
386 subscriber
387 }
388 };
389 };
390 match self {
391 Self::Empty | Self::Leaf(..) | Self::WildcardNode { .. } => None,
396 Self::IndexNode { ref mut nested, .. } => nested
397 .get_mut(*i)
398 .and_then(|nested| nested.as_mut().and_then(|nested| nested.take(path))),
399 }
400 }
401
402 #[instrument(level = "trace", skip_all)]
405 fn insert(&mut self, path: &[Option<usize>], sub: Subscriber) -> bool {
406 match self {
407 Self::Empty => {
408 *self = Self::from((path, sub));
409 true
410 }
411 Self::Leaf(..) => {
412 let Some((i, path)) = path.split_first() else {
413 return false;
414 };
415 let Self::Leaf(subscriber) = mem::take(self) else {
416 return false;
417 };
418 if let Some(i) = i {
419 let n = i.saturating_add(1);
420 let mut nested = Vec::with_capacity(n);
421 nested.resize_with(n, Option::default);
422 nested[*i] = Some(Self::from((path, sub)));
423 *self = Self::IndexNode {
424 subscriber: Some(subscriber),
425 nested,
426 };
427 } else {
428 *self = Self::WildcardNode {
429 subscriber: Some(subscriber),
430 nested: Some(Box::new(Self::from((path, sub)))),
431 };
432 }
433 true
434 }
435 Self::WildcardNode {
436 ref mut subscriber,
437 ref mut nested,
438 } => match (&subscriber, path) {
439 (None, []) => {
440 *subscriber = Some(sub);
441 true
442 }
443 (_, [None, path @ ..]) => {
444 if let Some(nested) = nested {
445 nested.insert(path, sub)
446 } else {
447 *nested = Some(Box::new(Self::from((path, sub))));
448 true
449 }
450 }
451 _ => false,
452 },
453 Self::IndexNode {
454 ref mut subscriber,
455 ref mut nested,
456 } => match (&subscriber, path) {
457 (None, []) => {
458 *subscriber = Some(sub);
459 true
460 }
461 (_, [Some(i), path @ ..]) => {
462 let cap = i.saturating_add(1);
463 if nested.len() < cap {
464 nested.resize_with(cap, Option::default);
465 }
466 let nested = &mut nested[*i];
467 if let Some(nested) = nested {
468 nested.insert(path, sub)
469 } else {
470 *nested = Some(Self::from((path, sub)));
471 true
472 }
473 }
474 _ => false,
475 },
476 }
477 }
478}
479
/// Incoming byte stream of an invocation, possibly for a nested path.
///
/// Used for results on the invoking side and for parameters on the serving side.
pub struct Reader {
    /// Bytes already received but not yet returned to the caller.
    buffer: Bytes,
    /// Subscription delivering this path's messages; `None` if none was registered for it.
    incoming: Option<Subscriber>,
    /// Trie of nested subscriptions, shared by this reader and every reader indexed from it.
    nested: Arc<std::sync::Mutex<IndexTrie>>,
    /// Path of this reader relative to the root reader (empty for the root).
    path: Box<[usize]>,
}
486
487impl wrpc_transport::Index<Self> for Reader {
488 #[instrument(level = "trace", skip(self))]
489 fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
490 ensure!(!path.is_empty());
491 trace!("locking index tree");
492 let mut nested = self
493 .nested
494 .lock()
495 .map_err(|err| anyhow!(err.to_string()).context("failed to lock map"))?;
496 trace!("taking index subscription");
497 let mut p = self.path.to_vec();
498 p.extend_from_slice(path);
499 let incoming = nested.take(&p);
500 Ok(Self {
501 buffer: Bytes::default(),
502 incoming,
503 nested: Arc::clone(&self.nested),
504 path: p.into_boxed_slice(),
505 })
506 }
507}
508
509impl AsyncRead for Reader {
510 #[instrument(level = "trace", skip_all, ret)]
511 fn poll_read(
512 mut self: Pin<&mut Self>,
513 cx: &mut Context<'_>,
514 buf: &mut ReadBuf<'_>,
515 ) -> Poll<std::io::Result<()>> {
516 let cap = buf.remaining();
517 if cap == 0 {
518 trace!("attempt to read empty buffer");
519 return Poll::Ready(Ok(()));
520 }
521
522 if !self.buffer.is_empty() {
523 if self.buffer.len() > cap {
524 trace!(cap, len = self.buffer.len(), "reading part of buffer");
525 buf.put_slice(&self.buffer.split_to(cap));
526 } else {
527 trace!(cap, len = self.buffer.len(), "reading full buffer");
528 buf.put_slice(&mem::take(&mut self.buffer));
529 }
530 return Poll::Ready(Ok(()));
531 }
532 let Some(incoming) = self.incoming.as_mut() else {
533 return Poll::Ready(Err(std::io::Error::new(
534 std::io::ErrorKind::NotFound,
535 format!("subscription not found for path {:?}", self.path),
536 )));
537 };
538 trace!("polling for next message");
539 match incoming.poll_next_unpin(cx) {
540 Poll::Ready(Some(Message { mut payload, .. })) => {
541 trace!(?payload, "received message");
542 if payload.is_empty() {
543 trace!("received stream shutdown message");
544 return Poll::Ready(Ok(()));
545 }
546 if payload.len() > cap {
547 trace!(len = payload.len(), cap, "partially reading the message");
548 buf.put_slice(&payload.split_to(cap));
549 self.buffer = payload;
550 } else {
551 trace!(len = payload.len(), cap, "filling the buffer with payload");
552 buf.put_slice(&payload);
553 }
554 Poll::Ready(Ok(()))
555 }
556 Poll::Ready(None) => {
557 trace!("subscription finished");
558 Poll::Ready(Ok(()))
559 }
560 Poll::Pending => Poll::Pending,
561 }
562 }
563}
564
/// Outgoing byte stream that publishes every write to a single subject.
///
/// Unless it was shut down via `poll_shutdown`, dropping a writer publishes an empty message
/// to `tx`, which readers treat as end of stream.
// NOTE(review): `Clone` copies `shutdown`, so every clone that is dropped without being shut
// down publishes its own end-of-stream message — confirm this is intended.
#[derive(Clone)]
pub struct SubjectWriter {
    /// Command sink of the underlying connection.
    nats: PollSender<ants::Command>,
    /// Server info; read for the maximum payload size on every write.
    info: watch::Receiver<Arc<ants::protocol::InfoOptions>>,
    /// Subject written to.
    tx: Bytes,
    /// Set once `poll_shutdown` has published the end-of-stream message.
    shutdown: bool,
    /// Shared handle to the client's background `JoinSet`, held so it is not dropped while
    /// this writer exists.
    tasks: Arc<JoinSet<()>>,
}
573
574impl SubjectWriter {
575 fn new(
576 nats: PollSender<ants::Command>,
577 info: watch::Receiver<Arc<ants::protocol::InfoOptions>>,
578 tx: Bytes,
579 tasks: Arc<JoinSet<()>>,
580 ) -> Self {
581 Self {
582 nats,
583 info,
584 tx,
585 shutdown: false,
586 tasks,
587 }
588 }
589}
590
591impl wrpc_transport::Index<Self> for SubjectWriter {
592 #[instrument(level = "trace", skip(self))]
593 fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
594 ensure!(!path.is_empty());
595 let tx = index_path(&self.tx, path);
596 Ok(Self {
597 nats: self.nats.clone(),
598 info: self.info.clone(),
599 tx,
600 shutdown: false,
601 tasks: Arc::clone(&self.tasks),
602 })
603 }
604}
605
606impl AsyncWrite for SubjectWriter {
607 #[instrument(level = "trace", skip_all, ret, fields(subject = ?self.tx, buf = format!("{buf:02x?}")))]
608 fn poll_write(
609 mut self: Pin<&mut Self>,
610 cx: &mut Context<'_>,
611 mut buf: &[u8],
612 ) -> Poll<std::io::Result<usize>> {
613 trace!("polling for readiness");
614 match self.nats.poll_ready_unpin(cx) {
615 Poll::Pending => return Poll::Pending,
616 Poll::Ready(Err(..)) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
617 Poll::Ready(Ok(())) => {}
618 }
619 let max_payload = self.info.borrow().max_payload;
620 if max_payload == 0 {
621 return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
622 }
623 let total = buf.len();
624 let cmd = if total < max_payload {
625 ants::Command::Publish {
626 subject: self.tx.clone(),
627 payload: Bytes::copy_from_slice(buf),
628 reply: Bytes::default(),
629 headers: Bytes::default(),
630 }
631 } else {
632 let mut cap = max_payload.saturating_div(total);
633 let rem = max_payload % total;
634 if rem > 0 {
635 cap = cap.saturating_add(1);
636 }
637 let mut cmds = Vec::with_capacity(cap);
638 while buf.len() > max_payload {
639 (buf, _) = buf.split_at(max_payload);
640 cmds.push(ants::Command::Publish {
641 subject: self.tx.clone(),
642 payload: Bytes::copy_from_slice(buf),
643 reply: Bytes::default(),
644 headers: Bytes::default(),
645 })
646 }
647 if !buf.is_empty() {
648 cmds.push(ants::Command::Publish {
649 subject: self.tx.clone(),
650 payload: Bytes::copy_from_slice(buf),
651 reply: Bytes::default(),
652 headers: Bytes::default(),
653 })
654 }
655 ants::Command::Batch(cmds.into_boxed_slice())
656 };
657 trace!("starting send");
658 match self.nats.start_send_unpin(cmd) {
659 Ok(()) => Poll::Ready(Ok(total)),
660 Err(..) => Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe))),
661 }
662 }
663
664 #[instrument(level = "trace", skip_all, ret, fields(subject = ?self.tx))]
665 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
666 trace!("polling for readiness");
667 match self.nats.poll_ready_unpin(cx) {
668 Poll::Pending => return Poll::Pending,
669 Poll::Ready(Err(..)) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
670 Poll::Ready(Ok(())) => {}
671 }
672 trace!("flushing");
673 match self.nats.start_send_unpin(ants::Command::Flush) {
674 Ok(()) => Poll::Ready(Ok(())),
675 Err(..) => Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe))),
676 }
677 }
678
679 #[instrument(level = "trace", skip_all, ret, fields(subject = ?self.tx))]
680 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
681 trace!("writing stream shutdown message");
682 ready!(self.as_mut().poll_write(cx, &[]))?;
683 self.shutdown = true;
684 Poll::Ready(Ok(()))
685 }
686}
687
688impl Drop for SubjectWriter {
689 fn drop(&mut self) {
690 if !self.shutdown {
691 let mut nats = self.nats.clone();
692 let subject = mem::take(&mut self.tx);
693 let tasks = Arc::clone(&self.tasks);
694 spawn_async(async move {
695 trace!("writing stream shutdown message");
696 if let Err(_) = nats
697 .send(ants::Command::Publish {
698 subject,
699 reply: Bytes::default(),
700 headers: Bytes::default(),
701 payload: Bytes::default(),
702 })
703 .await
704 {
705 warn!("failed to publish stream shutdown message");
706 }
707 drop(tasks);
708 });
709 }
710 }
711}
712
/// Top-level parameter writer of an invocation.
///
/// Starts in `Handshaking` until the peer replies to the handshake. It then writes out any
/// buffered parameter bytes (`Draining`) and finally writes directly to the peer's parameter
/// subject (`Active`).
#[derive(Default)]
pub enum RootParamWriter {
    /// Placeholder left behind while moving between states or after a fatal error.
    #[default]
    Corrupted,
    /// Waiting for the peer's handshake reply on `sub`.
    Handshaking {
        nats: ants::Client,
        info: watch::Receiver<Arc<ants::protocol::InfoOptions>>,
        sub: Subscriber,
        /// Nested writers requested before the handshake completed; resolved once it does.
        indexed: std::sync::Mutex<Vec<(Vec<usize>, oneshot::Sender<SubjectWriter>)>>,
        /// Parameter bytes still to be sent once the handshake completes.
        buffer: Bytes,
        tasks: Arc<JoinSet<()>>,
    },
    /// Handshake completed; writing the remaining `buffer` to `tx`.
    Draining {
        tx: SubjectWriter,
        buffer: Bytes,
    },
    /// Writes go directly to the peer's parameter subject.
    Active(SubjectWriter),
}
731
732impl RootParamWriter {
733 fn new(
734 nats: ants::Client,
735 info: watch::Receiver<Arc<ants::protocol::InfoOptions>>,
736 sub: Subscriber,
737 buffer: Bytes,
738 tasks: Arc<JoinSet<()>>,
739 ) -> Self {
740 Self::Handshaking {
741 nats,
742 info,
743 sub,
744 indexed: std::sync::Mutex::default(),
745 buffer,
746 tasks,
747 }
748 }
749}
750
751impl RootParamWriter {
752 #[instrument(level = "trace", skip_all, ret)]
753 fn poll_active(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
754 match &mut *self {
755 Self::Corrupted => Poll::Ready(Err(corrupted_memory_error())),
756 Self::Handshaking { sub, .. } => {
757 trace!("polling for handshake response");
758 match sub.poll_next_unpin(cx) {
759 Poll::Ready(Some(Message { reply: tx, .. })) => {
760 if tx.is_empty() {
761 return Poll::Ready(Err(std::io::Error::new(
762 std::io::ErrorKind::InvalidInput,
763 "peer did not specify a reply subject",
764 )));
765 }
766 let Self::Handshaking {
767 nats,
768 info,
769 indexed,
770 buffer,
771 tasks,
772 ..
773 } = mem::take(&mut *self)
774 else {
775 return Poll::Ready(Err(corrupted_memory_error()));
776 };
777 let tx = SubjectWriter::new(
778 PollSender::new(nats.commands().clone()),
779 info,
780 param_subject(&tx),
781 tasks,
782 );
783 let indexed = indexed.into_inner().map_err(|err| {
784 std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
785 })?;
786 for (path, tx_tx) in indexed {
787 let tx = tx.index(&path).map_err(|err| {
788 std::io::Error::new(std::io::ErrorKind::Other, err)
789 })?;
790 tx_tx.send(tx).map_err(|_| {
791 std::io::Error::from(std::io::ErrorKind::BrokenPipe)
792 })?;
793 }
794 trace!("handshake succeeded");
795 if buffer.is_empty() {
796 *self = Self::Active(tx);
797 Poll::Ready(Ok(()))
798 } else {
799 *self = Self::Draining { tx, buffer };
800 self.poll_active(cx)
801 }
802 }
803 Poll::Ready(None) => {
804 *self = Self::Corrupted;
805 Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)))
806 }
807 Poll::Pending => Poll::Pending,
808 }
809 }
810 Self::Draining { tx, buffer } => {
811 let mut tx = pin!(tx);
812 while !buffer.is_empty() {
813 trace!(?tx.tx, "draining parameter buffer");
814 match tx.as_mut().poll_write(cx, buffer) {
815 Poll::Ready(Ok(n)) => {
816 buffer.advance(n);
817 }
818 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
819 Poll::Pending => return Poll::Pending,
820 }
821 }
822 let Self::Draining { tx, .. } = mem::take(&mut *self) else {
823 return Poll::Ready(Err(corrupted_memory_error()));
824 };
825 trace!("parameter buffer draining succeeded");
826 *self = Self::Active(tx);
827 Poll::Ready(Ok(()))
828 }
829 Self::Active(..) => Poll::Ready(Ok(())),
830 }
831 }
832}
833
834impl wrpc_transport::Index<IndexedParamWriter> for RootParamWriter {
835 #[instrument(level = "trace", skip(self))]
836 fn index(&self, path: &[usize]) -> anyhow::Result<IndexedParamWriter> {
837 ensure!(!path.is_empty());
838 match self {
839 Self::Corrupted => Err(anyhow!(corrupted_memory_error())),
840 Self::Handshaking { indexed, .. } => {
841 let (tx_tx, tx_rx) = oneshot::channel();
842 let mut indexed = indexed.lock().map_err(|err| {
843 std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
844 })?;
845 indexed.push((path.to_vec(), tx_tx));
846 Ok(IndexedParamWriter::Handshaking {
847 tx_rx,
848 indexed: std::sync::Mutex::default(),
849 })
850 }
851 Self::Draining { tx, .. } | Self::Active(tx) => {
852 tx.index(path).map(IndexedParamWriter::Active)
853 }
854 }
855 }
856}
857
858impl AsyncWrite for RootParamWriter {
859 #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
860 fn poll_write(
861 mut self: Pin<&mut Self>,
862 cx: &mut Context<'_>,
863 buf: &[u8],
864 ) -> Poll<std::io::Result<usize>> {
865 match self.as_mut().poll_active(cx)? {
866 Poll::Ready(()) => {
867 let Self::Active(tx) = &mut *self else {
868 return Poll::Ready(Err(corrupted_memory_error()));
869 };
870 trace!("writing buffer");
871 pin!(tx).poll_write(cx, buf)
872 }
873 Poll::Pending => Poll::Pending,
874 }
875 }
876
877 #[instrument(level = "trace", skip_all, ret)]
878 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
879 match self.as_mut().poll_active(cx)? {
880 Poll::Ready(()) => {
881 let Self::Active(tx) = &mut *self else {
882 return Poll::Ready(Err(corrupted_memory_error()));
883 };
884 trace!("flushing");
885 pin!(tx).poll_flush(cx)
886 }
887 Poll::Pending => Poll::Pending,
888 }
889 }
890
891 #[instrument(level = "trace", skip_all, ret)]
892 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
893 match self.as_mut().poll_active(cx)? {
894 Poll::Ready(()) => {
895 let Self::Active(tx) = &mut *self else {
896 return Poll::Ready(Err(corrupted_memory_error()));
897 };
898 trace!("shutting down");
899 pin!(tx).poll_shutdown(cx)
900 }
901 Poll::Pending => Poll::Pending,
902 }
903 }
904}
905
/// Writer for a nested parameter of an invocation.
#[derive(Default)]
pub enum IndexedParamWriter {
    /// Placeholder left behind while moving between states.
    #[default]
    Corrupted,
    /// Waiting for the parent writer to deliver this path's [`SubjectWriter`].
    Handshaking {
        tx_rx: oneshot::Receiver<SubjectWriter>,
        /// Writers nested below this one, requested before `tx_rx` resolved.
        indexed: std::sync::Mutex<Vec<(Vec<usize>, oneshot::Sender<SubjectWriter>)>>,
    },
    /// Writes go directly to this path's subject.
    Active(SubjectWriter),
}
916
917impl IndexedParamWriter {
918 #[instrument(level = "trace", skip_all, ret)]
919 fn poll_active(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
920 match &mut *self {
921 Self::Corrupted => Poll::Ready(Err(corrupted_memory_error())),
922 Self::Handshaking { tx_rx, .. } => {
923 trace!("polling for handshake");
924 match pin!(tx_rx).poll(cx) {
925 Poll::Ready(Ok(tx)) => {
926 let Self::Handshaking { indexed, .. } = mem::take(&mut *self) else {
927 return Poll::Ready(Err(corrupted_memory_error()));
928 };
929 let indexed = indexed.into_inner().map_err(|err| {
930 std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
931 })?;
932 for (path, tx_tx) in indexed {
933 let tx = tx.index(&path).map_err(|err| {
934 std::io::Error::new(std::io::ErrorKind::Other, err)
935 })?;
936 tx_tx.send(tx).map_err(|_| {
937 std::io::Error::from(std::io::ErrorKind::BrokenPipe)
938 })?;
939 }
940 *self = Self::Active(tx);
941 Poll::Ready(Ok(()))
942 }
943 Poll::Ready(Err(..)) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
944 Poll::Pending => Poll::Pending,
945 }
946 }
947 Self::Active(..) => Poll::Ready(Ok(())),
948 }
949 }
950}
951
952impl wrpc_transport::Index<Self> for IndexedParamWriter {
953 #[instrument(level = "trace", skip_all)]
954 fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
955 ensure!(!path.is_empty());
956 match self {
957 Self::Corrupted => Err(anyhow!(corrupted_memory_error())),
958 Self::Handshaking { indexed, .. } => {
959 let (tx_tx, tx_rx) = oneshot::channel();
960 let mut indexed = indexed.lock().map_err(|err| {
961 std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
962 })?;
963 indexed.push((path.to_vec(), tx_tx));
964 Ok(Self::Handshaking {
965 tx_rx,
966 indexed: std::sync::Mutex::default(),
967 })
968 }
969 Self::Active(tx) => tx.index(path).map(Self::Active),
970 }
971 }
972}
973
974impl AsyncWrite for IndexedParamWriter {
975 #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
976 fn poll_write(
977 mut self: Pin<&mut Self>,
978 cx: &mut Context<'_>,
979 buf: &[u8],
980 ) -> Poll<std::io::Result<usize>> {
981 match self.as_mut().poll_active(cx)? {
982 Poll::Ready(()) => {
983 let Self::Active(tx) = &mut *self else {
984 return Poll::Ready(Err(corrupted_memory_error()));
985 };
986 trace!("writing buffer");
987 pin!(tx).poll_write(cx, buf)
988 }
989 Poll::Pending => Poll::Pending,
990 }
991 }
992
993 #[instrument(level = "trace", skip_all, ret)]
994 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
995 match self.as_mut().poll_active(cx)? {
996 Poll::Ready(()) => {
997 let Self::Active(tx) = &mut *self else {
998 return Poll::Ready(Err(corrupted_memory_error()));
999 };
1000 trace!("flushing");
1001 pin!(tx).poll_flush(cx)
1002 }
1003 Poll::Pending => Poll::Pending,
1004 }
1005 }
1006
1007 #[instrument(level = "trace", skip_all, ret)]
1008 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1009 match self.as_mut().poll_active(cx)? {
1010 Poll::Ready(()) => {
1011 let Self::Active(tx) = &mut *self else {
1012 return Poll::Ready(Err(corrupted_memory_error()));
1013 };
1014 trace!("shutting down");
1015 pin!(tx).poll_shutdown(cx)
1016 }
1017 Poll::Pending => Poll::Pending,
1018 }
1019 }
1020}
1021
/// Outgoing parameter stream of an invocation: either the top-level writer or a nested one.
pub enum ParamWriter {
    /// Writer for the top-level parameters.
    Root(RootParamWriter),
    /// Writer for a nested parameter path.
    Nested(IndexedParamWriter),
}
1026
1027impl wrpc_transport::Index<Self> for ParamWriter {
1028 fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
1029 ensure!(!path.is_empty());
1030 match self {
1031 ParamWriter::Root(w) => w.index(path),
1032 ParamWriter::Nested(w) => w.index(path),
1033 }
1034 .map(Self::Nested)
1035 }
1036}
1037
1038impl AsyncWrite for ParamWriter {
1039 #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
1040 fn poll_write(
1041 mut self: Pin<&mut Self>,
1042 cx: &mut Context<'_>,
1043 buf: &[u8],
1044 ) -> Poll<std::io::Result<usize>> {
1045 match &mut *self {
1046 ParamWriter::Root(w) => pin!(w).poll_write(cx, buf),
1047 ParamWriter::Nested(w) => pin!(w).poll_write(cx, buf),
1048 }
1049 }
1050
1051 #[instrument(level = "trace", skip_all, ret)]
1052 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1053 match &mut *self {
1054 ParamWriter::Root(w) => pin!(w).poll_flush(cx),
1055 ParamWriter::Nested(w) => pin!(w).poll_flush(cx),
1056 }
1057 }
1058
1059 #[instrument(level = "trace", skip_all, ret)]
1060 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1061 match &mut *self {
1062 ParamWriter::Root(w) => pin!(w).poll_shutdown(cx),
1063 ParamWriter::Nested(w) => pin!(w).poll_shutdown(cx),
1064 }
1065 }
1066}
1067
impl wrpc_transport::Invoke for Client {
    type Context = Bytes;
    type Outgoing = ParamWriter;
    type Incoming = Reader;

    /// Invokes `func` of `instance` on a serving peer.
    ///
    /// Requests subscriptions for the handshake reply (on a fresh inbox subject `rx`), for the
    /// results (`rx.results`), and for one nested result subject per entry of `paths`. It then
    /// publishes the handshake to the invocation subject with `rx` as reply subject, carrying
    /// as much of `params` as fits into the server's maximum payload less the size of
    /// `headers`. The remaining parameter bytes are written by the returned writer once the
    /// peer replies.
    ///
    /// # Errors
    ///
    /// Fails if `paths` cannot be arranged into a subscription tree, or if subscribing,
    /// fetching the server info, or publishing the handshake fails.
    #[instrument(level = "trace", skip(self, paths, params), fields(params = format!("{params:02x?}")))]
    async fn invoke<P: AsRef<[Option<usize>]> + Send + Sync>(
        &self,
        headers: Self::Context,
        instance: &str,
        func: &str,
        mut params: Bytes,
        paths: impl AsRef<[P]> + Send,
    ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)> {
        let paths = paths.as_ref();
        let mut cmds = Vec::with_capacity(paths.len().saturating_add(2));

        // The peer replies to the handshake on this fresh subject under the client's inbox.
        let rx = new_inbox(&self.inbox);
        let (handshake_tx, handshake_rx) = mpsc::channel(1);
        cmds.push(Command::Subscribe(rx.clone(), handshake_tx));

        let result = result_subject(&rx);
        let (result_tx, result_rx) = mpsc::channel(16);
        cmds.push(Command::Subscribe(result.clone(), result_tx));

        // Lazy: the closure pushes one `Subscribe` per path into `cmds` as `collect` below
        // consumes it, so it must be fully consumed before `cmds` is sent.
        let nested = paths.iter().map(|path| {
            let (tx, rx) = mpsc::channel(16);
            let subject = subscribe_path(&result, path.as_ref());
            cmds.push(Command::Subscribe(subject.clone(), tx));
            Subscriber {
                rx: ReceiverStream::new(rx),
                commands: self.commands.clone(),
                subject,
                tasks: Arc::clone(&self.tasks),
            }
        });
        // `collect` yields an empty trie if any path is rejected (occupied or conflicting).
        let nested: IndexTrie = zip(paths.iter(), nested).collect();
        ensure!(
            paths.is_empty() == nested.is_empty(),
            "failed to construct subscription tree"
        );

        // Subscriptions are requested before the handshake is published.
        // NOTE(review): the demux task picks among ready `select!` branches, so a reply could
        // in principle be routed before this batch is applied — confirm this cannot happen.
        self.commands
            .send(Command::Batch(cmds.into_boxed_slice()))
            .await
            .context("failed to subscribe")?;

        let info = self
            .nats
            .server_info()
            .await
            .context("failed to get server info")?;
        let mut max_payload = info.borrow().max_payload;
        let param_tx = invocation_subject(&self.prefix, instance, func);
        // Leave room for the headers within the handshake message's payload budget.
        if !headers.is_empty() {
            max_payload = max_payload.saturating_sub(headers.len());
        }
        trace!("publishing handshake");
        // The handshake carries the first chunk of `params`; the remainder stays in `params`
        // and is written by the returned `RootParamWriter` after the peer replies.
        self.nats
            .publish(
                param_tx,
                rx.clone(),
                headers,
                params.split_to(max_payload.min(params.len())),
            )
            .await
            .context("failed to publish handshake")?;
        Ok((
            ParamWriter::Root(RootParamWriter::new(
                self.nats.clone(),
                info,
                Subscriber {
                    rx: ReceiverStream::new(handshake_rx),
                    commands: self.commands.clone(),
                    subject: rx,
                    tasks: Arc::clone(&self.tasks),
                },
                params,
                Arc::clone(&self.tasks),
            )),
            Reader {
                buffer: Bytes::default(),
                incoming: Some(Subscriber {
                    rx: ReceiverStream::new(result_rx),
                    commands: self.commands.clone(),
                    subject: result,
                    tasks: Arc::clone(&self.tasks),
                }),
                nested: Arc::new(std::sync::Mutex::new(nested)),
                path: Box::default(),
            },
        ))
    }
}
1162
1163async fn handle_message(
1164 nats: &ants::Client,
1165 rx: Bytes,
1166 commands: mpsc::Sender<Command>,
1167 ants::Message {
1168 reply: tx,
1169 payload,
1170 headers,
1171 ..
1172 }: ants::Message,
1173 paths: &[Box<[Option<usize>]>],
1174 tasks: Arc<JoinSet<()>>,
1175) -> anyhow::Result<(Bytes, SubjectWriter, Reader)> {
1176 if tx.is_empty() {
1177 bail!("peer did not specify a reply subject")
1178 }
1179
1180 let mut cmds = Vec::with_capacity(paths.len().saturating_add(1));
1181
1182 let param = Bytes::from(param_subject(&rx));
1183 let (param_tx, param_rx) = mpsc::channel(16);
1184 cmds.push(Command::Subscribe(param.clone(), param_tx));
1185
1186 let nested = paths.iter().map(|path| {
1187 let (tx, rx) = mpsc::channel(16);
1188 let subject = Bytes::from(subscribe_path(¶m, path.as_ref()));
1189 cmds.push(Command::Subscribe(subject.clone(), tx));
1190 Subscriber {
1191 rx: ReceiverStream::new(rx),
1192 commands: commands.clone(),
1193 subject,
1194 tasks: Arc::clone(&tasks),
1195 }
1196 });
1197 let nested: IndexTrie = zip(paths.iter(), nested).collect();
1198 ensure!(
1199 paths.is_empty() == nested.is_empty(),
1200 "failed to construct subscription tree"
1201 );
1202
1203 commands
1204 .send(Command::Batch(cmds.into_boxed_slice()))
1205 .await
1206 .context("failed to subscribe")?;
1207
1208 trace!("publishing handshake response");
1209 nats.publish(tx.clone(), rx, Bytes::default(), Bytes::default())
1210 .await
1211 .context("failed to publish handshake accept")?;
1212 let info = nats
1213 .server_info()
1214 .await
1215 .context("failed to get server info")?;
1216 Ok((
1217 headers,
1218 SubjectWriter::new(
1219 PollSender::new(nats.commands().clone()),
1220 info,
1221 result_subject(&tx),
1222 Arc::clone(&tasks),
1223 ),
1224 Reader {
1225 buffer: payload,
1226 incoming: Some(Subscriber {
1227 rx: ReceiverStream::new(param_rx),
1228 commands,
1229 subject: param,
1230 tasks,
1231 }),
1232 nested: Arc::new(std::sync::Mutex::new(nested)),
1233 path: Box::default(),
1234 },
1235 ))
1236}
1237
1238impl wrpc_transport::Serve for Client {
1239 type Context = Bytes;
1240 type Outgoing = SubjectWriter;
1241 type Incoming = Reader;
1242
1243 #[instrument(level = "trace", skip(self, paths))]
1244 async fn serve(
1245 &self,
1246 instance: &str,
1247 func: &str,
1248 paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
1249 ) -> anyhow::Result<
1250 impl Stream<Item = anyhow::Result<(Self::Context, Self::Outgoing, Self::Incoming)>> + 'static,
1251 > {
1252 let subject = invocation_subject(&self.prefix, instance, func);
1253 debug!(?subject, "subscribing on invocation subject");
1254 let (sub_tx, sub_rx) = mpsc::channel(256);
1255 self.nats
1256 .subscribe(subject, self.queue_group.clone(), self.next_sid(), sub_tx)
1257 .await?;
1258 let nats = self.nats.clone();
1259 let paths = paths.into();
1260 let commands = self.commands.clone();
1261 let inbox = self.inbox.clone();
1262 let tasks = Arc::clone(&self.tasks);
1263 Ok(ReceiverStream::new(sub_rx).then(move |msg| {
1264 let tasks = Arc::clone(&tasks);
1265 let nats = nats.clone();
1266 let paths = Arc::clone(&paths);
1267 let commands = commands.clone();
1268 let rx = new_inbox(&inbox);
1269 async move { handle_message(&nats, rx, commands, msg, &paths, tasks).await }
1270 }))
1271 }
1272}