Skip to main content

vox_types/
channel.rs

1use std::marker::PhantomData;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::sync::Mutex;
5use std::sync::atomic::{AtomicBool, Ordering};
6
7use facet::Facet;
8use facet_core::PtrConst;
9use moire::sync::{Notify, Semaphore, mpsc};
10
11use crate::ChannelId;
12use crate::{Backing, ChannelClose, ChannelItem, ChannelReset, Metadata, Payload, SelfRef};
13
14// ---------------------------------------------------------------------------
15// Thread-local channel binder — set during deserialization so TryFrom impls
16// can bind channels immediately.
17// ---------------------------------------------------------------------------
18
19std::thread_local! {
20    static CHANNEL_BINDER: std::cell::RefCell<Option<&'static dyn ChannelBinder>> =
21        const { std::cell::RefCell::new(None) };
22}
23
24/// Set the thread-local channel binder for the duration of `f`.
25///
26/// Any `Tx<T>` or `Rx<T>` deserialized (via `TryFrom<ChannelId>`) during `f`
27/// will be bound through this binder.
28pub fn with_channel_binder<R>(binder: &dyn ChannelBinder, f: impl FnOnce() -> R) -> R {
29    // SAFETY: we restore the previous value (always None in practice) on exit,
30    // so the binder reference doesn't escape the closure's lifetime.
31    #[allow(unsafe_code)]
32    let static_ref: &'static dyn ChannelBinder = unsafe { std::mem::transmute(binder) };
33    CHANNEL_BINDER.with(|cell| {
34        let prev = cell.borrow_mut().replace(static_ref);
35        let result = f();
36        *cell.borrow_mut() = prev;
37        result
38    })
39}
40
41// r[impl rpc.channel.pair]
42/// The binding stored in a channel core — either a sink or a receiver, never both.
43pub enum ChannelBinding {
44    Sink(BoundChannelSink),
45    Receiver(BoundChannelReceiver),
46}
47
48pub trait ChannelLiveness: crate::MaybeSend + crate::MaybeSync + 'static {}
49
50impl<T: crate::MaybeSend + crate::MaybeSync + 'static> ChannelLiveness for T {}
51
52pub type ChannelLivenessHandle = Arc<dyn ChannelLiveness>;
53
54pub trait ChannelCreditReplenisher: crate::MaybeSend + crate::MaybeSync + 'static {
55    fn on_item_consumed(&self);
56}
57
58pub type ChannelCreditReplenisherHandle = Arc<dyn ChannelCreditReplenisher>;
59
60#[derive(Clone)]
61pub struct BoundChannelSink {
62    pub sink: Arc<dyn ChannelSink>,
63    pub liveness: Option<ChannelLivenessHandle>,
64}
65
66pub struct BoundChannelReceiver {
67    pub receiver: mpsc::Receiver<IncomingChannelMessage>,
68    pub liveness: Option<ChannelLivenessHandle>,
69    pub replenisher: Option<ChannelCreditReplenisherHandle>,
70}
71
72struct LogicalReceiverState {
73    generation: u64,
74    liveness: Option<ChannelLivenessHandle>,
75    sender: Option<mpsc::Sender<LogicalIncomingChannelMessage>>,
76    receiver: Option<mpsc::Receiver<LogicalIncomingChannelMessage>>,
77}
78
// r[impl rpc.channel.pair]
/// Shared state between a `Tx`/`Rx` pair created by `channel()`.
///
/// Holds two independent slots:
///
/// - `binding`: an optional [`ChannelBinding`]. A sink stored here is cloned by
///   the paired `Tx` every time it resolves its sink, so this mutex is locked on
///   each send/close of a paired `Tx`. A receiver stored here is moved out by
///   the paired `Rx` on first use. `set_binding` may replace an existing binding.
/// - `logical_receiver`: state for retryable receiver bindings (see
///   `bind_retryable_receiver`).
///
/// `binding_changed` is notified by `set_binding`, `bind_retryable_receiver` and
/// `finish_retry_binding`, so a paired handle waiting for a binding can re-check.
pub struct ChannelCore {
    binding: Mutex<Option<ChannelBinding>>,
    logical_receiver: Mutex<Option<LogicalReceiverState>>,
    binding_changed: Notify,
}

impl ChannelCore {
    /// Create a core with no binding and no logical-receiver state.
    fn new() -> Self {
        Self {
            binding: Mutex::new(None),
            logical_receiver: Mutex::new(None),
            binding_changed: Notify::new("vox_types.channel.binding_changed"),
        }
    }

    /// Store or replace a binding in the core, then wake every waiter on
    /// `binding_changed`.
    pub fn set_binding(&self, binding: ChannelBinding) {
        let mut guard = self.binding.lock().expect("channel core mutex poisoned");
        *guard = Some(binding);
        self.binding_changed.notify_waiters();
    }

    /// Clone the sink from the core (for Tx reading the sink).
    /// Returns None if no sink has been set or if the binding is a Receiver.
    /// The sink stays in the core, so every paired `Tx` call can read it again.
    pub fn get_sink(&self) -> Option<Arc<dyn ChannelSink>> {
        let guard = self.binding.lock().expect("channel core mutex poisoned");
        match guard.as_ref() {
            Some(ChannelBinding::Sink(bound)) => Some(bound.sink.clone()),
            _ => None,
        }
    }

    /// Take the receiver out of the core (for Rx on first recv).
    /// Returns None if no receiver has been set or if it was already taken.
    /// A `Sink` binding is left in place untouched.
    pub fn take_receiver(&self) -> Option<BoundChannelReceiver> {
        let mut guard = self.binding.lock().expect("channel core mutex poisoned");
        match guard.take() {
            Some(ChannelBinding::Receiver(bound)) => Some(bound),
            other => {
                // Put it back if it wasn't a receiver
                *guard = other;
                None
            }
        }
    }

    /// Bind `bound` as the current source of the paired `Rx`'s messages, in a
    /// way that can be repeated (e.g. for a re-sent call).
    ///
    /// On non-wasm targets, when no Tokio runtime is current, this degrades to a
    /// plain `set_binding(ChannelBinding::Receiver(bound))`.
    ///
    /// Otherwise the core keeps one long-lived "logical" mpsc channel
    /// (capacity 64) that the `Rx` reads from. Each call bumps `generation`,
    /// records `bound.liveness`, notifies waiters, and spawns a task that
    /// forwards messages from `bound.receiver` into the logical channel, tagging
    /// each with `bound.replenisher`. A forwarder stops when its source yields
    /// `None`, when sending into the logical channel fails, or when it receives
    /// a message after a newer generation was bound (or the state was cleared);
    /// that last message is dropped, not forwarded.
    pub fn bind_retryable_receiver(self: &Arc<Self>, bound: BoundChannelReceiver) {
        // The runtime probe is compiled out on wasm32, which always takes the
        // logical-channel path below.
        #[cfg(not(target_arch = "wasm32"))]
        if tokio::runtime::Handle::try_current().is_err() {
            self.set_binding(ChannelBinding::Receiver(bound));
            return;
        }

        let mut guard = self
            .logical_receiver
            .lock()
            .expect("channel core logical receiver mutex poisoned");
        // First retryable bind creates the logical channel; later binds reuse it.
        let state = guard.get_or_insert_with(|| {
            let (tx, rx) = mpsc::channel("vox_types.channel.logical_receiver", 64);
            LogicalReceiverState {
                generation: 0,
                liveness: None,
                sender: Some(tx),
                receiver: Some(rx),
            }
        });
        // Invalidate forwarders spawned by earlier binds.
        state.generation = state.generation.wrapping_add(1);
        state.liveness = bound.liveness.clone();
        let generation = state.generation;

        // `finish_retry_binding` takes `sender` and clears the whole state under
        // this same lock, so a state with `sender == None` should not be
        // observable here. NOTE(review): if it ever is, `bound` is dropped silently.
        let Some(sender) = state.sender.clone() else {
            return;
        };

        // Notified while `guard` is still held; a woken `Rx` blocks briefly in
        // `take_logical_receiver` until `drop(guard)` below.
        self.binding_changed.notify_waiters();

        drop(guard);
        let core = Arc::clone(self);

        moire::task::spawn(async move {
            let mut receiver = bound.receiver;
            let replenisher = bound.replenisher.clone();
            while let Some(msg) = receiver.recv().await {
                // Re-check the generation per message: a newer bind (or
                // `finish_retry_binding`, which clears the state) retires us.
                let is_current_generation = {
                    let guard = core
                        .logical_receiver
                        .lock()
                        .expect("channel core logical receiver mutex poisoned");
                    guard
                        .as_ref()
                        .map(|state| state.generation == generation)
                        .unwrap_or(false)
                };
                if !is_current_generation {
                    return;
                }
                let forwarded = LogicalIncomingChannelMessage {
                    msg,
                    replenisher: replenisher.clone(),
                };
                // Fails once the logical receiver has been dropped.
                if sender.send(forwarded).await.is_err() {
                    return;
                }
            }
        });
    }

    /// Move the logical receiver out of the core, together with the liveness
    /// handle recorded by the most recent retryable bind.
    ///
    /// Returns `None` if no retryable binding exists or the receiver was
    /// already taken.
    pub fn take_logical_receiver(
        &self,
    ) -> Option<(
        mpsc::Receiver<LogicalIncomingChannelMessage>,
        Option<ChannelLivenessHandle>,
    )> {
        self.logical_receiver
            .lock()
            .expect("channel core logical receiver mutex poisoned")
            .as_mut()
            .and_then(|state| {
                state
                    .receiver
                    .take()
                    .map(|receiver| (receiver, state.liveness.clone()))
            })
    }

    /// End the retryable-binding phase.
    ///
    /// Enqueues a synthetic `Close` (empty metadata) on the logical channel on a
    /// best-effort basis. It uses `try_send`, so the `Close` is dropped if the
    /// channel is full. It then drops all logical-receiver state, which makes
    /// each forwarding task exit on its next message. Finally it clears
    /// `binding` and notifies waiters.
    ///
    /// NOTE(review):
    /// - If the paired `Rx` has not yet taken the logical receiver, clearing the
    ///   state drops that receiver together with any buffered items and the
    ///   `Close`.
    /// - A plain `Receiver` binding left by the non-Tokio path of
    ///   `bind_retryable_receiver` is discarded here as well.
    /// - In both cases the `Rx` then waits in `recv` for a new binding.
    /// - Because the state is recreated on the next retryable bind, `generation`
    ///   restarts at 1 and could match a forwarder that is still running from
    ///   before.
    ///
    /// Confirm callers only invoke this once none of that can matter.
    pub fn finish_retry_binding(&self) {
        let mut guard = self
            .logical_receiver
            .lock()
            .expect("channel core logical receiver mutex poisoned");
        if let Some(state) = guard.as_mut() {
            if let Some(sender) = state.sender.as_ref() {
                let close = SelfRef::owning(
                    Backing::Boxed(Box::<[u8]>::default()),
                    ChannelClose {
                        metadata: Metadata::default(),
                    },
                );
                let _ = sender.try_send(LogicalIncomingChannelMessage {
                    msg: IncomingChannelMessage::Close(close),
                    replenisher: None,
                });
            }
            state.sender.take();
        }
        *guard = None;
        // The shadowed `logical_receiver` guard is still held here, so both locks
        // are held while `binding` is cleared (order: logical_receiver → binding).
        let mut guard = self.binding.lock().expect("channel core mutex poisoned");
        *guard = None;
        self.binding_changed.notify_waiters();
    }
}
236
237/// Slot for the shared channel core, accessible via facet reflection.
238#[derive(Facet)]
239#[facet(opaque)]
240pub(crate) struct CoreSlot {
241    pub(crate) inner: Option<Arc<ChannelCore>>,
242}
243
244impl CoreSlot {
245    pub(crate) fn empty() -> Self {
246        Self { inner: None }
247    }
248}
249
250// r[impl rpc.channel.pair]
251/// Create a channel pair with shared state.
252///
253/// Both ends hold an `Arc` reference to the same `ChannelCore`. The framework
254/// binds the handle that appears in args or return values, and the paired
255/// handle reads or takes the binding from the shared core.
256pub fn channel<T>() -> (Tx<T>, Rx<T>) {
257    let core = Arc::new(ChannelCore::new());
258    (Tx::paired(core.clone()), Rx::paired(core))
259}
260
261/// Runtime sink implemented by the session driver.
262///
263/// The contract is strict: successful completion means the item has gone
264/// through the conduit to the link commit boundary.
265pub trait ChannelSink: crate::MaybeSend + crate::MaybeSync + 'static {
266    fn send_payload<'payload>(
267        &self,
268        payload: Payload<'payload>,
269    ) -> Pin<Box<dyn crate::MaybeSendFuture<Output = Result<(), TxError>> + 'payload>>;
270
271    fn close_channel(
272        &self,
273        metadata: Metadata,
274    ) -> Pin<Box<dyn crate::MaybeSendFuture<Output = Result<(), TxError>> + 'static>>;
275
276    /// Synchronous drop-time close signal.
277    ///
278    /// This is used by `Tx::drop` to notify the runtime immediately without
279    /// spawning detached tasks. Implementations should enqueue a close intent
280    /// to their runtime/driver if possible.
281    fn close_channel_on_drop(&self) {}
282}
283
284// r[impl rpc.flow-control.credit]
285// r[impl rpc.flow-control.credit.exhaustion]
286/// A [`ChannelSink`] wrapper that enforces credit-based flow control.
287///
288/// Each `send_payload` acquires one permit from the semaphore, blocking if
289/// credit is zero. The semaphore is shared with the driver so that incoming
290/// `GrantCredit` messages can add permits via [`CreditSink::credit`].
291pub struct CreditSink<S: ChannelSink> {
292    inner: S,
293    credit: Arc<Semaphore>,
294}
295
296impl<S: ChannelSink> CreditSink<S> {
297    // r[impl rpc.flow-control.credit.initial]
298    // r[impl rpc.flow-control.credit.initial.zero]
299    /// Wrap `inner` with `initial_credit` permits (the const generic `N`).
300    pub fn new(inner: S, initial_credit: u32) -> Self {
301        Self {
302            inner,
303            credit: Arc::new(Semaphore::new(
304                "vox_types.channel.credit",
305                initial_credit as usize,
306            )),
307        }
308    }
309
310    /// Returns the credit semaphore. The driver holds a clone so
311    /// `GrantCredit` messages can call `add_permits`.
312    pub fn credit(&self) -> &Arc<Semaphore> {
313        &self.credit
314    }
315}
316
317impl<S: ChannelSink> ChannelSink for CreditSink<S> {
318    fn send_payload<'payload>(
319        &self,
320        payload: Payload<'payload>,
321    ) -> Pin<Box<dyn crate::MaybeSendFuture<Output = Result<(), TxError>> + 'payload>> {
322        let credit = self.credit.clone();
323        let fut = self.inner.send_payload(payload);
324        Box::pin(async move {
325            let permit = credit
326                .acquire_owned()
327                .await
328                .map_err(|_| TxError::Transport("channel credit semaphore closed".into()))?;
329            std::mem::forget(permit);
330            fut.await
331        })
332    }
333
334    fn close_channel(
335        &self,
336        metadata: Metadata,
337    ) -> Pin<Box<dyn crate::MaybeSendFuture<Output = Result<(), TxError>> + 'static>> {
338        // Close does not consume credit — it's a control message.
339        self.inner.close_channel(metadata)
340    }
341
342    fn close_channel_on_drop(&self) {
343        self.inner.close_channel_on_drop();
344    }
345}
346
347/// Message delivered to an `Rx` by the driver.
348pub enum IncomingChannelMessage {
349    Item(SelfRef<ChannelItem<'static>>),
350    Close(SelfRef<ChannelClose<'static>>),
351    Reset(SelfRef<ChannelReset<'static>>),
352}
353
354pub struct LogicalIncomingChannelMessage {
355    pub msg: IncomingChannelMessage,
356    pub replenisher: Option<ChannelCreditReplenisherHandle>,
357}
358
359/// Sender-side runtime slot.
360#[derive(Facet)]
361#[facet(opaque)]
362pub(crate) struct SinkSlot {
363    pub(crate) inner: Option<Arc<dyn ChannelSink>>,
364}
365
366impl SinkSlot {
367    pub(crate) fn empty() -> Self {
368        Self { inner: None }
369    }
370}
371
372/// Opaque liveness retention slot for bound channel handles.
373#[derive(Facet)]
374#[facet(opaque)]
375pub(crate) struct LivenessSlot {
376    pub(crate) inner: Option<ChannelLivenessHandle>,
377}
378
379impl LivenessSlot {
380    pub(crate) fn empty() -> Self {
381        Self { inner: None }
382    }
383}
384
385/// Receiver-side runtime slot.
386#[derive(Facet)]
387#[facet(opaque)]
388pub(crate) struct ReceiverSlot {
389    pub(crate) inner: Option<mpsc::Receiver<IncomingChannelMessage>>,
390}
391
392impl ReceiverSlot {
393    pub(crate) fn empty() -> Self {
394        Self { inner: None }
395    }
396}
397
398#[derive(Facet)]
399#[facet(opaque)]
400pub(crate) struct LogicalReceiverSlot {
401    pub(crate) inner: Option<mpsc::Receiver<LogicalIncomingChannelMessage>>,
402}
403
404impl LogicalReceiverSlot {
405    pub(crate) fn empty() -> Self {
406        Self { inner: None }
407    }
408}
409
410/// Receiver-side credit replenishment slot.
411#[derive(Facet)]
412#[facet(opaque)]
413pub(crate) struct ReplenisherSlot {
414    pub(crate) inner: Option<ChannelCreditReplenisherHandle>,
415}
416
417impl ReplenisherSlot {
418    pub(crate) fn empty() -> Self {
419        Self { inner: None }
420    }
421}
422
/// Sender handle: "I send". The holder of a `Tx<T>` sends items of type `T`.
///
/// In method args, the handler holds it (handler sends → caller).
///
/// Wire encoding is always unit (`()`), with channel IDs carried exclusively
/// in `Message::Request.channels`.
// r[impl rpc.channel]
// r[impl rpc.channel.direction]
// r[impl rpc.channel.payload-encoding]
// NOTE(review): the `///` text above is left verbatim because the Facet derive
// presumably records it as shape doc metadata. It disagrees with `Rx`'s docs
// ("serialized inline in the postcard payload"), although both types use the
// same `proxy = crate::ChannelId`; confirm which description is accurate.
#[derive(Facet)]
#[facet(proxy = crate::ChannelId)]
pub struct Tx<T> {
    // `ChannelId::RESERVED` until the handle is deserialized from a real id.
    pub(crate) channel_id: ChannelId,
    // Local sink (standalone / callee-side handle); consulted before `core`.
    pub(crate) sink: SinkSlot,
    // Shared core for `channel()` pairs; empty for standalone handles.
    pub(crate) core: CoreSlot,
    // Keeps the session/connection alive while this handle is bound.
    pub(crate) liveness: LivenessSlot,
    // Set by `close()` and by `Drop`, so the drop-time close signal fires at
    // most once and never after an explicit `close()`.
    #[facet(opaque)]
    closed: AtomicBool,
    #[facet(opaque)]
    _marker: PhantomData<T>,
}
444
445impl<T> Tx<T> {
446    /// Create a standalone unbound Tx (used by deserialization).
447    pub fn unbound() -> Self {
448        Self {
449            channel_id: ChannelId::RESERVED,
450            sink: SinkSlot::empty(),
451            core: CoreSlot::empty(),
452            liveness: LivenessSlot::empty(),
453            closed: AtomicBool::new(false),
454            _marker: PhantomData,
455        }
456    }
457
458    /// Create a Tx that is part of a `channel()` pair.
459    fn paired(core: Arc<ChannelCore>) -> Self {
460        Self {
461            channel_id: ChannelId::RESERVED,
462            sink: SinkSlot::empty(),
463            core: CoreSlot { inner: Some(core) },
464            liveness: LivenessSlot::empty(),
465            closed: AtomicBool::new(false),
466            _marker: PhantomData,
467        }
468    }
469
470    pub fn is_bound(&self) -> bool {
471        if self.sink.inner.is_some() {
472            return true;
473        }
474        if let Some(core) = &self.core.inner {
475            return core.get_sink().is_some();
476        }
477        false
478    }
479
480    /// Check if this Tx is part of a channel() pair (has a shared core).
481    pub fn has_core(&self) -> bool {
482        self.core.inner.is_some()
483    }
484
485    // r[impl rpc.channel.pair.tx-read]
486    fn resolve_sink_now(&self) -> Option<Arc<dyn ChannelSink>> {
487        // Fast path: local slot (standalone/callee-side handle)
488        if let Some(sink) = &self.sink.inner {
489            return Some(sink.clone());
490        }
491        // Slow path: read from shared core (paired handle)
492        if let Some(core) = &self.core.inner
493            && let Some(sink) = core.get_sink()
494        {
495            return Some(sink);
496        }
497        None
498    }
499
500    pub async fn send<'value>(&self, value: T) -> Result<(), TxError>
501    where
502        T: Facet<'value>,
503    {
504        let sink = if let Some(sink) = self.resolve_sink_now() {
505            sink
506        } else if let Some(core) = &self.core.inner {
507            loop {
508                let notified = core.binding_changed.notified();
509                if let Some(sink) = self.resolve_sink_now() {
510                    break sink;
511                }
512                notified.await;
513            }
514        } else {
515            return Err(TxError::Unbound);
516        };
517        let ptr = PtrConst::new((&value as *const T).cast::<u8>());
518        // SAFETY: `value` is explicitly dropped only after `await`, so the pointer
519        // remains valid for the whole send operation.
520        let payload = unsafe { Payload::outgoing_unchecked(ptr, T::SHAPE) };
521        let result = sink.send_payload(payload).await;
522        drop(value);
523        result
524    }
525
526    // r[impl rpc.channel.lifecycle]
527    pub async fn close<'value>(&self, metadata: Metadata<'value>) -> Result<(), TxError> {
528        self.closed.store(true, Ordering::Release);
529        let sink = if let Some(sink) = self.resolve_sink_now() {
530            sink
531        } else if let Some(core) = &self.core.inner {
532            loop {
533                let notified = core.binding_changed.notified();
534                if let Some(sink) = self.resolve_sink_now() {
535                    break sink;
536                }
537                notified.await;
538            }
539        } else {
540            return Err(TxError::Unbound);
541        };
542        sink.close_channel(metadata).await
543    }
544
545    #[doc(hidden)]
546    pub fn bind(&mut self, sink: Arc<dyn ChannelSink>) {
547        self.bind_with_liveness(sink, None);
548    }
549
550    #[doc(hidden)]
551    pub fn bind_with_liveness(
552        &mut self,
553        sink: Arc<dyn ChannelSink>,
554        liveness: Option<ChannelLivenessHandle>,
555    ) {
556        self.sink.inner = Some(sink);
557        self.liveness.inner = liveness;
558    }
559
560    #[doc(hidden)]
561    pub fn finish_retry_binding(&self) {
562        if let Some(core) = &self.core.inner {
563            core.finish_retry_binding();
564        }
565    }
566}
567
568impl<T> Drop for Tx<T> {
569    fn drop(&mut self) {
570        if self.closed.swap(true, Ordering::AcqRel) {
571            return;
572        }
573
574        let sink = if let Some(sink) = &self.sink.inner {
575            Some(sink.clone())
576        } else if let Some(core) = &self.core.inner {
577            core.get_sink()
578        } else {
579            None
580        };
581
582        let Some(sink) = sink else {
583            return;
584        };
585
586        // Synchronous signal into the runtime/driver; no detached async work here.
587        sink.close_channel_on_drop();
588    }
589}
590
591impl<T> TryFrom<&Tx<T>> for ChannelId {
592    type Error = String;
593
594    fn try_from(value: &Tx<T>) -> Result<Self, Self::Error> {
595        // Case 1: Caller passes Tx in args (callee sends, caller receives).
596        // Allocate a channel ID and store the receiver binding in the shared
597        // core so the caller's paired Rx can pick it up.
598        CHANNEL_BINDER.with(|cell| {
599            let borrow = cell.borrow();
600            let Some(binder) = *borrow else {
601                return Err("serializing Tx requires an active ChannelBinder".to_string());
602            };
603            let (channel_id, bound) = binder.create_rx();
604            if let Some(core) = &value.core.inner {
605                core.bind_retryable_receiver(bound);
606            }
607            Ok(channel_id)
608        })
609    }
610}
611
612impl<T> TryFrom<ChannelId> for Tx<T> {
613    type Error = String;
614
615    fn try_from(channel_id: ChannelId) -> Result<Self, Self::Error> {
616        let mut tx = Self::unbound();
617        tx.channel_id = channel_id;
618
619        CHANNEL_BINDER.with(|cell| {
620            let Some(binder) = *cell.borrow() else {
621                return Err("deserializing Tx requires an active ChannelBinder".to_string());
622            };
623            let sink = binder.bind_tx(channel_id);
624            let liveness = binder.channel_liveness();
625            tx.bind_with_liveness(sink, liveness);
626            Ok(())
627        })?;
628
629        Ok(tx)
630    }
631}
632
/// Failure modes of `Tx::send` and `Tx::close`.
#[derive(Debug)]
pub enum TxError {
    /// The handle has no sink and no shared core to obtain one from.
    Unbound,
    /// The runtime/transport failed; the string describes the failure.
    Transport(String),
}

impl std::fmt::Display for TxError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TxError::Unbound => f.write_str("channel is not bound"),
            TxError::Transport(detail) => write!(f, "transport error: {detail}"),
        }
    }
}

impl std::error::Error for TxError {}
650
/// Receiver handle: "I receive". The holder of an `Rx<T>` receives items of type `T`.
///
/// In method args, the handler holds it (handler receives ← caller).
///
/// Channel IDs are serialized inline in the postcard payload.
// NOTE(review): the `///` text above is left verbatim because the Facet derive
// presumably records it as shape doc metadata. `Tx`'s docs instead say channel
// IDs travel exclusively in `Message::Request.channels`, although both types use
// the same `proxy = crate::ChannelId`; confirm which description is accurate.
#[derive(Facet)]
#[facet(proxy = crate::ChannelId)]
pub struct Rx<T> {
    // `ChannelId::RESERVED` until the handle is deserialized from a real id.
    pub(crate) channel_id: ChannelId,
    // Direct receiver: set by deserialization / `bind*`, or taken from the core.
    pub(crate) receiver: ReceiverSlot,
    // Retryable ("logical") receiver taken from the core; `recv` prefers it.
    pub(crate) logical_receiver: LogicalReceiverSlot,
    // Shared core for `channel()` pairs; empty for standalone handles.
    pub(crate) core: CoreSlot,
    // Keeps the session/connection alive while this handle is bound.
    pub(crate) liveness: LivenessSlot,
    // Notified after each successfully decoded item from `receiver`.
    pub(crate) replenisher: ReplenisherSlot,
    #[facet(opaque)]
    _marker: PhantomData<T>,
}
668
669impl<T> Rx<T> {
670    /// Create a standalone unbound Rx (used by deserialization).
671    pub fn unbound() -> Self {
672        Self {
673            channel_id: ChannelId::RESERVED,
674            receiver: ReceiverSlot::empty(),
675            logical_receiver: LogicalReceiverSlot::empty(),
676            core: CoreSlot::empty(),
677            liveness: LivenessSlot::empty(),
678            replenisher: ReplenisherSlot::empty(),
679            _marker: PhantomData,
680        }
681    }
682
683    /// Create an Rx that is part of a `channel()` pair.
684    fn paired(core: Arc<ChannelCore>) -> Self {
685        Self {
686            channel_id: ChannelId::RESERVED,
687            receiver: ReceiverSlot::empty(),
688            logical_receiver: LogicalReceiverSlot::empty(),
689            core: CoreSlot { inner: Some(core) },
690            liveness: LivenessSlot::empty(),
691            replenisher: ReplenisherSlot::empty(),
692            _marker: PhantomData,
693        }
694    }
695
696    pub fn is_bound(&self) -> bool {
697        self.receiver.inner.is_some()
698    }
699
700    /// Check if this Rx is part of a channel() pair (has a shared core).
701    pub fn has_core(&self) -> bool {
702        self.core.inner.is_some()
703    }
704
705    // r[impl rpc.channel.pair.rx-take]
706    pub async fn recv(&mut self) -> Result<Option<SelfRef<T>>, RxError>
707    where
708        T: Facet<'static>,
709    {
710        loop {
711            if self.logical_receiver.inner.is_none()
712                && let Some(core) = &self.core.inner
713                && let Some((receiver, liveness)) = core.take_logical_receiver()
714            {
715                self.logical_receiver.inner = Some(receiver);
716                self.liveness.inner = liveness;
717            }
718
719            if let Some(receiver) = self.logical_receiver.inner.as_mut() {
720                match receiver.recv().await {
721                    Some(LogicalIncomingChannelMessage {
722                        msg: IncomingChannelMessage::Close(_),
723                        ..
724                    })
725                    | None => return Ok(None),
726                    Some(LogicalIncomingChannelMessage {
727                        msg: IncomingChannelMessage::Reset(_),
728                        ..
729                    }) => return Err(RxError::Reset),
730                    Some(LogicalIncomingChannelMessage {
731                        msg: IncomingChannelMessage::Item(msg),
732                        replenisher,
733                    }) => {
734                        let value = msg
735                            .try_repack(|item, _backing_bytes| {
736                                let Payload::PostcardBytes(bytes) = item.item else {
737                                    return Err(RxError::Protocol(
738                                        "incoming channel item payload was not Incoming".into(),
739                                    ));
740                                };
741                                vox_postcard::from_slice_borrowed(bytes)
742                                    .map_err(RxError::Deserialize)
743                            })
744                            .map(Some);
745                        if value.is_ok()
746                            && let Some(replenisher) = replenisher.as_ref()
747                        {
748                            replenisher.on_item_consumed();
749                        }
750                        return value;
751                    }
752                }
753            }
754
755            if self.receiver.inner.is_none()
756                && let Some(core) = &self.core.inner
757                && let Some(bound) = core.take_receiver()
758            {
759                self.receiver.inner = Some(bound.receiver);
760                self.liveness.inner = bound.liveness;
761                self.replenisher.inner = bound.replenisher;
762            }
763
764            if let Some(receiver) = self.receiver.inner.as_mut() {
765                return match receiver.recv().await {
766                    Some(IncomingChannelMessage::Close(_)) | None => Ok(None),
767                    Some(IncomingChannelMessage::Reset(_)) => Err(RxError::Reset),
768                    Some(IncomingChannelMessage::Item(msg)) => {
769                        let value = msg
770                            .try_repack(|item, _backing_bytes| {
771                                let Payload::PostcardBytes(bytes) = item.item else {
772                                    return Err(RxError::Protocol(
773                                        "incoming channel item payload was not Incoming".into(),
774                                    ));
775                                };
776                                vox_postcard::from_slice_borrowed(bytes)
777                                    .map_err(RxError::Deserialize)
778                            })
779                            .map(Some);
780                        if value.is_ok()
781                            && let Some(replenisher) = &self.replenisher.inner
782                        {
783                            replenisher.on_item_consumed();
784                        }
785                        value
786                    }
787                };
788            }
789
790            let Some(core) = &self.core.inner else {
791                return Err(RxError::Unbound);
792            };
793            core.binding_changed.notified().await;
794        }
795    }
796
797    #[doc(hidden)]
798    pub fn bind(&mut self, receiver: mpsc::Receiver<IncomingChannelMessage>) {
799        self.bind_with_liveness(receiver, None);
800    }
801
802    #[doc(hidden)]
803    pub fn bind_with_liveness(
804        &mut self,
805        receiver: mpsc::Receiver<IncomingChannelMessage>,
806        liveness: Option<ChannelLivenessHandle>,
807    ) {
808        self.receiver.inner = Some(receiver);
809        self.logical_receiver.inner = None;
810        self.liveness.inner = liveness;
811        self.replenisher.inner = None;
812    }
813}
814
815impl<T> TryFrom<&Rx<T>> for ChannelId {
816    type Error = String;
817
818    fn try_from(value: &Rx<T>) -> Result<Self, Self::Error> {
819        // Case 2: Caller passes Rx in args (callee receives, caller sends).
820        // Allocate a channel ID and store the sink binding in the shared
821        // core so the caller's paired Tx can pick it up.
822        CHANNEL_BINDER.with(|cell| {
823            let borrow = cell.borrow();
824            let Some(binder) = *borrow else {
825                return Err("serializing Rx requires an active ChannelBinder".to_string());
826            };
827            let (channel_id, sink) = binder.create_tx();
828            let liveness = binder.channel_liveness();
829            if let Some(core) = &value.core.inner {
830                core.set_binding(ChannelBinding::Sink(BoundChannelSink { sink, liveness }));
831            }
832            Ok(channel_id)
833        })
834    }
835}
836
837impl<T> TryFrom<ChannelId> for Rx<T> {
838    type Error = String;
839
840    fn try_from(channel_id: ChannelId) -> Result<Self, Self::Error> {
841        let mut rx = Self::unbound();
842        rx.channel_id = channel_id;
843
844        CHANNEL_BINDER.with(|cell| {
845            let Some(binder) = *cell.borrow() else {
846                return Err("deserializing Rx requires an active ChannelBinder".to_string());
847            };
848            let bound = binder.register_rx(channel_id);
849            rx.receiver.inner = Some(bound.receiver);
850            rx.liveness.inner = bound.liveness;
851            rx.replenisher.inner = bound.replenisher;
852            Ok(())
853        })?;
854
855        Ok(rx)
856    }
857}
858
859/// Error when receiving from an `Rx`.
860#[derive(Debug)]
861pub enum RxError {
862    Unbound,
863    Reset,
864    Deserialize(vox_postcard::error::DeserializeError),
865    Protocol(String),
866}
867
868impl std::fmt::Display for RxError {
869    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
870        match self {
871            Self::Unbound => write!(f, "channel is not bound"),
872            Self::Reset => write!(f, "channel reset by peer"),
873            Self::Deserialize(e) => write!(f, "deserialize error: {e}"),
874            Self::Protocol(msg) => write!(f, "protocol error: {msg}"),
875        }
876    }
877}
878
879impl std::error::Error for RxError {}
880
881/// Check if a shape represents a `Tx` channel.
882pub fn is_tx(shape: &facet_core::Shape) -> bool {
883    shape.decl_id == Tx::<()>::SHAPE.decl_id
884}
885
886/// Check if a shape represents an `Rx` channel.
887pub fn is_rx(shape: &facet_core::Shape) -> bool {
888    shape.decl_id == Rx::<()>::SHAPE.decl_id
889}
890
891/// Check if a shape represents any channel type (`Tx` or `Rx`).
892pub fn is_channel(shape: &facet_core::Shape) -> bool {
893    is_tx(shape) || is_rx(shape)
894}
895
/// Connects channel handles to a session/connection while they are
/// serialized or deserialized.
///
/// A binder is made active for the current thread with `with_channel_binder`.
/// While it is active, serializing and deserializing `Tx`/`Rx` values calls
/// into it to allocate channel IDs and to obtain sinks and receivers.
pub trait ChannelBinder: crate::MaybeSend + crate::MaybeSync {
    /// Allocate a channel ID and create a sink for sending items.
    ///
    /// Caller side: used when an `Rx` is serialized while its paired `Tx` is
    /// kept. The sink is stored in the shared core so the kept `Tx` can send.
    fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>);

    /// Allocate a channel ID, register it for routing, and return a receiver.
    ///
    /// Caller side: used when a `Tx` is serialized while its paired `Rx` is
    /// kept. The receiver is stored in the shared core for the kept `Rx`.
    fn create_rx(&self) -> (ChannelId, BoundChannelReceiver);

    /// Create a sink for a known channel ID (callee side).
    ///
    /// The channel ID comes from `Request.channels`. Used to bind a `Tx`
    /// during deserialization.
    fn bind_tx(&self, channel_id: ChannelId) -> Arc<dyn ChannelSink>;

    /// Register an inbound channel by ID and return the receiver (callee side).
    ///
    /// The channel ID comes from `Request.channels`. Called by
    /// `TryFrom<ChannelId> for Rx<T>` when an `Rx` is deserialized.
    fn register_rx(&self, channel_id: ChannelId) -> BoundChannelReceiver;

    /// Optional opaque handle that keeps the underlying session/connection alive
    /// for the lifetime of any bound channel handle.
    ///
    /// The default implementation returns `None`.
    fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
        None
    }
}
920
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Backing, ChannelClose, ChannelItem, ChannelReset, Metadata, SelfRef};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// `ChannelSink` test double that counts how often each method is called.
    /// Every operation succeeds immediately.
    struct CountingSink {
        send_calls: AtomicUsize,
        close_calls: AtomicUsize,
        close_on_drop_calls: AtomicUsize,
    }

    impl CountingSink {
        fn new() -> Self {
            Self {
                send_calls: AtomicUsize::new(0),
                close_calls: AtomicUsize::new(0),
                close_on_drop_calls: AtomicUsize::new(0),
            }
        }
    }

    // Counters are bumped when each method is called, before any returned
    // future is polled.
    impl ChannelSink for CountingSink {
        fn send_payload<'payload>(
            &self,
            _payload: Payload<'payload>,
        ) -> Pin<Box<dyn crate::MaybeSendFuture<Output = Result<(), TxError>> + 'payload>> {
            self.send_calls.fetch_add(1, Ordering::AcqRel);
            Box::pin(async { Ok(()) })
        }

        fn close_channel(
            &self,
            _metadata: Metadata,
        ) -> Pin<Box<dyn crate::MaybeSendFuture<Output = Result<(), TxError>> + 'static>> {
            self.close_calls.fetch_add(1, Ordering::AcqRel);
            Box::pin(async { Ok(()) })
        }

        fn close_channel_on_drop(&self) {
            self.close_on_drop_calls.fetch_add(1, Ordering::AcqRel);
        }
    }

    /// `ChannelCreditReplenisher` test double that counts consumed items.
    struct CountingReplenisher {
        calls: AtomicUsize,
    }

    impl CountingReplenisher {
        fn new() -> Self {
            Self {
                calls: AtomicUsize::new(0),
            }
        }
    }

    impl ChannelCreditReplenisher for CountingReplenisher {
        fn on_item_consumed(&self) {
            self.calls.fetch_add(1, Ordering::AcqRel);
        }
    }

    // An explicit `close()` is reported once through `close_channel`, and the
    // later drop must not also trigger `close_channel_on_drop`.
    #[tokio::test]
    async fn tx_close_does_not_emit_drop_close_after_explicit_close() {
        let sink_impl = Arc::new(CountingSink::new());
        let sink: Arc<dyn ChannelSink> = sink_impl.clone();

        let mut tx = Tx::<u32>::unbound();
        tx.bind(sink);
        tx.close(Metadata::default())
            .await
            .expect("close should succeed");
        drop(tx);

        assert_eq!(sink_impl.close_calls.load(Ordering::Acquire), 1);
        assert_eq!(sink_impl.close_on_drop_calls.load(Ordering::Acquire), 0);
    }

    // Dropping a directly bound `Tx` without closing it notifies the sink once.
    #[test]
    fn tx_drop_emits_close_on_drop_for_bound_sink() {
        let sink_impl = Arc::new(CountingSink::new());
        let sink: Arc<dyn ChannelSink> = sink_impl.clone();

        let mut tx = Tx::<u32>::unbound();
        tx.bind(sink);
        drop(tx);

        assert_eq!(sink_impl.close_on_drop_calls.load(Ordering::Acquire), 1);
    }

    // Same as above, except the sink is stored in the shared core of a
    // `channel()` pair instead of being bound on the `Tx` itself.
    #[test]
    fn tx_drop_emits_close_on_drop_for_paired_core_binding() {
        let sink_impl = Arc::new(CountingSink::new());
        let sink: Arc<dyn ChannelSink> = sink_impl.clone();

        let (tx, _rx) = channel::<u32>();
        let core = tx.core.inner.as_ref().expect("paired tx should have core");
        core.set_binding(ChannelBinding::Sink(BoundChannelSink {
            sink,
            liveness: None,
        }));
        drop(tx);

        assert_eq!(sink_impl.close_on_drop_calls.load(Ordering::Acquire), 1);
    }

    // `recv()` on an `Rx` that was never bound fails with `RxError::Unbound`.
    #[tokio::test]
    async fn rx_recv_returns_unbound_when_not_bound() {
        let mut rx = Rx::<u32>::unbound();
        let err = match rx.recv().await {
            Ok(_) => panic!("unbound rx should fail"),
            Err(err) => err,
        };
        assert!(matches!(err, RxError::Unbound));
    }

    // A `Close` message ends the stream: `recv()` yields `Ok(None)`.
    #[tokio::test]
    async fn rx_recv_returns_none_on_close() {
        let (tx, rx_inner) = mpsc::channel("vox_types.channel.test.rx1", 1);
        let mut rx = Rx::<u32>::unbound();
        rx.bind(rx_inner);

        let close = SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelClose {
                metadata: Metadata::default(),
            },
        );
        tx.send(IncomingChannelMessage::Close(close))
            .await
            .expect("send close");

        assert!(rx.recv().await.expect("recv should succeed").is_none());
    }

    // A `Reset` message is surfaced as `RxError::Reset`.
    #[tokio::test]
    async fn rx_recv_returns_reset_error() {
        let (tx, rx_inner) = mpsc::channel("vox_types.channel.test.rx2", 1);
        let mut rx = Rx::<u32>::unbound();
        rx.bind(rx_inner);

        let reset = SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelReset {
                metadata: Metadata::default(),
            },
        );
        tx.send(IncomingChannelMessage::Reset(reset))
            .await
            .expect("send reset");

        let err = match rx.recv().await {
            Ok(_) => panic!("reset should be surfaced as error"),
            Err(err) => err,
        };
        assert!(matches!(err, RxError::Reset));
    }

    // An incoming item whose payload is the outgoing variant
    // (`Payload::outgoing`) is rejected with `RxError::Protocol`.
    #[tokio::test]
    async fn rx_recv_rejects_outgoing_payload_variant_as_protocol_error() {
        static VALUE: u32 = 42;

        let (tx, rx_inner) = mpsc::channel("vox_types.channel.test.rx3", 1);
        let mut rx = Rx::<u32>::unbound();
        rx.bind(rx_inner);

        let item = SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelItem {
                item: Payload::outgoing(&VALUE),
            },
        );
        tx.send(IncomingChannelMessage::Item(item))
            .await
            .expect("send item");

        let err = match rx.recv().await {
            Ok(_) => panic!("outgoing payload should be protocol error"),
            Err(err) => err,
        };
        assert!(matches!(err, RxError::Protocol(_)));
    }

    // Decoding one item from a directly bound `Rx` calls the replenisher once.
    #[tokio::test]
    async fn rx_recv_notifies_replenisher_after_consuming_an_item() {
        let (tx, rx_inner) = mpsc::channel("vox_types.channel.test.rx4", 1);
        let replenisher = Arc::new(CountingReplenisher::new());
        let mut rx = Rx::<u32>::unbound();
        rx.bind(rx_inner);
        rx.replenisher.inner = Some(replenisher.clone());

        let encoded = vox_postcard::to_vec(&123_u32).expect("serialize test item");
        // `Box::leak` gives the encoded bytes a `'static` lifetime so
        // `Payload::PostcardBytes` can borrow them; the leak is test-only.
        let item = SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelItem {
                item: Payload::PostcardBytes(Box::leak(encoded.into_boxed_slice())),
            },
        );
        tx.send(IncomingChannelMessage::Item(item))
            .await
            .expect("send item");

        let value = rx
            .recv()
            .await
            .expect("recv should succeed")
            .expect("expected item");
        assert_eq!(*value, 123_u32);
        assert_eq!(replenisher.calls.load(Ordering::Acquire), 1);
    }

    // Same as above, except the receiver and replenisher are attached to a
    // shared `ChannelCore` via `bind_retryable_receiver` and read through a
    // paired `Rx`.
    #[tokio::test]
    async fn rx_recv_logical_receiver_decodes_items_and_notifies_replenisher() {
        let (tx, rx_inner) = mpsc::channel("vox_types.channel.test.rx5", 1);
        let replenisher = Arc::new(CountingReplenisher::new());
        let core = Arc::new(ChannelCore::new());
        core.bind_retryable_receiver(BoundChannelReceiver {
            receiver: rx_inner,
            liveness: None,
            replenisher: Some(replenisher.clone()),
        });

        let mut rx = Rx::<u32>::paired(core);

        let encoded = vox_postcard::to_vec(&321_u32).expect("serialize test item");
        let item = SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelItem {
                item: Payload::PostcardBytes(Box::leak(encoded.into_boxed_slice())),
            },
        );
        tx.send(IncomingChannelMessage::Item(item))
            .await
            .expect("send item");

        let value = rx
            .recv()
            .await
            .expect("recv should succeed")
            .expect("expected item");
        assert_eq!(*value, 321_u32);
        assert_eq!(replenisher.calls.load(Ordering::Acquire), 1);
    }

    // ========================================================================
    // Channel binding through ser/deser
    // ========================================================================

    /// A test binder that hands out sequential channel IDs (100, 102, 104, ...)
    /// along with fresh `CountingSink`s and mpsc-backed receivers.
    struct TestBinder {
        next_id: std::sync::Mutex<u64>,
    }

    impl TestBinder {
        fn new() -> Self {
            Self {
                next_id: std::sync::Mutex::new(100),
            }
        }

        /// Returns the current ID and advances the counter by 2.
        fn alloc_id(&self) -> ChannelId {
            let mut guard = self.next_id.lock().unwrap();
            let id = *guard;
            *guard += 2;
            ChannelId(id)
        }
    }

    impl ChannelBinder for TestBinder {
        fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>) {
            (self.alloc_id(), Arc::new(CountingSink::new()))
        }

        fn create_rx(&self) -> (ChannelId, BoundChannelReceiver) {
            let (tx, rx) = mpsc::channel("vox_types.channel.test.bind_retryable1", 8);
            // Keep the sender alive by leaking it — test only.
            std::mem::forget(tx);
            (
                self.alloc_id(),
                BoundChannelReceiver {
                    receiver: rx,
                    liveness: None,
                    replenisher: None,
                },
            )
        }

        fn bind_tx(&self, _channel_id: ChannelId) -> Arc<dyn ChannelSink> {
            Arc::new(CountingSink::new())
        }

        // NOTE(review): unlike `create_rx`, the sender is not forgotten here.
        // `_tx` lives until this function returns and is then dropped, so the
        // returned receiver has no live sender. That is fine for the tests
        // below, which only check `is_bound()`.
        fn register_rx(&self, _channel_id: ChannelId) -> BoundChannelReceiver {
            let (_tx, rx) = mpsc::channel("vox_types.channel.test.bind_retryable2", 8);
            BoundChannelReceiver {
                receiver: rx,
                liveness: None,
                replenisher: None,
            }
        }
    }

    // Case 1: Caller passes Tx in args, keeps paired Rx.
    // Serializing the Tx allocates a channel ID via create_rx() and stores
    // the receiver in the shared logical core so the kept Rx can survive retries.
    #[tokio::test]
    async fn case1_serialize_tx_allocates_and_binds_paired_rx() {
        use facet::Facet;

        #[derive(Facet)]
        struct Args {
            data: u32,
            tx: Tx<u32>,
        }

        let (tx, rx) = channel::<u32>();
        let args = Args { data: 42, tx };

        let binder = TestBinder::new();
        let bytes =
            with_channel_binder(&binder, || vox_postcard::to_vec(&args).expect("serialize"));

        // The channel ID should be in the serialized bytes (after the u32 data field).
        assert!(!bytes.is_empty());

        // The kept Rx should now have a receiver binding in the shared core.
        assert!(
            rx.core.inner.is_some(),
            "paired Rx should have a shared core"
        );
        let core = rx.core.inner.as_ref().unwrap();
        assert!(
            core.take_logical_receiver().is_some(),
            "core should have a logical receiver binding from create_rx()"
        );
    }

    // Case 2: Caller passes Rx in args, keeps paired Tx.
    // Serializing the Rx allocates a channel ID via create_tx() and stores
    // the sink in the shared core so the kept Tx can use it.
    #[test]
    fn case2_serialize_rx_allocates_and_binds_paired_tx() {
        use facet::Facet;

        #[derive(Facet)]
        struct Args {
            data: u32,
            rx: Rx<u32>,
        }

        let (tx, rx) = channel::<u32>();
        let args = Args { data: 42, rx };

        let binder = TestBinder::new();
        let bytes =
            with_channel_binder(&binder, || vox_postcard::to_vec(&args).expect("serialize"));

        assert!(!bytes.is_empty());

        // The kept Tx should now have a sink binding in the shared core.
        assert!(tx.core.inner.is_some());
        let core = tx.core.inner.as_ref().unwrap();
        assert!(
            core.get_sink().is_some(),
            "core should have a Sink binding from create_tx()"
        );
    }

    // Case 3: Callee deserializes Tx from args.
    // The Tx is bound directly via bind_tx() during deserialization.
    #[test]
    fn case3_deserialize_tx_binds_via_binder() {
        use facet::Facet;

        #[derive(Facet)]
        struct Args {
            data: u32,
            tx: Tx<u32>,
        }

        // Simulate wire bytes: a u32 (42) followed by a channel ID (varint 7).
        let mut bytes = vox_postcard::to_vec(&42_u32).unwrap();
        bytes.extend_from_slice(&vox_postcard::to_vec(&ChannelId(7)).unwrap());

        let binder = TestBinder::new();
        let args: Args = with_channel_binder(&binder, || {
            vox_postcard::from_slice(&bytes).expect("deserialize")
        });

        assert_eq!(args.data, 42);
        assert_eq!(args.tx.channel_id, ChannelId(7));
        assert!(
            args.tx.is_bound(),
            "deserialized Tx should be bound via bind_tx()"
        );
    }

    // Case 4: Callee deserializes Rx from args.
    // The Rx is bound directly via register_rx() during deserialization.
    #[test]
    fn case4_deserialize_rx_binds_via_binder() {
        use facet::Facet;

        #[derive(Facet)]
        struct Args {
            data: u32,
            rx: Rx<u32>,
        }

        // Simulate wire bytes: a u32 (42) followed by a channel ID (varint 7).
        let mut bytes = vox_postcard::to_vec(&42_u32).unwrap();
        bytes.extend_from_slice(&vox_postcard::to_vec(&ChannelId(7)).unwrap());

        let binder = TestBinder::new();
        let args: Args = with_channel_binder(&binder, || {
            vox_postcard::from_slice(&bytes).expect("deserialize")
        });

        assert_eq!(args.data, 42);
        assert_eq!(args.rx.channel_id, ChannelId(7));
        assert!(
            args.rx.is_bound(),
            "deserialized Rx should be bound via register_rx()"
        );
    }

    // Round-trip: serialize with caller binder, deserialize with callee binder.
    // Verifies the channel ID allocated during serialization appears in the
    // deserialized handle.
    #[test]
    fn channel_id_round_trips_through_ser_deser() {
        use facet::Facet;

        #[derive(Facet)]
        struct Args {
            tx: Tx<u32>,
        }

        let (tx, _rx) = channel::<u32>();
        let args = Args { tx };

        let caller_binder = TestBinder::new();
        let bytes = with_channel_binder(&caller_binder, || {
            vox_postcard::to_vec(&args).expect("serialize")
        });

        let callee_binder = TestBinder::new();
        let deserialized: Args = with_channel_binder(&callee_binder, || {
            vox_postcard::from_slice(&bytes).expect("deserialize")
        });

        // The caller binder starts at ID 100, so the deserialized Tx should have that ID.
        assert_eq!(deserialized.tx.channel_id, ChannelId(100));
        assert!(deserialized.tx.is_bound());
    }
}