1use std::{
2 collections::BTreeMap,
3 pin::Pin,
4 sync::{
5 Arc, Weak,
6 atomic::{AtomicU64, Ordering},
7 },
8};
9
10use moire::sync::{Semaphore, SyncMutex};
11use tokio::sync::watch;
12
13use moire::task::FutureExt as _;
14use vox_types::{
15 BoxFut, CallResult, Caller, ChannelBinder, ChannelBody, ChannelClose, ChannelCreditReplenisher,
16 ChannelCreditReplenisherHandle, ChannelId, ChannelItem, ChannelLivenessHandle, ChannelMessage,
17 ChannelRetryMode, ChannelSink, CreditSink, Handler, IdAllocator, IncomingChannelMessage,
18 Payload, ReplySink, RequestBody, RequestCall, RequestId, RequestMessage, RequestResponse,
19 SelfRef, TxError, VoxError, ensure_operation_id, metadata_channel_retry_mode,
20 metadata_operation_id,
21};
22
23use crate::session::{
24 ConnectionHandle, ConnectionMessage, ConnectionSender, DropControlRequest, FailureDisposition,
25};
26use crate::{InMemoryOperationStore, OperationStore};
27use moire::sync::mpsc;
28use vox_types::{OperationId, PostcardPayload, SchemaHash, TypeRef};
29
/// A response delivered to a [`DriverCaller::call`] that is waiting on its
/// [`ResponseSlot`].
struct PendingResponse {
    /// The received message; `call` expects its body to be `RequestBody::Response`.
    msg: SelfRef<RequestMessage<'static>>,
    /// Schema tracker handed back to the caller alongside the response value
    /// (as `WithTracker::tracker`).
    schemas: Arc<vox_types::SchemaRecvTracker>,
}
39
/// Sending half of the one-shot channel a [`DriverCaller::call`] awaits; stored in
/// `DriverShared::pending_responses` under the call's request id.
type ResponseSlot = moire::sync::oneshot::Sender<PendingResponse>;
41
/// Driver-side bookkeeping for a request whose handler task is running.
struct InFlightHandler {
    /// Handle of the handler task; `Driver::abort_channel_handlers` aborts it.
    handle: moire::task::JoinHandle<()>,
    /// Method targeted by the request.
    method_id: vox_types::MethodId,
    /// Retry policy of the method; `retry.idem` influences the failure disposition
    /// chosen for requests that carry channels.
    retry: vox_types::RetryPolicy,
    /// Whether the request carries channels; such handlers are aborted on resume.
    has_channels: bool,
    /// Operation the request is bound to, if any.
    operation_id: Option<OperationId>,
}
49
/// Tracks operations whose handlers are currently running, so that a repeated
/// request for the same operation id can attach to the running handler instead of
/// starting a new one.
struct LiveOperationTracker {
    /// Live operations keyed by operation id.
    live: HashMap<OperationId, LiveOperation>,
    /// Reverse index from every bound request (owner or waiter) to its operation.
    request_to_operation: HashMap<RequestId, OperationId>,
}
65
/// State of one live operation.
struct LiveOperation {
    /// Method the operation was admitted for; an attaching request must match it.
    method_id: vox_types::MethodId,
    /// Hash of `(method_id, args)`; an attaching request must match it.
    args_hash: u64,
    /// Request that started the operation.
    owner_request_id: RequestId,
    /// Every request bound to this operation, including the owner.
    waiters: Vec<RequestId>,
    /// Retry policy; `retry.persist` decides whether a cancel detaches or aborts.
    retry: vox_types::RetryPolicy,
}
73
/// Outcome of [`LiveOperationTracker::admit`].
enum AdmitResult {
    /// No live operation had this id; it was registered with the request as owner.
    Start,
    /// The request was attached as a waiter to an existing live operation.
    Attached,
    /// A live operation with this id exists for a different method or arguments.
    Conflict,
}
82
83impl LiveOperationTracker {
84 fn new() -> Self {
85 Self {
86 live: HashMap::new(),
87 request_to_operation: HashMap::new(),
88 }
89 }
90
91 fn admit(
92 &mut self,
93 operation_id: OperationId,
94 method_id: vox_types::MethodId,
95 args: &[u8],
96 retry: vox_types::RetryPolicy,
97 request_id: RequestId,
98 ) -> AdmitResult {
99 use std::hash::{Hash, Hasher};
100 let args_hash = {
101 let mut h = std::collections::hash_map::DefaultHasher::new();
102 method_id.hash(&mut h);
103 args.hash(&mut h);
104 h.finish()
105 };
106 let live_operations = self.live.len();
107
108 if let Some(live) = self.live.get_mut(&operation_id) {
109 if live.method_id != method_id || live.args_hash != args_hash {
110 let request_bindings = self.request_to_operation.len();
111 tracing::trace!(
112 %operation_id,
113 %request_id,
114 ?method_id,
115 live_operations,
116 request_bindings,
117 "live operation conflict"
118 );
119 return AdmitResult::Conflict;
120 }
121 live.waiters.push(request_id);
122 self.request_to_operation.insert(request_id, operation_id);
123 let waiters = live.waiters.len();
124 let request_bindings = self.request_to_operation.len();
125 tracing::trace!(
126 %operation_id,
127 %request_id,
128 ?method_id,
129 waiters,
130 live_operations,
131 request_bindings,
132 "live operation attached"
133 );
134 return AdmitResult::Attached;
135 }
136
137 self.live.insert(
138 operation_id,
139 LiveOperation {
140 method_id,
141 args_hash,
142 owner_request_id: request_id,
143 waiters: vec![request_id],
144 retry,
145 },
146 );
147 self.request_to_operation.insert(request_id, operation_id);
148 let live_operations = self.live.len();
149 let request_bindings = self.request_to_operation.len();
150 tracing::trace!(
151 %operation_id,
152 %request_id,
153 ?method_id,
154 live_operations,
155 request_bindings,
156 "live operation admitted"
157 );
158 AdmitResult::Start
159 }
160
161 fn seal(&mut self, operation_id: OperationId) -> Vec<RequestId> {
163 if let Some(live) = self.live.remove(&operation_id) {
164 for waiter in &live.waiters {
165 self.request_to_operation.remove(waiter);
166 }
167 let waiters = live.waiters.len();
168 let live_operations = self.live.len();
169 let request_bindings = self.request_to_operation.len();
170 tracing::trace!(
171 %operation_id,
172 waiters,
173 live_operations,
174 request_bindings,
175 "live operation sealed"
176 );
177 live.waiters
178 } else {
179 vec![]
180 }
181 }
182
183 fn release(&mut self, operation_id: OperationId) -> Option<LiveOperation> {
185 if let Some(live) = self.live.remove(&operation_id) {
186 for waiter in &live.waiters {
187 self.request_to_operation.remove(waiter);
188 }
189 let waiters = live.waiters.len();
190 let live_operations = self.live.len();
191 let request_bindings = self.request_to_operation.len();
192 tracing::trace!(
193 %operation_id,
194 waiters,
195 live_operations,
196 request_bindings,
197 "live operation released"
198 );
199 Some(live)
200 } else {
201 None
202 }
203 }
204
205 fn cancel(&mut self, request_id: RequestId) -> CancelResult {
207 let Some(&operation_id) = self.request_to_operation.get(&request_id) else {
208 return CancelResult::NotFound;
209 };
210 let live_operations = self.live.len();
211 let Some(live) = self.live.get_mut(&operation_id) else {
212 self.request_to_operation.remove(&request_id);
213 return CancelResult::NotFound;
214 };
215
216 if live.retry.persist {
217 if live.owner_request_id == request_id {
219 return CancelResult::NotFound; }
221 live.waiters.retain(|w| *w != request_id);
222 self.request_to_operation.remove(&request_id);
223 let waiters = live.waiters.len();
224 let request_bindings = self.request_to_operation.len();
225 tracing::trace!(
226 %operation_id,
227 %request_id,
228 waiters,
229 live_operations,
230 request_bindings,
231 "live operation detached waiter"
232 );
233 CancelResult::Detached
234 } else {
235 let live = self.live.remove(&operation_id).unwrap();
237 for waiter in &live.waiters {
238 self.request_to_operation.remove(waiter);
239 }
240 let waiters = live.waiters.len();
241 let live_operations = self.live.len();
242 let request_bindings = self.request_to_operation.len();
243 tracing::trace!(
244 %operation_id,
245 %request_id,
246 waiters,
247 live_operations,
248 request_bindings,
249 "live operation aborted"
250 );
251 CancelResult::Abort {
252 owner_request_id: live.owner_request_id,
253 waiters: live.waiters,
254 }
255 }
256 }
257}
258
/// Outcome of [`LiveOperationTracker::cancel`].
enum CancelResult {
    /// The request is not bound to a live operation, or it owns a persistent one.
    NotFound,
    /// The request was detached from a persistent operation, which keeps running.
    Detached,
    /// The non-persistent operation was removed from tracking; carries its owner and
    /// every request that was bound to it so the caller can act on them.
    Abort {
        owner_request_id: RequestId,
        waiters: Vec<RequestId>,
    },
}
267
268use std::collections::HashMap;
269
/// State shared (via `Arc`) between the [`Driver`], its [`DriverCaller`]s and its
/// channel binders.
struct DriverShared {
    /// Response slots of outgoing calls awaiting a reply, keyed by request id.
    pending_responses: SyncMutex<BTreeMap<RequestId, ResponseSlot>>,
    /// Allocator for outgoing request ids.
    request_ids: SyncMutex<IdAllocator<RequestId>>,
    /// Counter minting operation ids for outgoing calls when the peer supports retry.
    next_operation_id: AtomicU64,
    /// Store receiving sealed operation responses.
    operations: Arc<dyn OperationStore>,
    /// Allocator for locally created channel ids.
    channel_ids: SyncMutex<IdAllocator<ChannelId>>,
    /// Senders feeding registered inbound channel receivers.
    channel_senders: SyncMutex<BTreeMap<ChannelId, mpsc::Sender<IncomingChannelMessage>>>,
    /// Inbound channel messages held for channels without a registered receiver yet;
    /// drained by `register_rx_channel_impl`.
    channel_buffers: SyncMutex<BTreeMap<ChannelId, Vec<IncomingChannelMessage>>>,
    /// Credit semaphores of outbound channel sinks.
    channel_credits: SyncMutex<BTreeMap<ChannelId, Arc<Semaphore>>>,
    /// Ids of outbound channels whose credit state was torn down by
    /// `Driver::close_all_channel_runtime_state`.
    stale_close_channels: SyncMutex<std::collections::HashSet<ChannelId>>,
}
298
/// Sends `request` on `control_tx` when dropped. Shared through `Arc`, so the
/// request goes out once the last holder releases it.
struct CallerDropGuard {
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    request: DropControlRequest,
}
303
304impl Drop for CallerDropGuard {
305 fn drop(&mut self) {
306 let _ = self.control_tx.send(self.request);
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::{DriverChannelCreditReplenisher, DriverLocalControl};
    use tokio::sync::mpsc::error::TryRecvError;
    use vox_types::{ChannelCreditReplenisher, ChannelId};

    /// Pops the next queued control message, expecting a credit grant.
    ///
    /// Panics with `context` when the queue is empty or holds another message.
    fn next_grant(
        rx: &mut moire::sync::mpsc::UnboundedReceiver<DriverLocalControl>,
        context: &str,
    ) -> (ChannelId, u32) {
        match rx.try_recv() {
            Ok(DriverLocalControl::GrantCredit {
                channel_id,
                additional,
            }) => (channel_id, additional),
            _ => panic!("{context}"),
        }
    }

    #[test]
    fn replenisher_batches_at_half_the_initial_window() {
        let (tx, mut rx) = moire::sync::mpsc::unbounded_channel("test.replenisher");
        let replenisher = DriverChannelCreditReplenisher::new(ChannelId(7), 16, tx);

        // Seven items stay below the 16 / 2 = 8 threshold.
        (0..7).for_each(|_| replenisher.on_item_consumed());
        assert!(
            matches!(rx.try_recv(), Err(TryRecvError::Empty)),
            "should not emit credit before reaching the batch threshold"
        );

        // The eighth item reaches it and flushes the whole batch.
        replenisher.on_item_consumed();
        assert_eq!(
            next_grant(&mut rx, "expected batched credit grant"),
            (ChannelId(7), 8)
        );
    }

    #[test]
    fn replenisher_grants_one_by_one_for_single_credit_windows() {
        let (tx, mut rx) = moire::sync::mpsc::unbounded_channel("test.replenisher.single");
        let replenisher = DriverChannelCreditReplenisher::new(ChannelId(9), 1, tx);

        replenisher.on_item_consumed();
        assert_eq!(
            next_grant(&mut rx, "expected immediate credit grant"),
            (ChannelId(9), 1)
        );
    }
}
358
/// [`ReplySink`] given to a request handler. It delivers the handler's response on
/// the connection and, if dropped without replying, reports a failure instead.
pub struct DriverReplySink {
    /// Connection to reply on; taken by `send_reply`, so `None` afterwards.
    sender: Option<ConnectionSender>,
    /// Request being answered.
    request_id: RequestId,
    /// Method of the request being answered.
    method_id: vox_types::MethodId,
    /// Retry policy; `retry.persist` selects the failure disposition on drop.
    retry: vox_types::RetryPolicy,
    /// Operation this request is bound to, if any.
    operation_id: Option<OperationId>,
    /// Store that receives the sealed response for `operation_id`.
    operations: Option<Arc<dyn OperationStore>>,
    /// Tracker used to find other requests attached to the same operation.
    live_operations: Option<Arc<SyncMutex<LiveOperationTracker>>>,
    /// Channel binder exposed through `ReplySink::channel_binder`.
    binder: DriverChannelBinder,
}
376
377async fn replay_sealed_response(
383 sender: ConnectionSender,
384 request_id: RequestId,
385 method_id: vox_types::MethodId,
386 encoded_response: &[u8],
387 root_type: TypeRef,
388 operations: &dyn OperationStore,
389) -> Result<(), ()> {
390 let mut response: RequestResponse<'_> =
391 vox_postcard::from_slice_borrowed(encoded_response).map_err(|_| ())?;
392 sender.prepare_replay_schemas(request_id, method_id, &root_type, operations, &mut response);
393 sender.send_response(request_id, response).await
394}
395
396fn extract_root_type_ref(schemas_cbor: &vox_types::CborPayload) -> TypeRef {
398 if schemas_cbor.is_empty() {
399 return TypeRef::concrete(SchemaHash(0));
400 }
401 let payload =
402 vox_types::SchemaPayload::from_cbor(&schemas_cbor.0).expect("schema CBOR must be valid");
403 payload.root
404}
405
406fn incoming_args_bytes<'a>(call: &'a RequestCall<'a>) -> &'a [u8] {
407 match &call.args {
408 Payload::PostcardBytes(bytes) => bytes,
409 Payload::Value { .. } => {
410 panic!("incoming request payload should always be decoded as incoming bytes")
411 }
412 }
413}
414
415impl ReplySink for DriverReplySink {
416 async fn send_reply(mut self, response: RequestResponse<'_>) {
417 let sender = self
418 .sender
419 .take()
420 .expect("unreachable: send_reply takes self by value");
421
422 vox_types::dlog!(
423 "[driver] send_reply: conn={:?} req={:?} method={:?} payload={} operation_id={:?}",
424 sender.connection_id(),
425 self.request_id,
426 self.method_id,
427 match &response.ret {
428 Payload::Value { .. } => "Value",
429 Payload::PostcardBytes(_) => "PostcardBytes",
430 },
431 self.operation_id
432 );
433
434 if let Payload::Value { shape, .. } = &response.ret
435 && let Ok(extracted) = vox_types::extract_schemas(shape)
436 {
437 vox_types::dlog!(
438 "[schema] driver send_reply: method={:?} root={:?}",
439 self.method_id,
440 extracted.root
441 );
442 }
443
444 if let (Some(operation_id), Some(operations)) = (self.operation_id, self.operations.take())
445 {
446 let mut response = response;
447 sender.prepare_response_for_method(self.request_id, self.method_id, &mut response);
448
449 let root_type = extract_root_type_ref(&response.schemas);
451
452 let schemas_for_wire = std::mem::take(&mut response.schemas);
454 let encoded_for_store = PostcardPayload(
455 vox_postcard::to_vec(&response).expect("serialize operation response for store"),
456 );
457 response.schemas = schemas_for_wire;
458
459 vox_types::dlog!(
461 "[driver] send_reply wire send: conn={:?} req={:?} method={:?} schemas={}",
462 sender.connection_id(),
463 self.request_id,
464 self.method_id,
465 response.schemas.0.len()
466 );
467 if let Err(_e) = sender.send_response(self.request_id, response).await {
468 sender.mark_failure(self.request_id, FailureDisposition::Cancelled);
469 }
470
471 let registry = sender.schema_registry();
473 operations.seal(operation_id, &encoded_for_store, &root_type, ®istry);
474
475 let waiters = self
477 .live_operations
478 .as_ref()
479 .map(|lo| lo.lock().seal(operation_id))
480 .unwrap_or_default();
481 for waiter in waiters {
482 if waiter == self.request_id {
483 continue;
484 }
485 if replay_sealed_response(
486 sender.clone(),
487 waiter,
488 self.method_id,
489 encoded_for_store.as_bytes(),
490 root_type.clone(),
491 operations.as_ref(),
492 )
493 .await
494 .is_err()
495 {
496 sender.mark_failure(waiter, FailureDisposition::Cancelled);
497 }
498 }
499 } else {
500 vox_types::dlog!(
501 "[driver] send_reply direct send: conn={:?} req={:?} method={:?}",
502 sender.connection_id(),
503 self.request_id,
504 self.method_id
505 );
506 if let Err(_e) = sender
507 .send_response_for_method(self.request_id, self.method_id, response)
508 .await
509 {
510 sender.mark_failure(self.request_id, FailureDisposition::Cancelled);
511 }
512 }
513 }
514
515 fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
516 Some(&self.binder)
517 }
518
519 fn request_id(&self) -> Option<RequestId> {
520 Some(self.request_id)
521 }
522
523 fn connection_id(&self) -> Option<vox_types::ConnectionId> {
524 self.sender.as_ref().map(|sender| sender.connection_id())
525 }
526}
527
528impl Drop for DriverReplySink {
530 fn drop(&mut self) {
531 if let Some(sender) = self.sender.take() {
532 let disposition = if self.retry.persist {
533 FailureDisposition::Indeterminate
534 } else {
535 FailureDisposition::Cancelled
536 };
537
538 if let Some(operation_id) = self.operation_id {
539 if let Some(live_ops) = self.live_operations.take()
545 && let Some(live) = live_ops.lock().release(operation_id)
546 {
547 for waiter in live.waiters {
548 sender.mark_failure(waiter, disposition);
549 }
550 return;
551 }
552 }
553
554 sender.mark_failure(self.request_id, disposition);
555 }
556 }
557}
558
/// Raw outbound channel sink that writes channel items and closes onto the
/// connection. It is wrapped in a [`CreditSink`] for flow control.
pub struct DriverChannelSink {
    /// Connection the channel messages are written to.
    sender: ConnectionSender,
    /// Channel this sink writes to.
    channel_id: ChannelId,
    /// Used to ask the driver to close the channel when the sink is dropped.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
}
571
572impl ChannelSink for DriverChannelSink {
573 fn send_payload<'payload>(
574 &self,
575 payload: Payload<'payload>,
576 ) -> Pin<Box<dyn vox_types::MaybeSendFuture<Output = Result<(), TxError>> + 'payload>> {
577 let sender = self.sender.clone();
578 let channel_id = self.channel_id;
579 Box::pin(async move {
580 sender
581 .send(ConnectionMessage::Channel(ChannelMessage {
582 id: channel_id,
583 body: ChannelBody::Item(ChannelItem { item: payload }),
584 }))
585 .await
586 .map_err(|()| TxError::Transport("connection closed".into()))
587 })
588 }
589
590 fn close_channel(
591 &self,
592 _metadata: vox_types::Metadata,
593 ) -> Pin<Box<dyn vox_types::MaybeSendFuture<Output = Result<(), TxError>> + 'static>> {
594 let sender = self.sender.clone();
598 let channel_id = self.channel_id;
599 Box::pin(async move {
600 sender
601 .send(ConnectionMessage::Channel(ChannelMessage {
602 id: channel_id,
603 body: ChannelBody::Close(ChannelClose {
604 metadata: Default::default(),
605 }),
606 }))
607 .await
608 .map_err(|()| TxError::Transport("connection closed".into()))
609 })
610 }
611
612 fn close_channel_on_drop(&self) {
613 let _ = self
614 .local_control_tx
615 .send(DriverLocalControl::CloseChannel {
616 channel_id: self.channel_id,
617 });
618 }
619}
620
/// Holds a [`DriverCaller`] without ever calling through it.
///
/// The inner caller is never read (hence `dead_code`); holding it keeps the caller's
/// shared drop guard alive.
#[must_use = "Dropping NoopCaller may close the connection if it is the last caller."]
#[derive(Clone)]
pub struct NoopCaller(#[allow(dead_code)] DriverCaller);
627
628impl From<DriverCaller> for NoopCaller {
629 fn from(caller: DriverCaller) -> Self {
630 Self(caller)
631 }
632}
633
/// Channel binder owned by the driver itself (see `Driver::internal_binder`) and
/// handed to reply sinks.
#[derive(Clone)]
struct DriverChannelBinder {
    /// Connection used by outbound sinks.
    sender: ConnectionSender,
    /// Shared channel bookkeeping.
    shared: Arc<DriverShared>,
    /// Control queue given to sinks and credit replenishers.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Existing connection drop guard, if any; exposed as channel liveness.
    drop_guard: Option<Arc<CallerDropGuard>>,
}
641
/// Initial per-channel credit window. It is the starting credit of every outbound
/// [`CreditSink`] and the basis of the inbound replenisher's batch threshold.
const DEFAULT_CHANNEL_CREDIT: u32 = 16;
644
645fn register_rx_channel_impl(
646 shared: &Arc<DriverShared>,
647 channel_id: ChannelId,
648 queue_name: &'static str,
649 liveness: Option<ChannelLivenessHandle>,
650 local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
651) -> vox_types::BoundChannelReceiver {
652 let (tx, rx) = mpsc::channel(queue_name, 64);
653
654 let mut terminal_buffered = false;
655 {
656 let mut senders = shared.channel_senders.lock();
657
658 senders.insert(channel_id, tx.clone());
667
668 let buffered = shared.channel_buffers.lock().remove(&channel_id);
669 if let Some(buffered) = buffered {
670 for msg in buffered {
671 let is_terminal = matches!(
672 msg,
673 IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
674 );
675 let _ = tx.try_send(msg);
676 if is_terminal {
677 terminal_buffered = true;
678 break;
679 }
680 }
681 }
682
683 if terminal_buffered {
684 senders.remove(&channel_id);
685 }
686 }
687
688 if terminal_buffered {
689 shared.channel_credits.lock().remove(&channel_id);
690 return vox_types::BoundChannelReceiver {
691 receiver: rx,
692 liveness,
693 replenisher: None,
694 };
695 }
696
697 vox_types::BoundChannelReceiver {
698 receiver: rx,
699 liveness,
700 replenisher: Some(Arc::new(DriverChannelCreditReplenisher::new(
701 channel_id,
702 DEFAULT_CHANNEL_CREDIT,
703 local_control_tx,
704 )) as ChannelCreditReplenisherHandle),
705 }
706}
707
708impl DriverChannelBinder {
709 fn create_tx_channel(&self) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
710 let channel_id = self.shared.channel_ids.lock().alloc();
711 let inner = DriverChannelSink {
712 sender: self.sender.clone(),
713 channel_id,
714 local_control_tx: self.local_control_tx.clone(),
715 };
716 let sink = Arc::new(CreditSink::new(inner, DEFAULT_CHANNEL_CREDIT));
717 self.shared
718 .channel_credits
719 .lock()
720 .insert(channel_id, Arc::clone(sink.credit()));
721 (channel_id, sink)
722 }
723
724 fn register_rx_channel(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
725 register_rx_channel_impl(
726 &self.shared,
727 channel_id,
728 "driver.register_rx_channel",
729 self.channel_liveness(),
730 self.local_control_tx.clone(),
731 )
732 }
733}
734
735impl ChannelBinder for DriverChannelBinder {
736 fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>) {
737 let (id, sink) = self.create_tx_channel();
738 (id, sink as Arc<dyn ChannelSink>)
739 }
740
741 fn create_rx(&self) -> (ChannelId, vox_types::BoundChannelReceiver) {
742 let channel_id = self.shared.channel_ids.lock().alloc();
743 let rx = self.register_rx_channel(channel_id);
744 (channel_id, rx)
745 }
746
747 fn bind_tx(&self, channel_id: ChannelId) -> Arc<dyn ChannelSink> {
748 let inner = DriverChannelSink {
749 sender: self.sender.clone(),
750 channel_id,
751 local_control_tx: self.local_control_tx.clone(),
752 };
753 let sink = Arc::new(CreditSink::new(inner, DEFAULT_CHANNEL_CREDIT));
754 self.shared
755 .channel_credits
756 .lock()
757 .insert(channel_id, Arc::clone(sink.credit()));
758 sink
759 }
760
761 fn register_rx(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
762 self.register_rx_channel(channel_id)
763 }
764
765 fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
766 self.drop_guard
767 .as_ref()
768 .map(|guard| guard.clone() as ChannelLivenessHandle)
769 }
770}
771
/// Handle for making calls over a driver's connection and binding channels on it.
#[derive(Clone)]
pub struct DriverCaller {
    /// Connection requests and channel messages are sent on.
    sender: ConnectionSender,
    /// State shared with the driver.
    shared: Arc<DriverShared>,
    /// Control queue given to sinks and credit replenishers.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Becomes `true` once the connection is closed.
    closed_rx: watch::Receiver<bool>,
    /// Resume generation; a change means the session was resumed.
    resumed_rx: watch::Receiver<u64>,
    /// Latest resume generation the driver has finished processing.
    resume_processed_rx: watch::Receiver<u64>,
    /// Whether calls get operation ids and are retransmitted after a resume.
    peer_supports_retry: bool,
    /// Shared guard that sends `DropControlRequest::Close` once its last holder drops.
    _drop_guard: Option<Arc<CallerDropGuard>>,
}
786
787impl DriverCaller {
788 pub fn create_tx_channel(&self) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
793 let channel_id = self.shared.channel_ids.lock().alloc();
794 let inner = DriverChannelSink {
795 sender: self.sender.clone(),
796 channel_id,
797 local_control_tx: self.local_control_tx.clone(),
798 };
799 let sink = Arc::new(CreditSink::new(inner, DEFAULT_CHANNEL_CREDIT));
800 self.shared
801 .channel_credits
802 .lock()
803 .insert(channel_id, Arc::clone(sink.credit()));
804 (channel_id, sink)
805 }
806
807 #[cfg(test)]
812 pub(crate) fn connection_sender(&self) -> &ConnectionSender {
813 &self.sender
814 }
815
816 pub fn register_rx_channel(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
821 register_rx_channel_impl(
822 &self.shared,
823 channel_id,
824 "driver.caller.register_rx_channel",
825 self.channel_liveness(),
826 self.local_control_tx.clone(),
827 )
828 }
829}
830
831impl ChannelBinder for DriverCaller {
832 fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>) {
833 let (id, sink) = self.create_tx_channel();
834 (id, sink as Arc<dyn ChannelSink>)
835 }
836
837 fn create_rx(&self) -> (ChannelId, vox_types::BoundChannelReceiver) {
838 let channel_id = self.shared.channel_ids.lock().alloc();
839 let rx = self.register_rx_channel(channel_id);
840 (channel_id, rx)
841 }
842
843 fn bind_tx(&self, channel_id: ChannelId) -> Arc<dyn ChannelSink> {
844 let inner = DriverChannelSink {
845 sender: self.sender.clone(),
846 channel_id,
847 local_control_tx: self.local_control_tx.clone(),
848 };
849 let sink = Arc::new(CreditSink::new(inner, DEFAULT_CHANNEL_CREDIT));
850 self.shared
851 .channel_credits
852 .lock()
853 .insert(channel_id, Arc::clone(sink.credit()));
854 sink
855 }
856
857 fn register_rx(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
858 self.register_rx_channel(channel_id)
859 }
860
861 fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
862 self._drop_guard
863 .as_ref()
864 .map(|guard| guard.clone() as ChannelLivenessHandle)
865 }
866}
867
impl Caller for DriverCaller {
    /// Sends `call` to the peer and waits for its response.
    ///
    /// When the peer supports retry, a fresh operation id is placed in the call
    /// metadata (via `ensure_operation_id`). After a session resume, the call waits
    /// until the driver reports that generation as processed. It then either
    /// retransmits the request under the same request id or, for
    /// `ChannelRetryMode::NonIdem`, gives up with `VoxError::Indeterminate`.
    ///
    /// # Errors
    /// - `SendFailed`: the initial send failed.
    /// - `ConnectionClosed`: the response slot was dropped or the connection closed.
    /// - `SessionShutdown`: a resume watch channel closed.
    /// - `Indeterminate`: a non-idempotent channel call was interrupted by a resume.
    async fn call<'a>(&'a self, mut call: RequestCall<'a>) -> CallResult {
        if self.peer_supports_retry {
            let operation_id = OperationId(
                self.shared
                    .next_operation_id
                    .fetch_add(1, Ordering::Relaxed),
            );
            ensure_operation_id(&mut call.metadata, operation_id);
        }

        let req_id = self.shared.request_ids.lock().alloc();

        // Register the response slot before sending, so a fast reply finds it.
        let (tx, rx) = moire::sync::oneshot::channel("driver.response");
        self.shared.pending_responses.lock().insert(req_id, tx);

        if self
            .sender
            .send_with_binder(
                ConnectionMessage::Request(RequestMessage {
                    id: req_id,
                    body: RequestBody::Call(RequestCall {
                        method_id: call.method_id,
                        args: call.args.reborrow(),
                        metadata: call.metadata.clone(),
                        schemas: Default::default(),
                    }),
                }),
                Some(self),
            )
            .await
            .is_err()
        {
            self.shared.pending_responses.lock().remove(&req_id);
            return Err(VoxError::SendFailed);
        }

        let mut resumed_rx = self.resumed_rx.clone();
        let mut seen_resume_generation = *resumed_rx.borrow();
        let mut resume_processed_rx = self.resume_processed_rx.clone();
        let mut closed_rx = self.closed_rx.clone();
        let mut response = std::pin::pin!(rx.named("awaiting_response"));

        let pending: PendingResponse = loop {
            tokio::select! {
                result = &mut response => {
                    match result {
                        Ok(pending) => break pending,
                        Err(_) => {
                            // The slot's sender was dropped; it is no longer in
                            // `pending_responses`, so there is nothing to clean up.
                            return Err(VoxError::ConnectionClosed);
                        }
                    }
                }
                changed = resumed_rx.changed(), if self.peer_supports_retry => {
                    vox_types::dlog!("[CALLER] resumed_rx fired");
                    if changed.is_err() {
                        self.shared.pending_responses.lock().remove(&req_id);
                        return Err(VoxError::SessionShutdown);
                    }
                    let generation = *resumed_rx.borrow();
                    if generation == seen_resume_generation {
                        continue;
                    }
                    seen_resume_generation = generation;
                    // Wait for the driver to finish its resume handling for this
                    // generation before retransmitting.
                    while *resume_processed_rx.borrow() < generation {
                        if resume_processed_rx.changed().await.is_err() {
                            self.shared.pending_responses.lock().remove(&req_id);
                            return Err(VoxError::SessionShutdown);
                        }
                    }
                    match metadata_channel_retry_mode(&call.metadata) {
                        ChannelRetryMode::NonIdem => {
                            self.shared.pending_responses.lock().remove(&req_id);
                            return Err(VoxError::Indeterminate);
                        }
                        ChannelRetryMode::Idem | ChannelRetryMode::None => {}
                    }
                    // Retransmit under the same request id; a failed send is ignored
                    // here and the loop keeps waiting for a response or a close.
                    let _ = self.sender.send_with_binder(
                        ConnectionMessage::Request(RequestMessage {
                            id: req_id,
                            body: RequestBody::Call(RequestCall {
                                method_id: call.method_id,
                                args: call.args.reborrow(),
                                metadata: call.metadata.clone(),
                                schemas: Default::default(),
                            }),
                        }),
                        Some(self),
                    ).await;
                }
                changed = closed_rx.changed() => {
                    vox_types::dlog!("[CALLER] closed_rx fired, value={}", *closed_rx.borrow());
                    if changed.is_err() || *closed_rx.borrow() {
                        self.shared.pending_responses.lock().remove(&req_id);
                        return Err(VoxError::ConnectionClosed);
                    }
                }
            }
        };

        let PendingResponse {
            msg: response_msg,
            schemas: response_schemas,
        } = pending;
        let response = response_msg.map(|m| match m.body {
            RequestBody::Response(r) => r,
            _ => unreachable!("pending_responses only gets Response variants"),
        });

        Ok(vox_types::WithTracker {
            value: response,
            tracker: response_schemas,
        })
    }

    /// Resolves once the connection is observed closed, or when the close watch's
    /// sender goes away.
    fn closed(&self) -> BoxFut<'_, ()> {
        Box::pin(async move {
            if *self.closed_rx.borrow() {
                return;
            }
            let mut rx = self.closed_rx.clone();
            while rx.changed().await.is_ok() {
                if *rx.borrow() {
                    return;
                }
            }
        })
    }

    /// `true` while the connection has not been marked closed.
    fn is_connected(&self) -> bool {
        !*self.closed_rx.borrow()
    }

    /// The caller itself binds channels for its calls.
    fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
        Some(self)
    }
}
1019
/// Per-connection driver. It dispatches incoming requests to `handler`, routes
/// responses to [`DriverCaller`]s, and manages channel state for the connection.
pub struct Driver<H: Handler<DriverReplySink>> {
    /// Connection outgoing messages are sent on.
    sender: ConnectionSender,
    /// Incoming messages from the session.
    rx: mpsc::Receiver<crate::session::RecvMessage>,
    /// Failure reports for requests, with their disposition.
    failures_rx: mpsc::UnboundedReceiver<(RequestId, FailureDisposition)>,
    /// Becomes `true` once the connection is closed.
    closed_rx: watch::Receiver<bool>,
    /// Resume generation; `run` tears down channel state when it changes.
    resumed_rx: watch::Receiver<u64>,
    /// Publishes the last resume generation `run` has finished processing.
    resume_processed_tx: watch::Sender<u64>,
    /// Whether the peer supports operation ids and retransmission.
    peer_supports_retry: bool,
    /// Close/credit requests posted by sinks and replenishers.
    local_control_rx: mpsc::UnboundedReceiver<DriverLocalControl>,
    /// Request handler.
    handler: Arc<H>,
    /// State shared with callers and binders.
    shared: Arc<DriverShared>,
    /// Handlers currently running, keyed by request id.
    in_flight_handlers: BTreeMap<RequestId, InFlightHandler>,
    /// Operations currently running, shared with reply sinks.
    live_operations: Arc<SyncMutex<LiveOperationTracker>>,
    /// Sending side of `local_control_rx`, cloned into sinks and replenishers.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Control channel used to build the connection drop guard, if available.
    drop_control_seed: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// Request the drop guard sends (`DropControlRequest::Close` for this connection).
    drop_control_request: DropControlRequest,
    /// Weak handle to the shared drop guard, so it can be reused while alive.
    drop_guard: SyncMutex<Option<Weak<CallerDropGuard>>>,
}
1048
/// Requests posted to the driver loop by channel sinks and credit replenishers.
enum DriverLocalControl {
    /// An outbound channel sink was dropped; its channel should be closed.
    CloseChannel {
        channel_id: ChannelId,
    },
    /// An inbound receiver consumed `additional` items; presumably forwarded to the
    /// peer as new credit.
    GrantCredit {
        channel_id: ChannelId,
        additional: u32,
    },
}
1058
/// Batches consumed-item notifications for an inbound channel into
/// [`DriverLocalControl::GrantCredit`] requests.
struct DriverChannelCreditReplenisher {
    /// Channel whose credit is replenished.
    channel_id: ChannelId,
    /// Number of consumed items that triggers a grant.
    threshold: u32,
    /// Queue the grants are posted to.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Items consumed since the last grant.
    pending: std::sync::Mutex<u32>,
}
1065
1066impl DriverChannelCreditReplenisher {
1067 fn new(
1068 channel_id: ChannelId,
1069 initial_credit: u32,
1070 local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
1071 ) -> Self {
1072 Self {
1073 channel_id,
1074 threshold: (initial_credit / 2).max(1),
1075 local_control_tx,
1076 pending: std::sync::Mutex::new(0),
1077 }
1078 }
1079}
1080
1081impl ChannelCreditReplenisher for DriverChannelCreditReplenisher {
1082 fn on_item_consumed(&self) {
1083 let mut pending = self.pending.lock().expect("pending credit mutex poisoned");
1084 *pending += 1;
1085 if *pending < self.threshold {
1086 return;
1087 }
1088
1089 let additional = *pending;
1090 *pending = 0;
1091 let _ = self.local_control_tx.send(DriverLocalControl::GrantCredit {
1092 channel_id: self.channel_id,
1093 additional,
1094 });
1095 }
1096}
1097
1098impl<H: Handler<DriverReplySink>> Driver<H> {
1099 fn close_all_channel_runtime_state(&self) {
1100 let mut credits = self.shared.channel_credits.lock();
1101 for semaphore in credits.values() {
1102 semaphore.close();
1103 }
1104 let mut stale = self.shared.stale_close_channels.lock();
1107 stale.extend(credits.keys().copied());
1108 credits.clear();
1109 drop(credits);
1110
1111 self.shared.channel_senders.lock().clear();
1112 self.shared.channel_buffers.lock().clear();
1113 }
1114
1115 fn close_outbound_channel(&self, channel_id: ChannelId) {
1116 if let Some(semaphore) = self.shared.channel_credits.lock().remove(&channel_id) {
1117 semaphore.close();
1118 }
1119 }
1120
1121 fn abort_channel_handlers(&mut self) {
1122 for (_req_id, in_flight) in &self.in_flight_handlers {
1123 if in_flight.has_channels {
1124 if let Some(operation_id) = in_flight.operation_id {
1125 self.shared.operations.remove(operation_id);
1126 self.live_operations.lock().release(operation_id);
1127 }
1128 in_flight.handle.abort();
1129 }
1130 }
1131 }
1132
1133 pub fn new(handle: ConnectionHandle, handler: H) -> Self {
1134 Self::with_operation_store(handle, handler, Arc::new(InMemoryOperationStore::default()))
1135 }
1136
1137 pub fn with_operation_store(
1138 handle: ConnectionHandle,
1139 handler: H,
1140 operation_store: Arc<dyn OperationStore>,
1141 ) -> Self {
1142 let conn_id = handle.connection_id();
1143 let ConnectionHandle {
1144 sender,
1145 rx,
1146 failures_rx,
1147 control_tx,
1148 closed_rx,
1149 resumed_rx,
1150 parity,
1151 peer_supports_retry,
1152 } = handle;
1153 let drop_control_request = DropControlRequest::Close(conn_id);
1154 let (local_control_tx, local_control_rx) = mpsc::unbounded_channel("driver.local_control");
1155 let (resume_processed_tx, _resume_processed_rx) = watch::channel(0_u64);
1156 Self {
1157 sender,
1158 rx,
1159 failures_rx,
1160 closed_rx,
1161 resumed_rx,
1162 resume_processed_tx,
1163 peer_supports_retry,
1164 local_control_rx,
1165 handler: Arc::new(handler),
1166 shared: Arc::new(DriverShared {
1167 pending_responses: SyncMutex::new("driver.pending_responses", BTreeMap::new()),
1168 request_ids: SyncMutex::new("driver.request_ids", IdAllocator::new(parity)),
1169 next_operation_id: AtomicU64::new(1),
1170 operations: operation_store,
1171 channel_ids: SyncMutex::new("driver.channel_ids", IdAllocator::new(parity)),
1172 channel_senders: SyncMutex::new("driver.channel_senders", BTreeMap::new()),
1173 channel_buffers: SyncMutex::new("driver.channel_buffers", BTreeMap::new()),
1174 channel_credits: SyncMutex::new("driver.channel_credits", BTreeMap::new()),
1175 stale_close_channels: SyncMutex::new(
1176 "driver.stale_close_channels",
1177 std::collections::HashSet::new(),
1178 ),
1179 }),
1180 in_flight_handlers: BTreeMap::new(),
1181 live_operations: Arc::new(SyncMutex::new(
1182 "driver.live_operations",
1183 LiveOperationTracker::new(),
1184 )),
1185 local_control_tx,
1186 drop_control_seed: control_tx,
1187 drop_control_request,
1188 drop_guard: SyncMutex::new("driver.drop_guard", None),
1189 }
1190 }
1191
1192 fn existing_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
1198 self.drop_guard.lock().as_ref().and_then(Weak::upgrade)
1199 }
1200
1201 fn connection_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
1202 if let Some(existing) = self.existing_drop_guard() {
1203 Some(existing)
1204 } else if let Some(seed) = &self.drop_control_seed {
1205 let mut guard = self.drop_guard.lock();
1206 if let Some(existing) = guard.as_ref().and_then(Weak::upgrade) {
1207 Some(existing)
1208 } else {
1209 let arc = Arc::new(CallerDropGuard {
1210 control_tx: seed.clone(),
1211 request: self.drop_control_request,
1212 });
1213 *guard = Some(Arc::downgrade(&arc));
1214 Some(arc)
1215 }
1216 } else {
1217 None
1218 }
1219 }
1220
1221 pub fn caller(&self) -> DriverCaller {
1222 let drop_guard = self.connection_drop_guard();
1223 DriverCaller {
1224 sender: self.sender.clone(),
1225 shared: Arc::clone(&self.shared),
1226 local_control_tx: self.local_control_tx.clone(),
1227 closed_rx: self.closed_rx.clone(),
1228 resumed_rx: self.resumed_rx.clone(),
1229 resume_processed_rx: self.resume_processed_tx.subscribe(),
1230 peer_supports_retry: self.peer_supports_retry,
1231 _drop_guard: drop_guard,
1232 }
1233 }
1234
1235 fn internal_binder(&self) -> DriverChannelBinder {
1236 DriverChannelBinder {
1237 sender: self.sender.clone(),
1238 shared: Arc::clone(&self.shared),
1239 local_control_tx: self.local_control_tx.clone(),
1240 drop_guard: self.existing_drop_guard(),
1241 }
1242 }
1243
1244 pub async fn run(&mut self) {
1249 let mut resumed_rx = self.resumed_rx.clone();
1250 let mut seen_resume_generation = *resumed_rx.borrow();
1251 loop {
1252 tracing::trace!("driver select loop top");
1253 tokio::select! {
1254 biased;
1255 changed = resumed_rx.changed() => {
1256 if changed.is_err() {
1257 break;
1258 }
1259 let generation = *resumed_rx.borrow();
1260 if generation != seen_resume_generation {
1261 seen_resume_generation = generation;
1262 self.close_all_channel_runtime_state();
1263 self.abort_channel_handlers();
1264 let _ = self.resume_processed_tx.send(generation);
1265 }
1266 }
1267 recv = self.rx.recv() => {
1268 match recv {
1269 Some(recv) => {
1270 self.handle_recv(recv);
1271 }
1272 None => {
1273 tracing::trace!("driver rx closed, exiting loop");
1274 break;
1275 }
1276 }
1277 }
1278 Some((req_id, disposition)) = self.failures_rx.recv() => {
1279 tracing::trace!(%req_id, ?disposition, "failures_rx fired");
1280 let in_flight_found = self.in_flight_handlers.contains_key(&req_id);
1281 let in_flight_method_id =
1282 self.in_flight_handlers.get(&req_id).map(|in_flight| in_flight.method_id);
1283 let reply_disposition = self
1284 .in_flight_handlers
1285 .get(&req_id)
1286 .map(|in_flight| {
1287 if in_flight.has_channels && !in_flight.retry.idem {
1288 Some(FailureDisposition::Indeterminate)
1289 } else if in_flight.has_channels && in_flight.retry.idem {
1290 None
1291 } else {
1292 Some(disposition)
1293 }
1294 })
1295 .unwrap_or(Some(disposition));
1296 tracing::trace!(%req_id, in_flight_found, ?reply_disposition, "failures_rx computed disposition");
1297 self.in_flight_handlers.remove(&req_id);
1299 let had_pending = self.shared.pending_responses.lock().remove(&req_id).is_some();
1300 tracing::trace!(%req_id, had_pending, "failures_rx checked pending_responses");
1301 if !had_pending {
1302 let Some(reply_disposition) = reply_disposition else {
1303 tracing::trace!(%req_id, "failures_rx: no reply_disposition, skipping");
1304 continue;
1305 };
1306 tracing::trace!(%req_id, ?reply_disposition, "failures_rx: sending error response");
1307 let vox_error = match reply_disposition {
1308 FailureDisposition::Cancelled => VoxError::Cancelled,
1309 FailureDisposition::Indeterminate => VoxError::Indeterminate,
1310 };
1311 if let Some(method_id) = in_flight_method_id
1312 && let Some(response_shape) = self.handler.response_wire_shape(method_id)
1313 && let Ok(extracted) = vox_types::extract_schemas(response_shape)
1314 {
1315 let registry = vox_types::build_registry(&extracted.schemas);
1316 let error: Result<(), VoxError<core::convert::Infallible>> =
1317 Err(vox_error);
1318 let encoded = vox_postcard::to_vec(&error)
1319 .expect("serialize runtime-generated error response");
1320 let mut response = RequestResponse {
1321 ret: Payload::PostcardBytes(Box::leak(encoded.into_boxed_slice())),
1322 metadata: Default::default(),
1323 schemas: Default::default(),
1324 };
1325 self.sender.prepare_response_from_source(
1326 req_id,
1327 method_id,
1328 &extracted.root,
1329 ®istry,
1330 &mut response,
1331 );
1332 let _ = self.sender.send_response(req_id, response).await;
1333 } else {
1334 let error: Result<(), VoxError<core::convert::Infallible>> =
1335 Err(vox_error);
1336 let _ = self.sender.send_response(req_id, RequestResponse {
1337 ret: Payload::outgoing(&error),
1338 metadata: Default::default(),
1339 schemas: Default::default(),
1340 }).await;
1341 }
1342 tracing::trace!(%req_id, "failures_rx: error response sent");
1343 }
1344 }
1345 Some(ctrl) = self.local_control_rx.recv() => {
1346 self.handle_local_control(ctrl).await;
1347 }
1348 }
1349 }
1350
1351 for (_, in_flight) in std::mem::take(&mut self.in_flight_handlers) {
1352 if !in_flight.retry.persist {
1353 in_flight.handle.abort();
1354 }
1355 }
1356 self.shared.pending_responses.lock().clear();
1357
1358 self.close_all_channel_runtime_state();
1362 }
1363
1364 async fn handle_local_control(&mut self, control: DriverLocalControl) {
1365 match control {
1366 DriverLocalControl::CloseChannel { channel_id } => {
1367 if self.shared.stale_close_channels.lock().remove(&channel_id) {
1372 tracing::trace!(%channel_id, "suppressing ChannelClose for stale channel");
1373 return;
1374 }
1375 let _ = self
1376 .sender
1377 .send(ConnectionMessage::Channel(ChannelMessage {
1378 id: channel_id,
1379 body: ChannelBody::Close(ChannelClose {
1380 metadata: Default::default(),
1381 }),
1382 }))
1383 .await;
1384 }
1385 DriverLocalControl::GrantCredit {
1386 channel_id,
1387 additional,
1388 } => {
1389 let _ = self
1390 .sender
1391 .send(ConnectionMessage::Channel(ChannelMessage {
1392 id: channel_id,
1393 body: ChannelBody::GrantCredit(vox_types::ChannelGrantCredit {
1394 additional,
1395 }),
1396 }))
1397 .await;
1398 }
1399 }
1400 }
1401
1402 fn handle_recv(&mut self, recv: crate::session::RecvMessage) {
1403 let crate::session::RecvMessage { schemas, msg } = recv;
1404 let is_request = matches!(&*msg, ConnectionMessage::Request(_));
1405 if is_request {
1406 if let ConnectionMessage::Request(req) = &*msg {
1407 vox_types::dlog!(
1408 "[driver] handle_recv request: conn={:?} req={:?} body={} method={:?}",
1409 self.sender.connection_id(),
1410 req.id,
1411 match &req.body {
1412 RequestBody::Call(_) => "Call",
1413 RequestBody::Response(_) => "Response",
1414 RequestBody::Cancel(_) => "Cancel",
1415 },
1416 match &req.body {
1417 RequestBody::Call(call) => Some(call.method_id),
1418 RequestBody::Response(_) | RequestBody::Cancel(_) => None,
1419 }
1420 );
1421 match &req.body {
1422 RequestBody::Call(call) => tracing::trace!(
1423 conn_id = self.sender.connection_id().0,
1424 req_id = req.id.0,
1425 method_id = call.method_id.0,
1426 "driver received call"
1427 ),
1428 RequestBody::Response(_) => tracing::trace!(
1429 conn_id = self.sender.connection_id().0,
1430 req_id = req.id.0,
1431 "driver received response message"
1432 ),
1433 RequestBody::Cancel(_) => tracing::trace!(
1434 conn_id = self.sender.connection_id().0,
1435 req_id = req.id.0,
1436 "driver received cancel message"
1437 ),
1438 }
1439 }
1440 let msg = msg.map(|m| match m {
1441 ConnectionMessage::Request(r) => r,
1442 _ => unreachable!(),
1443 });
1444 self.handle_request(msg, schemas);
1445 } else {
1446 let msg = msg.map(|m| match m {
1447 ConnectionMessage::Channel(c) => c,
1448 _ => unreachable!(),
1449 });
1450 self.handle_channel(msg);
1451 }
1452 }
1453
1454 fn handle_request(
1455 &mut self,
1456 msg: SelfRef<RequestMessage<'static>>,
1457 schemas: Arc<vox_types::SchemaRecvTracker>,
1458 ) {
1459 let req_id = msg.id;
1460 let is_call = matches!(&msg.body, RequestBody::Call(_));
1461 let is_response = matches!(&msg.body, RequestBody::Response(_));
1462 let is_cancel = matches!(&msg.body, RequestBody::Cancel(_));
1463
1464 if is_call {
1465 let method_id = match &msg.body {
1466 RequestBody::Call(call) => call.method_id,
1467 _ => unreachable!(),
1468 };
1469 vox_types::dlog!(
1470 "[driver] inbound call: conn={:?} req={:?} method={:?}",
1471 self.sender.connection_id(),
1472 req_id,
1473 method_id
1474 );
1475 let call = msg.map(|m| match m.body {
1478 RequestBody::Call(c) => c,
1479 _ => unreachable!(),
1480 });
1481 let handler = Arc::clone(&self.handler);
1482 let retry = handler.retry_policy(call.method_id);
1483 let operation_id = metadata_operation_id(&call.metadata);
1484 let method_id = call.method_id;
1485
1486 if let Some(operation_id) = operation_id {
1487 let admit = self.live_operations.lock().admit(
1489 operation_id,
1490 call.method_id,
1491 incoming_args_bytes(&call),
1492 retry,
1493 req_id,
1494 );
1495 match admit {
1496 AdmitResult::Attached => return,
1497 AdmitResult::Conflict => {
1498 let sender = self.sender.clone();
1499 moire::task::spawn(
1500 async move {
1501 let error: Result<(), VoxError<core::convert::Infallible>> =
1502 Err(VoxError::InvalidPayload("operation ID conflict".into()));
1503 let _ = sender
1504 .send_response(
1505 req_id,
1506 RequestResponse {
1507 ret: Payload::outgoing(&error),
1508 metadata: Default::default(),
1509 schemas: Default::default(),
1510 },
1511 )
1512 .await;
1513 }
1514 .named("operation_reject"),
1515 );
1516 return;
1517 }
1518 AdmitResult::Start => {}
1519 }
1520
1521 match self.shared.operations.lookup(operation_id) {
1523 crate::OperationState::Sealed => {
1524 if let Some(sealed) = self.shared.operations.get_sealed(operation_id) {
1526 let sender = self.sender.clone();
1527 let method_id = call.method_id;
1528 let operations = Arc::clone(&self.shared.operations);
1529 self.live_operations.lock().seal(operation_id);
1531 moire::task::spawn(
1532 async move {
1533 if replay_sealed_response(
1534 sender.clone(),
1535 req_id,
1536 method_id,
1537 sealed.response.as_bytes(),
1538 sealed.root_type,
1539 operations.as_ref(),
1540 )
1541 .await
1542 .is_err()
1543 {
1544 sender.mark_failure(req_id, FailureDisposition::Cancelled);
1545 }
1546 }
1547 .named("operation_replay"),
1548 );
1549 return;
1550 }
1551 }
1552 crate::OperationState::Admitted => {
1553 self.live_operations.lock().seal(operation_id);
1555 let sender = self.sender.clone();
1556 moire::task::spawn(
1557 async move {
1558 let error: Result<(), VoxError<core::convert::Infallible>> =
1559 Err(VoxError::Indeterminate);
1560 let _ = sender
1561 .send_response(
1562 req_id,
1563 RequestResponse {
1564 ret: Payload::outgoing(&error),
1565 metadata: Default::default(),
1566 schemas: Default::default(),
1567 },
1568 )
1569 .await;
1570 }
1571 .named("operation_indeterminate"),
1572 );
1573 return;
1574 }
1575 crate::OperationState::Unknown => {
1576 if !retry.idem {
1579 self.shared.operations.admit(operation_id);
1580 }
1581 }
1582 }
1583 }
1584 let reply = DriverReplySink {
1585 sender: Some(self.sender.clone()),
1586 request_id: req_id,
1587 method_id: call.method_id,
1588 retry,
1589 operation_id,
1590 operations: operation_id.map(|_| Arc::clone(&self.shared.operations)),
1591 live_operations: operation_id.map(|_| Arc::clone(&self.live_operations)),
1592 binder: self.internal_binder(),
1593 };
1594 let has_channels = handler.args_have_channels(call.method_id);
1595 let join_handle = moire::task::spawn(
1596 async move {
1597 vox_types::dlog!(
1598 "[driver] handler start: req={:?} method={:?}",
1599 req_id,
1600 method_id
1601 );
1602 handler.handle(call, reply, schemas).await;
1603 vox_types::dlog!(
1604 "[driver] handler done: req={:?} method={:?}",
1605 req_id,
1606 method_id
1607 );
1608 }
1609 .named("handler"),
1610 );
1611 self.in_flight_handlers.insert(
1612 req_id,
1613 InFlightHandler {
1614 handle: join_handle,
1615 method_id,
1616 retry,
1617 has_channels,
1618 operation_id,
1619 },
1620 );
1621 } else if is_response {
1622 vox_types::dlog!(
1624 "[driver] inbound response: conn={:?} req={:?}",
1625 self.sender.connection_id(),
1626 req_id
1627 );
1628 tracing::trace!(%req_id, "driver received response");
1629 if let Some(tx) = self.shared.pending_responses.lock().remove(&req_id) {
1630 vox_types::dlog!("[driver] routing response to waiter: req={:?}", req_id);
1631 tracing::trace!(%req_id, "routing response to pending oneshot");
1632 let _: Result<(), _> = tx.send(PendingResponse { msg, schemas });
1633 } else {
1634 vox_types::dlog!("[driver] dropped unmatched response: req={:?}", req_id);
1635 tracing::trace!(%req_id, "no pending response slot for this req_id");
1636 }
1637 } else if is_cancel {
1638 vox_types::dlog!(
1639 "[driver] inbound cancel: conn={:?} req={:?}",
1640 self.sender.connection_id(),
1641 req_id
1642 );
1643 tracing::trace!(%req_id, in_flight = self.in_flight_handlers.contains_key(&req_id), "received cancel");
1646 match self.live_operations.lock().cancel(req_id) {
1647 CancelResult::NotFound => {
1648 let should_abort = self
1649 .in_flight_handlers
1650 .get(&req_id)
1651 .map(|in_flight| !in_flight.retry.persist)
1652 .unwrap_or(false);
1653 tracing::trace!(%req_id, should_abort, "cancel: not in live operations");
1654 if should_abort && let Some(in_flight) = self.in_flight_handlers.remove(&req_id)
1655 {
1656 tracing::trace!(%req_id, "aborting handler");
1657 in_flight.handle.abort();
1658 }
1659 }
1660 CancelResult::Detached => {}
1661 CancelResult::Abort {
1662 owner_request_id,
1663 waiters,
1664 } => {
1665 if let Some(in_flight) = self.in_flight_handlers.remove(&owner_request_id) {
1666 if let Some(op_id) = in_flight.operation_id {
1667 self.shared.operations.remove(op_id);
1668 }
1669 in_flight.handle.abort();
1670 }
1671 for waiter in waiters {
1672 self.sender
1673 .mark_failure(waiter, FailureDisposition::Cancelled);
1674 }
1675 }
1676 }
1677 }
1680 }
1681
1682 fn handle_channel(&mut self, msg: SelfRef<ChannelMessage<'static>>) {
1683 let chan_id = msg.id;
1684
1685 let sender = self.shared.channel_senders.lock().get(&chan_id).cloned();
1688
1689 match &msg.body {
1690 ChannelBody::Item(_item) => {
1692 if let Some(tx) = &sender {
1693 tracing::trace!(
1694 conn_id = self.sender.connection_id().0,
1695 channel_id = chan_id.0,
1696 registered = true,
1697 "driver received channel item"
1698 );
1699 let item = msg.map(|m| match m.body {
1700 ChannelBody::Item(item) => item,
1701 _ => unreachable!(),
1702 });
1703 let _ = tx.try_send(IncomingChannelMessage::Item(item));
1705 } else {
1706 tracing::trace!(
1707 conn_id = self.sender.connection_id().0,
1708 channel_id = chan_id.0,
1709 registered = false,
1710 "driver buffered channel item before registration"
1711 );
1712 let item = msg.map(|m| match m.body {
1714 ChannelBody::Item(item) => item,
1715 _ => unreachable!(),
1716 });
1717 self.shared
1718 .channel_buffers
1719 .lock()
1720 .entry(chan_id)
1721 .or_default()
1722 .push(IncomingChannelMessage::Item(item));
1723 }
1724 }
1725 ChannelBody::Close(_close) => {
1727 if let Some(tx) = &sender {
1728 tracing::trace!(
1729 conn_id = self.sender.connection_id().0,
1730 channel_id = chan_id.0,
1731 registered = true,
1732 "driver received channel close"
1733 );
1734 let close = msg.map(|m| match m.body {
1735 ChannelBody::Close(close) => close,
1736 _ => unreachable!(),
1737 });
1738 let _ = tx.try_send(IncomingChannelMessage::Close(close));
1739 } else {
1740 tracing::trace!(
1741 conn_id = self.sender.connection_id().0,
1742 channel_id = chan_id.0,
1743 registered = false,
1744 "driver buffered channel close before registration"
1745 );
1746 let close = msg.map(|m| match m.body {
1748 ChannelBody::Close(close) => close,
1749 _ => unreachable!(),
1750 });
1751 self.shared
1752 .channel_buffers
1753 .lock()
1754 .entry(chan_id)
1755 .or_default()
1756 .push(IncomingChannelMessage::Close(close));
1757 }
1758 self.shared.channel_senders.lock().remove(&chan_id);
1759 self.close_outbound_channel(chan_id);
1760 }
1761 ChannelBody::Reset(_reset) => {
1763 if let Some(tx) = &sender {
1764 tracing::trace!(
1765 conn_id = self.sender.connection_id().0,
1766 channel_id = chan_id.0,
1767 registered = true,
1768 "driver received channel reset"
1769 );
1770 let reset = msg.map(|m| match m.body {
1771 ChannelBody::Reset(reset) => reset,
1772 _ => unreachable!(),
1773 });
1774 let _ = tx.try_send(IncomingChannelMessage::Reset(reset));
1775 } else {
1776 tracing::trace!(
1777 conn_id = self.sender.connection_id().0,
1778 channel_id = chan_id.0,
1779 registered = false,
1780 "driver buffered channel reset before registration"
1781 );
1782 let reset = msg.map(|m| match m.body {
1784 ChannelBody::Reset(reset) => reset,
1785 _ => unreachable!(),
1786 });
1787 self.shared
1788 .channel_buffers
1789 .lock()
1790 .entry(chan_id)
1791 .or_default()
1792 .push(IncomingChannelMessage::Reset(reset));
1793 }
1794 self.shared.channel_senders.lock().remove(&chan_id);
1795 self.close_outbound_channel(chan_id);
1796 }
1797 ChannelBody::GrantCredit(grant) => {
1800 tracing::trace!(
1801 conn_id = self.sender.connection_id().0,
1802 channel_id = chan_id.0,
1803 additional = grant.additional,
1804 "driver received channel credit"
1805 );
1806 if let Some(semaphore) = self.shared.channel_credits.lock().get(&chan_id) {
1807 semaphore.add_permits(grant.additional as usize);
1808 }
1809 }
1810 }
1811 }
1812}