1use std::{future::Future, pin::Pin, sync::Arc};
2
3use crate::{
4 ClientCallOutcome, ClientContext, ClientMiddleware, ClientRequest, ConnectionId, Extensions,
5 MaybeSend, MaybeSendFuture, MaybeSync, Metadata, RequestCall, RequestId, RequestResponse,
6 SelfRef, ServiceDescriptor, VoxError,
7};
8
9pub type BoxFut<'a, T> = Pin<Box<dyn MaybeSendFuture<Output = T> + 'a>>;
11
12pub type CallResult = Result<crate::WithTracker<SelfRef<RequestResponse<'static>>>, VoxError>;
17
18pub trait Call<'wire, T, E>: MaybeSend
95where
96 T: facet::Facet<'wire> + MaybeSend,
97 E: facet::Facet<'wire> + MaybeSend,
98{
99 fn reply(self, result: Result<T, E>) -> impl std::future::Future<Output = ()> + MaybeSend;
101
102 fn ok(self, value: T) -> impl std::future::Future<Output = ()> + MaybeSend
106 where
107 Self: Sized,
108 {
109 self.reply(Ok(value))
110 }
111
112 fn err(self, error: E) -> impl std::future::Future<Output = ()> + MaybeSend
116 where
117 Self: Sized,
118 {
119 self.reply(Err(error))
120 }
121}
122
/// Destination for the reply to one incoming request.
///
/// Every sending method takes `self` by value, so a sink sends at most one
/// reply through this trait.
pub trait ReplySink: MaybeSend + MaybeSync + 'static {
    /// Sends `response` as the reply, consuming the sink.
    fn send_reply(
        self,
        response: RequestResponse<'_>,
    ) -> impl std::future::Future<Output = ()> + MaybeSend;

    /// Sends `error` as the reply.
    ///
    /// The error is encoded as `Err(error)` of type `Result<(), VoxError<E>>`
    /// with default metadata and schemas. Because the Ok type is `()`, the
    /// outgoing shape does not describe the method's real Ok type. Use
    /// `send_typed_error` when that matters.
    fn send_error<E: for<'a> facet::Facet<'a> + MaybeSend>(
        self,
        error: VoxError<E>,
    ) -> impl std::future::Future<Output = ()> + MaybeSend
    where
        Self: Sized,
    {
        use crate::{Payload, RequestResponse};
        async move {
            // Only the error half carries data; the Ok half is a placeholder `()`.
            let wire: Result<(), VoxError<E>> = Err(error);
            self.send_reply(RequestResponse {
                ret: Payload::outgoing(&wire),
                metadata: Default::default(),
                schemas: Default::default(),
            })
            .await;
        }
    }

    /// Sends `error` as the reply, keeping the method's Ok type in the shape.
    ///
    /// The error is encoded as `Err(error)` of type `Result<T, VoxError<E>>`,
    /// with default metadata and schemas. The outgoing shape therefore
    /// describes the real Ok type `T`, not a `()` placeholder.
    fn send_typed_error<'wire, T, E>(
        self,
        error: VoxError<E>,
    ) -> impl std::future::Future<Output = ()> + MaybeSend
    where
        Self: Sized,
        T: facet::Facet<'wire> + MaybeSend,
        E: facet::Facet<'wire> + MaybeSend,
    {
        use crate::{Payload, RequestResponse};
        async move {
            let wire: Result<T, VoxError<E>> = Err(error);
            // Type-erase `wire` into a (pointer, shape) pair by hand. This goes
            // through the unchecked constructor, presumably because `T`/`E` are
            // only `Facet<'wire>` here, not the `for<'a> Facet<'a>` bound that
            // `send_error` relies on for `Payload::outgoing`. Confirm this.
            let ptr = facet::PtrConst::new((&wire as *const Result<T, VoxError<E>>).cast::<u8>());
            let shape = <Result<T, VoxError<E>> as facet::Facet<'wire>>::SHAPE;
            // SAFETY: `ptr` points at `wire`, whose type is exactly the type that
            // `shape` describes. `wire` is a local of this async block and is not
            // moved or dropped until after `send_reply(..).await` below completes,
            // so the payload never outlives the value it points to.
            // NOTE(review): this assumes `send_reply` does not keep the payload
            // alive beyond its returned future. Confirm against the
            // `Payload::outgoing_unchecked` contract.
            let ret = unsafe { Payload::outgoing_unchecked(ptr, shape) };
            self.send_reply(RequestResponse {
                ret,
                metadata: Default::default(),
                schemas: Default::default(),
            })
            .await;
        }
    }

    /// Channel binder available on this reply path, if any. Defaults to `None`.
    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
        None
    }

    /// Id of the request this sink presumably replies to, if known. Defaults to `None`.
    fn request_id(&self) -> Option<RequestId> {
        None
    }

    /// Id of the connection the reply presumably travels on, if known. Defaults to `None`.
    fn connection_id(&self) -> Option<ConnectionId> {
        None
    }
}
217
218pub trait Caller: Clone + MaybeSend + MaybeSync + 'static {
233 fn call<'a>(
240 &'a self,
241 call: RequestCall<'a>,
242 ) -> impl Future<Output = CallResult> + MaybeSend + 'a;
243
244 fn closed(&self) -> BoxFut<'_, ()> {
249 Box::pin(std::future::pending())
250 }
251
252 fn is_connected(&self) -> bool {
257 true
258 }
259
260 fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
265 None
266 }
267}
268
269trait ErasedCallerDyn: MaybeSend + MaybeSync + 'static {
270 fn call<'a>(&'a self, call: RequestCall<'a>) -> BoxFut<'a, CallResult>;
271
272 fn closed(&self) -> BoxFut<'_, ()>;
273
274 fn is_connected(&self) -> bool;
275
276 fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder>;
277}
278
279impl<C: Caller> ErasedCallerDyn for C {
280 fn call<'a>(&'a self, call: RequestCall<'a>) -> BoxFut<'a, CallResult> {
281 Box::pin(Caller::call(self, call))
282 }
283
284 fn closed(&self) -> BoxFut<'_, ()> {
285 Caller::closed(self)
286 }
287
288 fn is_connected(&self) -> bool {
289 Caller::is_connected(self)
290 }
291
292 fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
293 Caller::channel_binder(self)
294 }
295}
296
297#[derive(Clone)]
299pub struct ErasedCaller {
300 inner: Arc<dyn ErasedCallerDyn>,
301 service: Option<&'static ServiceDescriptor>,
302 middlewares: Vec<Arc<dyn ClientMiddleware>>,
303}
304
305impl ErasedCaller {
306 pub fn new<C: Caller>(caller: C) -> Self {
307 Self {
308 inner: Arc::new(caller),
309 service: None,
310 middlewares: vec![],
311 }
312 }
313
314 pub fn with_middleware(
315 mut self,
316 service: &'static ServiceDescriptor,
317 middleware: impl ClientMiddleware,
318 ) -> Self {
319 if let Some(existing_service) = self.service {
320 assert_eq!(
321 existing_service.service_name, service.service_name,
322 "ErasedCaller middleware service mismatch"
323 );
324 } else {
325 self.service = Some(service);
326 }
327 self.middlewares.push(Arc::new(middleware));
328 self
329 }
330}
331
332impl Caller for ErasedCaller {
333 async fn call<'a>(&'a self, mut call: RequestCall<'a>) -> CallResult {
334 let Some(service) = self.service else {
335 return self.inner.call(call).await;
336 };
337
338 let extensions = Extensions::new();
339 let method = service.by_id(call.method_id);
340 let context = ClientContext::new(method, call.method_id, &extensions);
341 let mut owned_metadata = crate::client_middleware::OwnedMetadata::default();
342
343 if !self.middlewares.is_empty() {
344 for middleware in &self.middlewares {
345 let mut request = ClientRequest::new(&mut call, &mut owned_metadata);
346 middleware.pre(&context, &mut request).await;
347 }
348 }
349
350 let result = self.inner.call(call).await;
351 if !self.middlewares.is_empty() {
352 let outcome = match &result {
353 Ok(_) => ClientCallOutcome::Response,
354 Err(error) => ClientCallOutcome::Error(error),
355 };
356 for middleware in self.middlewares.iter().rev() {
357 middleware.post(&context, outcome).await;
358 }
359 }
360 result
361 }
362
363 fn closed(&self) -> BoxFut<'_, ()> {
364 self.inner.closed()
365 }
366
367 fn is_connected(&self) -> bool {
368 self.inner.is_connected()
369 }
370
371 fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
372 self.inner.channel_binder()
373 }
374}
375
376pub trait Handler<R: ReplySink>: MaybeSend + MaybeSync + 'static {
377 fn retry_policy(&self, _method_id: crate::MethodId) -> crate::RetryPolicy {
379 crate::RetryPolicy::VOLATILE
380 }
381
382 fn args_have_channels(&self, _method_id: crate::MethodId) -> bool {
384 false
385 }
386
387 fn response_wire_shape(&self, _method_id: crate::MethodId) -> Option<&'static facet::Shape> {
392 None
393 }
394
395 fn handle(
397 &self,
398 call: SelfRef<crate::RequestCall<'static>>,
399 reply: R,
400 schemas: std::sync::Arc<crate::SchemaRecvTracker>,
401 ) -> impl std::future::Future<Output = ()> + MaybeSend + '_;
402}
403
404impl<R: ReplySink> Handler<R> for () {
405 async fn handle(
406 &self,
407 _call: SelfRef<crate::RequestCall<'static>>,
408 _reply: R,
409 _schemas: std::sync::Arc<crate::SchemaRecvTracker>,
410 ) {
411 }
412}
413
414pub struct ResponseParts<'a, T> {
420 pub ret: T,
422 pub metadata: Metadata<'a>,
424}
425
426impl<'a, T> std::ops::Deref for ResponseParts<'a, T> {
427 type Target = T;
428 fn deref(&self) -> &T {
429 &self.ret
430 }
431}
432
433pub struct SinkCall<R: ReplySink> {
439 reply: R,
440}
441
442impl<R: ReplySink> SinkCall<R> {
443 pub fn new(reply: R) -> Self {
444 Self { reply }
445 }
446}
447
/// Encodes the handler's result and hands it to the wrapped [`ReplySink`].
impl<'wire, T, E, R> Call<'wire, T, E> for SinkCall<R>
where
    T: facet::Facet<'wire> + MaybeSend,
    E: facet::Facet<'wire> + MaybeSend,
    R: ReplySink,
{
    /// Sends `result` as the reply.
    ///
    /// A user error is wrapped as `VoxError::User`, and the value is encoded as
    /// `Result<T, VoxError<E>>` with default metadata and schemas.
    async fn reply(self, result: Result<T, E>) {
        use crate::{Payload, RequestResponse};
        // The wire type wraps application errors in `VoxError::User`.
        let wire: Result<T, crate::VoxError<E>> = result.map_err(crate::VoxError::User);
        // Type-erase `wire` into a (pointer, shape) pair by hand, since `T`/`E`
        // are only known to be `Facet<'wire>` here.
        let ptr =
            facet::PtrConst::new((&wire as *const Result<T, crate::VoxError<E>>).cast::<u8>());
        let shape = <Result<T, crate::VoxError<E>> as facet::Facet<'wire>>::SHAPE;
        // SAFETY: `ptr` points at `wire`, whose type is exactly the type that
        // `shape` describes. `wire` is a local of this async fn and stays alive,
        // unmoved, until after `send_reply(..).await` below completes, so the
        // payload never outlives the value it points to.
        // NOTE(review): this assumes `send_reply` does not keep the payload
        // alive beyond its returned future. Confirm against the
        // `Payload::outgoing_unchecked` contract.
        let ret = unsafe { Payload::outgoing_unchecked(ptr, shape) };
        self.reply
            .send_reply(RequestResponse {
                ret,
                metadata: Default::default(),
                schemas: Default::default(),
            })
            .await;
    }
}
472
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use crate::{MaybeSend, Metadata, Payload, RequestCall, RequestResponse};

    use super::{Call, CallResult, Caller, Handler, ReplySink, ResponseParts};

    /// `Call` double that stores whatever result it is asked to reply with.
    struct RecordingCall<T, E> {
        observed: Arc<Mutex<Option<Result<T, E>>>>,
    }

    impl<'wire, T, E> Call<'wire, T, E> for RecordingCall<T, E>
    where
        T: facet::Facet<'wire> + MaybeSend + Send + 'static,
        E: facet::Facet<'wire> + MaybeSend + Send + 'static,
    {
        async fn reply(self, result: Result<T, E>) {
            let mut guard = self.observed.lock().expect("recording mutex poisoned");
            *guard = Some(result);
        }
    }

    /// `ReplySink` double that records two things:
    /// - whether `send_reply` ran;
    /// - whether the payload it received was a `Payload::Value`.
    struct RecordingReplySink {
        saw_send_reply: Arc<Mutex<bool>>,
        saw_outgoing_payload: Arc<Mutex<bool>>,
    }

    impl ReplySink for RecordingReplySink {
        async fn send_reply(self, response: RequestResponse<'_>) {
            let mut saw_send_reply = self
                .saw_send_reply
                .lock()
                .expect("send-reply mutex poisoned");
            *saw_send_reply = true;

            let mut saw_outgoing = self
                .saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned");
            *saw_outgoing = matches!(response.ret, Payload::Value { .. });
        }
    }

    /// `Caller` that exists only to exercise the trait's default methods.
    #[derive(Clone)]
    struct NoopCaller;

    impl Caller for NoopCaller {
        async fn call<'a>(&'a self, _call: RequestCall<'a>) -> CallResult {
            unreachable!("NoopCaller::call is not used by this test")
        }
    }

    #[tokio::test]
    async fn call_ok_and_err_route_through_reply() {
        let observed_ok: Arc<Mutex<Option<Result<u32, &'static str>>>> = Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_ok),
        }
        .ok(7)
        .await;
        assert!(matches!(
            *observed_ok.lock().expect("ok mutex poisoned"),
            Some(Ok(7))
        ));

        let observed_err: Arc<Mutex<Option<Result<u32, &'static str>>>> =
            Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_err),
        }
        .err("boom")
        .await;
        assert!(matches!(
            *observed_err.lock().expect("err mutex poisoned"),
            Some(Err("boom"))
        ));
    }

    #[tokio::test]
    async fn reply_sink_send_error_uses_outgoing_payload_and_reply_path() {
        let saw_send_reply = Arc::new(Mutex::new(false));
        let saw_outgoing_payload = Arc::new(Mutex::new(false));
        let sink = RecordingReplySink {
            saw_send_reply: Arc::clone(&saw_send_reply),
            saw_outgoing_payload: Arc::clone(&saw_outgoing_payload),
        };

        sink.send_error(crate::VoxError::<String>::Cancelled).await;

        assert!(*saw_send_reply.lock().expect("send-reply mutex poisoned"));
        assert!(
            *saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned")
        );
    }

    #[tokio::test]
    async fn reply_sink_send_typed_error_preserves_ok_shape() {
        use crate::{
            SchemaKind, TypeRef, VariantPayload, VoxError, build_registry, extract_schemas,
        };

        /// Sink that records the root schema of the shape it was handed.
        struct ShapeReplySink {
            observed_root: Arc<Mutex<Option<TypeRef>>>,
        }

        impl ReplySink for ShapeReplySink {
            async fn send_reply(self, response: RequestResponse<'_>) {
                let Payload::Value { shape, .. } = response.ret else {
                    panic!("typed error should use outgoing payload");
                };
                let extracted = extract_schemas(shape).expect("response shape should extract");
                *self
                    .observed_root
                    .lock()
                    .expect("observed-root mutex poisoned") = Some(extracted.root);
            }
        }

        let observed_root = Arc::new(Mutex::new(None));
        ShapeReplySink {
            observed_root: Arc::clone(&observed_root),
        }
        .send_typed_error::<(String, i32), String>(VoxError::Cancelled)
        .await;

        let root = observed_root
            .lock()
            .expect("observed-root mutex poisoned")
            .clone()
            .expect("typed error should record a root");
        let extracted =
            extract_schemas(<Result<(String, i32), VoxError<String>> as facet::Facet>::SHAPE)
                .expect("expected result shape should extract");
        let registry = build_registry(&extracted.schemas);
        let root_kind = root.resolve_kind(&registry).expect("root should resolve");
        let SchemaKind::Enum { variants, .. } = root_kind else {
            panic!("expected result enum root");
        };
        let ok_variant = variants
            .iter()
            .find(|variant| variant.name == "Ok")
            .expect("Result should have Ok variant");
        let VariantPayload::Newtype { type_ref } = &ok_variant.payload else {
            panic!("Ok variant should be newtype");
        };
        match type_ref
            .resolve_kind(&registry)
            .expect("Ok payload should resolve")
        {
            SchemaKind::Tuple { elements } => {
                assert_eq!(elements.len(), 2, "Ok tuple should have two elements");
            }
            other => panic!("expected Ok payload to be tuple, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn unit_handler_is_noop() {
        let req = crate::SelfRef::owning(
            crate::Backing::Boxed(Box::<[u8]>::default()),
            RequestCall {
                method_id: crate::MethodId(1),
                metadata: Metadata::default(),
                args: Payload::PostcardBytes(&[]),
                schemas: Default::default(),
            },
        );
        ().handle(
            req,
            RecordingReplySink {
                saw_send_reply: Arc::new(Mutex::new(false)),
                saw_outgoing_payload: Arc::new(Mutex::new(false)),
            },
            Arc::new(crate::SchemaRecvTracker::new()),
        )
        .await;
    }

    #[test]
    fn response_parts_deref_exposes_ret() {
        let parts = ResponseParts {
            ret: 42_u32,
            metadata: Metadata::default(),
        };
        assert_eq!(*parts, 42);
    }

    #[test]
    fn default_channel_binder_accessor_for_caller_returns_none() {
        let caller = NoopCaller;
        assert!(caller.channel_binder().is_none());
    }

    #[test]
    fn default_channel_binder_accessor_for_reply_sink_returns_none() {
        let sink = RecordingReplySink {
            saw_send_reply: Arc::new(Mutex::new(false)),
            saw_outgoing_payload: Arc::new(Mutex::new(false)),
        };
        assert!(sink.channel_binder().is_none());
    }
}