1use std::{future::Future, pin::Pin, sync::Arc};
2
3use crate::{
4 ConnectionId, MaybeSend, MaybeSendFuture, MaybeSync, Metadata, MethodId, RequestCall,
5 RequestId, RequestResponse, SchemaRecvTracker, SelfRef, VoxError,
6};
7
/// A pinned, heap-allocated future that yields `T` and may borrow for `'a`.
///
/// The trait object is bounded by [`MaybeSendFuture`] rather than by `Send`
/// directly. NOTE(review): presumably `MaybeSendFuture` is `Send` on some
/// targets and relaxed on others; confirm against its definition.
pub type BoxFut<'a, T> = Pin<Box<dyn MaybeSendFuture<Output = T> + 'a>>;

/// Result of issuing a call: either the response, held in a [`SelfRef`] and
/// wrapped in `crate::WithTracker`, or a [`VoxError`].
pub type CallResult = Result<crate::WithTracker<SelfRef<RequestResponse<'static>>>, VoxError>;
16
17pub trait Call<'wire, T, E>: MaybeSend
94where
95 T: facet::Facet<'wire> + MaybeSend,
96 E: facet::Facet<'wire> + MaybeSend,
97{
98 fn reply(self, result: Result<T, E>) -> impl Future<Output = ()> + MaybeSend;
100
101 fn ok(self, value: T) -> impl Future<Output = ()> + MaybeSend
105 where
106 Self: Sized,
107 {
108 self.reply(Ok(value))
109 }
110
111 fn err(self, error: E) -> impl Future<Output = ()> + MaybeSend
115 where
116 Self: Sized,
117 {
118 self.reply(Err(error))
119 }
120}
121
/// Destination for the response to one request.
///
/// Implementors provide [`ReplySink::send_reply`]. The error helpers are
/// default methods built on top of it, and the accessors default to `None`.
pub trait ReplySink: MaybeSend + MaybeSync + 'static {
    /// Sends `response` as the reply, consuming the sink.
    fn send_reply(self, response: RequestResponse<'_>) -> impl Future<Output = ()> + MaybeSend;

    /// Replies with an error encoded as `Result<(), VoxError<E>>`.
    ///
    /// The `Ok` side of the encoded type is `()`, not the method's real return
    /// type. Use [`ReplySink::send_typed_error`] when the `Ok` shape must be
    /// preserved. The response's `metadata` and `schemas` are left at their
    /// `Default` values.
    fn send_error<E: for<'a> facet::Facet<'a> + MaybeSend>(
        self,
        error: VoxError<E>,
    ) -> impl Future<Output = ()> + MaybeSend
    where
        Self: Sized,
    {
        use crate::{Payload, RequestResponse};
        async move {
            // Only the Err variant is ever sent; the Ok type is fixed to `()`.
            let wire: Result<(), VoxError<E>> = Err(error);
            self.send_reply(RequestResponse {
                ret: Payload::outgoing(&wire),
                metadata: Default::default(),
                schemas: Default::default(),
            })
            .await;
        }
    }

    /// Replies with an error encoded as `Result<T, VoxError<E>>`.
    ///
    /// Unlike [`ReplySink::send_error`], the encoded shape carries the real `Ok`
    /// type `T`, even though only the `Err` variant is sent. The response's
    /// `metadata` and `schemas` are left at their `Default` values.
    ///
    /// The payload is built from a raw pointer and shape via
    /// `Payload::outgoing_unchecked`. Presumably this is because `T`/`E` are
    /// only `Facet<'wire>` here, not `for<'a> Facet<'a>` as `send_error`
    /// requires for `Payload::outgoing`; confirm.
    fn send_typed_error<'wire, T, E>(
        self,
        error: VoxError<E>,
    ) -> impl Future<Output = ()> + MaybeSend
    where
        Self: Sized,
        T: facet::Facet<'wire> + MaybeSend,
        E: facet::Facet<'wire> + MaybeSend,
    {
        use crate::{Payload, RequestResponse};
        async move {
            let wire: Result<T, VoxError<E>> = Err(error);
            // Type-erase `wire` so it can be paired with its shape below.
            let ptr = facet::PtrConst::new((&wire as *const Result<T, VoxError<E>>).cast::<u8>());
            let shape = <Result<T, VoxError<E>> as facet::Facet<'wire>>::SHAPE;
            // SAFETY: `ptr` points at `wire`, whose type is exactly the type
            // `shape` describes. `wire` is a local of this async block and lives
            // until the block ends, i.e. past the `send_reply(..).await` that
            // consumes the payload. NOTE(review): confirm against the documented
            // contract of `Payload::outgoing_unchecked`.
            let ret = unsafe { Payload::outgoing_unchecked(ptr, shape) };
            self.send_reply(RequestResponse {
                ret,
                metadata: Default::default(),
                schemas: Default::default(),
            })
            .await;
        }
    }

    /// Channel binder associated with this sink, if any. Defaults to `None`.
    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
        None
    }

    /// Id of the request this sink answers, if known. Defaults to `None`.
    fn request_id(&self) -> Option<RequestId> {
        None
    }

    /// Connection id associated with this sink, if known. Defaults to `None`.
    fn connection_id(&self) -> Option<ConnectionId> {
        None
    }
}
213
/// Serves incoming request calls, answering through a [`ReplySink`] of type `R`.
///
/// Only [`Handler::handle`] is required. The per-method queries have
/// defaults that are the same for every method id.
pub trait Handler<R: ReplySink>: MaybeSend + MaybeSync + 'static {
    /// Retry policy for `method_id`. Defaults to `RetryPolicy::VOLATILE`.
    fn retry_policy(&self, _method_id: MethodId) -> crate::RetryPolicy {
        crate::RetryPolicy::VOLATILE
    }

    /// Whether the arguments of `method_id` carry channels. Defaults to `false`.
    fn args_have_channels(&self, _method_id: MethodId) -> bool {
        false
    }

    /// Static response shape for `method_id`, if the handler provides one.
    /// Defaults to `None`.
    fn response_wire_shape(&self, _method_id: MethodId) -> Option<&'static facet::Shape> {
        None
    }

    /// Handles one request `call`.
    ///
    /// `reply` is the sink for this call's response and `schemas` is a shared
    /// `SchemaRecvTracker`. The returned future may borrow `self`.
    fn handle(
        &self,
        call: SelfRef<RequestCall<'static>>,
        reply: R,
        schemas: Arc<SchemaRecvTracker>,
    ) -> impl Future<Output = ()> + MaybeSend + '_;
}
253
254impl<R: ReplySink> Handler<R> for () {
255 async fn handle(
256 &self,
257 _call: SelfRef<RequestCall<'static>>,
258 _reply: R,
259 _schemas: Arc<SchemaRecvTracker>,
260 ) {
261 }
262}
263
/// Parts of a response: the return value plus its metadata.
///
/// Dereferences to `ret`.
pub struct ResponseParts<'a, T> {
    /// The return value.
    pub ret: T,
    /// Metadata carried alongside the return value.
    pub metadata: Metadata<'a>,
}
275
276impl<'a, T> std::ops::Deref for ResponseParts<'a, T> {
277 type Target = T;
278 fn deref(&self) -> &T {
279 &self.ret
280 }
281}
282
/// [`Call`] implementation that encodes the typed result and forwards it to the
/// wrapped [`ReplySink`].
pub struct SinkCall<R: ReplySink> {
    // The sink the encoded response is handed to; consumed by `Call::reply`.
    reply: R,
}
291
292impl<R: ReplySink> SinkCall<R> {
293 pub fn new(reply: R) -> Self {
294 Self { reply }
295 }
296}
297
/// Encodes the result as `Result<T, VoxError<E>>` and sends it through the
/// wrapped sink with default `metadata` and `schemas`.
///
/// User errors become `VoxError::User`.
impl<'wire, T, E, R> Call<'wire, T, E> for SinkCall<R>
where
    T: facet::Facet<'wire> + MaybeSend,
    E: facet::Facet<'wire> + MaybeSend,
    R: ReplySink,
{
    async fn reply(self, result: Result<T, E>) {
        use crate::{Payload, RequestResponse};
        // Lift the user error into `VoxError::User` so the encoded type is
        // `Result<T, VoxError<E>>`, the same form `send_typed_error` encodes.
        let wire: Result<T, VoxError<E>> = result.map_err(VoxError::User);
        // Type-erase `wire` so it can be paired with its shape below.
        let ptr = facet::PtrConst::new((&wire as *const Result<T, VoxError<E>>).cast::<u8>());
        let shape = <Result<T, VoxError<E>> as facet::Facet<'wire>>::SHAPE;
        // SAFETY: `ptr` points at `wire`, whose type is exactly the type `shape`
        // describes. `wire` is a local of this async fn and stays alive until
        // after the `send_reply(..).await` below, which is the only consumer of
        // the payload. NOTE(review): confirm against the documented contract of
        // `Payload::outgoing_unchecked`.
        let ret = unsafe { Payload::outgoing_unchecked(ptr, shape) };

        self.reply
            .send_reply(RequestResponse {
                ret,
                metadata: Default::default(),
                schemas: Default::default(),
            })
            .await;
    }
}
322
#[cfg(test)]
mod tests {
    //! Unit tests for the call/reply plumbing: `Call` defaults, `ReplySink`
    //! error helpers, the no-op unit handler and `ResponseParts` deref.

    use std::sync::{Arc, Mutex};

    use crate::{MaybeSend, Metadata, Payload, RequestCall, RequestResponse};

    use super::{Call, Handler, ReplySink, ResponseParts};

    /// `Call` that stores whatever result it was asked to reply with.
    struct RecordingCall<T, E> {
        observed: Arc<Mutex<Option<Result<T, E>>>>,
    }

    impl<'wire, T, E> Call<'wire, T, E> for RecordingCall<T, E>
    where
        T: facet::Facet<'wire> + MaybeSend + Send + 'static,
        E: facet::Facet<'wire> + MaybeSend + Send + 'static,
    {
        async fn reply(self, result: Result<T, E>) {
            let mut guard = self.observed.lock().expect("recording mutex poisoned");
            *guard = Some(result);
        }
    }

    /// `ReplySink` that records whether `send_reply` ran and whether the
    /// payload was an outgoing `Payload::Value`.
    struct RecordingReplySink {
        saw_send_reply: Arc<Mutex<bool>>,
        saw_outgoing_payload: Arc<Mutex<bool>>,
    }

    impl ReplySink for RecordingReplySink {
        async fn send_reply(self, response: RequestResponse<'_>) {
            let mut saw_send_reply = self
                .saw_send_reply
                .lock()
                .expect("send-reply mutex poisoned");
            *saw_send_reply = true;

            let mut saw_outgoing = self
                .saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned");
            *saw_outgoing = matches!(response.ret, Payload::Value { .. });
        }
    }

    #[tokio::test]
    async fn call_ok_and_err_route_through_reply() {
        let observed_ok: Arc<Mutex<Option<Result<u32, &'static str>>>> = Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_ok),
        }
        .ok(7)
        .await;
        assert!(matches!(
            *observed_ok.lock().expect("ok mutex poisoned"),
            Some(Ok(7))
        ));

        let observed_err: Arc<Mutex<Option<Result<u32, &'static str>>>> =
            Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_err),
        }
        .err("boom")
        .await;
        assert!(matches!(
            *observed_err.lock().expect("err mutex poisoned"),
            Some(Err("boom"))
        ));
    }

    #[tokio::test]
    async fn reply_sink_send_error_uses_outgoing_payload_and_reply_path() {
        let saw_send_reply = Arc::new(Mutex::new(false));
        let saw_outgoing_payload = Arc::new(Mutex::new(false));
        let sink = RecordingReplySink {
            saw_send_reply: Arc::clone(&saw_send_reply),
            saw_outgoing_payload: Arc::clone(&saw_outgoing_payload),
        };

        sink.send_error(crate::VoxError::<String>::Cancelled).await;

        assert!(*saw_send_reply.lock().expect("send-reply mutex poisoned"));
        assert!(
            *saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned")
        );
    }

    #[tokio::test]
    async fn reply_sink_send_typed_error_preserves_ok_shape() {
        use crate::{
            SchemaKind, TypeRef, VariantPayload, VoxError, build_registry, extract_schemas,
        };

        /// Sink that captures the root schema of the outgoing payload's shape.
        struct ShapeReplySink {
            observed_root: Arc<Mutex<Option<TypeRef>>>,
        }

        impl ReplySink for ShapeReplySink {
            async fn send_reply(self, response: RequestResponse<'_>) {
                let Payload::Value { shape, .. } = response.ret else {
                    panic!("typed error should use outgoing payload");
                };
                let extracted = extract_schemas(shape).expect("response shape should extract");
                *self
                    .observed_root
                    .lock()
                    .expect("observed-root mutex poisoned") = Some(extracted.root.clone());
            }
        }

        let observed_root = Arc::new(Mutex::new(None));
        ShapeReplySink {
            observed_root: Arc::clone(&observed_root),
        }
        .send_typed_error::<(String, i32), String>(VoxError::Cancelled)
        .await;

        let root = observed_root
            .lock()
            .expect("observed-root mutex poisoned")
            .clone()
            .expect("typed error should record a root");
        let extracted =
            extract_schemas(<Result<(String, i32), VoxError<String>> as facet::Facet>::SHAPE)
                .expect("expected result shape should extract");
        let registry = build_registry(&extracted.schemas);
        // Fixed: these call sites previously read `®istry` (a mis-encoded
        // `&registry`), which does not compile.
        let root_kind = root.resolve_kind(&registry).expect("root should resolve");
        let SchemaKind::Enum { variants, .. } = root_kind else {
            panic!("expected result enum root");
        };
        let ok_variant = variants
            .iter()
            .find(|variant| variant.name == "Ok")
            .expect("Result should have Ok variant");
        let VariantPayload::Newtype { type_ref } = &ok_variant.payload else {
            panic!("Ok variant should be newtype");
        };
        match type_ref
            .resolve_kind(&registry)
            .expect("Ok payload should resolve")
        {
            SchemaKind::Tuple { elements } => {
                assert_eq!(elements.len(), 2, "Ok tuple should have two elements");
            }
            other => panic!("expected Ok payload to be tuple, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn unit_handler_is_noop() {
        let req = crate::SelfRef::owning(
            crate::Backing::Boxed(Box::<[u8]>::default()),
            RequestCall {
                method_id: crate::MethodId(1),
                metadata: Metadata::default(),
                args: Payload::PostcardBytes(&[]),
                schemas: Default::default(),
            },
        );
        // Keep handles on the sink's flags so we can check that nothing was sent.
        let saw_send_reply = Arc::new(Mutex::new(false));
        let saw_outgoing_payload = Arc::new(Mutex::new(false));
        ().handle(
            req,
            RecordingReplySink {
                saw_send_reply: Arc::clone(&saw_send_reply),
                saw_outgoing_payload: Arc::clone(&saw_outgoing_payload),
            },
            Arc::new(crate::SchemaRecvTracker::new()),
        )
        .await;

        // The unit handler must drop the sink without replying.
        assert!(!*saw_send_reply.lock().expect("send-reply mutex poisoned"));
        assert!(
            !*saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned")
        );
    }

    #[test]
    fn response_parts_deref_exposes_ret() {
        let parts = ResponseParts {
            ret: 42_u32,
            metadata: Metadata::default(),
        };
        assert_eq!(*parts, 42);
    }

    #[test]
    fn default_channel_binder_accessor_for_reply_sink_returns_none() {
        let sink = RecordingReplySink {
            saw_send_reply: Arc::new(Mutex::new(false)),
            saw_outgoing_payload: Arc::new(Mutex::new(false)),
        };
        assert!(sink.channel_binder().is_none());
    }
}