Skip to main content

vox_types/
calls.rs

1use std::{future::Future, pin::Pin, sync::Arc};
2
3use crate::{
4    ClientCallOutcome, ClientContext, ClientMiddleware, ClientRequest, Extensions, MaybeSend,
5    MaybeSendFuture, MaybeSync, Metadata, RequestCall, RequestResponse, VoxError, SelfRef,
6    ServiceDescriptor,
7};
8
9/// A boxed future that is `Send` on native targets and `!Send` on wasm32.
10pub type BoxFut<'a, T> = Pin<Box<dyn MaybeSendFuture<Output = T> + 'a>>;
11
12/// Result type for one caller-visible RPC call: either a tracked response or an error.
13///
14/// The tracked value is the wire-level [`RequestResponse`] that resolved the
15/// current request attempt for that call.
16pub type CallResult = Result<crate::WithTracker<SelfRef<RequestResponse<'static>>>, VoxError>;
17
18// As a recap, a service defined like so:
19//
20// #[vox::service]
21// trait Hash {
22//   async fn hash(&self, payload: &[u8]) -> Result<&[u8], E>;
23// }
24//
25// Would expand to the following caller:
26//
27// impl HashClient {
28//   async fn hash(&self, payload: &[u8]) -> Result<SelfRef<&[u8]>, VoxError<E>>;
29// }
30//
// It would also expand to a service trait (what users implement):
32//
33// trait Hash {
34//   async fn hash(&self, call: impl Call<&[u8], E>, payload: &[u8]);
35// }
36//
// And a HashDispatcher<S: Hash> that implements Handler<R: ReplySink>:
// it deserializes args, constructs a Call<T, E> implementation from the
// ReplySink (in this file, that is SinkCall<R>), and routes to the
// appropriate method by method ID.
40//
41// For owned success returns, generated methods return values directly and
42// the dispatcher sends replies on their behalf.
43//
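// HashDispatcher<S> implements Handler<R>. Note that Handler::handle returns
// `impl Future`, which makes Handler not dyn-compatible, so it cannot be
// stored directly as Box<dyn Handler<R>>. Erasing S and the service type needs
// a wrapper that boxes the returned future (NOTE(review): confirm where that
// wrapper lives).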
46//
// Why does the Hash service trait take `impl Call`? So that the server can
// reply with something _borrowed_ from its own stack frame.
49//
50// For example:
51//
52// impl Hash for MyHasher {
53//   async fn hash(&self, call: impl Call<&[u8], E>, payload: &[u8]) {
54//     let result: [u8; 16] = compute_hash(payload);
55//     call.ok(&result).await;
56//   }
57// }
58//
59// Call's public API is:
60//
61// trait Call<T, E> {
62//   async fn reply(self, result: Result<T, E>);
63//   async fn ok(self, value: T) { self.reply(Ok(value)).await }
64//   async fn err(self, error: E) { self.reply(Err(error)).await }
65// }
66//
67// If a Call is dropped before reply/ok/err is called, the caller will
68// receive a VoxError::Cancelled error. This is to ensure that the caller
69// is always notified, even if the handler panics or otherwise fails to
70// reply.
71
72/// Represents an in-progress API-level call as seen by a server handler.
73///
74/// A `Call` is handed to a [`Handler`] implementation for one incoming
75/// request attempt. It provides the mechanism for sending the terminal
76/// response for that attempt back to the caller. The response can be sent
77/// via [`Call::reply`], [`Call::ok`], or [`Call::err`].
78///
79/// In the retry model, one logical operation may span multiple request
80/// attempts over time, but each `Call` value corresponds to exactly one
81/// request attempt currently being handled.
82///
83/// # Cancellation
84///
85/// If a `Call` is dropped without a reply being sent, the caller will
86/// automatically receive a [`VoxError::Cancelled`] error. This guarantees
87/// that the caller is always notified, even if the handler panics or
88/// otherwise fails to produce a reply.
89///
90/// # Type Parameters
91///
92/// - `T`: The success value type of the response.
93/// - `E`: The error value type of the response.
94pub trait Call<'wire, T, E>: MaybeSend
95where
96    T: facet::Facet<'wire> + MaybeSend,
97    E: facet::Facet<'wire> + MaybeSend,
98{
99    /// Send the terminal response for this request attempt, consuming this `Call`.
100    fn reply(self, result: Result<T, E>) -> impl std::future::Future<Output = ()> + MaybeSend;
101
102    /// Send a successful response for this request attempt, consuming this `Call`.
103    ///
104    /// Equivalent to `self.reply(Ok(value)).await`.
105    fn ok(self, value: T) -> impl std::future::Future<Output = ()> + MaybeSend
106    where
107        Self: Sized,
108    {
109        self.reply(Ok(value))
110    }
111
112    /// Send an error response for this request attempt, consuming this `Call`.
113    ///
114    /// Equivalent to `self.reply(Err(error)).await`.
115    fn err(self, error: E) -> impl std::future::Future<Output = ()> + MaybeSend
116    where
117        Self: Sized,
118    {
119        self.reply(Err(error))
120    }
121}
122
123/// Sink for sending the terminal response for one request attempt.
124///
125/// Implemented by the session driver. Provides backpressure: `send_reply`
126/// awaits until the transport can accept the response before serializing it.
127///
128/// # Cancellation
129///
130/// If the `ReplySink` is dropped without `send_reply` being called, the caller
131/// will automatically receive a [`crate::VoxError::Cancelled`] error.
132pub trait ReplySink: MaybeSend + MaybeSync + 'static {
133    /// Send the terminal response for this request attempt, consuming the sink.
134    /// Any error that happens during `send_reply` must set a flag in the driver
135    /// for it to resolve the attempt as failed.
136    ///
137    /// This cannot return a `Result` because we cannot trust callers to deal
138    /// with it, and they cannot try sending a second response anyway.
139    ///
140    /// Do not spawn a task to send the error because it too, might fail.
141    fn send_reply(
142        self,
143        response: RequestResponse<'_>,
144    ) -> impl std::future::Future<Output = ()> + MaybeSend;
145
146    /// Send an error response for this request attempt, consuming the sink.
147    ///
148    /// This is a convenience method used by generated dispatchers when
149    /// deserialization fails or the method ID is unknown.
150    fn send_error<E: for<'a> facet::Facet<'a> + MaybeSend>(
151        self,
152        error: VoxError<E>,
153    ) -> impl std::future::Future<Output = ()> + MaybeSend
154    where
155        Self: Sized,
156    {
157        use crate::{Payload, RequestResponse};
158        // Wire format is always Result<T, VoxError<E>>. We don't know T here,
159        // but postcard encodes () as zero bytes, so Result<(), VoxError<E>>
160        // produces the same Err variant encoding as any Result<T, VoxError<E>>.
161        async move {
162            let wire: Result<(), VoxError<E>> = Err(error);
163            self.send_reply(RequestResponse {
164                ret: Payload::outgoing(&wire),
165                metadata: Default::default(),
166                schemas: Default::default(),
167            })
168            .await;
169        }
170    }
171
172    /// Send an error response using the full wire shape `Result<T, VoxError<E>>`.
173    ///
174    /// This preserves the method's real `Ok` type for schema extraction.
175    fn send_typed_error<'wire, T, E>(
176        self,
177        error: VoxError<E>,
178    ) -> impl std::future::Future<Output = ()> + MaybeSend
179    where
180        Self: Sized,
181        T: facet::Facet<'wire> + MaybeSend,
182        E: facet::Facet<'wire> + MaybeSend,
183    {
184        use crate::{Payload, RequestResponse};
185        async move {
186            let wire: Result<T, VoxError<E>> = Err(error);
187            let ptr = facet::PtrConst::new((&wire as *const Result<T, VoxError<E>>).cast::<u8>());
188            let shape = <Result<T, VoxError<E>> as facet::Facet<'wire>>::SHAPE;
189            let ret = unsafe { Payload::outgoing_unchecked(ptr, shape) };
190            self.send_reply(RequestResponse {
191                ret,
192                metadata: Default::default(),
193                schemas: Default::default(),
194            })
195            .await;
196        }
197    }
198
199    /// Return a channel binder for binding Tx/Rx handles in deserialized args.
200    ///
201    /// Returns `None` by default. The driver's `ReplySink` implementation
202    /// overrides this to provide actual channel binding.
203    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
204        None
205    }
206}
207
208/// Type-erased handler for incoming service calls.
209///
210/// Implemented (by the macro-generated dispatch code) for server-side types.
211/// Takes a fully decoded [`RequestCall`](crate::RequestCall) — one wire-level
212/// request attempt already parsed from the connection — and a [`ReplySink`]
213/// through which the terminal response for that attempt is sent.
214///
215/// The dispatch impl decodes the args, routes by [`crate::MethodId`], and
216/// invokes the appropriate typed [`Call`]-based method on the concrete server type.
217///
218/// Generated clients hold an [`ErasedCaller`] and use it to start API-level
219/// calls. The caller serializes the outgoing [`RequestCall`] (with borrowed
220/// args), registers a pending response slot for that request attempt, and
221/// awaits the response from the peer.
222pub trait Caller: Clone + MaybeSend + MaybeSync + 'static {
223    /// Start one outgoing request attempt for an API-level call and wait for
224    /// its response.
225    ///
226    /// Returns the wire-level response paired with the `SchemaRecvTracker` that
227    /// was active when the response was received, for schema-aware
228    /// deserialization.
229    fn call<'a>(
230        &'a self,
231        call: RequestCall<'a>,
232    ) -> impl Future<Output = CallResult> + MaybeSend + 'a;
233
234    /// Resolve when the underlying connection closes.
235    ///
236    /// Runtime-backed callers can override this to expose connection liveness.
237    /// The default implementation never resolves.
238    fn closed(&self) -> BoxFut<'_, ()> {
239        Box::pin(std::future::pending())
240    }
241
242    /// Return whether the underlying connection is still considered connected.
243    ///
244    /// Runtime-backed callers can override this to provide eager liveness
245    /// checks. The default implementation assumes the connection is live.
246    fn is_connected(&self) -> bool {
247        true
248    }
249
250    /// Return a channel binder for binding Tx/Rx handles in args before sending.
251    ///
252    /// Returns `None` by default. The driver's `Caller` implementation
253    /// overrides this to provide actual channel binding.
254    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
255        None
256    }
257}
258
259trait ErasedCallerDyn: MaybeSend + MaybeSync + 'static {
260    fn call<'a>(&'a self, call: RequestCall<'a>) -> BoxFut<'a, CallResult>;
261
262    fn closed(&self) -> BoxFut<'_, ()>;
263
264    fn is_connected(&self) -> bool;
265
266    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder>;
267}
268
269impl<C: Caller> ErasedCallerDyn for C {
270    fn call<'a>(&'a self, call: RequestCall<'a>) -> BoxFut<'a, CallResult> {
271        Box::pin(Caller::call(self, call))
272    }
273
274    fn closed(&self) -> BoxFut<'_, ()> {
275        Caller::closed(self)
276    }
277
278    fn is_connected(&self) -> bool {
279        Caller::is_connected(self)
280    }
281
282    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
283        Caller::channel_binder(self)
284    }
285}
286
287/// Type-erased [`Caller`] wrapper used by generated clients.
288#[derive(Clone)]
289pub struct ErasedCaller {
290    inner: Arc<dyn ErasedCallerDyn>,
291    service: Option<&'static ServiceDescriptor>,
292    middlewares: Vec<Arc<dyn ClientMiddleware>>,
293}
294
295impl ErasedCaller {
296    pub fn new<C: Caller>(caller: C) -> Self {
297        Self {
298            inner: Arc::new(caller),
299            service: None,
300            middlewares: vec![],
301        }
302    }
303
304    pub fn with_middleware(
305        mut self,
306        service: &'static ServiceDescriptor,
307        middleware: impl ClientMiddleware,
308    ) -> Self {
309        if let Some(existing_service) = self.service {
310            assert_eq!(
311                existing_service.service_name, service.service_name,
312                "ErasedCaller middleware service mismatch"
313            );
314        } else {
315            self.service = Some(service);
316        }
317        self.middlewares.push(Arc::new(middleware));
318        self
319    }
320}
321
322impl Caller for ErasedCaller {
323    async fn call<'a>(&'a self, mut call: RequestCall<'a>) -> CallResult {
324        let Some(service) = self.service else {
325            return self.inner.call(call).await;
326        };
327
328        let extensions = Extensions::new();
329        let method = service.by_id(call.method_id);
330        let context = ClientContext::new(method, call.method_id, &extensions);
331        let mut owned_metadata = crate::client_middleware::OwnedMetadata::default();
332
333        if !self.middlewares.is_empty() {
334            for middleware in &self.middlewares {
335                let mut request = ClientRequest::new(&mut call, &mut owned_metadata);
336                middleware.pre(&context, &mut request).await;
337            }
338        }
339
340        let result = self.inner.call(call).await;
341        if !self.middlewares.is_empty() {
342            let outcome = match &result {
343                Ok(_) => ClientCallOutcome::Response,
344                Err(error) => ClientCallOutcome::Error(error),
345            };
346            for middleware in self.middlewares.iter().rev() {
347                middleware.post(&context, outcome).await;
348            }
349        }
350        result
351    }
352
353    fn closed(&self) -> BoxFut<'_, ()> {
354        self.inner.closed()
355    }
356
357    fn is_connected(&self) -> bool {
358        self.inner.is_connected()
359    }
360
361    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
362        self.inner.channel_binder()
363    }
364}
365
366pub trait Handler<R: ReplySink>: MaybeSend + MaybeSync + 'static {
367    /// Return the static retry policy for a method ID served by this handler.
368    fn retry_policy(&self, _method_id: crate::MethodId) -> crate::RetryPolicy {
369        crate::RetryPolicy::VOLATILE
370    }
371
372    /// Return whether the method's argument shape contains any channels.
373    fn args_have_channels(&self, _method_id: crate::MethodId) -> bool {
374        false
375    }
376
377    /// Return the canonical wire response shape for a method, if known.
378    ///
379    /// This is the full wire type `Result<T, VoxError<E>>`, not the
380    /// user-facing return type `T` or `Result<T, E>`.
381    fn response_wire_shape(&self, _method_id: crate::MethodId) -> Option<&'static facet::Shape> {
382        None
383    }
384
385    /// Dispatch an incoming call to the appropriate method implementation.
386    fn handle(
387        &self,
388        call: SelfRef<crate::RequestCall<'static>>,
389        reply: R,
390        schemas: std::sync::Arc<crate::SchemaRecvTracker>,
391    ) -> impl std::future::Future<Output = ()> + MaybeSend + '_;
392}
393
394impl<R: ReplySink> Handler<R> for () {
395    async fn handle(
396        &self,
397        _call: SelfRef<crate::RequestCall<'static>>,
398        _reply: R,
399        _schemas: std::sync::Arc<crate::SchemaRecvTracker>,
400    ) {
401    }
402}
403
404/// A decoded response value paired with response metadata.
405///
406/// This helper is available for lower-level callers that need both the
407/// decoded value and metadata together. Generated Rust client methods do
408/// not expose response metadata in their return types.
409pub struct ResponseParts<'a, T> {
410    /// The decoded return value.
411    pub ret: T,
412    /// Metadata attached to the response by the server.
413    pub metadata: Metadata<'a>,
414}
415
416impl<'a, T> std::ops::Deref for ResponseParts<'a, T> {
417    type Target = T;
418    fn deref(&self) -> &T {
419        &self.ret
420    }
421}
422
423/// Concrete [`Call`] implementation backed by a [`ReplySink`].
424///
425/// Constructed by the dispatcher and handed to the server method.
426/// When the server calls [`Call::reply`], the result is serialized and
427/// sent through the sink.
428pub struct SinkCall<R: ReplySink> {
429    reply: R,
430}
431
432impl<R: ReplySink> SinkCall<R> {
433    pub fn new(reply: R) -> Self {
434        Self { reply }
435    }
436}
437
438impl<'wire, T, E, R> Call<'wire, T, E> for SinkCall<R>
439where
440    T: facet::Facet<'wire> + MaybeSend,
441    E: facet::Facet<'wire> + MaybeSend,
442    R: ReplySink,
443{
444    async fn reply(self, result: Result<T, E>) {
445        use crate::{Payload, RequestResponse};
446        let wire: Result<T, crate::VoxError<E>> = result.map_err(crate::VoxError::User);
447        let ptr =
448            facet::PtrConst::new((&wire as *const Result<T, crate::VoxError<E>>).cast::<u8>());
449        let shape = <Result<T, crate::VoxError<E>> as facet::Facet<'wire>>::SHAPE;
450        // SAFETY: `wire` lives until `send_reply(...).await` completes in this function,
451        // and `shape` matches the pointed value exactly.
452        let ret = unsafe { Payload::outgoing_unchecked(ptr, shape) };
453        self.reply
454            .send_reply(RequestResponse {
455                ret,
456                metadata: Default::default(),
457                schemas: Default::default(),
458            })
459            .await;
460    }
461}
462
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use crate::{MaybeSend, Metadata, Payload, RequestCall, RequestResponse};

    use super::{Call, CallResult, Caller, ErasedCaller, Handler, ReplySink, ResponseParts};

    /// `Call` test double that records the result it was asked to reply with.
    struct RecordingCall<T, E> {
        observed: Arc<Mutex<Option<Result<T, E>>>>,
    }

    impl<'wire, T, E> Call<'wire, T, E> for RecordingCall<T, E>
    where
        T: facet::Facet<'wire> + MaybeSend + Send + 'static,
        E: facet::Facet<'wire> + MaybeSend + Send + 'static,
    {
        async fn reply(self, result: Result<T, E>) {
            let mut guard = self.observed.lock().expect("recording mutex poisoned");
            *guard = Some(result);
        }
    }

    /// `ReplySink` test double that records whether `send_reply` ran and
    /// whether the response carried an outgoing `Payload::Value`.
    struct RecordingReplySink {
        saw_send_reply: Arc<Mutex<bool>>,
        saw_outgoing_payload: Arc<Mutex<bool>>,
    }

    impl ReplySink for RecordingReplySink {
        async fn send_reply(self, response: RequestResponse<'_>) {
            // Each guard is a temporary, released at the end of its statement,
            // so the two locks are never held at the same time.
            *self
                .saw_send_reply
                .lock()
                .expect("send-reply mutex poisoned") = true;
            *self
                .saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned") =
                matches!(response.ret, Payload::Value { .. });
        }
    }

    #[derive(Clone)]
    struct NoopCaller;

    impl Caller for NoopCaller {
        async fn call<'a>(&'a self, _call: RequestCall<'a>) -> CallResult {
            unreachable!("NoopCaller::call is not used by this test")
        }
    }

    #[tokio::test]
    async fn call_ok_and_err_route_through_reply() {
        let observed_ok: Arc<Mutex<Option<Result<u32, &'static str>>>> = Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_ok),
        }
        .ok(7)
        .await;
        assert!(matches!(
            *observed_ok.lock().expect("ok mutex poisoned"),
            Some(Ok(7))
        ));

        let observed_err: Arc<Mutex<Option<Result<u32, &'static str>>>> =
            Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_err),
        }
        .err("boom")
        .await;
        assert!(matches!(
            *observed_err.lock().expect("err mutex poisoned"),
            Some(Err("boom"))
        ));
    }

    #[tokio::test]
    async fn reply_sink_send_error_uses_outgoing_payload_and_reply_path() {
        let saw_send_reply = Arc::new(Mutex::new(false));
        let saw_outgoing_payload = Arc::new(Mutex::new(false));
        let sink = RecordingReplySink {
            saw_send_reply: Arc::clone(&saw_send_reply),
            saw_outgoing_payload: Arc::clone(&saw_outgoing_payload),
        };

        sink.send_error(crate::VoxError::<String>::Cancelled).await;

        assert!(*saw_send_reply.lock().expect("send-reply mutex poisoned"));
        assert!(
            *saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned")
        );
    }

    #[tokio::test]
    async fn reply_sink_send_typed_error_preserves_ok_shape() {
        use crate::{
            VoxError, SchemaKind, TypeRef, VariantPayload, build_registry, extract_schemas,
        };

        struct ShapeReplySink {
            observed_root: Arc<Mutex<Option<TypeRef>>>,
        }

        impl ReplySink for ShapeReplySink {
            async fn send_reply(self, response: RequestResponse<'_>) {
                let Payload::Value { shape, .. } = response.ret else {
                    panic!("typed error should use outgoing payload");
                };
                let extracted = extract_schemas(shape).expect("response shape should extract");
                *self
                    .observed_root
                    .lock()
                    .expect("observed-root mutex poisoned") = Some(extracted.root);
            }
        }

        let observed_root = Arc::new(Mutex::new(None));
        ShapeReplySink {
            observed_root: Arc::clone(&observed_root),
        }
        .send_typed_error::<(String, i32), String>(VoxError::Cancelled)
        .await;

        let root = observed_root
            .lock()
            .expect("observed-root mutex poisoned")
            .clone()
            .expect("typed error should record a root");
        let extracted =
            extract_schemas(<Result<(String, i32), VoxError<String>> as facet::Facet>::SHAPE)
                .expect("expected result shape should extract");
        let registry = build_registry(&extracted.schemas);
        let root_kind = root.resolve_kind(&registry).expect("root should resolve");
        let SchemaKind::Enum { variants, .. } = root_kind else {
            panic!("expected result enum root");
        };
        let ok_variant = variants
            .iter()
            .find(|variant| variant.name == "Ok")
            .expect("Result should have Ok variant");
        let VariantPayload::Newtype { type_ref } = &ok_variant.payload else {
            panic!("Ok variant should be newtype");
        };
        match type_ref
            .resolve_kind(&registry)
            .expect("Ok payload should resolve")
        {
            SchemaKind::Tuple { elements } => {
                assert_eq!(elements.len(), 2, "Ok tuple should have two elements");
            }
            other => panic!("expected Ok payload to be tuple, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn unit_handler_is_noop() {
        let saw_send_reply = Arc::new(Mutex::new(false));
        let saw_outgoing_payload = Arc::new(Mutex::new(false));
        let req = crate::SelfRef::owning(
            crate::Backing::Boxed(Box::<[u8]>::default()),
            RequestCall {
                method_id: crate::MethodId(1),
                metadata: Metadata::default(),
                args: Payload::PostcardBytes(&[]),
                schemas: Default::default(),
            },
        );
        ().handle(
            req,
            RecordingReplySink {
                saw_send_reply: Arc::clone(&saw_send_reply),
                saw_outgoing_payload: Arc::clone(&saw_outgoing_payload),
            },
            Arc::new(crate::SchemaRecvTracker::new()),
        )
        .await;

        // The unit handler must never touch the reply sink.
        assert!(
            !*saw_send_reply.lock().expect("send-reply mutex poisoned"),
            "unit handler must not send a reply"
        );
        assert!(
            !*saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned"),
            "unit handler must not produce a payload"
        );
    }

    #[test]
    fn response_parts_deref_exposes_ret() {
        let parts = ResponseParts {
            ret: 42_u32,
            metadata: Metadata::default(),
        };
        assert_eq!(*parts, 42);
    }

    #[test]
    fn default_channel_binder_accessor_for_caller_returns_none() {
        let caller = NoopCaller;
        assert!(caller.channel_binder().is_none());
    }

    #[test]
    fn default_channel_binder_accessor_for_reply_sink_returns_none() {
        let sink = RecordingReplySink {
            saw_send_reply: Arc::new(Mutex::new(false)),
            saw_outgoing_payload: Arc::new(Mutex::new(false)),
        };
        assert!(sink.channel_binder().is_none());
    }

    #[test]
    fn erased_caller_without_middleware_forwards_liveness_defaults() {
        let caller = ErasedCaller::new(NoopCaller);
        assert!(caller.is_connected());
        assert!(caller.channel_binder().is_none());
    }
}