1use std::{future::Future, pin::Pin, sync::Arc};
2
3use crate::{
4 ClientCallOutcome, ClientContext, ClientMiddleware, ClientRequest, Extensions, MaybeSend,
5 MaybeSendFuture, MaybeSync, Metadata, RequestCall, RequestResponse, VoxError, SelfRef,
6 ServiceDescriptor,
7};
8
9pub type BoxFut<'a, T> = Pin<Box<dyn MaybeSendFuture<Output = T> + 'a>>;
11
12pub type CallResult = Result<crate::WithTracker<SelfRef<RequestResponse<'static>>>, VoxError>;
17
18pub trait Call<'wire, T, E>: MaybeSend
95where
96 T: facet::Facet<'wire> + MaybeSend,
97 E: facet::Facet<'wire> + MaybeSend,
98{
99 fn reply(self, result: Result<T, E>) -> impl std::future::Future<Output = ()> + MaybeSend;
101
102 fn ok(self, value: T) -> impl std::future::Future<Output = ()> + MaybeSend
106 where
107 Self: Sized,
108 {
109 self.reply(Ok(value))
110 }
111
112 fn err(self, error: E) -> impl std::future::Future<Output = ()> + MaybeSend
116 where
117 Self: Sized,
118 {
119 self.reply(Err(error))
120 }
121}
122
123pub trait ReplySink: MaybeSend + MaybeSync + 'static {
133 fn send_reply(
142 self,
143 response: RequestResponse<'_>,
144 ) -> impl std::future::Future<Output = ()> + MaybeSend;
145
146 fn send_error<E: for<'a> facet::Facet<'a> + MaybeSend>(
151 self,
152 error: VoxError<E>,
153 ) -> impl std::future::Future<Output = ()> + MaybeSend
154 where
155 Self: Sized,
156 {
157 use crate::{Payload, RequestResponse};
158 async move {
162 let wire: Result<(), VoxError<E>> = Err(error);
163 self.send_reply(RequestResponse {
164 ret: Payload::outgoing(&wire),
165 metadata: Default::default(),
166 schemas: Default::default(),
167 })
168 .await;
169 }
170 }
171
172 fn send_typed_error<'wire, T, E>(
176 self,
177 error: VoxError<E>,
178 ) -> impl std::future::Future<Output = ()> + MaybeSend
179 where
180 Self: Sized,
181 T: facet::Facet<'wire> + MaybeSend,
182 E: facet::Facet<'wire> + MaybeSend,
183 {
184 use crate::{Payload, RequestResponse};
185 async move {
186 let wire: Result<T, VoxError<E>> = Err(error);
187 let ptr = facet::PtrConst::new((&wire as *const Result<T, VoxError<E>>).cast::<u8>());
188 let shape = <Result<T, VoxError<E>> as facet::Facet<'wire>>::SHAPE;
189 let ret = unsafe { Payload::outgoing_unchecked(ptr, shape) };
190 self.send_reply(RequestResponse {
191 ret,
192 metadata: Default::default(),
193 schemas: Default::default(),
194 })
195 .await;
196 }
197 }
198
199 fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
204 None
205 }
206}
207
208pub trait Caller: Clone + MaybeSend + MaybeSync + 'static {
223 fn call<'a>(
230 &'a self,
231 call: RequestCall<'a>,
232 ) -> impl Future<Output = CallResult> + MaybeSend + 'a;
233
234 fn closed(&self) -> BoxFut<'_, ()> {
239 Box::pin(std::future::pending())
240 }
241
242 fn is_connected(&self) -> bool {
247 true
248 }
249
250 fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
255 None
256 }
257}
258
259trait ErasedCallerDyn: MaybeSend + MaybeSync + 'static {
260 fn call<'a>(&'a self, call: RequestCall<'a>) -> BoxFut<'a, CallResult>;
261
262 fn closed(&self) -> BoxFut<'_, ()>;
263
264 fn is_connected(&self) -> bool;
265
266 fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder>;
267}
268
269impl<C: Caller> ErasedCallerDyn for C {
270 fn call<'a>(&'a self, call: RequestCall<'a>) -> BoxFut<'a, CallResult> {
271 Box::pin(Caller::call(self, call))
272 }
273
274 fn closed(&self) -> BoxFut<'_, ()> {
275 Caller::closed(self)
276 }
277
278 fn is_connected(&self) -> bool {
279 Caller::is_connected(self)
280 }
281
282 fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
283 Caller::channel_binder(self)
284 }
285}
286
287#[derive(Clone)]
289pub struct ErasedCaller {
290 inner: Arc<dyn ErasedCallerDyn>,
291 service: Option<&'static ServiceDescriptor>,
292 middlewares: Vec<Arc<dyn ClientMiddleware>>,
293}
294
295impl ErasedCaller {
296 pub fn new<C: Caller>(caller: C) -> Self {
297 Self {
298 inner: Arc::new(caller),
299 service: None,
300 middlewares: vec![],
301 }
302 }
303
304 pub fn with_middleware(
305 mut self,
306 service: &'static ServiceDescriptor,
307 middleware: impl ClientMiddleware,
308 ) -> Self {
309 if let Some(existing_service) = self.service {
310 assert_eq!(
311 existing_service.service_name, service.service_name,
312 "ErasedCaller middleware service mismatch"
313 );
314 } else {
315 self.service = Some(service);
316 }
317 self.middlewares.push(Arc::new(middleware));
318 self
319 }
320}
321
322impl Caller for ErasedCaller {
323 async fn call<'a>(&'a self, mut call: RequestCall<'a>) -> CallResult {
324 let Some(service) = self.service else {
325 return self.inner.call(call).await;
326 };
327
328 let extensions = Extensions::new();
329 let method = service.by_id(call.method_id);
330 let context = ClientContext::new(method, call.method_id, &extensions);
331 let mut owned_metadata = crate::client_middleware::OwnedMetadata::default();
332
333 if !self.middlewares.is_empty() {
334 for middleware in &self.middlewares {
335 let mut request = ClientRequest::new(&mut call, &mut owned_metadata);
336 middleware.pre(&context, &mut request).await;
337 }
338 }
339
340 let result = self.inner.call(call).await;
341 if !self.middlewares.is_empty() {
342 let outcome = match &result {
343 Ok(_) => ClientCallOutcome::Response,
344 Err(error) => ClientCallOutcome::Error(error),
345 };
346 for middleware in self.middlewares.iter().rev() {
347 middleware.post(&context, outcome).await;
348 }
349 }
350 result
351 }
352
353 fn closed(&self) -> BoxFut<'_, ()> {
354 self.inner.closed()
355 }
356
357 fn is_connected(&self) -> bool {
358 self.inner.is_connected()
359 }
360
361 fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
362 self.inner.channel_binder()
363 }
364}
365
366pub trait Handler<R: ReplySink>: MaybeSend + MaybeSync + 'static {
367 fn retry_policy(&self, _method_id: crate::MethodId) -> crate::RetryPolicy {
369 crate::RetryPolicy::VOLATILE
370 }
371
372 fn args_have_channels(&self, _method_id: crate::MethodId) -> bool {
374 false
375 }
376
377 fn response_wire_shape(&self, _method_id: crate::MethodId) -> Option<&'static facet::Shape> {
382 None
383 }
384
385 fn handle(
387 &self,
388 call: SelfRef<crate::RequestCall<'static>>,
389 reply: R,
390 schemas: std::sync::Arc<crate::SchemaRecvTracker>,
391 ) -> impl std::future::Future<Output = ()> + MaybeSend + '_;
392}
393
394impl<R: ReplySink> Handler<R> for () {
395 async fn handle(
396 &self,
397 _call: SelfRef<crate::RequestCall<'static>>,
398 _reply: R,
399 _schemas: std::sync::Arc<crate::SchemaRecvTracker>,
400 ) {
401 }
402}
403
404pub struct ResponseParts<'a, T> {
410 pub ret: T,
412 pub metadata: Metadata<'a>,
414}
415
416impl<'a, T> std::ops::Deref for ResponseParts<'a, T> {
417 type Target = T;
418 fn deref(&self) -> &T {
419 &self.ret
420 }
421}
422
423pub struct SinkCall<R: ReplySink> {
429 reply: R,
430}
431
432impl<R: ReplySink> SinkCall<R> {
433 pub fn new(reply: R) -> Self {
434 Self { reply }
435 }
436}
437
438impl<'wire, T, E, R> Call<'wire, T, E> for SinkCall<R>
439where
440 T: facet::Facet<'wire> + MaybeSend,
441 E: facet::Facet<'wire> + MaybeSend,
442 R: ReplySink,
443{
444 async fn reply(self, result: Result<T, E>) {
445 use crate::{Payload, RequestResponse};
446 let wire: Result<T, crate::VoxError<E>> = result.map_err(crate::VoxError::User);
447 let ptr =
448 facet::PtrConst::new((&wire as *const Result<T, crate::VoxError<E>>).cast::<u8>());
449 let shape = <Result<T, crate::VoxError<E>> as facet::Facet<'wire>>::SHAPE;
450 let ret = unsafe { Payload::outgoing_unchecked(ptr, shape) };
453 self.reply
454 .send_reply(RequestResponse {
455 ret,
456 metadata: Default::default(),
457 schemas: Default::default(),
458 })
459 .await;
460 }
461}
462
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use crate::{MaybeSend, Metadata, Payload, RequestCall, RequestResponse};

    use super::{Call, CallResult, Caller, Handler, ReplySink, ResponseParts};

    /// Test `Call` that stores the result it was replied with.
    struct RecordingCall<T, E> {
        observed: Arc<Mutex<Option<Result<T, E>>>>,
    }

    impl<'wire, T, E> Call<'wire, T, E> for RecordingCall<T, E>
    where
        T: facet::Facet<'wire> + MaybeSend + Send + 'static,
        E: facet::Facet<'wire> + MaybeSend + Send + 'static,
    {
        async fn reply(self, result: Result<T, E>) {
            let mut guard = self.observed.lock().expect("recording mutex poisoned");
            *guard = Some(result);
        }
    }

    /// Test `ReplySink` that records whether `send_reply` ran and whether the response
    /// carried a `Payload::Value`.
    struct RecordingReplySink {
        saw_send_reply: Arc<Mutex<bool>>,
        saw_outgoing_payload: Arc<Mutex<bool>>,
    }

    impl ReplySink for RecordingReplySink {
        async fn send_reply(self, response: RequestResponse<'_>) {
            let mut saw_send_reply = self
                .saw_send_reply
                .lock()
                .expect("send-reply mutex poisoned");
            *saw_send_reply = true;

            let mut saw_outgoing = self
                .saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned");
            *saw_outgoing = matches!(response.ret, Payload::Value { .. });
        }
    }

    /// Caller whose `call` is never reached; only its default methods are exercised.
    #[derive(Clone)]
    struct NoopCaller;

    impl Caller for NoopCaller {
        async fn call<'a>(&'a self, _call: RequestCall<'a>) -> CallResult {
            unreachable!("NoopCaller::call is not used by this test")
        }
    }

    #[tokio::test]
    async fn call_ok_and_err_route_through_reply() {
        let observed_ok: Arc<Mutex<Option<Result<u32, &'static str>>>> = Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_ok),
        }
        .ok(7)
        .await;
        assert!(matches!(
            *observed_ok.lock().expect("ok mutex poisoned"),
            Some(Ok(7))
        ));

        let observed_err: Arc<Mutex<Option<Result<u32, &'static str>>>> =
            Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_err),
        }
        .err("boom")
        .await;
        assert!(matches!(
            *observed_err.lock().expect("err mutex poisoned"),
            Some(Err("boom"))
        ));
    }

    #[tokio::test]
    async fn reply_sink_send_error_uses_outgoing_payload_and_reply_path() {
        let saw_send_reply = Arc::new(Mutex::new(false));
        let saw_outgoing_payload = Arc::new(Mutex::new(false));
        let sink = RecordingReplySink {
            saw_send_reply: Arc::clone(&saw_send_reply),
            saw_outgoing_payload: Arc::clone(&saw_outgoing_payload),
        };

        sink.send_error(crate::VoxError::<String>::Cancelled).await;

        assert!(*saw_send_reply.lock().expect("send-reply mutex poisoned"));
        assert!(
            *saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned")
        );
    }

    #[tokio::test]
    async fn reply_sink_send_typed_error_preserves_ok_shape() {
        use crate::{
            VoxError, SchemaKind, TypeRef, VariantPayload, build_registry, extract_schemas,
        };

        /// Sink that extracts the schema root of the response payload's shape.
        struct ShapeReplySink {
            observed_root: Arc<Mutex<Option<TypeRef>>>,
        }

        impl ReplySink for ShapeReplySink {
            async fn send_reply(self, response: RequestResponse<'_>) {
                let Payload::Value { shape, .. } = response.ret else {
                    panic!("typed error should use outgoing payload");
                };
                let extracted = extract_schemas(shape).expect("response shape should extract");
                *self
                    .observed_root
                    .lock()
                    .expect("observed-root mutex poisoned") = Some(extracted.root);
            }
        }

        let observed_root = Arc::new(Mutex::new(None));
        ShapeReplySink {
            observed_root: Arc::clone(&observed_root),
        }
        .send_typed_error::<(String, i32), String>(VoxError::Cancelled)
        .await;

        let root = observed_root
            .lock()
            .expect("observed-root mutex poisoned")
            .clone()
            .expect("typed error should record a root");
        let extracted =
            extract_schemas(<Result<(String, i32), VoxError<String>> as facet::Facet>::SHAPE)
                .expect("expected result shape should extract");
        let registry = build_registry(&extracted.schemas);
        let root_kind = root.resolve_kind(&registry).expect("root should resolve");
        let SchemaKind::Enum { variants, .. } = root_kind else {
            panic!("expected result enum root");
        };
        let ok_variant = variants
            .iter()
            .find(|variant| variant.name == "Ok")
            .expect("Result should have Ok variant");
        let VariantPayload::Newtype { type_ref } = &ok_variant.payload else {
            panic!("Ok variant should be newtype");
        };
        match type_ref
            .resolve_kind(&registry)
            .expect("Ok payload should resolve")
        {
            SchemaKind::Tuple { elements } => {
                assert_eq!(elements.len(), 2, "Ok tuple should have two elements");
            }
            other => panic!("expected Ok payload to be tuple, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn unit_handler_is_noop() {
        let req = crate::SelfRef::owning(
            crate::Backing::Boxed(Box::<[u8]>::default()),
            RequestCall {
                method_id: crate::MethodId(1),
                metadata: Metadata::default(),
                args: Payload::PostcardBytes(&[]),
                schemas: Default::default(),
            },
        );
        ().handle(
            req,
            RecordingReplySink {
                saw_send_reply: Arc::new(Mutex::new(false)),
                saw_outgoing_payload: Arc::new(Mutex::new(false)),
            },
            Arc::new(crate::SchemaRecvTracker::new()),
        )
        .await;
    }

    #[test]
    fn response_parts_deref_exposes_ret() {
        let parts = ResponseParts {
            ret: 42_u32,
            metadata: Metadata::default(),
        };
        assert_eq!(*parts, 42);
    }

    #[test]
    fn default_channel_binder_accessor_for_caller_returns_none() {
        let caller = NoopCaller;
        assert!(caller.channel_binder().is_none());
    }

    #[test]
    fn default_channel_binder_accessor_for_reply_sink_returns_none() {
        let sink = RecordingReplySink {
            saw_send_reply: Arc::new(Mutex::new(false)),
            saw_outgoing_payload: Arc::new(Mutex::new(false)),
        };
        assert!(sink.channel_binder().is_none());
    }
}