Skip to main content

vox_types/
channel.rs

1use std::marker::PhantomData;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::sync::Mutex;
5use std::sync::atomic::{AtomicBool, Ordering};
6
7use facet::Facet;
8use facet_core::PtrConst;
9use moire::sync::{Notify, Semaphore, mpsc};
10
11use crate::ChannelId;
12use crate::{Backing, ChannelClose, ChannelItem, ChannelReset, Metadata, Payload, SelfRef};
13
14// ---------------------------------------------------------------------------
15// Thread-local channel binder — set during deserialization so TryFrom impls
16// can bind channels immediately.
17// ---------------------------------------------------------------------------
18
// Per-thread slot holding the binder installed by `set_channel_binder`.
//
// The `'static` lifetime is not real: `set_channel_binder` erases the
// binder's actual lifetime before storing it here, and the matching
// `ChannelBinderGuard` puts the previous value back when it is dropped.
std::thread_local! {
    static CHANNEL_BINDER: std::cell::RefCell<Option<&'static dyn ChannelBinder>> =
        const { std::cell::RefCell::new(None) };
}
23
24/// Set the thread-local channel binder for the duration of `f`.
25///
26/// Any `Tx<T>` or `Rx<T>` deserialized (via `TryFrom<ChannelId>`) during `f`
27/// will be bound through this binder.
28pub fn with_channel_binder<R>(binder: &dyn ChannelBinder, f: impl FnOnce() -> R) -> R {
29    let _guard = set_channel_binder(binder);
30    f()
31}
32
/// Set the thread-local channel binder, returning a guard that restores
/// the previous value on drop.
///
/// Prefer this over [`with_channel_binder`] when the code that runs under
/// the binder needs to return borrowed data (closures can't return borrows
/// from captures).
///
/// The returned guard remembers whichever binder was installed before the
/// call, so nested installations unwind correctly as long as the guards are
/// dropped in reverse order of creation.
pub fn set_channel_binder(binder: &dyn ChannelBinder) -> ChannelBinderGuard<'_> {
    // SAFETY: we restore the previous value on drop (via ChannelBinderGuard),
    // so the binder reference doesn't escape the guard's lifetime.
    // NOTE(review): this argument relies on the guard actually being dropped,
    // and on the thread whose slot was written here. A leaked guard
    // (`mem::forget`) or a guard dropped on another thread would leave the
    // erased reference in this thread's slot — confirm callers never do that.
    #[allow(unsafe_code)]
    let static_ref: &'static dyn ChannelBinder = unsafe { std::mem::transmute(binder) };
    // Install the new binder and remember what it replaced.
    let prev = CHANNEL_BINDER.with(|cell| cell.borrow_mut().replace(static_ref));
    ChannelBinderGuard {
        prev,
        _lifetime: std::marker::PhantomData,
    }
}
50
51/// RAII guard that restores the previous thread-local channel binder on drop.
52pub struct ChannelBinderGuard<'a> {
53    prev: Option<&'static dyn ChannelBinder>,
54    _lifetime: std::marker::PhantomData<&'a dyn ChannelBinder>,
55}
56
57impl Drop for ChannelBinderGuard<'_> {
58    fn drop(&mut self) {
59        CHANNEL_BINDER.with(|cell| {
60            *cell.borrow_mut() = self.prev.take();
61        });
62    }
63}
64
// r[impl rpc.channel.pair]
/// The binding stored in a channel core — either a sink or a receiver, never both.
pub enum ChannelBinding {
    /// Sending side: read (cloned) by the paired `Tx` via `ChannelCore::get_sink`.
    Sink(BoundChannelSink),
    /// Receiving side: taken once by the paired `Rx` via `ChannelCore::take_receiver`.
    Receiver(BoundChannelReceiver),
}
71
/// Marker trait for opaque values that are held only to keep something alive.
///
/// It has no methods; the blanket impl below makes every
/// `MaybeSend + MaybeSync + 'static` type usable as a liveness token.
pub trait ChannelLiveness: crate::MaybeSend + crate::MaybeSync + 'static {}

impl<T: crate::MaybeSend + crate::MaybeSync + 'static> ChannelLiveness for T {}

/// Shared, type-erased liveness token stored inside bound channel handles.
pub type ChannelLivenessHandle = Arc<dyn ChannelLiveness>;
77
/// Callback through which a receiver reports that it consumed one item.
pub trait ChannelCreditReplenisher: crate::MaybeSend + crate::MaybeSync + 'static {
    /// Called by `Rx::recv` after an incoming item was successfully
    /// deserialized and is about to be handed to the caller.
    fn on_item_consumed(&self);
}

/// Shared, type-erased replenisher attached to a bound receiver.
pub type ChannelCreditReplenisherHandle = Arc<dyn ChannelCreditReplenisher>;
83
/// A sink together with the optional liveness token obtained from the binder.
#[derive(Clone)]
pub struct BoundChannelSink {
    /// Runtime sink that outgoing items and close signals are sent through.
    pub sink: Arc<dyn ChannelSink>,
    /// Optional token from `ChannelBinder::channel_liveness`.
    pub liveness: Option<ChannelLivenessHandle>,
}
89
/// A receiver together with the handles that travel with it when it is bound.
pub struct BoundChannelReceiver {
    /// Queue on which the driver delivers incoming channel messages.
    pub receiver: mpsc::Receiver<IncomingChannelMessage>,
    /// Optional liveness token copied into the `Rx` that ends up owning the receiver.
    pub liveness: Option<ChannelLivenessHandle>,
    /// Optional callback invoked once per successfully received item.
    pub replenisher: Option<ChannelCreditReplenisherHandle>,
}
95
/// State of the "logical" receiver that survives rebinding of the underlying
/// receiver (see `ChannelCore::bind_retryable_receiver`).
struct LogicalReceiverState {
    /// Incremented on every rebind; forwarding tasks compare against it and
    /// stop once they are no longer the most recent one.
    generation: u64,
    /// Liveness token of the most recent binding.
    liveness: Option<ChannelLivenessHandle>,
    /// Sending half of the logical queue, cloned into each forwarding task.
    sender: Option<mpsc::Sender<LogicalIncomingChannelMessage>>,
    /// Receiving half of the logical queue until `take_logical_receiver`
    /// hands it out; `None` afterwards.
    receiver: Option<mpsc::Receiver<LogicalIncomingChannelMessage>>,
}
102
// r[impl rpc.channel.pair]
/// Shared state between a `Tx`/`Rx` pair created by `channel()`.
///
/// `binding` is written by `set_binding` whenever a binding is stored or
/// replaced. A paired `Rx` takes a receiver binding out once
/// (`take_receiver`); a paired `Tx` without a local sink locks the mutex and
/// clones the sink each time it resolves it (`get_sink`).
///
/// `logical_receiver` backs the rebinding path set up by
/// `bind_retryable_receiver`, and `binding_changed` wakes handles that are
/// waiting for a binding to appear.
pub struct ChannelCore {
    /// Current binding, if any.
    binding: Mutex<Option<ChannelBinding>>,
    /// Logical receiver state; `None` until the first retryable bind and
    /// again after `finish_retry_binding`.
    logical_receiver: Mutex<Option<LogicalReceiverState>>,
    /// Notified whenever a binding is stored, rebound, or cleared.
    binding_changed: Notify,
}
114
115impl ChannelCore {
    /// Create an empty core: no binding, no logical receiver.
    fn new() -> Self {
        Self {
            binding: Mutex::new(None),
            logical_receiver: Mutex::new(None),
            binding_changed: Notify::new("vox_types.channel.binding_changed"),
        }
    }
123
124    /// Store or replace a binding in the core.
125    pub fn set_binding(&self, binding: ChannelBinding) {
126        let mut guard = self.binding.lock().expect("channel core mutex poisoned");
127        *guard = Some(binding);
128        self.binding_changed.notify_waiters();
129    }
130
131    /// Clone the sink from the core (for Tx reading the sink).
132    /// Returns None if no sink has been set or if the binding is a Receiver.
133    pub fn get_sink(&self) -> Option<Arc<dyn ChannelSink>> {
134        let guard = self.binding.lock().expect("channel core mutex poisoned");
135        match guard.as_ref() {
136            Some(ChannelBinding::Sink(bound)) => Some(bound.sink.clone()),
137            _ => None,
138        }
139    }
140
141    /// Take the receiver out of the core (for Rx on first recv).
142    /// Returns None if no receiver has been set or if it was already taken.
143    pub fn take_receiver(&self) -> Option<BoundChannelReceiver> {
144        let mut guard = self.binding.lock().expect("channel core mutex poisoned");
145        match guard.take() {
146            Some(ChannelBinding::Receiver(bound)) => Some(bound),
147            other => {
148                // Put it back if it wasn't a receiver
149                *guard = other;
150                None
151            }
152        }
153    }
154
    /// Bind `bound` as the current source of incoming messages in a way that
    /// allows later rebinding.
    ///
    /// Instead of handing `bound.receiver` to the `Rx` directly, a forwarding
    /// task copies its messages into a long-lived "logical" queue that the
    /// `Rx` reads from. Each call bumps a generation counter; a forwarding
    /// task from an earlier call stops the next time it receives a message
    /// and finds that its generation is no longer current.
    ///
    /// On non-wasm targets, when no Tokio runtime is available on the calling
    /// thread, no task can be spawned and `bound` is stored as a plain
    /// receiver binding instead.
    pub fn bind_retryable_receiver(self: &Arc<Self>, bound: BoundChannelReceiver) {
        #[cfg(not(target_arch = "wasm32"))]
        if tokio::runtime::Handle::try_current().is_err() {
            self.set_binding(ChannelBinding::Receiver(bound));
            return;
        }

        let mut guard = self
            .logical_receiver
            .lock()
            .expect("channel core logical receiver mutex poisoned");
        // First retryable bind: create the logical queue (capacity 64).
        let state = guard.get_or_insert_with(|| {
            let (tx, rx) = mpsc::channel("vox_types.channel.logical_receiver", 64);
            LogicalReceiverState {
                generation: 0,
                liveness: None,
                sender: Some(tx),
                receiver: Some(rx),
            }
        });
        // This binding supersedes any earlier one.
        state.generation = state.generation.wrapping_add(1);
        state.liveness = bound.liveness.clone();
        let generation = state.generation;

        let Some(sender) = state.sender.clone() else {
            return;
        };

        // Wake an `Rx` that is waiting for the logical receiver to appear.
        self.binding_changed.notify_waiters();

        // Release the lock before spawning; the task re-locks per message.
        drop(guard);
        let core = Arc::clone(self);

        moire::task::spawn(async move {
            let mut receiver = bound.receiver;
            let replenisher = bound.replenisher.clone();
            while let Some(msg) = receiver.recv().await {
                let is_current_generation = {
                    let guard = core
                        .logical_receiver
                        .lock()
                        .expect("channel core logical receiver mutex poisoned");
                    guard
                        .as_ref()
                        .map(|state| state.generation == generation)
                        .unwrap_or(false)
                };
                // Superseded (or the logical receiver was torn down): stop.
                // NOTE(review): `msg` has already been received at this point
                // and is dropped without being forwarded.
                if !is_current_generation {
                    return;
                }
                // Attach this binding's replenisher to each forwarded message.
                let forwarded = LogicalIncomingChannelMessage {
                    msg,
                    replenisher: replenisher.clone(),
                };
                // The logical receiver is gone; nothing left to forward to.
                if sender.send(forwarded).await.is_err() {
                    return;
                }
            }
        });
    }
215
216    pub fn take_logical_receiver(
217        &self,
218    ) -> Option<(
219        mpsc::Receiver<LogicalIncomingChannelMessage>,
220        Option<ChannelLivenessHandle>,
221    )> {
222        self.logical_receiver
223            .lock()
224            .expect("channel core logical receiver mutex poisoned")
225            .as_mut()
226            .and_then(|state| {
227                state
228                    .receiver
229                    .take()
230                    .map(|receiver| (receiver, state.liveness.clone()))
231            })
232    }
233
234    pub fn finish_retry_binding(&self) {
235        let mut guard = self
236            .logical_receiver
237            .lock()
238            .expect("channel core logical receiver mutex poisoned");
239        if let Some(state) = guard.as_mut() {
240            if let Some(sender) = state.sender.as_ref() {
241                let close = SelfRef::owning(
242                    Backing::Boxed(Box::<[u8]>::default()),
243                    ChannelClose {
244                        metadata: Metadata::default(),
245                    },
246                );
247                let _ = sender.try_send(LogicalIncomingChannelMessage {
248                    msg: IncomingChannelMessage::Close(close),
249                    replenisher: None,
250                });
251            }
252            state.sender.take();
253        }
254        *guard = None;
255        let mut guard = self.binding.lock().expect("channel core mutex poisoned");
256        *guard = None;
257        self.binding_changed.notify_waiters();
258    }
259}
260
/// Slot for the shared channel core, accessible via facet reflection.
///
/// Marked `#[facet(opaque)]`; `inner` is `Some` only for handles created by
/// `channel()`.
#[derive(Facet)]
#[facet(opaque)]
pub(crate) struct CoreSlot {
    pub(crate) inner: Option<Arc<ChannelCore>>,
}

impl CoreSlot {
    /// A slot with no core (standalone handle).
    pub(crate) fn empty() -> Self {
        Self { inner: None }
    }
}
273
274// r[impl rpc.channel.pair]
275/// Create a channel pair with shared state.
276///
277/// Both ends hold an `Arc` reference to the same `ChannelCore`. The framework
278/// binds the handle that appears in args or return values, and the paired
279/// handle reads or takes the binding from the shared core.
280pub fn channel<T>() -> (Tx<T>, Rx<T>) {
281    let core = Arc::new(ChannelCore::new());
282    (Tx::paired(core.clone()), Rx::paired(core))
283}
284
/// Runtime sink implemented by the session driver.
///
/// The contract is strict: successful completion means the item has gone
/// through the conduit to the link commit boundary.
pub trait ChannelSink: crate::MaybeSend + crate::MaybeSync + 'static {
    /// Send one item.
    ///
    /// The returned future may borrow from `payload` (it is bounded by
    /// `'payload`) and resolves to `Ok(())` or a [`TxError`].
    fn send_payload<'payload>(
        &self,
        payload: Payload<'payload>,
    ) -> Pin<Box<dyn crate::MaybeSendFuture<Output = Result<(), TxError>> + 'payload>>;

    /// Close the channel, passing `metadata` along with the close.
    ///
    /// The returned future is `'static`: it does not borrow from `self`.
    fn close_channel(
        &self,
        metadata: Metadata,
    ) -> Pin<Box<dyn crate::MaybeSendFuture<Output = Result<(), TxError>> + 'static>>;

    /// Synchronous drop-time close signal.
    ///
    /// This is used by `Tx::drop` to notify the runtime immediately without
    /// spawning detached tasks. Implementations should enqueue a close intent
    /// to their runtime/driver if possible.
    ///
    /// The default implementation does nothing.
    fn close_channel_on_drop(&self) {}
}
307
// r[impl rpc.flow-control.credit]
// r[impl rpc.flow-control.credit.exhaustion]
/// A [`ChannelSink`] wrapper that enforces credit-based flow control.
///
/// Each `send_payload` acquires one permit from the semaphore, blocking if
/// credit is zero. The semaphore is shared with the driver so that incoming
/// `GrantCredit` messages can add permits via [`CreditSink::credit`].
pub struct CreditSink<S: ChannelSink> {
    /// The wrapped sink that actually transmits items.
    inner: S,
    /// Remaining send credit; one permit is consumed per sent item.
    credit: Arc<Semaphore>,
}
319
impl<S: ChannelSink> CreditSink<S> {
    // r[impl rpc.flow-control.credit.initial]
    // r[impl rpc.flow-control.credit.initial.zero]
    /// Wrap `inner` with a new credit semaphore holding `initial_credit` permits.
    ///
    /// With `initial_credit == 0` the semaphore starts empty, so sends wait
    /// until permits are added through [`CreditSink::credit`].
    pub fn new(inner: S, initial_credit: u32) -> Self {
        Self {
            inner,
            credit: Arc::new(Semaphore::new(
                "vox_types.channel.credit",
                initial_credit as usize,
            )),
        }
    }

    /// Returns the credit semaphore. The driver holds a clone so
    /// `GrantCredit` messages can call `add_permits`.
    pub fn credit(&self) -> &Arc<Semaphore> {
        &self.credit
    }
}
340
341impl<S: ChannelSink> ChannelSink for CreditSink<S> {
342    fn send_payload<'payload>(
343        &self,
344        payload: Payload<'payload>,
345    ) -> Pin<Box<dyn crate::MaybeSendFuture<Output = Result<(), TxError>> + 'payload>> {
346        let credit = self.credit.clone();
347        let fut = self.inner.send_payload(payload);
348        Box::pin(async move {
349            let permit = credit
350                .acquire_owned()
351                .await
352                .map_err(|_| TxError::Transport("channel credit semaphore closed".into()))?;
353            std::mem::forget(permit);
354            fut.await
355        })
356    }
357
358    fn close_channel(
359        &self,
360        metadata: Metadata,
361    ) -> Pin<Box<dyn crate::MaybeSendFuture<Output = Result<(), TxError>> + 'static>> {
362        // Close does not consume credit — it's a control message.
363        self.inner.close_channel(metadata)
364    }
365
366    fn close_channel_on_drop(&self) {
367        self.inner.close_channel_on_drop();
368    }
369}
370
/// Message delivered to an `Rx` by the driver.
pub enum IncomingChannelMessage {
    /// A data item; `Rx::recv` deserializes it and returns it to the caller.
    Item(SelfRef<ChannelItem<'static>>),
    /// End of stream; `Rx::recv` reports it as `Ok(None)`.
    Close(SelfRef<ChannelClose<'static>>),
    /// Channel reset; `Rx::recv` reports it as `Err(RxError::Reset)`.
    Reset(SelfRef<ChannelReset<'static>>),
}
377
/// An incoming message as it travels through the logical (rebindable) queue.
pub struct LogicalIncomingChannelMessage {
    /// The message received from the underlying receiver.
    pub msg: IncomingChannelMessage,
    /// Replenisher of the binding the message arrived on, if it has one.
    pub replenisher: Option<ChannelCreditReplenisherHandle>,
}
382
/// Sender-side runtime slot.
///
/// Holds the sink a `Tx` was bound to directly (via `Tx::bind*`); empty for
/// handles that resolve their sink through a shared core.
#[derive(Facet)]
#[facet(opaque)]
pub(crate) struct SinkSlot {
    pub(crate) inner: Option<Arc<dyn ChannelSink>>,
}

impl SinkSlot {
    /// A slot with no sink.
    pub(crate) fn empty() -> Self {
        Self { inner: None }
    }
}
395
/// Opaque liveness retention slot for bound channel handles.
///
/// The stored handle is never read back in this file; it is only kept so it
/// is not dropped before the channel handle that owns the slot.
#[derive(Facet)]
#[facet(opaque)]
pub(crate) struct LivenessSlot {
    pub(crate) inner: Option<ChannelLivenessHandle>,
}

impl LivenessSlot {
    /// A slot holding no liveness token.
    pub(crate) fn empty() -> Self {
        Self { inner: None }
    }
}
408
/// Receiver-side runtime slot.
///
/// Holds the raw receiver an `Rx` reads from once it has been bound or has
/// taken a receiver binding out of the shared core.
#[derive(Facet)]
#[facet(opaque)]
pub(crate) struct ReceiverSlot {
    pub(crate) inner: Option<mpsc::Receiver<IncomingChannelMessage>>,
}

impl ReceiverSlot {
    /// A slot with no receiver.
    pub(crate) fn empty() -> Self {
        Self { inner: None }
    }
}
421
/// Receiver-side slot for the logical (rebindable) queue.
///
/// Filled by `Rx::recv` when the shared core hands out its logical receiver.
#[derive(Facet)]
#[facet(opaque)]
pub(crate) struct LogicalReceiverSlot {
    pub(crate) inner: Option<mpsc::Receiver<LogicalIncomingChannelMessage>>,
}

impl LogicalReceiverSlot {
    /// A slot with no logical receiver.
    pub(crate) fn empty() -> Self {
        Self { inner: None }
    }
}
433
/// Receiver-side credit replenishment slot.
///
/// When set, `Rx::recv` calls `on_item_consumed` on it after each item that
/// was received through the raw receiver and deserialized successfully.
#[derive(Facet)]
#[facet(opaque)]
pub(crate) struct ReplenisherSlot {
    pub(crate) inner: Option<ChannelCreditReplenisherHandle>,
}

impl ReplenisherSlot {
    /// A slot with no replenisher.
    pub(crate) fn empty() -> Self {
        Self { inner: None }
    }
}
446
/// Sender handle: "I send". The holder of a `Tx<T>` sends items of type `T`.
///
/// In method args, the handler holds it (handler sends → caller).
///
/// Wire encoding is always unit (`()`), with channel IDs carried exclusively
/// in `Message::Request.channels`.
// NOTE(review): the `Rx` docs in this file say channel IDs are serialized
// inline in the postcard payload — confirm which description is current.
// r[impl rpc.channel]
// r[impl rpc.channel.direction]
// r[impl rpc.channel.payload-encoding]
#[derive(Facet)]
#[facet(proxy = crate::ChannelId)]
pub struct Tx<T> {
    /// `ChannelId::RESERVED` until set by deserialization.
    pub(crate) channel_id: ChannelId,
    /// Sink bound directly to this handle, if any.
    pub(crate) sink: SinkSlot,
    /// Shared core when this handle came from `channel()`.
    pub(crate) core: CoreSlot,
    /// Liveness token stored alongside a direct binding.
    pub(crate) liveness: LivenessSlot,
    /// Set by `close()` and by `Drop`; `Drop` skips the drop-time close
    /// signal when it is already set.
    #[facet(opaque)]
    closed: AtomicBool,
    #[facet(opaque)]
    _marker: PhantomData<T>,
}
468
469impl<T> Tx<T> {
470    /// Create a standalone unbound Tx (used by deserialization).
471    pub fn unbound() -> Self {
472        Self {
473            channel_id: ChannelId::RESERVED,
474            sink: SinkSlot::empty(),
475            core: CoreSlot::empty(),
476            liveness: LivenessSlot::empty(),
477            closed: AtomicBool::new(false),
478            _marker: PhantomData,
479        }
480    }
481
482    /// Create a Tx that is part of a `channel()` pair.
483    fn paired(core: Arc<ChannelCore>) -> Self {
484        Self {
485            channel_id: ChannelId::RESERVED,
486            sink: SinkSlot::empty(),
487            core: CoreSlot { inner: Some(core) },
488            liveness: LivenessSlot::empty(),
489            closed: AtomicBool::new(false),
490            _marker: PhantomData,
491        }
492    }
493
494    pub fn is_bound(&self) -> bool {
495        if self.sink.inner.is_some() {
496            return true;
497        }
498        if let Some(core) = &self.core.inner {
499            return core.get_sink().is_some();
500        }
501        false
502    }
503
504    /// Check if this Tx is part of a channel() pair (has a shared core).
505    pub fn has_core(&self) -> bool {
506        self.core.inner.is_some()
507    }
508
509    // r[impl rpc.channel.pair.tx-read]
510    fn resolve_sink_now(&self) -> Option<Arc<dyn ChannelSink>> {
511        // Fast path: local slot (standalone/callee-side handle)
512        if let Some(sink) = &self.sink.inner {
513            return Some(sink.clone());
514        }
515        // Slow path: read from shared core (paired handle)
516        if let Some(core) = &self.core.inner
517            && let Some(sink) = core.get_sink()
518        {
519            return Some(sink);
520        }
521        None
522    }
523
524    pub async fn send<'value>(&self, value: T) -> Result<(), TxError>
525    where
526        T: Facet<'value>,
527    {
528        let sink = if let Some(sink) = self.resolve_sink_now() {
529            sink
530        } else if let Some(core) = &self.core.inner {
531            loop {
532                let notified = core.binding_changed.notified();
533                if let Some(sink) = self.resolve_sink_now() {
534                    break sink;
535                }
536                notified.await;
537            }
538        } else {
539            return Err(TxError::Unbound);
540        };
541        let ptr = PtrConst::new((&value as *const T).cast::<u8>());
542        // SAFETY: `value` is explicitly dropped only after `await`, so the pointer
543        // remains valid for the whole send operation.
544        let payload = unsafe { Payload::outgoing_unchecked(ptr, T::SHAPE) };
545        let result = sink.send_payload(payload).await;
546        drop(value);
547        result
548    }
549
550    // r[impl rpc.channel.lifecycle]
551    pub async fn close<'value>(&self, metadata: Metadata<'value>) -> Result<(), TxError> {
552        self.closed.store(true, Ordering::Release);
553        let sink = if let Some(sink) = self.resolve_sink_now() {
554            sink
555        } else if let Some(core) = &self.core.inner {
556            loop {
557                let notified = core.binding_changed.notified();
558                if let Some(sink) = self.resolve_sink_now() {
559                    break sink;
560                }
561                notified.await;
562            }
563        } else {
564            return Err(TxError::Unbound);
565        };
566        sink.close_channel(metadata).await
567    }
568
    /// Bind this handle directly to `sink`, with no liveness token.
    #[doc(hidden)]
    pub fn bind(&mut self, sink: Arc<dyn ChannelSink>) {
        self.bind_with_liveness(sink, None);
    }

    /// Bind this handle directly to `sink` and store `liveness` alongside it.
    ///
    /// Overwrites any sink and liveness token stored previously.
    #[doc(hidden)]
    pub fn bind_with_liveness(
        &mut self,
        sink: Arc<dyn ChannelSink>,
        liveness: Option<ChannelLivenessHandle>,
    ) {
        self.sink.inner = Some(sink);
        self.liveness.inner = liveness;
    }

    /// Forward to `ChannelCore::finish_retry_binding` on the shared core.
    ///
    /// Does nothing for a handle without a core.
    #[doc(hidden)]
    pub fn finish_retry_binding(&self) {
        if let Some(core) = &self.core.inner {
            core.finish_retry_binding();
        }
    }
590}
591
592impl<T> Drop for Tx<T> {
593    fn drop(&mut self) {
594        if self.closed.swap(true, Ordering::AcqRel) {
595            return;
596        }
597
598        let sink = if let Some(sink) = &self.sink.inner {
599            Some(sink.clone())
600        } else if let Some(core) = &self.core.inner {
601            core.get_sink()
602        } else {
603            None
604        };
605
606        let Some(sink) = sink else {
607            return;
608        };
609
610        // Synchronous signal into the runtime/driver; no detached async work here.
611        sink.close_channel_on_drop();
612    }
613}
614
615impl<T> TryFrom<&Tx<T>> for ChannelId {
616    type Error = String;
617
618    fn try_from(value: &Tx<T>) -> Result<Self, Self::Error> {
619        // Case 1: Caller passes Tx in args (callee sends, caller receives).
620        // Allocate a channel ID and store the receiver binding in the shared
621        // core so the caller's paired Rx can pick it up.
622        CHANNEL_BINDER.with(|cell| {
623            let borrow = cell.borrow();
624            let Some(binder) = *borrow else {
625                return Err("serializing Tx requires an active ChannelBinder".to_string());
626            };
627            let (channel_id, bound) = binder.create_rx();
628            if let Some(core) = &value.core.inner {
629                core.bind_retryable_receiver(bound);
630            }
631            Ok(channel_id)
632        })
633    }
634}
635
636impl<T> TryFrom<ChannelId> for Tx<T> {
637    type Error = String;
638
639    fn try_from(channel_id: ChannelId) -> Result<Self, Self::Error> {
640        let mut tx = Self::unbound();
641        tx.channel_id = channel_id;
642
643        CHANNEL_BINDER.with(|cell| {
644            let Some(binder) = *cell.borrow() else {
645                return Err("deserializing Tx requires an active ChannelBinder".to_string());
646            };
647            let sink = binder.bind_tx(channel_id);
648            let liveness = binder.channel_liveness();
649            tx.bind_with_liveness(sink, liveness);
650            Ok(())
651        })?;
652
653        Ok(tx)
654    }
655}
656
/// Error when sending on a `Tx`.
#[derive(Debug)]
pub enum TxError {
    /// The handle has neither a sink nor a shared core to obtain one from.
    Unbound,
    /// The sink reported a failure; the string carries its description.
    Transport(String),
}

impl std::fmt::Display for TxError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Unbound => f.write_str("channel is not bound"),
            Self::Transport(msg) => write!(f, "transport error: {msg}"),
        }
    }
}

impl std::error::Error for TxError {}
674
/// Receiver handle: "I receive". The holder of an `Rx<T>` receives items of type `T`.
///
/// In method args, the handler holds it (handler receives ← caller).
///
/// Channel IDs are serialized inline in the postcard payload.
// NOTE(review): the `Tx` docs in this file say channel IDs are carried
// exclusively in `Message::Request.channels` — confirm which is current.
#[derive(Facet)]
#[facet(proxy = crate::ChannelId)]
pub struct Rx<T> {
    /// `ChannelId::RESERVED` until set by deserialization.
    pub(crate) channel_id: ChannelId,
    /// Raw receiver, bound directly or taken from the shared core.
    pub(crate) receiver: ReceiverSlot,
    /// Logical (rebindable) receiver taken from the shared core; `recv`
    /// prefers it over `receiver` when both are present.
    pub(crate) logical_receiver: LogicalReceiverSlot,
    /// Shared core when this handle came from `channel()`.
    pub(crate) core: CoreSlot,
    /// Liveness token that came with the receiver in use.
    pub(crate) liveness: LivenessSlot,
    /// Replenisher used for items read from the raw receiver.
    pub(crate) replenisher: ReplenisherSlot,
    #[facet(opaque)]
    _marker: PhantomData<T>,
}
692
693impl<T> Rx<T> {
694    /// Create a standalone unbound Rx (used by deserialization).
695    pub fn unbound() -> Self {
696        Self {
697            channel_id: ChannelId::RESERVED,
698            receiver: ReceiverSlot::empty(),
699            logical_receiver: LogicalReceiverSlot::empty(),
700            core: CoreSlot::empty(),
701            liveness: LivenessSlot::empty(),
702            replenisher: ReplenisherSlot::empty(),
703            _marker: PhantomData,
704        }
705    }
706
707    /// Create an Rx that is part of a `channel()` pair.
708    fn paired(core: Arc<ChannelCore>) -> Self {
709        Self {
710            channel_id: ChannelId::RESERVED,
711            receiver: ReceiverSlot::empty(),
712            logical_receiver: LogicalReceiverSlot::empty(),
713            core: CoreSlot { inner: Some(core) },
714            liveness: LivenessSlot::empty(),
715            replenisher: ReplenisherSlot::empty(),
716            _marker: PhantomData,
717        }
718    }
719
    /// Whether a raw receiver is installed in this handle's own slot.
    ///
    /// This does not consult the shared core or the logical receiver slot, so
    /// a paired handle reports `false` until `recv` has taken a raw receiver.
    pub fn is_bound(&self) -> bool {
        self.receiver.inner.is_some()
    }

    /// Check if this Rx is part of a channel() pair (has a shared core).
    pub fn has_core(&self) -> bool {
        self.core.inner.is_some()
    }
728
729    // r[impl rpc.channel.pair.rx-take]
730    pub async fn recv(&mut self) -> Result<Option<SelfRef<T>>, RxError>
731    where
732        T: Facet<'static>,
733    {
734        loop {
735            if self.logical_receiver.inner.is_none()
736                && let Some(core) = &self.core.inner
737                && let Some((receiver, liveness)) = core.take_logical_receiver()
738            {
739                self.logical_receiver.inner = Some(receiver);
740                self.liveness.inner = liveness;
741            }
742
743            if let Some(receiver) = self.logical_receiver.inner.as_mut() {
744                match receiver.recv().await {
745                    Some(LogicalIncomingChannelMessage {
746                        msg: IncomingChannelMessage::Close(_),
747                        ..
748                    })
749                    | None => return Ok(None),
750                    Some(LogicalIncomingChannelMessage {
751                        msg: IncomingChannelMessage::Reset(_),
752                        ..
753                    }) => return Err(RxError::Reset),
754                    Some(LogicalIncomingChannelMessage {
755                        msg: IncomingChannelMessage::Item(msg),
756                        replenisher,
757                    }) => {
758                        let value = msg
759                            .try_repack(|item, _backing_bytes| {
760                                let Payload::PostcardBytes(bytes) = item.item else {
761                                    return Err(RxError::Protocol(
762                                        "incoming channel item payload was not Incoming".into(),
763                                    ));
764                                };
765                                vox_postcard::from_slice_borrowed(bytes)
766                                    .map_err(RxError::Deserialize)
767                            })
768                            .map(Some);
769                        if value.is_ok()
770                            && let Some(replenisher) = replenisher.as_ref()
771                        {
772                            replenisher.on_item_consumed();
773                        }
774                        return value;
775                    }
776                }
777            }
778
779            if self.receiver.inner.is_none()
780                && let Some(core) = &self.core.inner
781                && let Some(bound) = core.take_receiver()
782            {
783                self.receiver.inner = Some(bound.receiver);
784                self.liveness.inner = bound.liveness;
785                self.replenisher.inner = bound.replenisher;
786            }
787
788            if let Some(receiver) = self.receiver.inner.as_mut() {
789                return match receiver.recv().await {
790                    Some(IncomingChannelMessage::Close(_)) | None => Ok(None),
791                    Some(IncomingChannelMessage::Reset(_)) => Err(RxError::Reset),
792                    Some(IncomingChannelMessage::Item(msg)) => {
793                        let value = msg
794                            .try_repack(|item, _backing_bytes| {
795                                let Payload::PostcardBytes(bytes) = item.item else {
796                                    return Err(RxError::Protocol(
797                                        "incoming channel item payload was not Incoming".into(),
798                                    ));
799                                };
800                                vox_postcard::from_slice_borrowed(bytes)
801                                    .map_err(RxError::Deserialize)
802                            })
803                            .map(Some);
804                        if value.is_ok()
805                            && let Some(replenisher) = &self.replenisher.inner
806                        {
807                            replenisher.on_item_consumed();
808                        }
809                        value
810                    }
811                };
812            }
813
814            let Some(core) = &self.core.inner else {
815                return Err(RxError::Unbound);
816            };
817            core.binding_changed.notified().await;
818        }
819    }
820
    /// Bind this handle directly to `receiver`, with no liveness token.
    #[doc(hidden)]
    pub fn bind(&mut self, receiver: mpsc::Receiver<IncomingChannelMessage>) {
        self.bind_with_liveness(receiver, None);
    }

    /// Bind this handle directly to `receiver` and store `liveness` with it.
    ///
    /// Any logical receiver and replenisher held so far are discarded, so
    /// subsequent `recv` calls read from `receiver`.
    #[doc(hidden)]
    pub fn bind_with_liveness(
        &mut self,
        receiver: mpsc::Receiver<IncomingChannelMessage>,
        liveness: Option<ChannelLivenessHandle>,
    ) {
        self.receiver.inner = Some(receiver);
        self.logical_receiver.inner = None;
        self.liveness.inner = liveness;
        self.replenisher.inner = None;
    }
837}
838
839impl<T> TryFrom<&Rx<T>> for ChannelId {
840    type Error = String;
841
842    fn try_from(value: &Rx<T>) -> Result<Self, Self::Error> {
843        // Case 2: Caller passes Rx in args (callee receives, caller sends).
844        // Allocate a channel ID and store the sink binding in the shared
845        // core so the caller's paired Tx can pick it up.
846        CHANNEL_BINDER.with(|cell| {
847            let borrow = cell.borrow();
848            let Some(binder) = *borrow else {
849                return Err("serializing Rx requires an active ChannelBinder".to_string());
850            };
851            let (channel_id, sink) = binder.create_tx();
852            let liveness = binder.channel_liveness();
853            if let Some(core) = &value.core.inner {
854                core.set_binding(ChannelBinding::Sink(BoundChannelSink { sink, liveness }));
855            }
856            Ok(channel_id)
857        })
858    }
859}
860
861impl<T> TryFrom<ChannelId> for Rx<T> {
862    type Error = String;
863
864    fn try_from(channel_id: ChannelId) -> Result<Self, Self::Error> {
865        let mut rx = Self::unbound();
866        rx.channel_id = channel_id;
867
868        CHANNEL_BINDER.with(|cell| {
869            let Some(binder) = *cell.borrow() else {
870                return Err("deserializing Rx requires an active ChannelBinder".to_string());
871            };
872            let bound = binder.register_rx(channel_id);
873            rx.receiver.inner = Some(bound.receiver);
874            rx.liveness.inner = bound.liveness;
875            rx.replenisher.inner = bound.replenisher;
876            Ok(())
877        })?;
878
879        Ok(rx)
880    }
881}
882
883/// Error when receiving from an `Rx`.
884#[derive(Debug)]
885pub enum RxError {
886    Unbound,
887    Reset,
888    Deserialize(vox_postcard::error::DeserializeError),
889    Protocol(String),
890}
891
892impl std::fmt::Display for RxError {
893    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
894        match self {
895            Self::Unbound => write!(f, "channel is not bound"),
896            Self::Reset => write!(f, "channel reset by peer"),
897            Self::Deserialize(e) => write!(f, "deserialize error: {e}"),
898            Self::Protocol(msg) => write!(f, "protocol error: {msg}"),
899        }
900    }
901}
902
// Marker impl: `RxError` relies on the default `std::error::Error` methods.
impl std::error::Error for RxError {}
904
905/// Check if a shape represents a `Tx` channel.
906pub fn is_tx(shape: &facet_core::Shape) -> bool {
907    shape.decl_id == Tx::<()>::SHAPE.decl_id
908}
909
910/// Check if a shape represents an `Rx` channel.
911pub fn is_rx(shape: &facet_core::Shape) -> bool {
912    shape.decl_id == Rx::<()>::SHAPE.decl_id
913}
914
915/// Check if a shape represents any channel type (`Tx` or `Rx`).
916pub fn is_channel(shape: &facet_core::Shape) -> bool {
917    is_tx(shape) || is_rx(shape)
918}
919
/// Allocates channel IDs and produces the sinks/receivers that back channel
/// handles.
///
/// The `TryFrom` conversions between channel handles and [`ChannelId`] look up
/// the binder installed for the current thread (see [`set_channel_binder`] and
/// [`with_channel_binder`]) and call into it.
pub trait ChannelBinder: crate::MaybeSend + crate::MaybeSync {
    /// Allocate a channel ID and create a sink for sending items.
    ///
    /// Called when an `Rx` is serialized; the returned sink is stored in the
    /// shared core so the paired `Tx` can pick it up.
    fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>);

    /// Allocate a channel ID, register it for routing, and return a receiver.
    fn create_rx(&self) -> (ChannelId, BoundChannelReceiver);

    /// Create a sink for a known channel ID (callee side).
    ///
    /// The channel ID comes from `Request.channels`.
    fn bind_tx(&self, channel_id: ChannelId) -> Arc<dyn ChannelSink>;

    /// Register an inbound channel by ID and return the receiver (callee side).
    ///
    /// The channel ID comes from `Request.channels`. Called when an `Rx` is
    /// built from a [`ChannelId`].
    fn register_rx(&self, channel_id: ChannelId) -> BoundChannelReceiver;

    /// Optional opaque handle that keeps the underlying session/connection alive
    /// for the lifetime of any bound channel handle.
    ///
    /// The default implementation returns `None`.
    fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
        None
    }
}
944
945#[cfg(test)]
946mod tests {
947    use super::*;
948    use crate::{Backing, ChannelClose, ChannelItem, ChannelReset, Metadata, SelfRef};
949    use std::sync::atomic::{AtomicUsize, Ordering};
950
951    struct CountingSink {
952        send_calls: AtomicUsize,
953        close_calls: AtomicUsize,
954        close_on_drop_calls: AtomicUsize,
955    }
956
957    impl CountingSink {
958        fn new() -> Self {
959            Self {
960                send_calls: AtomicUsize::new(0),
961                close_calls: AtomicUsize::new(0),
962                close_on_drop_calls: AtomicUsize::new(0),
963            }
964        }
965    }
966
967    impl ChannelSink for CountingSink {
968        fn send_payload<'payload>(
969            &self,
970            _payload: Payload<'payload>,
971        ) -> Pin<Box<dyn crate::MaybeSendFuture<Output = Result<(), TxError>> + 'payload>> {
972            self.send_calls.fetch_add(1, Ordering::AcqRel);
973            Box::pin(async { Ok(()) })
974        }
975
976        fn close_channel(
977            &self,
978            _metadata: Metadata,
979        ) -> Pin<Box<dyn crate::MaybeSendFuture<Output = Result<(), TxError>> + 'static>> {
980            self.close_calls.fetch_add(1, Ordering::AcqRel);
981            Box::pin(async { Ok(()) })
982        }
983
984        fn close_channel_on_drop(&self) {
985            self.close_on_drop_calls.fetch_add(1, Ordering::AcqRel);
986        }
987    }
988
989    struct CountingReplenisher {
990        calls: AtomicUsize,
991    }
992
993    impl CountingReplenisher {
994        fn new() -> Self {
995            Self {
996                calls: AtomicUsize::new(0),
997            }
998        }
999    }
1000
1001    impl ChannelCreditReplenisher for CountingReplenisher {
1002        fn on_item_consumed(&self) {
1003            self.calls.fetch_add(1, Ordering::AcqRel);
1004        }
1005    }
1006
1007    #[tokio::test]
1008    async fn tx_close_does_not_emit_drop_close_after_explicit_close() {
1009        let sink_impl = Arc::new(CountingSink::new());
1010        let sink: Arc<dyn ChannelSink> = sink_impl.clone();
1011
1012        let mut tx = Tx::<u32>::unbound();
1013        tx.bind(sink);
1014        tx.close(Metadata::default())
1015            .await
1016            .expect("close should succeed");
1017        drop(tx);
1018
1019        assert_eq!(sink_impl.close_calls.load(Ordering::Acquire), 1);
1020        assert_eq!(sink_impl.close_on_drop_calls.load(Ordering::Acquire), 0);
1021    }
1022
1023    #[test]
1024    fn tx_drop_emits_close_on_drop_for_bound_sink() {
1025        let sink_impl = Arc::new(CountingSink::new());
1026        let sink: Arc<dyn ChannelSink> = sink_impl.clone();
1027
1028        let mut tx = Tx::<u32>::unbound();
1029        tx.bind(sink);
1030        drop(tx);
1031
1032        assert_eq!(sink_impl.close_on_drop_calls.load(Ordering::Acquire), 1);
1033    }
1034
1035    #[test]
1036    fn tx_drop_emits_close_on_drop_for_paired_core_binding() {
1037        let sink_impl = Arc::new(CountingSink::new());
1038        let sink: Arc<dyn ChannelSink> = sink_impl.clone();
1039
1040        let (tx, _rx) = channel::<u32>();
1041        let core = tx.core.inner.as_ref().expect("paired tx should have core");
1042        core.set_binding(ChannelBinding::Sink(BoundChannelSink {
1043            sink,
1044            liveness: None,
1045        }));
1046        drop(tx);
1047
1048        assert_eq!(sink_impl.close_on_drop_calls.load(Ordering::Acquire), 1);
1049    }
1050
1051    #[tokio::test]
1052    async fn rx_recv_returns_unbound_when_not_bound() {
1053        let mut rx = Rx::<u32>::unbound();
1054        let err = match rx.recv().await {
1055            Ok(_) => panic!("unbound rx should fail"),
1056            Err(err) => err,
1057        };
1058        assert!(matches!(err, RxError::Unbound));
1059    }
1060
1061    #[tokio::test]
1062    async fn rx_recv_returns_none_on_close() {
1063        let (tx, rx_inner) = mpsc::channel("vox_types.channel.test.rx1", 1);
1064        let mut rx = Rx::<u32>::unbound();
1065        rx.bind(rx_inner);
1066
1067        let close = SelfRef::owning(
1068            Backing::Boxed(Box::<[u8]>::default()),
1069            ChannelClose {
1070                metadata: Metadata::default(),
1071            },
1072        );
1073        tx.send(IncomingChannelMessage::Close(close))
1074            .await
1075            .expect("send close");
1076
1077        assert!(rx.recv().await.expect("recv should succeed").is_none());
1078    }
1079
1080    #[tokio::test]
1081    async fn rx_recv_returns_reset_error() {
1082        let (tx, rx_inner) = mpsc::channel("vox_types.channel.test.rx2", 1);
1083        let mut rx = Rx::<u32>::unbound();
1084        rx.bind(rx_inner);
1085
1086        let reset = SelfRef::owning(
1087            Backing::Boxed(Box::<[u8]>::default()),
1088            ChannelReset {
1089                metadata: Metadata::default(),
1090            },
1091        );
1092        tx.send(IncomingChannelMessage::Reset(reset))
1093            .await
1094            .expect("send reset");
1095
1096        let err = match rx.recv().await {
1097            Ok(_) => panic!("reset should be surfaced as error"),
1098            Err(err) => err,
1099        };
1100        assert!(matches!(err, RxError::Reset));
1101    }
1102
1103    #[tokio::test]
1104    async fn rx_recv_rejects_outgoing_payload_variant_as_protocol_error() {
1105        static VALUE: u32 = 42;
1106
1107        let (tx, rx_inner) = mpsc::channel("vox_types.channel.test.rx3", 1);
1108        let mut rx = Rx::<u32>::unbound();
1109        rx.bind(rx_inner);
1110
1111        let item = SelfRef::owning(
1112            Backing::Boxed(Box::<[u8]>::default()),
1113            ChannelItem {
1114                item: Payload::outgoing(&VALUE),
1115            },
1116        );
1117        tx.send(IncomingChannelMessage::Item(item))
1118            .await
1119            .expect("send item");
1120
1121        let err = match rx.recv().await {
1122            Ok(_) => panic!("outgoing payload should be protocol error"),
1123            Err(err) => err,
1124        };
1125        assert!(matches!(err, RxError::Protocol(_)));
1126    }
1127
1128    #[tokio::test]
1129    async fn rx_recv_notifies_replenisher_after_consuming_an_item() {
1130        let (tx, rx_inner) = mpsc::channel("vox_types.channel.test.rx4", 1);
1131        let replenisher = Arc::new(CountingReplenisher::new());
1132        let mut rx = Rx::<u32>::unbound();
1133        rx.bind(rx_inner);
1134        rx.replenisher.inner = Some(replenisher.clone());
1135
1136        let encoded = vox_postcard::to_vec(&123_u32).expect("serialize test item");
1137        let item = SelfRef::owning(
1138            Backing::Boxed(Box::<[u8]>::default()),
1139            ChannelItem {
1140                item: Payload::PostcardBytes(Box::leak(encoded.into_boxed_slice())),
1141            },
1142        );
1143        tx.send(IncomingChannelMessage::Item(item))
1144            .await
1145            .expect("send item");
1146
1147        let value = rx
1148            .recv()
1149            .await
1150            .expect("recv should succeed")
1151            .expect("expected item");
1152        assert_eq!(*value.get(), 123_u32);
1153        assert_eq!(replenisher.calls.load(Ordering::Acquire), 1);
1154    }
1155
1156    #[tokio::test]
1157    async fn rx_recv_logical_receiver_decodes_items_and_notifies_replenisher() {
1158        let (tx, rx_inner) = mpsc::channel("vox_types.channel.test.rx5", 1);
1159        let replenisher = Arc::new(CountingReplenisher::new());
1160        let core = Arc::new(ChannelCore::new());
1161        core.bind_retryable_receiver(BoundChannelReceiver {
1162            receiver: rx_inner,
1163            liveness: None,
1164            replenisher: Some(replenisher.clone()),
1165        });
1166
1167        let mut rx = Rx::<u32>::paired(core);
1168
1169        let encoded = vox_postcard::to_vec(&321_u32).expect("serialize test item");
1170        let item = SelfRef::owning(
1171            Backing::Boxed(Box::<[u8]>::default()),
1172            ChannelItem {
1173                item: Payload::PostcardBytes(Box::leak(encoded.into_boxed_slice())),
1174            },
1175        );
1176        tx.send(IncomingChannelMessage::Item(item))
1177            .await
1178            .expect("send item");
1179
1180        let value = rx
1181            .recv()
1182            .await
1183            .expect("recv should succeed")
1184            .expect("expected item");
1185        assert_eq!(*value.get(), 321_u32);
1186        assert_eq!(replenisher.calls.load(Ordering::Acquire), 1);
1187    }
1188
1189    // ========================================================================
1190    // Channel binding through ser/deser
1191    // ========================================================================
1192
1193    /// A test binder that tracks allocations and bindings.
1194    struct TestBinder {
1195        next_id: std::sync::Mutex<u64>,
1196    }
1197
1198    impl TestBinder {
1199        fn new() -> Self {
1200            Self {
1201                next_id: std::sync::Mutex::new(100),
1202            }
1203        }
1204
1205        fn alloc_id(&self) -> ChannelId {
1206            let mut guard = self.next_id.lock().unwrap();
1207            let id = *guard;
1208            *guard += 2;
1209            ChannelId(id)
1210        }
1211    }
1212
1213    impl ChannelBinder for TestBinder {
1214        fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>) {
1215            (self.alloc_id(), Arc::new(CountingSink::new()))
1216        }
1217
1218        fn create_rx(&self) -> (ChannelId, BoundChannelReceiver) {
1219            let (tx, rx) = mpsc::channel("vox_types.channel.test.bind_retryable1", 8);
1220            // Keep the sender alive by leaking it — test only.
1221            std::mem::forget(tx);
1222            (
1223                self.alloc_id(),
1224                BoundChannelReceiver {
1225                    receiver: rx,
1226                    liveness: None,
1227                    replenisher: None,
1228                },
1229            )
1230        }
1231
1232        fn bind_tx(&self, _channel_id: ChannelId) -> Arc<dyn ChannelSink> {
1233            Arc::new(CountingSink::new())
1234        }
1235
1236        fn register_rx(&self, _channel_id: ChannelId) -> BoundChannelReceiver {
1237            let (_tx, rx) = mpsc::channel("vox_types.channel.test.bind_retryable2", 8);
1238            BoundChannelReceiver {
1239                receiver: rx,
1240                liveness: None,
1241                replenisher: None,
1242            }
1243        }
1244    }
1245
1246    // Case 1: Caller passes Tx in args, keeps paired Rx.
1247    // Serializing the Tx allocates a channel ID via create_rx() and stores
1248    // the receiver in the shared logical core so the kept Rx can survive retries.
1249    #[tokio::test]
1250    async fn case1_serialize_tx_allocates_and_binds_paired_rx() {
1251        use facet::Facet;
1252
1253        #[derive(Facet)]
1254        struct Args {
1255            data: u32,
1256            tx: Tx<u32>,
1257        }
1258
1259        let (tx, rx) = channel::<u32>();
1260        let args = Args { data: 42, tx };
1261
1262        let binder = TestBinder::new();
1263        let bytes =
1264            with_channel_binder(&binder, || vox_postcard::to_vec(&args).expect("serialize"));
1265
1266        // The channel ID should be in the serialized bytes (after the u32 data field).
1267        assert!(!bytes.is_empty());
1268
1269        // The kept Rx should now have a receiver binding in the shared core.
1270        assert!(
1271            rx.core.inner.is_some(),
1272            "paired Rx should have a shared core"
1273        );
1274        let core = rx.core.inner.as_ref().unwrap();
1275        assert!(
1276            core.take_logical_receiver().is_some(),
1277            "core should have a logical receiver binding from create_rx()"
1278        );
1279    }
1280
1281    // Case 2: Caller passes Rx in args, keeps paired Tx.
1282    // Serializing the Rx allocates a channel ID via create_tx() and stores
1283    // the sink in the shared core so the kept Tx can use it.
1284    #[test]
1285    fn case2_serialize_rx_allocates_and_binds_paired_tx() {
1286        use facet::Facet;
1287
1288        #[derive(Facet)]
1289        struct Args {
1290            data: u32,
1291            rx: Rx<u32>,
1292        }
1293
1294        let (tx, rx) = channel::<u32>();
1295        let args = Args { data: 42, rx };
1296
1297        let binder = TestBinder::new();
1298        let bytes =
1299            with_channel_binder(&binder, || vox_postcard::to_vec(&args).expect("serialize"));
1300
1301        assert!(!bytes.is_empty());
1302
1303        // The kept Tx should now have a sink binding in the shared core.
1304        assert!(tx.core.inner.is_some());
1305        let core = tx.core.inner.as_ref().unwrap();
1306        assert!(
1307            core.get_sink().is_some(),
1308            "core should have a Sink binding from create_tx()"
1309        );
1310    }
1311
1312    // Case 3: Callee deserializes Tx from args.
1313    // The Tx is bound directly via bind_tx() during deserialization.
1314    #[test]
1315    fn case3_deserialize_tx_binds_via_binder() {
1316        use facet::Facet;
1317
1318        #[derive(Facet)]
1319        struct Args {
1320            data: u32,
1321            tx: Tx<u32>,
1322        }
1323
1324        // Simulate wire bytes: a u32 (42) followed by a channel ID (varint 7).
1325        let mut bytes = vox_postcard::to_vec(&42_u32).unwrap();
1326        bytes.extend_from_slice(&vox_postcard::to_vec(&ChannelId(7)).unwrap());
1327
1328        let binder = TestBinder::new();
1329        let args: Args = with_channel_binder(&binder, || {
1330            vox_postcard::from_slice(&bytes).expect("deserialize")
1331        });
1332
1333        assert_eq!(args.data, 42);
1334        assert_eq!(args.tx.channel_id, ChannelId(7));
1335        assert!(
1336            args.tx.is_bound(),
1337            "deserialized Tx should be bound via bind_tx()"
1338        );
1339    }
1340
1341    // Case 4: Callee deserializes Rx from args.
1342    // The Rx is bound directly via register_rx() during deserialization.
1343    #[test]
1344    fn case4_deserialize_rx_binds_via_binder() {
1345        use facet::Facet;
1346
1347        #[derive(Facet)]
1348        struct Args {
1349            data: u32,
1350            rx: Rx<u32>,
1351        }
1352
1353        // Simulate wire bytes: a u32 (42) followed by a channel ID (varint 7).
1354        let mut bytes = vox_postcard::to_vec(&42_u32).unwrap();
1355        bytes.extend_from_slice(&vox_postcard::to_vec(&ChannelId(7)).unwrap());
1356
1357        let binder = TestBinder::new();
1358        let args: Args = with_channel_binder(&binder, || {
1359            vox_postcard::from_slice(&bytes).expect("deserialize")
1360        });
1361
1362        assert_eq!(args.data, 42);
1363        assert_eq!(args.rx.channel_id, ChannelId(7));
1364        assert!(
1365            args.rx.is_bound(),
1366            "deserialized Rx should be bound via register_rx()"
1367        );
1368    }
1369
1370    // Round-trip: serialize with caller binder, deserialize with callee binder.
1371    // Verifies the channel ID allocated during serialization appears in the
1372    // deserialized handle.
1373    #[test]
1374    fn channel_id_round_trips_through_ser_deser() {
1375        use facet::Facet;
1376
1377        #[derive(Facet)]
1378        struct Args {
1379            tx: Tx<u32>,
1380        }
1381
1382        let (tx, _rx) = channel::<u32>();
1383        let args = Args { tx };
1384
1385        let caller_binder = TestBinder::new();
1386        let bytes = with_channel_binder(&caller_binder, || {
1387            vox_postcard::to_vec(&args).expect("serialize")
1388        });
1389
1390        let callee_binder = TestBinder::new();
1391        let deserialized: Args = with_channel_binder(&callee_binder, || {
1392            vox_postcard::from_slice(&bytes).expect("deserialize")
1393        });
1394
1395        // The caller binder starts at ID 100, so the deserialized Tx should have that ID.
1396        assert_eq!(deserialized.tx.channel_id, ChannelId(100));
1397        assert!(deserialized.tx.is_bound());
1398    }
1399}