1use std::{
2 collections::BTreeMap,
3 pin::Pin,
4 sync::{
5 Arc, Weak,
6 atomic::{AtomicU64, Ordering},
7 },
8};
9
10use moire::sync::{Semaphore, SyncMutex};
11use tokio::sync::watch;
12
13use moire::task::FutureExt as _;
14use vox_types::{
15 BoxFut, CallResult, ChannelBinder, ChannelBody, ChannelClose, ChannelCreditReplenisher,
16 ChannelCreditReplenisherHandle, ChannelId, ChannelItem, ChannelLivenessHandle, ChannelMessage,
17 ChannelRetryMode, ChannelSink, CreditSink, Handler, IdAllocator, IncomingChannelMessage,
18 MaybeSend, MaybeSync, Payload, ReplySink, RequestBody, RequestCall, RequestId, RequestMessage,
19 RequestResponse, SelfRef, TxError, VoxError, ensure_operation_id, metadata_channel_retry_mode,
20 metadata_operation_id,
21};
22
23use crate::session::{
24 ConnectionHandle, ConnectionMessage, ConnectionSender, DropControlRequest, FailureDisposition,
25};
26use crate::{InMemoryOperationStore, OperationStore};
27use moire::sync::mpsc;
28use vox_types::{OperationId, PostcardPayload, SchemaHash, TypeRef};
29
/// A response delivered to a waiting caller together with its schema tracker.
struct PendingResponse {
    /// The received request message; `call_inner` expects its body to be the
    /// `Response` variant.
    msg: SelfRef<RequestMessage<'static>>,
    /// Schema tracker returned to the caller alongside the response.
    schemas: Arc<vox_types::SchemaRecvTracker>,
}
39
/// Sending half of the one-shot channel used to hand a [`PendingResponse`]
/// to the caller waiting on a request.
type ResponseSlot = moire::sync::oneshot::Sender<PendingResponse>;
41
/// Bookkeeping for a handler task associated with an in-flight request.
struct InFlightHandler {
    /// Join handle of the handler task; used to abort it.
    handle: moire::task::JoinHandle<()>,
    /// Method the request targets.
    method_id: vox_types::MethodId,
    /// Retry policy recorded for the call.
    retry: vox_types::RetryPolicy,
    /// Whether the call involves channels; `abort_channel_handlers` only
    /// aborts handlers with this flag set.
    has_channels: bool,
    /// Operation id associated with the request, if any.
    operation_id: Option<OperationId>,
}
49
/// Tracks operations that are currently live and the requests bound to them.
struct LiveOperationTracker {
    /// Live operations keyed by operation id.
    live: HashMap<OperationId, LiveOperation>,
    /// Reverse index from every bound request (owner or waiter) to its operation.
    request_to_operation: HashMap<RequestId, OperationId>,
}
65
/// State kept for a single live operation.
struct LiveOperation {
    /// Method the operation was admitted with.
    method_id: vox_types::MethodId,
    /// Hash over the method id and argument bytes, computed in `admit` and
    /// used to detect a conflicting reuse of the same operation id.
    args_hash: u64,
    /// Request that first admitted the operation.
    owner_request_id: RequestId,
    /// Every request bound to the operation; the owner is the first entry.
    waiters: Vec<RequestId>,
    /// Retry policy supplied when the operation was admitted.
    retry: vox_types::RetryPolicy,
}
73
/// Outcome of [`LiveOperationTracker::admit`].
enum AdmitResult {
    /// No live operation had this id; a new one was registered.
    Start,
    /// A matching live operation exists; the request was added as a waiter.
    Attached,
    /// A live operation with this id exists but its method or argument hash differs.
    Conflict,
}
82
83impl LiveOperationTracker {
84 fn new() -> Self {
85 Self {
86 live: HashMap::new(),
87 request_to_operation: HashMap::new(),
88 }
89 }
90
91 fn admit(
92 &mut self,
93 operation_id: OperationId,
94 method_id: vox_types::MethodId,
95 args: &[u8],
96 retry: vox_types::RetryPolicy,
97 request_id: RequestId,
98 ) -> AdmitResult {
99 use std::hash::{Hash, Hasher};
100 let args_hash = {
101 let mut h = std::collections::hash_map::DefaultHasher::new();
102 method_id.hash(&mut h);
103 args.hash(&mut h);
104 h.finish()
105 };
106 let live_operations = self.live.len();
107
108 if let Some(live) = self.live.get_mut(&operation_id) {
109 if live.method_id != method_id || live.args_hash != args_hash {
110 let request_bindings = self.request_to_operation.len();
111 tracing::trace!(
112 %operation_id,
113 %request_id,
114 ?method_id,
115 live_operations,
116 request_bindings,
117 "live operation conflict"
118 );
119 return AdmitResult::Conflict;
120 }
121 live.waiters.push(request_id);
122 self.request_to_operation.insert(request_id, operation_id);
123 let waiters = live.waiters.len();
124 let request_bindings = self.request_to_operation.len();
125 tracing::trace!(
126 %operation_id,
127 %request_id,
128 ?method_id,
129 waiters,
130 live_operations,
131 request_bindings,
132 "live operation attached"
133 );
134 return AdmitResult::Attached;
135 }
136
137 self.live.insert(
138 operation_id,
139 LiveOperation {
140 method_id,
141 args_hash,
142 owner_request_id: request_id,
143 waiters: vec![request_id],
144 retry,
145 },
146 );
147 self.request_to_operation.insert(request_id, operation_id);
148 let live_operations = self.live.len();
149 let request_bindings = self.request_to_operation.len();
150 tracing::trace!(
151 %operation_id,
152 %request_id,
153 ?method_id,
154 live_operations,
155 request_bindings,
156 "live operation admitted"
157 );
158 AdmitResult::Start
159 }
160
161 fn seal(&mut self, operation_id: OperationId) -> Vec<RequestId> {
163 if let Some(live) = self.live.remove(&operation_id) {
164 for waiter in &live.waiters {
165 self.request_to_operation.remove(waiter);
166 }
167 let waiters = live.waiters.len();
168 let live_operations = self.live.len();
169 let request_bindings = self.request_to_operation.len();
170 tracing::trace!(
171 %operation_id,
172 waiters,
173 live_operations,
174 request_bindings,
175 "live operation sealed"
176 );
177 live.waiters
178 } else {
179 vec![]
180 }
181 }
182
183 fn release(&mut self, operation_id: OperationId) -> Option<LiveOperation> {
185 if let Some(live) = self.live.remove(&operation_id) {
186 for waiter in &live.waiters {
187 self.request_to_operation.remove(waiter);
188 }
189 let waiters = live.waiters.len();
190 let live_operations = self.live.len();
191 let request_bindings = self.request_to_operation.len();
192 tracing::trace!(
193 %operation_id,
194 waiters,
195 live_operations,
196 request_bindings,
197 "live operation released"
198 );
199 Some(live)
200 } else {
201 None
202 }
203 }
204
205 fn cancel(&mut self, request_id: RequestId) -> CancelResult {
207 let Some(&operation_id) = self.request_to_operation.get(&request_id) else {
208 return CancelResult::NotFound;
209 };
210 let live_operations = self.live.len();
211 let Some(live) = self.live.get_mut(&operation_id) else {
212 self.request_to_operation.remove(&request_id);
213 return CancelResult::NotFound;
214 };
215
216 if live.retry.persist {
217 if live.owner_request_id == request_id {
219 return CancelResult::NotFound; }
221 live.waiters.retain(|w| *w != request_id);
222 self.request_to_operation.remove(&request_id);
223 let waiters = live.waiters.len();
224 let request_bindings = self.request_to_operation.len();
225 tracing::trace!(
226 %operation_id,
227 %request_id,
228 waiters,
229 live_operations,
230 request_bindings,
231 "live operation detached waiter"
232 );
233 CancelResult::Detached
234 } else {
235 let live = self.live.remove(&operation_id).unwrap();
237 for waiter in &live.waiters {
238 self.request_to_operation.remove(waiter);
239 }
240 let waiters = live.waiters.len();
241 let live_operations = self.live.len();
242 let request_bindings = self.request_to_operation.len();
243 tracing::trace!(
244 %operation_id,
245 %request_id,
246 waiters,
247 live_operations,
248 request_bindings,
249 "live operation aborted"
250 );
251 CancelResult::Abort {
252 owner_request_id: live.owner_request_id,
253 waiters: live.waiters,
254 }
255 }
256 }
257}
258
/// Outcome of [`LiveOperationTracker::cancel`].
enum CancelResult {
    /// Nothing was cancelled: the request is unknown, its operation is no
    /// longer live, or it is the owner of an operation whose retry policy
    /// has `persist` set.
    NotFound,
    /// The request was a non-owner waiter on a persistent operation and has
    /// been detached from it.
    Detached,
    /// The operation was removed from the tracker.
    Abort {
        /// Request that originally admitted the operation.
        owner_request_id: RequestId,
        /// All requests that were bound to the operation.
        waiters: Vec<RequestId>,
    },
}
267
268use std::collections::HashMap;
269
/// State shared between the driver and the callers/binders created from it.
struct DriverShared {
    /// Response slots for outstanding requests, keyed by request id.
    pending_responses: SyncMutex<BTreeMap<RequestId, ResponseSlot>>,
    /// Allocator for outgoing request ids.
    request_ids: SyncMutex<IdAllocator<RequestId>>,
    /// Counter from which `call_inner` draws operation ids.
    next_operation_id: AtomicU64,
    /// Store of operation records.
    operations: Arc<dyn OperationStore>,
    /// Allocator for locally created channel ids.
    channel_ids: SyncMutex<IdAllocator<ChannelId>>,
    /// Senders feeding registered receive channels, keyed by channel id.
    channel_senders: SyncMutex<BTreeMap<ChannelId, mpsc::Sender<IncomingChannelMessage>>>,
    /// Messages buffered per channel; drained when the channel is registered.
    channel_buffers: SyncMutex<BTreeMap<ChannelId, Vec<IncomingChannelMessage>>>,
    /// Credit semaphores of outbound channels, keyed by channel id.
    channel_credits: SyncMutex<BTreeMap<ChannelId, Arc<Semaphore>>>,
    /// Channel ids recorded when all channel runtime state is torn down.
    // NOTE(review): how these ids are consumed is not visible in this chunk.
    stale_close_channels: SyncMutex<std::collections::HashSet<ChannelId>>,
}
298
/// Guard that sends a drop-control request when it is dropped.
struct CallerDropGuard {
    /// Channel on which the request is sent.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Request sent on drop.
    request: DropControlRequest,
}
303
304impl Drop for CallerDropGuard {
305 fn drop(&mut self) {
306 let _ = self.control_tx.send(self.request);
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::{DriverChannelCreditReplenisher, DriverLocalControl};
    use vox_types::{ChannelCreditReplenisher, ChannelId};

    /// With an initial window of 16, no grant is emitted for the first seven
    /// consumed items and the eighth produces a single grant of 8.
    #[tokio::test]
    async fn replenisher_batches_at_half_the_initial_window() {
        let (tx, mut rx) = moire::sync::mpsc::unbounded_channel("test.replenisher");
        let replenisher = DriverChannelCreditReplenisher::new(ChannelId(7), 16, tx);

        (0..7).for_each(|_| replenisher.on_item_consumed());

        let early =
            tokio::time::timeout(std::time::Duration::from_millis(20), rx.recv()).await;
        assert!(
            early.is_err(),
            "should not emit credit before reaching the batch threshold"
        );

        replenisher.on_item_consumed();
        match rx.recv().await {
            Some(DriverLocalControl::GrantCredit {
                channel_id,
                additional,
            }) => {
                assert_eq!(channel_id, ChannelId(7));
                assert_eq!(additional, 8);
            }
            _ => panic!("expected batched credit grant"),
        }
    }

    /// With an initial window of 1, every consumed item is granted back immediately.
    #[tokio::test]
    async fn replenisher_grants_one_by_one_for_single_credit_windows() {
        let (tx, mut rx) = moire::sync::mpsc::unbounded_channel("test.replenisher.single");
        let replenisher = DriverChannelCreditReplenisher::new(ChannelId(9), 1, tx);

        replenisher.on_item_consumed();
        match rx.recv().await {
            Some(DriverLocalControl::GrantCredit {
                channel_id,
                additional,
            }) => {
                assert_eq!(channel_id, ChannelId(9));
                assert_eq!(additional, 1);
            }
            _ => panic!("expected immediate credit grant"),
        }
    }
}
359
/// Reply sink handed to a handler for one incoming request.
///
/// If the sink is dropped without `send_reply` having been called, its
/// `Drop` implementation marks the request as failed.
pub struct DriverReplySink {
    /// Connection sender; taken by `send_reply`, so `Some` in `Drop` means no
    /// reply was sent.
    sender: Option<ConnectionSender>,
    /// Request being answered.
    request_id: RequestId,
    /// Method the request targets.
    method_id: vox_types::MethodId,
    /// Retry policy; `persist` selects the failure disposition used on drop.
    retry: vox_types::RetryPolicy,
    /// Operation id attached to the request, if any.
    operation_id: Option<OperationId>,
    /// Operation store used to seal the response when an operation id is present.
    operations: Option<Arc<dyn OperationStore>>,
    /// Tracker of live operations, used to find other requests waiting on the
    /// same operation.
    live_operations: Option<Arc<SyncMutex<LiveOperationTracker>>>,
    /// Channel binder exposed through `channel_binder`.
    binder: DriverChannelBinder,
}
377
378async fn replay_sealed_response(
384 sender: ConnectionSender,
385 request_id: RequestId,
386 method_id: vox_types::MethodId,
387 encoded_response: &[u8],
388 root_type: TypeRef,
389 operations: &dyn OperationStore,
390) -> Result<(), ()> {
391 let mut response: RequestResponse<'_> =
392 vox_postcard::from_slice_borrowed(encoded_response).map_err(|_| ())?;
393 sender.prepare_replay_schemas(request_id, method_id, &root_type, operations, &mut response);
394 sender.send_response(request_id, response).await
395}
396
397fn extract_root_type_ref(schemas_cbor: &vox_types::CborPayload) -> TypeRef {
399 if schemas_cbor.is_empty() {
400 return TypeRef::concrete(SchemaHash(0));
401 }
402 let payload =
403 vox_types::SchemaPayload::from_cbor(&schemas_cbor.0).expect("schema CBOR must be valid");
404 payload.root
405}
406
407fn incoming_args_bytes<'a>(call: &'a RequestCall<'a>) -> &'a [u8] {
408 match &call.args {
409 Payload::PostcardBytes(bytes) => bytes,
410 Payload::Value { .. } => {
411 panic!("incoming request payload should always be decoded as incoming bytes")
412 }
413 }
414}
415
416impl ReplySink for DriverReplySink {
417 async fn send_reply(mut self, response: RequestResponse<'_>) {
418 let sender = self
419 .sender
420 .take()
421 .expect("unreachable: send_reply takes self by value");
422
423 vox_types::dlog!(
424 "[driver] send_reply: conn={:?} req={:?} method={:?} payload={} operation_id={:?}",
425 sender.connection_id(),
426 self.request_id,
427 self.method_id,
428 match &response.ret {
429 Payload::Value { .. } => "Value",
430 Payload::PostcardBytes(_) => "PostcardBytes",
431 },
432 self.operation_id
433 );
434
435 if let Payload::Value { shape, .. } = &response.ret
436 && let Ok(extracted) = vox_types::extract_schemas(shape)
437 {
438 vox_types::dlog!(
439 "[schema] driver send_reply: method={:?} root={:?}",
440 self.method_id,
441 extracted.root
442 );
443 }
444
445 if let (Some(operation_id), Some(operations)) = (self.operation_id, self.operations.take())
446 {
447 let mut response = response;
448 sender.prepare_response_for_method(self.request_id, self.method_id, &mut response);
449
450 let root_type = extract_root_type_ref(&response.schemas);
452
453 let schemas_for_wire = std::mem::take(&mut response.schemas);
455 let encoded_for_store = PostcardPayload(
456 vox_postcard::to_vec(&response).expect("serialize operation response for store"),
457 );
458 response.schemas = schemas_for_wire;
459
460 vox_types::dlog!(
462 "[driver] send_reply wire send: conn={:?} req={:?} method={:?} schemas={}",
463 sender.connection_id(),
464 self.request_id,
465 self.method_id,
466 response.schemas.0.len()
467 );
468 if let Err(_e) = sender.send_response(self.request_id, response).await {
469 sender.mark_failure(self.request_id, FailureDisposition::Cancelled);
470 }
471
472 let registry = sender.schema_registry();
474 operations.seal(operation_id, &encoded_for_store, &root_type, ®istry);
475
476 let waiters = self
478 .live_operations
479 .as_ref()
480 .map(|lo| lo.lock().seal(operation_id))
481 .unwrap_or_default();
482 for waiter in waiters {
483 if waiter == self.request_id {
484 continue;
485 }
486 if replay_sealed_response(
487 sender.clone(),
488 waiter,
489 self.method_id,
490 encoded_for_store.as_bytes(),
491 root_type.clone(),
492 operations.as_ref(),
493 )
494 .await
495 .is_err()
496 {
497 sender.mark_failure(waiter, FailureDisposition::Cancelled);
498 }
499 }
500 } else {
501 vox_types::dlog!(
502 "[driver] send_reply direct send: conn={:?} req={:?} method={:?}",
503 sender.connection_id(),
504 self.request_id,
505 self.method_id
506 );
507 if let Err(_e) = sender
508 .send_response_for_method(self.request_id, self.method_id, response)
509 .await
510 {
511 sender.mark_failure(self.request_id, FailureDisposition::Cancelled);
512 }
513 }
514 }
515
516 fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
517 Some(&self.binder)
518 }
519
520 fn request_id(&self) -> Option<RequestId> {
521 Some(self.request_id)
522 }
523
524 fn connection_id(&self) -> Option<vox_types::ConnectionId> {
525 self.sender.as_ref().map(|sender| sender.connection_id())
526 }
527}
528
529impl Drop for DriverReplySink {
531 fn drop(&mut self) {
532 if let Some(sender) = self.sender.take() {
533 let disposition = if self.retry.persist {
534 FailureDisposition::Indeterminate
535 } else {
536 FailureDisposition::Cancelled
537 };
538
539 if let Some(operation_id) = self.operation_id {
540 if let Some(live_ops) = self.live_operations.take()
546 && let Some(live) = live_ops.lock().release(operation_id)
547 {
548 for waiter in live.waiters {
549 sender.mark_failure(waiter, disposition);
550 }
551 return;
552 }
553 }
554
555 sender.mark_failure(self.request_id, disposition);
556 }
557 }
558}
559
/// Channel sink that sends items and close messages for one channel over the
/// connection.
pub struct DriverChannelSink {
    /// Connection sender used for items and close messages.
    sender: ConnectionSender,
    /// Channel this sink writes to.
    channel_id: ChannelId,
    /// Local control channel used to request a close when the sink is dropped.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
}
572
573impl ChannelSink for DriverChannelSink {
574 fn send_payload<'payload>(
575 &self,
576 payload: Payload<'payload>,
577 ) -> Pin<Box<dyn vox_types::MaybeSendFuture<Output = Result<(), TxError>> + 'payload>> {
578 let sender = self.sender.clone();
579 let channel_id = self.channel_id;
580 Box::pin(async move {
581 sender
582 .send(ConnectionMessage::Channel(ChannelMessage {
583 id: channel_id,
584 body: ChannelBody::Item(ChannelItem { item: payload }),
585 }))
586 .await
587 .map_err(|()| TxError::Transport("connection closed".into()))
588 })
589 }
590
591 fn close_channel(
592 &self,
593 _metadata: vox_types::Metadata,
594 ) -> Pin<Box<dyn vox_types::MaybeSendFuture<Output = Result<(), TxError>> + 'static>> {
595 let sender = self.sender.clone();
599 let channel_id = self.channel_id;
600 Box::pin(async move {
601 sender
602 .send(ConnectionMessage::Channel(ChannelMessage {
603 id: channel_id,
604 body: ChannelBody::Close(ChannelClose {
605 metadata: Default::default(),
606 }),
607 }))
608 .await
609 .map_err(|()| TxError::Transport("connection closed".into()))
610 })
611 }
612
613 fn close_channel_on_drop(&self) {
614 let _ = self
615 .local_control_tx
616 .send(DriverLocalControl::CloseChannel {
617 channel_id: self.channel_id,
618 });
619 }
620}
621
/// Variant of `Handler<DriverReplySink>` whose `handle_erased` returns a
/// boxed future instead of being an `async fn`.
///
/// Every `Handler<DriverReplySink>` implements this trait through a blanket
/// implementation.
pub trait ErasedHandler: MaybeSend + MaybeSync + 'static {
    /// Retry policy for `method_id`; defaults to `RetryPolicy::VOLATILE`.
    fn retry_policy(&self, method_id: vox_types::MethodId) -> vox_types::RetryPolicy {
        let _ = method_id;
        vox_types::RetryPolicy::VOLATILE
    }

    /// Whether the arguments of `method_id` contain channels; defaults to `false`.
    fn args_have_channels(&self, method_id: vox_types::MethodId) -> bool {
        let _ = method_id;
        false
    }

    /// Wire shape of the response of `method_id`, if known; defaults to `None`.
    fn response_wire_shape(&self, method_id: vox_types::MethodId) -> Option<&'static facet::Shape> {
        let _ = method_id;
        None
    }

    /// Handles `call`, replying through `reply`, and returns the work as a
    /// boxed future.
    fn handle_erased(
        &self,
        call: SelfRef<RequestCall<'static>>,
        reply: DriverReplySink,
        schemas: std::sync::Arc<vox_types::SchemaRecvTracker>,
    ) -> BoxFut<'_, ()>;
}
649
650impl<H: Handler<DriverReplySink>> ErasedHandler for H {
651 fn retry_policy(&self, method_id: vox_types::MethodId) -> vox_types::RetryPolicy {
652 Handler::retry_policy(self, method_id)
653 }
654
655 fn args_have_channels(&self, method_id: vox_types::MethodId) -> bool {
656 Handler::args_have_channels(self, method_id)
657 }
658
659 fn response_wire_shape(&self, method_id: vox_types::MethodId) -> Option<&'static facet::Shape> {
660 Handler::response_wire_shape(self, method_id)
661 }
662
663 fn handle_erased(
664 &self,
665 call: SelfRef<RequestCall<'static>>,
666 reply: DriverReplySink,
667 schemas: std::sync::Arc<vox_types::SchemaRecvTracker>,
668 ) -> BoxFut<'_, ()> {
669 Box::pin(Handler::handle(self, call, reply, schemas))
670 }
671}
672
673impl Handler<DriverReplySink> for Box<dyn ErasedHandler> {
674 fn retry_policy(&self, method_id: vox_types::MethodId) -> vox_types::RetryPolicy {
675 (**self).retry_policy(method_id)
676 }
677
678 fn args_have_channels(&self, method_id: vox_types::MethodId) -> bool {
679 (**self).args_have_channels(method_id)
680 }
681
682 fn response_wire_shape(&self, method_id: vox_types::MethodId) -> Option<&'static facet::Shape> {
683 (**self).response_wire_shape(method_id)
684 }
685
686 async fn handle(
687 &self,
688 call: SelfRef<RequestCall<'static>>,
689 reply: DriverReplySink,
690 schemas: std::sync::Arc<vox_types::SchemaRecvTracker>,
691 ) {
692 (**self).handle_erased(call, reply, schemas).await
693 }
694}
695
/// Cloneable handle for issuing calls, optionally wrapped with client middleware.
#[must_use = "Dropping this caller may close the connection if it is the last caller."]
#[derive(Clone)]
pub struct Caller {
    /// Shared underlying driver caller.
    inner: Arc<DriverCaller>,
    /// Service descriptor the middlewares were registered for; `None` until
    /// `with_middleware` is first called.
    service: Option<&'static vox_types::ServiceDescriptor>,
    /// Middlewares in registration order.
    middlewares: Vec<Arc<dyn vox_types::ClientMiddleware>>,
}
708
709impl Caller {
710 pub fn new(driver: DriverCaller) -> Self {
712 Self {
713 inner: Arc::new(driver),
714 service: None,
715 middlewares: vec![],
716 }
717 }
718
719 #[cfg(test)]
721 pub(crate) fn driver(&self) -> &DriverCaller {
722 &self.inner
723 }
724
725 pub fn with_middleware(
727 mut self,
728 service: &'static vox_types::ServiceDescriptor,
729 middleware: impl vox_types::ClientMiddleware,
730 ) -> Self {
731 if let Some(existing_service) = self.service {
732 assert_eq!(
733 existing_service.service_name, service.service_name,
734 "Caller middleware service mismatch"
735 );
736 } else {
737 self.service = Some(service);
738 }
739 self.middlewares.push(Arc::new(middleware));
740 self
741 }
742
743 pub async fn call(&self, mut call: RequestCall<'_>) -> CallResult {
746 use vox_types::{
747 ClientCallOutcome, ClientContext, ClientRequest, Extensions, OwnedMetadata,
748 };
749
750 let Some(service) = self.service else {
751 return self.inner.call_inner(call).await;
752 };
753
754 let extensions = Extensions::new();
755 let method = service.by_id(call.method_id);
756 let context = ClientContext::new(method, call.method_id, &extensions);
757 let mut owned_metadata = OwnedMetadata::default();
758
759 if !self.middlewares.is_empty() {
760 for middleware in &self.middlewares {
761 let mut request = ClientRequest::new(&mut call, &mut owned_metadata);
762 middleware.pre(&context, &mut request).await;
763 }
764 }
765
766 let result = self.inner.call_inner(call).await;
767 if !self.middlewares.is_empty() {
768 let outcome = match &result {
769 Ok(_) => ClientCallOutcome::Response,
770 Err(error) => ClientCallOutcome::Error(error),
771 };
772 for middleware in self.middlewares.iter().rev() {
773 middleware.post(&context, outcome).await;
774 }
775 }
776 result
777 }
778
779 pub async fn closed(&self) {
781 if *self.inner.closed_rx.borrow() {
782 return;
783 }
784 let mut rx = self.inner.closed_rx.clone();
785 while rx.changed().await.is_ok() {
786 if *rx.borrow() {
787 return;
788 }
789 }
790 }
791
792 pub fn is_connected(&self) -> bool {
794 !*self.inner.closed_rx.borrow()
795 }
796
797 pub fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
799 Some(self.inner.as_ref())
800 }
801}
802
/// Construction of a client type from a [`Caller`] and an optional session handle.
pub trait FromVoxSession {
    /// Name of the service this client speaks.
    const SERVICE_NAME: &'static str;

    /// Builds the client from `caller` and, when available, the session handle.
    fn from_vox_session(
        caller: Caller,
        session_handle: Option<crate::session::SessionHandle>,
    ) -> Self;
}
818
/// Client that only holds the caller and session handle it was built from;
/// its service name is `"Noop"`.
#[must_use = "Dropping NoopClient may close the connection if it is the last caller."]
#[derive(Clone)]
pub struct NoopClient {
    /// Caller the client was built from.
    pub caller: Caller,
    /// Session handle the client was built from, if any.
    pub session: Option<crate::session::SessionHandle>,
}
831
832impl FromVoxSession for NoopClient {
833 const SERVICE_NAME: &'static str = "Noop";
834
835 fn from_vox_session(caller: Caller, session: Option<crate::session::SessionHandle>) -> Self {
836 Self { caller, session }
837 }
838}
839
/// Channel binder backed by the driver's shared channel state.
#[derive(Clone)]
struct DriverChannelBinder {
    /// Connection sender given to the sinks this binder creates.
    sender: ConnectionSender,
    /// Shared driver state (channel ids, senders, buffers, credits).
    shared: Arc<DriverShared>,
    /// Local control channel given to sinks and credit replenishers.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Optional drop guard handed out as the channel liveness handle.
    drop_guard: Option<Arc<CallerDropGuard>>,
}
847
/// Initial credit used for new channel sinks and receive-side replenishers.
const DEFAULT_CHANNEL_CREDIT: u32 = 16;
850
851fn register_rx_channel_impl(
852 shared: &Arc<DriverShared>,
853 channel_id: ChannelId,
854 queue_name: &'static str,
855 liveness: Option<ChannelLivenessHandle>,
856 local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
857) -> vox_types::BoundChannelReceiver {
858 let (tx, rx) = mpsc::channel(queue_name, 64);
859
860 let mut terminal_buffered = false;
861 {
862 let mut senders = shared.channel_senders.lock();
863
864 senders.insert(channel_id, tx.clone());
873
874 let buffered = shared.channel_buffers.lock().remove(&channel_id);
875 if let Some(buffered) = buffered {
876 for msg in buffered {
877 let is_terminal = matches!(
878 msg,
879 IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
880 );
881 let _ = tx.try_send(msg);
882 if is_terminal {
883 terminal_buffered = true;
884 break;
885 }
886 }
887 }
888
889 if terminal_buffered {
890 senders.remove(&channel_id);
891 }
892 }
893
894 if terminal_buffered {
895 shared.channel_credits.lock().remove(&channel_id);
896 return vox_types::BoundChannelReceiver {
897 receiver: rx,
898 liveness,
899 replenisher: None,
900 };
901 }
902
903 vox_types::BoundChannelReceiver {
904 receiver: rx,
905 liveness,
906 replenisher: Some(Arc::new(DriverChannelCreditReplenisher::new(
907 channel_id,
908 DEFAULT_CHANNEL_CREDIT,
909 local_control_tx,
910 )) as ChannelCreditReplenisherHandle),
911 }
912}
913
914impl DriverChannelBinder {
915 fn create_tx_channel(&self) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
916 let channel_id = self.shared.channel_ids.lock().alloc();
917 let inner = DriverChannelSink {
918 sender: self.sender.clone(),
919 channel_id,
920 local_control_tx: self.local_control_tx.clone(),
921 };
922 let sink = Arc::new(CreditSink::new(inner, DEFAULT_CHANNEL_CREDIT));
923 self.shared
924 .channel_credits
925 .lock()
926 .insert(channel_id, Arc::clone(sink.credit()));
927 (channel_id, sink)
928 }
929
930 fn register_rx_channel(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
931 register_rx_channel_impl(
932 &self.shared,
933 channel_id,
934 "driver.register_rx_channel",
935 self.channel_liveness(),
936 self.local_control_tx.clone(),
937 )
938 }
939}
940
941impl ChannelBinder for DriverChannelBinder {
942 fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>) {
943 let (id, sink) = self.create_tx_channel();
944 (id, sink as Arc<dyn ChannelSink>)
945 }
946
947 fn create_rx(&self) -> (ChannelId, vox_types::BoundChannelReceiver) {
948 let channel_id = self.shared.channel_ids.lock().alloc();
949 let rx = self.register_rx_channel(channel_id);
950 (channel_id, rx)
951 }
952
953 fn bind_tx(&self, channel_id: ChannelId) -> Arc<dyn ChannelSink> {
954 let inner = DriverChannelSink {
955 sender: self.sender.clone(),
956 channel_id,
957 local_control_tx: self.local_control_tx.clone(),
958 };
959 let sink = Arc::new(CreditSink::new(inner, DEFAULT_CHANNEL_CREDIT));
960 self.shared
961 .channel_credits
962 .lock()
963 .insert(channel_id, Arc::clone(sink.credit()));
964 sink
965 }
966
967 fn register_rx(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
968 self.register_rx_channel(channel_id)
969 }
970
971 fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
972 self.drop_guard
973 .as_ref()
974 .map(|guard| guard.clone() as ChannelLivenessHandle)
975 }
976}
977
/// Caller side of a driver: sends requests and waits for their responses.
#[derive(Clone)]
pub struct DriverCaller {
    /// Connection sender used for requests and channel sinks.
    sender: ConnectionSender,
    /// State shared with the driver.
    shared: Arc<DriverShared>,
    /// Local control channel given to sinks and credit replenishers.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Watch on the closed flag; `true` means the connection is closed.
    closed_rx: watch::Receiver<bool>,
    /// Watch on the resume generation; a new value makes pending calls re-send.
    resumed_rx: watch::Receiver<u64>,
    /// Watch on the last resume generation that has been processed; pending
    /// calls wait for it to catch up before re-sending.
    resume_processed_rx: watch::Receiver<u64>,
    /// Whether the peer supports retry; gates operation ids and resume handling.
    peer_supports_retry: bool,
    /// Optional drop guard, also handed out as the channel liveness handle.
    _drop_guard: Option<Arc<CallerDropGuard>>,
}
992
993impl DriverCaller {
994 pub fn create_tx_channel(&self) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
999 let channel_id = self.shared.channel_ids.lock().alloc();
1000 let inner = DriverChannelSink {
1001 sender: self.sender.clone(),
1002 channel_id,
1003 local_control_tx: self.local_control_tx.clone(),
1004 };
1005 let sink = Arc::new(CreditSink::new(inner, DEFAULT_CHANNEL_CREDIT));
1006 self.shared
1007 .channel_credits
1008 .lock()
1009 .insert(channel_id, Arc::clone(sink.credit()));
1010 (channel_id, sink)
1011 }
1012
1013 #[cfg(test)]
1018 pub(crate) fn connection_sender(&self) -> &ConnectionSender {
1019 &self.sender
1020 }
1021
1022 pub fn register_rx_channel(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
1027 register_rx_channel_impl(
1028 &self.shared,
1029 channel_id,
1030 "driver.caller.register_rx_channel",
1031 self.channel_liveness(),
1032 self.local_control_tx.clone(),
1033 )
1034 }
1035}
1036
1037impl ChannelBinder for DriverCaller {
1038 fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>) {
1039 let (id, sink) = self.create_tx_channel();
1040 (id, sink as Arc<dyn ChannelSink>)
1041 }
1042
1043 fn create_rx(&self) -> (ChannelId, vox_types::BoundChannelReceiver) {
1044 let channel_id = self.shared.channel_ids.lock().alloc();
1045 let rx = self.register_rx_channel(channel_id);
1046 (channel_id, rx)
1047 }
1048
1049 fn bind_tx(&self, channel_id: ChannelId) -> Arc<dyn ChannelSink> {
1050 let inner = DriverChannelSink {
1051 sender: self.sender.clone(),
1052 channel_id,
1053 local_control_tx: self.local_control_tx.clone(),
1054 };
1055 let sink = Arc::new(CreditSink::new(inner, DEFAULT_CHANNEL_CREDIT));
1056 self.shared
1057 .channel_credits
1058 .lock()
1059 .insert(channel_id, Arc::clone(sink.credit()));
1060 sink
1061 }
1062
1063 fn register_rx(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
1064 self.register_rx_channel(channel_id)
1065 }
1066
1067 fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
1068 self._drop_guard
1069 .as_ref()
1070 .map(|guard| guard.clone() as ChannelLivenessHandle)
1071 }
1072}
1073
impl DriverCaller {
    /// Sends `call` to the peer and waits for the matching response.
    ///
    /// While waiting, the call also reacts to session resumption (re-sending
    /// the request when the peer supports retry) and to connection closure.
    ///
    /// # Errors
    ///
    /// * `VoxError::SendFailed` — the initial send failed.
    /// * `VoxError::ConnectionClosed` — the response slot was dropped, or the
    ///   closed watch reported `true` or lost its sender.
    /// * `VoxError::SessionShutdown` — a resume watch channel lost its sender.
    /// * `VoxError::Indeterminate` — a resume happened and the call's channel
    ///   retry mode is `NonIdem`.
    async fn call_inner(&self, mut call: RequestCall<'_>) -> CallResult {
        // Operation ids are only attached when the peer supports retry.
        if self.peer_supports_retry {
            let operation_id = OperationId(
                self.shared
                    .next_operation_id
                    .fetch_add(1, Ordering::Relaxed),
            );
            ensure_operation_id(&mut call.metadata, operation_id);
        }

        let req_id = self.shared.request_ids.lock().alloc();

        // Register the response slot before sending the request.
        let (tx, rx) = moire::sync::oneshot::channel("driver.response");
        self.shared.pending_responses.lock().insert(req_id, tx);

        if self
            .sender
            .send_with_binder(
                ConnectionMessage::Request(RequestMessage {
                    id: req_id,
                    body: RequestBody::Call(RequestCall {
                        method_id: call.method_id,
                        args: call.args.reborrow(),
                        metadata: call.metadata.clone(),
                        schemas: Default::default(),
                    }),
                }),
                Some(self),
            )
            .await
            .is_err()
        {
            // Nothing will ever answer this request; drop its slot.
            self.shared.pending_responses.lock().remove(&req_id);
            return Err(VoxError::SendFailed);
        }

        let mut resumed_rx = self.resumed_rx.clone();
        let mut seen_resume_generation = *resumed_rx.borrow();
        let mut resume_processed_rx = self.resume_processed_rx.clone();
        let mut closed_rx = self.closed_rx.clone();
        let mut response = std::pin::pin!(rx.named("awaiting_response"));

        let pending: PendingResponse = loop {
            tokio::select! {
                result = &mut response => {
                    match result {
                        Ok(pending) => break pending,
                        // The sending half of the slot was dropped without a response.
                        Err(_) => {
                            return Err(VoxError::ConnectionClosed);
                        }
                    }
                }
                // Resume handling is only active when the peer supports retry.
                changed = resumed_rx.changed(), if self.peer_supports_retry => {
                    vox_types::dlog!("[CALLER] resumed_rx fired");
                    if changed.is_err() {
                        self.shared.pending_responses.lock().remove(&req_id);
                        return Err(VoxError::SessionShutdown);
                    }
                    let generation = *resumed_rx.borrow();
                    // Ignore notifications that do not carry a new generation.
                    if generation == seen_resume_generation {
                        continue;
                    }
                    seen_resume_generation = generation;
                    // Wait until this resume generation has been processed.
                    while *resume_processed_rx.borrow() < generation {
                        if resume_processed_rx.changed().await.is_err() {
                            self.shared.pending_responses.lock().remove(&req_id);
                            return Err(VoxError::SessionShutdown);
                        }
                    }
                    // Calls marked non-idempotent are not re-sent after a resume.
                    match metadata_channel_retry_mode(&call.metadata) {
                        ChannelRetryMode::NonIdem => {
                            self.shared.pending_responses.lock().remove(&req_id);
                            return Err(VoxError::Indeterminate);
                        }
                        ChannelRetryMode::Idem | ChannelRetryMode::None => {}
                    }
                    // Re-send under the same request id; a send failure is ignored here.
                    let _ = self.sender.send_with_binder(
                        ConnectionMessage::Request(RequestMessage {
                            id: req_id,
                            body: RequestBody::Call(RequestCall {
                                method_id: call.method_id,
                                args: call.args.reborrow(),
                                metadata: call.metadata.clone(),
                                schemas: Default::default(),
                            }),
                        }),
                        Some(self),
                    ).await;
                }
                changed = closed_rx.changed() => {
                    vox_types::dlog!("[CALLER] closed_rx fired, value={}", *closed_rx.borrow());
                    if changed.is_err() || *closed_rx.borrow() {
                        self.shared.pending_responses.lock().remove(&req_id);
                        return Err(VoxError::ConnectionClosed);
                    }
                }
            }
        };

        let PendingResponse {
            msg: response_msg,
            schemas: response_schemas,
        } = pending;
        // Project the message down to its response body.
        let response = response_msg.map(|m| match m.body {
            RequestBody::Response(r) => r,
            _ => unreachable!("pending_responses only gets Response variants"),
        });

        Ok(vox_types::WithTracker {
            value: response,
            tracker: response_schemas,
        })
    }
}
1204
/// Driver for one connection, dispatching incoming requests to handler `H`.
// NOTE(review): the driver's run loop is outside this chunk; field notes
// below only state what the visible code and types show.
pub struct Driver<H: Handler<DriverReplySink>> {
    /// Connection sender.
    sender: ConnectionSender,
    /// Incoming messages from the session.
    rx: mpsc::Receiver<crate::session::RecvMessage>,
    /// Failure notifications as `(request id, disposition)` pairs.
    failures_rx: mpsc::UnboundedReceiver<(RequestId, FailureDisposition)>,
    /// Watch on the closed flag.
    closed_rx: watch::Receiver<bool>,
    /// Watch on the resume generation.
    resumed_rx: watch::Receiver<u64>,
    /// Sender for the last processed resume generation.
    resume_processed_tx: watch::Sender<u64>,
    /// Whether the peer supports retry.
    peer_supports_retry: bool,
    /// Receiving end of the local control channel.
    local_control_rx: mpsc::UnboundedReceiver<DriverLocalControl>,
    /// Request handler.
    handler: Arc<H>,
    /// State shared with callers and binders.
    shared: Arc<DriverShared>,
    /// Handlers currently in flight, keyed by request id.
    in_flight_handlers: BTreeMap<RequestId, InFlightHandler>,
    /// Tracker of live operations.
    live_operations: Arc<SyncMutex<LiveOperationTracker>>,
    /// Sending end of the local control channel.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Control channel sender for drop-control requests, if available.
    drop_control_seed: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// Drop-control request associated with this driver.
    drop_control_request: DropControlRequest,
    /// Weak reference to a caller drop guard, if one has been recorded.
    drop_guard: SyncMutex<Option<Weak<CallerDropGuard>>>,
}
1233
/// Control messages sent to the driver over the local control channel.
enum DriverLocalControl {
    /// Sent by a channel sink's `close_channel_on_drop` to request closing `channel_id`.
    CloseChannel {
        channel_id: ChannelId,
    },
    /// Sent by a credit replenisher to grant `additional` credit on `channel_id`.
    GrantCredit {
        channel_id: ChannelId,
        additional: u32,
    },
    /// Reports that the handler for `request_id` has completed.
    HandlerCompleted {
        request_id: RequestId,
    },
}
1246
/// Batches consumed-item notifications into credit grants for one channel.
struct DriverChannelCreditReplenisher {
    /// Channel the credit is granted on.
    channel_id: ChannelId,
    /// Number of consumed items that triggers a grant (half the initial
    /// credit, at least 1).
    threshold: u32,
    /// Local control channel on which grants are sent.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Consumed items not yet granted back.
    pending: std::sync::Mutex<u32>,
}
1253
1254impl DriverChannelCreditReplenisher {
1255 fn new(
1256 channel_id: ChannelId,
1257 initial_credit: u32,
1258 local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
1259 ) -> Self {
1260 Self {
1261 channel_id,
1262 threshold: (initial_credit / 2).max(1),
1263 local_control_tx,
1264 pending: std::sync::Mutex::new(0),
1265 }
1266 }
1267}
1268
1269impl ChannelCreditReplenisher for DriverChannelCreditReplenisher {
1270 fn on_item_consumed(&self) {
1271 let mut pending = self.pending.lock().expect("pending credit mutex poisoned");
1272 *pending += 1;
1273 if *pending < self.threshold {
1274 return;
1275 }
1276
1277 let additional = *pending;
1278 *pending = 0;
1279 let _ = self.local_control_tx.send(DriverLocalControl::GrantCredit {
1280 channel_id: self.channel_id,
1281 additional,
1282 });
1283 }
1284}
1285
1286impl<H: Handler<DriverReplySink>> Driver<H> {
1287 fn close_all_channel_runtime_state(&self) {
1288 let mut credits = self.shared.channel_credits.lock();
1289 for semaphore in credits.values() {
1290 semaphore.close();
1291 }
1292 let mut stale = self.shared.stale_close_channels.lock();
1295 stale.extend(credits.keys().copied());
1296 credits.clear();
1297 drop(credits);
1298
1299 self.shared.channel_senders.lock().clear();
1300 self.shared.channel_buffers.lock().clear();
1301 }
1302
1303 fn close_outbound_channel(&self, channel_id: ChannelId) {
1304 if let Some(semaphore) = self.shared.channel_credits.lock().remove(&channel_id) {
1305 semaphore.close();
1306 }
1307 }
1308
1309 fn abort_channel_handlers(&mut self) {
1310 for in_flight in self.in_flight_handlers.values() {
1311 if in_flight.has_channels {
1312 if let Some(operation_id) = in_flight.operation_id {
1313 self.shared.operations.remove(operation_id);
1314 self.live_operations.lock().release(operation_id);
1315 }
1316 in_flight.handle.abort();
1317 }
1318 }
1319 }
1320
1321 pub fn new(handle: ConnectionHandle, handler: H) -> Self {
1322 Self::with_operation_store(handle, handler, Arc::new(InMemoryOperationStore::default()))
1323 }
1324
1325 pub fn with_operation_store(
1326 handle: ConnectionHandle,
1327 handler: H,
1328 operation_store: Arc<dyn OperationStore>,
1329 ) -> Self {
1330 let conn_id = handle.connection_id();
1331 let ConnectionHandle {
1332 sender,
1333 rx,
1334 failures_rx,
1335 control_tx,
1336 closed_rx,
1337 resumed_rx,
1338 parity,
1339 peer_supports_retry,
1340 } = handle;
1341 let drop_control_request = DropControlRequest::Close(conn_id);
1342 let (local_control_tx, local_control_rx) = mpsc::unbounded_channel("driver.local_control");
1343 let (resume_processed_tx, _resume_processed_rx) = watch::channel(0_u64);
1344 Self {
1345 sender,
1346 rx,
1347 failures_rx,
1348 closed_rx,
1349 resumed_rx,
1350 resume_processed_tx,
1351 peer_supports_retry,
1352 local_control_rx,
1353 handler: Arc::new(handler),
1354 shared: Arc::new(DriverShared {
1355 pending_responses: SyncMutex::new("driver.pending_responses", BTreeMap::new()),
1356 request_ids: SyncMutex::new("driver.request_ids", IdAllocator::new(parity)),
1357 next_operation_id: AtomicU64::new(1),
1358 operations: operation_store,
1359 channel_ids: SyncMutex::new("driver.channel_ids", IdAllocator::new(parity)),
1360 channel_senders: SyncMutex::new("driver.channel_senders", BTreeMap::new()),
1361 channel_buffers: SyncMutex::new("driver.channel_buffers", BTreeMap::new()),
1362 channel_credits: SyncMutex::new("driver.channel_credits", BTreeMap::new()),
1363 stale_close_channels: SyncMutex::new(
1364 "driver.stale_close_channels",
1365 std::collections::HashSet::new(),
1366 ),
1367 }),
1368 in_flight_handlers: BTreeMap::new(),
1369 live_operations: Arc::new(SyncMutex::new(
1370 "driver.live_operations",
1371 LiveOperationTracker::new(),
1372 )),
1373 local_control_tx,
1374 drop_control_seed: control_tx,
1375 drop_control_request,
1376 drop_guard: SyncMutex::new("driver.drop_guard", None),
1377 }
1378 }
1379
1380 fn existing_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
1386 self.drop_guard.lock().as_ref().and_then(Weak::upgrade)
1387 }
1388
1389 fn connection_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
1390 if let Some(existing) = self.existing_drop_guard() {
1391 Some(existing)
1392 } else if let Some(seed) = &self.drop_control_seed {
1393 let mut guard = self.drop_guard.lock();
1394 if let Some(existing) = guard.as_ref().and_then(Weak::upgrade) {
1395 Some(existing)
1396 } else {
1397 let arc = Arc::new(CallerDropGuard {
1398 control_tx: seed.clone(),
1399 request: self.drop_control_request,
1400 });
1401 *guard = Some(Arc::downgrade(&arc));
1402 Some(arc)
1403 }
1404 } else {
1405 None
1406 }
1407 }
1408
1409 pub fn caller(&self) -> DriverCaller {
1410 let drop_guard = self.connection_drop_guard();
1411 DriverCaller {
1412 sender: self.sender.clone(),
1413 shared: Arc::clone(&self.shared),
1414 local_control_tx: self.local_control_tx.clone(),
1415 closed_rx: self.closed_rx.clone(),
1416 resumed_rx: self.resumed_rx.clone(),
1417 resume_processed_rx: self.resume_processed_tx.subscribe(),
1418 peer_supports_retry: self.peer_supports_retry,
1419 _drop_guard: drop_guard,
1420 }
1421 }
1422
1423 fn internal_binder(&self) -> DriverChannelBinder {
1424 DriverChannelBinder {
1425 sender: self.sender.clone(),
1426 shared: Arc::clone(&self.shared),
1427 local_control_tx: self.local_control_tx.clone(),
1428 drop_guard: self.existing_drop_guard(),
1429 }
1430 }
1431
1432 pub async fn run(&mut self) {
1437 let mut resumed_rx = self.resumed_rx.clone();
1438 let mut seen_resume_generation = *resumed_rx.borrow();
1439 loop {
1440 tracing::trace!("driver select loop top");
1441 tokio::select! {
1442 biased;
1443 changed = resumed_rx.changed() => {
1444 if changed.is_err() {
1445 break;
1446 }
1447 let generation = *resumed_rx.borrow();
1448 if generation != seen_resume_generation {
1449 seen_resume_generation = generation;
1450 self.close_all_channel_runtime_state();
1451 self.abort_channel_handlers();
1452 let _ = self.resume_processed_tx.send(generation);
1453 }
1454 }
1455 recv = self.rx.recv() => {
1456 match recv {
1457 Some(recv) => {
1458 self.handle_recv(recv);
1459 }
1460 None => {
1461 tracing::trace!("driver rx closed, exiting loop");
1462 break;
1463 }
1464 }
1465 }
1466 Some((req_id, disposition)) = self.failures_rx.recv() => {
1467 tracing::trace!(%req_id, ?disposition, "failures_rx fired");
1468 let in_flight_found = self.in_flight_handlers.contains_key(&req_id);
1469 let in_flight_method_id =
1470 self.in_flight_handlers.get(&req_id).map(|in_flight| in_flight.method_id);
1471 let reply_disposition = self
1472 .in_flight_handlers
1473 .get(&req_id)
1474 .map(|in_flight| {
1475 if in_flight.has_channels && !in_flight.retry.idem {
1476 Some(FailureDisposition::Indeterminate)
1477 } else if in_flight.has_channels && in_flight.retry.idem {
1478 None
1479 } else {
1480 Some(disposition)
1481 }
1482 })
1483 .unwrap_or(Some(disposition));
1484 tracing::trace!(%req_id, in_flight_found, ?reply_disposition, "failures_rx computed disposition");
1485 self.in_flight_handlers.remove(&req_id);
1487 tracing::trace!(%req_id, in_flight = self.in_flight_handlers.len(), "handler removed on failure");
1488 let had_pending = self.shared.pending_responses.lock().remove(&req_id).is_some();
1489 tracing::trace!(%req_id, had_pending, "failures_rx checked pending_responses");
1490 if !had_pending {
1491 let Some(reply_disposition) = reply_disposition else {
1492 tracing::trace!(%req_id, "failures_rx: no reply_disposition, skipping");
1493 continue;
1494 };
1495 tracing::trace!(%req_id, ?reply_disposition, "failures_rx: sending error response");
1496 let vox_error = match reply_disposition {
1497 FailureDisposition::Cancelled => VoxError::Cancelled,
1498 FailureDisposition::Indeterminate => VoxError::Indeterminate,
1499 };
1500 if let Some(method_id) = in_flight_method_id
1501 && let Some(response_shape) = self.handler.response_wire_shape(method_id)
1502 && let Ok(extracted) = vox_types::extract_schemas(response_shape)
1503 {
1504 let registry = vox_types::build_registry(&extracted.schemas);
1505 let error: Result<(), VoxError<core::convert::Infallible>> =
1506 Err(vox_error);
1507 let encoded = vox_postcard::to_vec(&error)
1508 .expect("serialize runtime-generated error response");
1509 let mut response = RequestResponse {
1510 ret: Payload::PostcardBytes(Box::leak(encoded.into_boxed_slice())),
1511 metadata: Default::default(),
1512 schemas: Default::default(),
1513 };
1514 self.sender.prepare_response_from_source(
1515 req_id,
1516 method_id,
1517 &extracted.root,
1518 ®istry,
1519 &mut response,
1520 );
1521 let _ = self.sender.send_response(req_id, response).await;
1522 } else {
1523 let error: Result<(), VoxError<core::convert::Infallible>> =
1524 Err(vox_error);
1525 let _ = self.sender.send_response(req_id, RequestResponse {
1526 ret: Payload::outgoing(&error),
1527 metadata: Default::default(),
1528 schemas: Default::default(),
1529 }).await;
1530 }
1531 tracing::trace!(%req_id, "failures_rx: error response sent");
1532 }
1533 }
1534 Some(ctrl) = self.local_control_rx.recv() => {
1535 self.handle_local_control(ctrl).await;
1536 }
1537 }
1538 }
1539
1540 for (_, in_flight) in std::mem::take(&mut self.in_flight_handlers) {
1541 if !in_flight.retry.persist {
1542 in_flight.handle.abort();
1543 }
1544 }
1545 self.shared.pending_responses.lock().clear();
1546
1547 self.close_all_channel_runtime_state();
1551 }
1552
1553 async fn handle_local_control(&mut self, control: DriverLocalControl) {
1554 match control {
1555 DriverLocalControl::CloseChannel { channel_id } => {
1556 if self.shared.stale_close_channels.lock().remove(&channel_id) {
1561 tracing::trace!(%channel_id, "suppressing ChannelClose for stale channel");
1562 return;
1563 }
1564 let _ = self
1565 .sender
1566 .send(ConnectionMessage::Channel(ChannelMessage {
1567 id: channel_id,
1568 body: ChannelBody::Close(ChannelClose {
1569 metadata: Default::default(),
1570 }),
1571 }))
1572 .await;
1573 }
1574 DriverLocalControl::GrantCredit {
1575 channel_id,
1576 additional,
1577 } => {
1578 let _ = self
1579 .sender
1580 .send(ConnectionMessage::Channel(ChannelMessage {
1581 id: channel_id,
1582 body: ChannelBody::GrantCredit(vox_types::ChannelGrantCredit {
1583 additional,
1584 }),
1585 }))
1586 .await;
1587 }
1588 DriverLocalControl::HandlerCompleted { request_id } => {
1589 let removed = self.in_flight_handlers.remove(&request_id).is_some();
1590 tracing::trace!(
1591 %request_id,
1592 removed,
1593 in_flight = self.in_flight_handlers.len(),
1594 "handler completion processed"
1595 );
1596 }
1597 }
1598 }
1599
1600 fn handle_recv(&mut self, recv: crate::session::RecvMessage) {
1601 let crate::session::RecvMessage { schemas, msg } = recv;
1602 let msg_ref = msg.get();
1603 let is_request = matches!(msg_ref, ConnectionMessage::Request(_));
1604 if is_request {
1605 if let ConnectionMessage::Request(req) = msg_ref {
1606 vox_types::dlog!(
1607 "[driver] handle_recv request: conn={:?} req={:?} body={} method={:?}",
1608 self.sender.connection_id(),
1609 req.id,
1610 match &req.body {
1611 RequestBody::Call(_) => "Call",
1612 RequestBody::Response(_) => "Response",
1613 RequestBody::Cancel(_) => "Cancel",
1614 },
1615 match &req.body {
1616 RequestBody::Call(call) => Some(call.method_id),
1617 RequestBody::Response(_) | RequestBody::Cancel(_) => None,
1618 }
1619 );
1620 match &req.body {
1621 RequestBody::Call(call) => tracing::trace!(
1622 conn_id = self.sender.connection_id().0,
1623 req_id = req.id.0,
1624 method_id = call.method_id.0,
1625 "driver received call"
1626 ),
1627 RequestBody::Response(_) => tracing::trace!(
1628 conn_id = self.sender.connection_id().0,
1629 req_id = req.id.0,
1630 "driver received response message"
1631 ),
1632 RequestBody::Cancel(_) => tracing::trace!(
1633 conn_id = self.sender.connection_id().0,
1634 req_id = req.id.0,
1635 "driver received cancel message"
1636 ),
1637 }
1638 }
1639 let msg = msg.map(|m| match m {
1640 ConnectionMessage::Request(r) => r,
1641 _ => unreachable!(),
1642 });
1643 self.handle_request(msg, schemas);
1644 } else {
1645 let msg = msg.map(|m| match m {
1646 ConnectionMessage::Channel(c) => c,
1647 _ => unreachable!(),
1648 });
1649 self.handle_channel(msg);
1650 }
1651 }
1652
    /// Dispatches one inbound request-level message according to its body.
    ///
    /// * `Call` — resolves the method's retry policy and the optional
    ///   operation ID, runs operation admission/replay when an operation ID
    ///   applies, then spawns the handler task and records it in
    ///   `in_flight_handlers`.
    /// * `Response` — hands the message and its schema tracker to the oneshot
    ///   registered in `pending_responses` for this request ID, or drops it
    ///   when none is registered.
    /// * `Cancel` — consults the live-operation tracker and aborts the
    ///   affected handler depending on the tracker's verdict and the
    ///   handler's retry policy.
    fn handle_request(
        &mut self,
        msg: SelfRef<RequestMessage<'static>>,
        schemas: Arc<vox_types::SchemaRecvTracker>,
    ) {
        let msg_ref = msg.get();
        let req_id = msg_ref.id;
        let is_call = matches!(&msg_ref.body, RequestBody::Call(_));
        let is_response = matches!(&msg_ref.body, RequestBody::Response(_));
        let is_cancel = matches!(&msg_ref.body, RequestBody::Cancel(_));

        if is_call {
            let method_id = match &msg_ref.body {
                RequestBody::Call(call) => call.method_id,
                _ => unreachable!(),
            };
            vox_types::dlog!(
                "[driver] inbound call: conn={:?} req={:?} method={:?}",
                self.sender.connection_id(),
                req_id,
                method_id
            );
            // Narrow the message down to its call body; `msg` is consumed.
            let call = msg.map(|m| match m.body {
                RequestBody::Call(c) => c,
                _ => unreachable!(),
            });
            let call_ref = call.get();
            let handler = Arc::clone(&self.handler);
            let retry = handler.retry_policy(call_ref.method_id);
            // An operation ID from the metadata is honoured only for methods
            // that are not idempotent.
            let operation_id = metadata_operation_id(&call_ref.metadata).filter(|_| !retry.idem);
            let method_id = call_ref.method_id;

            if let Some(operation_id) = operation_id {
                // Step 1: register this request with the live-operation tracker.
                let admit = self.live_operations.lock().admit(
                    operation_id,
                    call_ref.method_id,
                    incoming_args_bytes(call_ref),
                    retry,
                    req_id,
                );
                match admit {
                    // Nothing to start for this request.
                    // NOTE(review): presumably it was recorded as a waiter on
                    // an already-live operation — confirm in `admit`.
                    AdmitResult::Attached => return,
                    AdmitResult::Conflict => {
                        // Reject from a separate task so the driver loop does
                        // not wait on the send.
                        let sender = self.sender.clone();
                        moire::task::spawn(
                            async move {
                                let error: Result<(), VoxError<core::convert::Infallible>> =
                                    Err(VoxError::InvalidPayload("operation ID conflict".into()));
                                let _ = sender
                                    .send_response(
                                        req_id,
                                        RequestResponse {
                                            ret: Payload::outgoing(&error),
                                            metadata: Default::default(),
                                            schemas: Default::default(),
                                        },
                                    )
                                    .await;
                            }
                            .named("operation_reject"),
                        );
                        return;
                    }
                    AdmitResult::Start => {}
                }

                // Step 2: consult the operation store for what is already
                // recorded under this operation ID.
                match self.shared.operations.lookup(operation_id) {
                    crate::OperationState::Sealed => {
                        // A sealed response is available: replay it instead of
                        // running the handler.
                        // NOTE(review): when `get_sealed` returns `None` here,
                        // control falls through and the handler is started
                        // below without `operations.admit` having been called
                        // — confirm this is intended.
                        if let Some(sealed) = self.shared.operations.get_sealed(operation_id) {
                            let sender = self.sender.clone();
                            let method_id = call_ref.method_id;
                            let operations = Arc::clone(&self.shared.operations);
                            self.live_operations.lock().seal(operation_id);
                            moire::task::spawn(
                                async move {
                                    // A failed replay is reported as a
                                    // cancelled request.
                                    if replay_sealed_response(
                                        sender.clone(),
                                        req_id,
                                        method_id,
                                        sealed.response.as_bytes(),
                                        sealed.root_type,
                                        operations.as_ref(),
                                    )
                                    .await
                                    .is_err()
                                    {
                                        sender.mark_failure(req_id, FailureDisposition::Cancelled);
                                    }
                                }
                                .named("operation_replay"),
                            );
                            return;
                        }
                    }
                    crate::OperationState::Admitted => {
                        // The store records the operation as admitted but holds
                        // no sealed response: answer `Indeterminate` rather
                        // than running the handler.
                        self.live_operations.lock().seal(operation_id);
                        let sender = self.sender.clone();
                        moire::task::spawn(
                            async move {
                                let error: Result<(), VoxError<core::convert::Infallible>> =
                                    Err(VoxError::Indeterminate);
                                let _ = sender
                                    .send_response(
                                        req_id,
                                        RequestResponse {
                                            ret: Payload::outgoing(&error),
                                            metadata: Default::default(),
                                            schemas: Default::default(),
                                        },
                                    )
                                    .await;
                            }
                            .named("operation_indeterminate"),
                        );
                        return;
                    }
                    crate::OperationState::Unknown => {
                        // First sighting: record the admission before the
                        // handler runs. The `!retry.idem` test is always true
                        // here because `operation_id` was filtered on it above.
                        if !retry.idem {
                            self.shared.operations.admit(operation_id);
                        }
                    }
                }
            }
            // The reply sink gets operation-store and tracker access only when
            // the call is bound to an operation ID.
            let reply = DriverReplySink {
                sender: Some(self.sender.clone()),
                request_id: req_id,
                method_id: call_ref.method_id,
                retry,
                operation_id,
                operations: operation_id.map(|_| Arc::clone(&self.shared.operations)),
                live_operations: operation_id.map(|_| Arc::clone(&self.live_operations)),
                binder: self.internal_binder(),
            };
            let has_channels = handler.args_have_channels(call_ref.method_id);
            let local_control_tx = self.local_control_tx.clone();
            let join_handle = moire::task::spawn(
                async move {
                    vox_types::dlog!(
                        "[driver] handler start: req={:?} method={:?}",
                        req_id,
                        method_id
                    );
                    handler.handle(call, reply, schemas).await;
                    vox_types::dlog!(
                        "[driver] handler done: req={:?} method={:?}",
                        req_id,
                        method_id
                    );
                    // Ask the driver to retire the in-flight entry; this line
                    // is not reached when the task is aborted.
                    let _ = local_control_tx
                        .send(DriverLocalControl::HandlerCompleted { request_id: req_id });
                }
                .named("handler"),
            );
            self.in_flight_handlers.insert(
                req_id,
                InFlightHandler {
                    handle: join_handle,
                    method_id,
                    retry,
                    has_channels,
                    operation_id,
                },
            );
            tracing::trace!(%req_id, in_flight = self.in_flight_handlers.len(), "handler inserted");
        } else if is_response {
            vox_types::dlog!(
                "[driver] inbound response: conn={:?} req={:?}",
                self.sender.connection_id(),
                req_id
            );
            tracing::trace!(%req_id, "driver received response");
            // Taking the slot out of the map means at most one response is
            // delivered per request ID.
            if let Some(tx) = self.shared.pending_responses.lock().remove(&req_id) {
                vox_types::dlog!("[driver] routing response to waiter: req={:?}", req_id);
                tracing::trace!(%req_id, "routing response to pending oneshot");
                // A send error (receiver already dropped) is ignored.
                let _: Result<(), _> = tx.send(PendingResponse { msg, schemas });
            } else {
                vox_types::dlog!("[driver] dropped unmatched response: req={:?}", req_id);
                tracing::trace!(%req_id, "no pending response slot for this req_id");
            }
        } else if is_cancel {
            vox_types::dlog!(
                "[driver] inbound cancel: conn={:?} req={:?}",
                self.sender.connection_id(),
                req_id
            );
            tracing::trace!(%req_id, in_flight = self.in_flight_handlers.contains_key(&req_id), "received cancel");
            match self.live_operations.lock().cancel(req_id) {
                CancelResult::NotFound => {
                    // Not tracked as a live operation: abort the handler
                    // directly, unless its retry policy is `persist`.
                    let should_abort = self
                        .in_flight_handlers
                        .get(&req_id)
                        .map(|in_flight| !in_flight.retry.persist)
                        .unwrap_or(false);
                    tracing::trace!(%req_id, should_abort, "cancel: not in live operations");
                    if should_abort && let Some(in_flight) = self.in_flight_handlers.remove(&req_id)
                    {
                        tracing::trace!(%req_id, "aborting handler");
                        in_flight.handle.abort();
                        tracing::trace!(%req_id, in_flight = self.in_flight_handlers.len(), "handler removed on cancel");
                    }
                }
                // No driver-side action for this verdict.
                CancelResult::Detached => {}
                CancelResult::Abort {
                    owner_request_id,
                    waiters,
                } => {
                    // Abort the owning handler, forget its operation in the
                    // store, and fail every waiter as cancelled.
                    if let Some(in_flight) = self.in_flight_handlers.remove(&owner_request_id) {
                        if let Some(op_id) = in_flight.operation_id {
                            self.shared.operations.remove(op_id);
                        }
                        in_flight.handle.abort();
                        tracing::trace!(%owner_request_id, in_flight = self.in_flight_handlers.len(), "owner handler removed on abort");
                    }
                    for waiter in waiters {
                        self.sender
                            .mark_failure(waiter, FailureDisposition::Cancelled);
                    }
                }
            }
        }
    }
1889
1890 fn handle_channel(&mut self, msg: SelfRef<ChannelMessage<'static>>) {
1891 let msg_ref = msg.get();
1892 let chan_id = msg_ref.id;
1893
1894 let sender = self.shared.channel_senders.lock().get(&chan_id).cloned();
1897
1898 match &msg_ref.body {
1899 ChannelBody::Item(_item) => {
1901 if let Some(tx) = &sender {
1902 tracing::trace!(
1903 conn_id = self.sender.connection_id().0,
1904 channel_id = chan_id.0,
1905 registered = true,
1906 "driver received channel item"
1907 );
1908 let item = msg.map(|m| match m.body {
1909 ChannelBody::Item(item) => item,
1910 _ => unreachable!(),
1911 });
1912 let _ = tx.try_send(IncomingChannelMessage::Item(item));
1914 } else {
1915 tracing::trace!(
1916 conn_id = self.sender.connection_id().0,
1917 channel_id = chan_id.0,
1918 registered = false,
1919 "driver buffered channel item before registration"
1920 );
1921 let item = msg.map(|m| match m.body {
1923 ChannelBody::Item(item) => item,
1924 _ => unreachable!(),
1925 });
1926 self.shared
1927 .channel_buffers
1928 .lock()
1929 .entry(chan_id)
1930 .or_default()
1931 .push(IncomingChannelMessage::Item(item));
1932 }
1933 }
1934 ChannelBody::Close(_close) => {
1936 if let Some(tx) = &sender {
1937 tracing::trace!(
1938 conn_id = self.sender.connection_id().0,
1939 channel_id = chan_id.0,
1940 registered = true,
1941 "driver received channel close"
1942 );
1943 let close = msg.map(|m| match m.body {
1944 ChannelBody::Close(close) => close,
1945 _ => unreachable!(),
1946 });
1947 let _ = tx.try_send(IncomingChannelMessage::Close(close));
1948 } else {
1949 tracing::trace!(
1950 conn_id = self.sender.connection_id().0,
1951 channel_id = chan_id.0,
1952 registered = false,
1953 "driver buffered channel close before registration"
1954 );
1955 let close = msg.map(|m| match m.body {
1957 ChannelBody::Close(close) => close,
1958 _ => unreachable!(),
1959 });
1960 self.shared
1961 .channel_buffers
1962 .lock()
1963 .entry(chan_id)
1964 .or_default()
1965 .push(IncomingChannelMessage::Close(close));
1966 }
1967 self.shared.channel_senders.lock().remove(&chan_id);
1968 self.close_outbound_channel(chan_id);
1969 }
1970 ChannelBody::Reset(_reset) => {
1972 if let Some(tx) = &sender {
1973 tracing::trace!(
1974 conn_id = self.sender.connection_id().0,
1975 channel_id = chan_id.0,
1976 registered = true,
1977 "driver received channel reset"
1978 );
1979 let reset = msg.map(|m| match m.body {
1980 ChannelBody::Reset(reset) => reset,
1981 _ => unreachable!(),
1982 });
1983 let _ = tx.try_send(IncomingChannelMessage::Reset(reset));
1984 } else {
1985 tracing::trace!(
1986 conn_id = self.sender.connection_id().0,
1987 channel_id = chan_id.0,
1988 registered = false,
1989 "driver buffered channel reset before registration"
1990 );
1991 let reset = msg.map(|m| match m.body {
1993 ChannelBody::Reset(reset) => reset,
1994 _ => unreachable!(),
1995 });
1996 self.shared
1997 .channel_buffers
1998 .lock()
1999 .entry(chan_id)
2000 .or_default()
2001 .push(IncomingChannelMessage::Reset(reset));
2002 }
2003 self.shared.channel_senders.lock().remove(&chan_id);
2004 self.close_outbound_channel(chan_id);
2005 }
2006 ChannelBody::GrantCredit(grant) => {
2009 tracing::trace!(
2010 conn_id = self.sender.connection_id().0,
2011 channel_id = chan_id.0,
2012 additional = grant.additional,
2013 "driver received channel credit"
2014 );
2015 if let Some(semaphore) = self.shared.channel_credits.lock().get(&chan_id) {
2016 semaphore.add_permits(grant.additional as usize);
2017 }
2018 }
2019 }
2020 }
2021}