Skip to main content

vox_types/
calls.rs

1use std::{future::Future, pin::Pin, sync::Arc};
2
3use crate::{
4    ClientCallOutcome, ClientContext, ClientMiddleware, ClientRequest, ConnectionId, Extensions,
5    MaybeSend, MaybeSendFuture, MaybeSync, Metadata, RequestCall, RequestId, RequestResponse,
6    SelfRef, ServiceDescriptor, VoxError,
7};
8
9/// A boxed future that is `Send` on native targets and `!Send` on wasm32.
10pub type BoxFut<'a, T> = Pin<Box<dyn MaybeSendFuture<Output = T> + 'a>>;
11
12/// Result type for one caller-visible RPC call: either a tracked response or an error.
13///
14/// The tracked value is the wire-level [`RequestResponse`] that resolved the
15/// current request attempt for that call.
16pub type CallResult = Result<crate::WithTracker<SelfRef<RequestResponse<'static>>>, VoxError>;
17
18// As a recap, a service defined like so:
19//
20// #[vox::service]
21// trait Hash {
22//   async fn hash(&self, payload: &[u8]) -> Result<&[u8], E>;
23// }
24//
25// Would expand to the following caller:
26//
27// impl HashClient {
28//   async fn hash(&self, payload: &[u8]) -> Result<SelfRef<&[u8]>, VoxError<E>>;
29// }
30//
// It would also expand to a service trait (what users implement):
32//
33// trait Hash {
34//   async fn hash(&self, call: impl Call<&[u8], E>, payload: &[u8]);
35// }
36//
// And a HashDispatcher<S: Hash> that implements Handler<R: ReplySink>:
// it deserializes args, constructs a Call implementation (e.g. SinkCall<R>)
// from the ReplySink, and routes to the appropriate method by method ID.
40//
41// For owned success returns, generated methods return values directly and
42// the dispatcher sends replies on their behalf.
43//
44// HashDispatcher<S> implements Handler<R>, and can be stored as
45// Box<dyn Handler<R>> to erase both S and the service type.
46//
// Why does the service trait take an `impl Call` instead of returning a value?
// So that the server can reply with something _borrowed_ from its own stack
// frame.
49//
50// For example:
51//
52// impl Hash for MyHasher {
53//   async fn hash(&self, call: impl Call<&[u8], E>, payload: &[u8]) {
54//     let result: [u8; 16] = compute_hash(payload);
55//     call.ok(&result).await;
56//   }
57// }
58//
59// Call's public API is:
60//
61// trait Call<T, E> {
62//   async fn reply(self, result: Result<T, E>);
63//   async fn ok(self, value: T) { self.reply(Ok(value)).await }
64//   async fn err(self, error: E) { self.reply(Err(error)).await }
65// }
66//
67// If a Call is dropped before reply/ok/err is called, the caller will
68// receive a VoxError::Cancelled error. This is to ensure that the caller
69// is always notified, even if the handler panics or otherwise fails to
70// reply.
71
72/// Represents an in-progress API-level call as seen by a server handler.
73///
74/// A `Call` is handed to a [`Handler`] implementation for one incoming
75/// request attempt. It provides the mechanism for sending the terminal
76/// response for that attempt back to the caller. The response can be sent
77/// via [`Call::reply`], [`Call::ok`], or [`Call::err`].
78///
79/// In the retry model, one logical operation may span multiple request
80/// attempts over time, but each `Call` value corresponds to exactly one
81/// request attempt currently being handled.
82///
83/// # Cancellation
84///
85/// If a `Call` is dropped without a reply being sent, the caller will
86/// automatically receive a [`VoxError::Cancelled`] error. This guarantees
87/// that the caller is always notified, even if the handler panics or
88/// otherwise fails to produce a reply.
89///
90/// # Type Parameters
91///
92/// - `T`: The success value type of the response.
93/// - `E`: The error value type of the response.
94pub trait Call<'wire, T, E>: MaybeSend
95where
96    T: facet::Facet<'wire> + MaybeSend,
97    E: facet::Facet<'wire> + MaybeSend,
98{
99    /// Send the terminal response for this request attempt, consuming this `Call`.
100    fn reply(self, result: Result<T, E>) -> impl std::future::Future<Output = ()> + MaybeSend;
101
102    /// Send a successful response for this request attempt, consuming this `Call`.
103    ///
104    /// Equivalent to `self.reply(Ok(value)).await`.
105    fn ok(self, value: T) -> impl std::future::Future<Output = ()> + MaybeSend
106    where
107        Self: Sized,
108    {
109        self.reply(Ok(value))
110    }
111
112    /// Send an error response for this request attempt, consuming this `Call`.
113    ///
114    /// Equivalent to `self.reply(Err(error)).await`.
115    fn err(self, error: E) -> impl std::future::Future<Output = ()> + MaybeSend
116    where
117        Self: Sized,
118    {
119        self.reply(Err(error))
120    }
121}
122
123/// Sink for sending the terminal response for one request attempt.
124///
125/// Implemented by the session driver. Provides backpressure: `send_reply`
126/// awaits until the transport can accept the response before serializing it.
127///
128/// # Cancellation
129///
130/// If the `ReplySink` is dropped without `send_reply` being called, the caller
131/// will automatically receive a [`crate::VoxError::Cancelled`] error.
132pub trait ReplySink: MaybeSend + MaybeSync + 'static {
133    /// Send the terminal response for this request attempt, consuming the sink.
134    /// Any error that happens during `send_reply` must set a flag in the driver
135    /// for it to resolve the attempt as failed.
136    ///
137    /// This cannot return a `Result` because we cannot trust callers to deal
138    /// with it, and they cannot try sending a second response anyway.
139    ///
140    /// Do not spawn a task to send the error because it too, might fail.
141    fn send_reply(
142        self,
143        response: RequestResponse<'_>,
144    ) -> impl std::future::Future<Output = ()> + MaybeSend;
145
146    /// Send an error response for this request attempt, consuming the sink.
147    ///
148    /// This is a convenience method used by generated dispatchers when
149    /// deserialization fails or the method ID is unknown.
150    fn send_error<E: for<'a> facet::Facet<'a> + MaybeSend>(
151        self,
152        error: VoxError<E>,
153    ) -> impl std::future::Future<Output = ()> + MaybeSend
154    where
155        Self: Sized,
156    {
157        use crate::{Payload, RequestResponse};
158        // Wire format is always Result<T, VoxError<E>>. We don't know T here,
159        // but postcard encodes () as zero bytes, so Result<(), VoxError<E>>
160        // produces the same Err variant encoding as any Result<T, VoxError<E>>.
161        async move {
162            let wire: Result<(), VoxError<E>> = Err(error);
163            self.send_reply(RequestResponse {
164                ret: Payload::outgoing(&wire),
165                metadata: Default::default(),
166                schemas: Default::default(),
167            })
168            .await;
169        }
170    }
171
172    /// Send an error response using the full wire shape `Result<T, VoxError<E>>`.
173    ///
174    /// This preserves the method's real `Ok` type for schema extraction.
175    fn send_typed_error<'wire, T, E>(
176        self,
177        error: VoxError<E>,
178    ) -> impl std::future::Future<Output = ()> + MaybeSend
179    where
180        Self: Sized,
181        T: facet::Facet<'wire> + MaybeSend,
182        E: facet::Facet<'wire> + MaybeSend,
183    {
184        use crate::{Payload, RequestResponse};
185        async move {
186            let wire: Result<T, VoxError<E>> = Err(error);
187            let ptr = facet::PtrConst::new((&wire as *const Result<T, VoxError<E>>).cast::<u8>());
188            let shape = <Result<T, VoxError<E>> as facet::Facet<'wire>>::SHAPE;
189            let ret = unsafe { Payload::outgoing_unchecked(ptr, shape) };
190            self.send_reply(RequestResponse {
191                ret,
192                metadata: Default::default(),
193                schemas: Default::default(),
194            })
195            .await;
196        }
197    }
198
199    /// Return a channel binder for binding Tx/Rx handles in deserialized args.
200    ///
201    /// Returns `None` by default. The driver's `ReplySink` implementation
202    /// overrides this to provide actual channel binding.
203    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
204        None
205    }
206
207    /// Return the wire-level request identifier for this reply sink when available.
208    fn request_id(&self) -> Option<RequestId> {
209        None
210    }
211
212    /// Return the virtual connection identifier for this reply sink when available.
213    fn connection_id(&self) -> Option<ConnectionId> {
214        None
215    }
216}
217
218/// Type-erased handler for incoming service calls.
219///
220/// Implemented (by the macro-generated dispatch code) for server-side types.
221/// Takes a fully decoded [`RequestCall`](crate::RequestCall) — one wire-level
222/// request attempt already parsed from the connection — and a [`ReplySink`]
223/// through which the terminal response for that attempt is sent.
224///
225/// The dispatch impl decodes the args, routes by [`crate::MethodId`], and
226/// invokes the appropriate typed [`Call`]-based method on the concrete server type.
227///
228/// Generated clients hold an [`ErasedCaller`] and use it to start API-level
229/// calls. The caller serializes the outgoing [`RequestCall`] (with borrowed
230/// args), registers a pending response slot for that request attempt, and
231/// awaits the response from the peer.
232pub trait Caller: Clone + MaybeSend + MaybeSync + 'static {
233    /// Start one outgoing request attempt for an API-level call and wait for
234    /// its response.
235    ///
236    /// Returns the wire-level response paired with the `SchemaRecvTracker` that
237    /// was active when the response was received, for schema-aware
238    /// deserialization.
239    fn call<'a>(
240        &'a self,
241        call: RequestCall<'a>,
242    ) -> impl Future<Output = CallResult> + MaybeSend + 'a;
243
244    /// Resolve when the underlying connection closes.
245    ///
246    /// Runtime-backed callers can override this to expose connection liveness.
247    /// The default implementation never resolves.
248    fn closed(&self) -> BoxFut<'_, ()> {
249        Box::pin(std::future::pending())
250    }
251
252    /// Return whether the underlying connection is still considered connected.
253    ///
254    /// Runtime-backed callers can override this to provide eager liveness
255    /// checks. The default implementation assumes the connection is live.
256    fn is_connected(&self) -> bool {
257        true
258    }
259
260    /// Return a channel binder for binding Tx/Rx handles in args before sending.
261    ///
262    /// Returns `None` by default. The driver's `Caller` implementation
263    /// overrides this to provide actual channel binding.
264    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
265        None
266    }
267}
268
269trait ErasedCallerDyn: MaybeSend + MaybeSync + 'static {
270    fn call<'a>(&'a self, call: RequestCall<'a>) -> BoxFut<'a, CallResult>;
271
272    fn closed(&self) -> BoxFut<'_, ()>;
273
274    fn is_connected(&self) -> bool;
275
276    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder>;
277}
278
279impl<C: Caller> ErasedCallerDyn for C {
280    fn call<'a>(&'a self, call: RequestCall<'a>) -> BoxFut<'a, CallResult> {
281        Box::pin(Caller::call(self, call))
282    }
283
284    fn closed(&self) -> BoxFut<'_, ()> {
285        Caller::closed(self)
286    }
287
288    fn is_connected(&self) -> bool {
289        Caller::is_connected(self)
290    }
291
292    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
293        Caller::channel_binder(self)
294    }
295}
296
297/// Type-erased [`Caller`] wrapper used by generated clients.
298#[derive(Clone)]
299pub struct ErasedCaller {
300    inner: Arc<dyn ErasedCallerDyn>,
301    service: Option<&'static ServiceDescriptor>,
302    middlewares: Vec<Arc<dyn ClientMiddleware>>,
303}
304
305impl ErasedCaller {
306    pub fn new<C: Caller>(caller: C) -> Self {
307        Self {
308            inner: Arc::new(caller),
309            service: None,
310            middlewares: vec![],
311        }
312    }
313
314    pub fn with_middleware(
315        mut self,
316        service: &'static ServiceDescriptor,
317        middleware: impl ClientMiddleware,
318    ) -> Self {
319        if let Some(existing_service) = self.service {
320            assert_eq!(
321                existing_service.service_name, service.service_name,
322                "ErasedCaller middleware service mismatch"
323            );
324        } else {
325            self.service = Some(service);
326        }
327        self.middlewares.push(Arc::new(middleware));
328        self
329    }
330}
331
332impl Caller for ErasedCaller {
333    async fn call<'a>(&'a self, mut call: RequestCall<'a>) -> CallResult {
334        let Some(service) = self.service else {
335            return self.inner.call(call).await;
336        };
337
338        let extensions = Extensions::new();
339        let method = service.by_id(call.method_id);
340        let context = ClientContext::new(method, call.method_id, &extensions);
341        let mut owned_metadata = crate::client_middleware::OwnedMetadata::default();
342
343        if !self.middlewares.is_empty() {
344            for middleware in &self.middlewares {
345                let mut request = ClientRequest::new(&mut call, &mut owned_metadata);
346                middleware.pre(&context, &mut request).await;
347            }
348        }
349
350        let result = self.inner.call(call).await;
351        if !self.middlewares.is_empty() {
352            let outcome = match &result {
353                Ok(_) => ClientCallOutcome::Response,
354                Err(error) => ClientCallOutcome::Error(error),
355            };
356            for middleware in self.middlewares.iter().rev() {
357                middleware.post(&context, outcome).await;
358            }
359        }
360        result
361    }
362
363    fn closed(&self) -> BoxFut<'_, ()> {
364        self.inner.closed()
365    }
366
367    fn is_connected(&self) -> bool {
368        self.inner.is_connected()
369    }
370
371    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
372        self.inner.channel_binder()
373    }
374}
375
376pub trait Handler<R: ReplySink>: MaybeSend + MaybeSync + 'static {
377    /// Return the static retry policy for a method ID served by this handler.
378    fn retry_policy(&self, _method_id: crate::MethodId) -> crate::RetryPolicy {
379        crate::RetryPolicy::VOLATILE
380    }
381
382    /// Return whether the method's argument shape contains any channels.
383    fn args_have_channels(&self, _method_id: crate::MethodId) -> bool {
384        false
385    }
386
387    /// Return the canonical wire response shape for a method, if known.
388    ///
389    /// This is the full wire type `Result<T, VoxError<E>>`, not the
390    /// user-facing return type `T` or `Result<T, E>`.
391    fn response_wire_shape(&self, _method_id: crate::MethodId) -> Option<&'static facet::Shape> {
392        None
393    }
394
395    /// Dispatch an incoming call to the appropriate method implementation.
396    fn handle(
397        &self,
398        call: SelfRef<crate::RequestCall<'static>>,
399        reply: R,
400        schemas: std::sync::Arc<crate::SchemaRecvTracker>,
401    ) -> impl std::future::Future<Output = ()> + MaybeSend + '_;
402}
403
404impl<R: ReplySink> Handler<R> for () {
405    async fn handle(
406        &self,
407        _call: SelfRef<crate::RequestCall<'static>>,
408        _reply: R,
409        _schemas: std::sync::Arc<crate::SchemaRecvTracker>,
410    ) {
411    }
412}
413
414/// A decoded response value paired with response metadata.
415///
416/// This helper is available for lower-level callers that need both the
417/// decoded value and metadata together. Generated Rust client methods do
418/// not expose response metadata in their return types.
419pub struct ResponseParts<'a, T> {
420    /// The decoded return value.
421    pub ret: T,
422    /// Metadata attached to the response by the server.
423    pub metadata: Metadata<'a>,
424}
425
426impl<'a, T> std::ops::Deref for ResponseParts<'a, T> {
427    type Target = T;
428    fn deref(&self) -> &T {
429        &self.ret
430    }
431}
432
433/// Concrete [`Call`] implementation backed by a [`ReplySink`].
434///
435/// Constructed by the dispatcher and handed to the server method.
436/// When the server calls [`Call::reply`], the result is serialized and
437/// sent through the sink.
438pub struct SinkCall<R: ReplySink> {
439    reply: R,
440}
441
442impl<R: ReplySink> SinkCall<R> {
443    pub fn new(reply: R) -> Self {
444        Self { reply }
445    }
446}
447
448impl<'wire, T, E, R> Call<'wire, T, E> for SinkCall<R>
449where
450    T: facet::Facet<'wire> + MaybeSend,
451    E: facet::Facet<'wire> + MaybeSend,
452    R: ReplySink,
453{
454    async fn reply(self, result: Result<T, E>) {
455        use crate::{Payload, RequestResponse};
456        let wire: Result<T, crate::VoxError<E>> = result.map_err(crate::VoxError::User);
457        let ptr =
458            facet::PtrConst::new((&wire as *const Result<T, crate::VoxError<E>>).cast::<u8>());
459        let shape = <Result<T, crate::VoxError<E>> as facet::Facet<'wire>>::SHAPE;
460        // SAFETY: `wire` lives until `send_reply(...).await` completes in this function,
461        // and `shape` matches the pointed value exactly.
462        let ret = unsafe { Payload::outgoing_unchecked(ptr, shape) };
463        self.reply
464            .send_reply(RequestResponse {
465                ret,
466                metadata: Default::default(),
467                schemas: Default::default(),
468            })
469            .await;
470    }
471}
472
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use crate::{MaybeSend, Metadata, Payload, RequestCall, RequestResponse};

    use super::{Call, CallResult, Caller, Handler, ReplySink, ResponseParts};

    /// `Call` implementation that records the result it was asked to send.
    struct RecordingCall<T, E> {
        observed: Arc<Mutex<Option<Result<T, E>>>>,
    }

    impl<'wire, T, E> Call<'wire, T, E> for RecordingCall<T, E>
    where
        T: facet::Facet<'wire> + MaybeSend + Send + 'static,
        E: facet::Facet<'wire> + MaybeSend + Send + 'static,
    {
        async fn reply(self, result: Result<T, E>) {
            let mut guard = self.observed.lock().expect("recording mutex poisoned");
            *guard = Some(result);
        }
    }

    /// `ReplySink` that records whether `send_reply` ran and whether it was
    /// given a `Payload::Value` payload.
    struct RecordingReplySink {
        saw_send_reply: Arc<Mutex<bool>>,
        saw_outgoing_payload: Arc<Mutex<bool>>,
    }

    impl ReplySink for RecordingReplySink {
        async fn send_reply(self, response: RequestResponse<'_>) {
            let mut saw_send_reply = self
                .saw_send_reply
                .lock()
                .expect("send-reply mutex poisoned");
            *saw_send_reply = true;

            let mut saw_outgoing = self
                .saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned");
            *saw_outgoing = matches!(response.ret, Payload::Value { .. });
        }
    }

    #[derive(Clone)]
    struct NoopCaller;

    impl Caller for NoopCaller {
        async fn call<'a>(&'a self, _call: RequestCall<'a>) -> CallResult {
            unreachable!("NoopCaller::call is not used by this test")
        }
    }

    #[tokio::test]
    async fn call_ok_and_err_route_through_reply() {
        let observed_ok: Arc<Mutex<Option<Result<u32, &'static str>>>> = Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_ok),
        }
        .ok(7)
        .await;
        assert!(matches!(
            *observed_ok.lock().expect("ok mutex poisoned"),
            Some(Ok(7))
        ));

        let observed_err: Arc<Mutex<Option<Result<u32, &'static str>>>> =
            Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_err),
        }
        .err("boom")
        .await;
        assert!(matches!(
            *observed_err.lock().expect("err mutex poisoned"),
            Some(Err("boom"))
        ));
    }

    #[tokio::test]
    async fn reply_sink_send_error_uses_outgoing_payload_and_reply_path() {
        let saw_send_reply = Arc::new(Mutex::new(false));
        let saw_outgoing_payload = Arc::new(Mutex::new(false));
        let sink = RecordingReplySink {
            saw_send_reply: Arc::clone(&saw_send_reply),
            saw_outgoing_payload: Arc::clone(&saw_outgoing_payload),
        };

        sink.send_error(crate::VoxError::<String>::Cancelled).await;

        assert!(*saw_send_reply.lock().expect("send-reply mutex poisoned"));
        assert!(
            *saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned")
        );
    }

    #[tokio::test]
    async fn reply_sink_send_typed_error_preserves_ok_shape() {
        use crate::{
            SchemaKind, TypeRef, VariantPayload, VoxError, build_registry, extract_schemas,
        };

        struct ShapeReplySink {
            observed_root: Arc<Mutex<Option<TypeRef>>>,
        }

        impl ReplySink for ShapeReplySink {
            async fn send_reply(self, response: RequestResponse<'_>) {
                let Payload::Value { shape, .. } = response.ret else {
                    panic!("typed error should use outgoing payload");
                };
                let extracted = extract_schemas(shape).expect("response shape should extract");
                *self
                    .observed_root
                    .lock()
                    .expect("observed-root mutex poisoned") = Some(extracted.root);
            }
        }

        let observed_root = Arc::new(Mutex::new(None));
        ShapeReplySink {
            observed_root: Arc::clone(&observed_root),
        }
        .send_typed_error::<(String, i32), String>(VoxError::Cancelled)
        .await;

        let root = observed_root
            .lock()
            .expect("observed-root mutex poisoned")
            .clone()
            .expect("typed error should record a root");
        let extracted =
            extract_schemas(<Result<(String, i32), VoxError<String>> as facet::Facet>::SHAPE)
                .expect("expected result shape should extract");
        let registry = build_registry(&extracted.schemas);
        let root_kind = root.resolve_kind(&registry).expect("root should resolve");
        let SchemaKind::Enum { variants, .. } = root_kind else {
            panic!("expected result enum root");
        };
        let ok_variant = variants
            .iter()
            .find(|variant| variant.name == "Ok")
            .expect("Result should have Ok variant");
        let VariantPayload::Newtype { type_ref } = &ok_variant.payload else {
            panic!("Ok variant should be newtype");
        };
        match type_ref
            .resolve_kind(&registry)
            .expect("Ok payload should resolve")
        {
            SchemaKind::Tuple { elements } => {
                assert_eq!(elements.len(), 2, "Ok tuple should have two elements");
            }
            other => panic!("expected Ok payload to be tuple, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn unit_handler_is_noop() {
        let req = crate::SelfRef::owning(
            crate::Backing::Boxed(Box::<[u8]>::default()),
            RequestCall {
                method_id: crate::MethodId(1),
                metadata: Metadata::default(),
                args: Payload::PostcardBytes(&[]),
                schemas: Default::default(),
            },
        );
        let saw_send_reply = Arc::new(Mutex::new(false));
        let saw_outgoing_payload = Arc::new(Mutex::new(false));
        ().handle(
            req,
            RecordingReplySink {
                saw_send_reply: Arc::clone(&saw_send_reply),
                saw_outgoing_payload: Arc::clone(&saw_outgoing_payload),
            },
            Arc::new(crate::SchemaRecvTracker::new()),
        )
        .await;

        // The unit handler must never reply: the sink is dropped untouched.
        assert!(!*saw_send_reply.lock().expect("send-reply mutex poisoned"));
        assert!(
            !*saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned")
        );
    }

    #[test]
    fn response_parts_deref_exposes_ret() {
        let parts = ResponseParts {
            ret: 42_u32,
            metadata: Metadata::default(),
        };
        assert_eq!(*parts, 42);
    }

    #[test]
    fn default_channel_binder_accessor_for_caller_returns_none() {
        let caller = NoopCaller;
        assert!(caller.channel_binder().is_none());
    }

    #[test]
    fn default_channel_binder_accessor_for_reply_sink_returns_none() {
        let sink = RecordingReplySink {
            saw_send_reply: Arc::new(Mutex::new(false)),
            saw_outgoing_payload: Arc::new(Mutex::new(false)),
        };
        assert!(sink.channel_binder().is_none());
    }
}