1use std::{
2 collections::{BTreeMap, HashMap, HashSet},
3 pin::Pin,
4 sync::{
5 Arc, Weak,
6 atomic::{AtomicU64, Ordering},
7 },
8};
9
10use vox_types::time::Instant;
11
12use futures_util::future::{AbortHandle, Abortable};
13use futures_util::stream::{FuturesUnordered, StreamExt as _};
14use moire::sync::{Semaphore, SyncMutex};
15use tokio::sync::watch;
16
17use moire::task::FutureExt as _;
18use vox_types::{
19 BoxFut, CallResult, ChannelBinder, ChannelBody, ChannelClose, ChannelCreditReplenisher,
20 ChannelCreditReplenisherHandle, ChannelEventContext, ChannelId, ChannelItem,
21 ChannelLivenessHandle, ChannelMailboxReceiver, ChannelMailboxSender, ChannelMessage,
22 ChannelRetryMode, ChannelSink, ConnectionId, CreditSink, Handler, IdAllocator,
23 IncomingChannelMessage, MaybeSend, MaybeSendFuture, MaybeSync, Payload, ReplySink, RequestBody,
24 RequestCall, RequestId, RequestMessage, RequestResponse, SelfRef, TrySendError, TxError,
25 VoxError, channel_mailbox, ensure_operation_id, metadata_channel_retry_mode,
26 metadata_operation_id,
27};
28use vox_types::{
29 ChannelCloseReason, ChannelDebugContext, ChannelDirection, ChannelEvent, ChannelResetReason,
30 ChannelSendOutcome, ChannelTrySendOutcome, DriverEvent, RpcOutcome, VoxObserverHandle,
31};
32use vox_types::{
33 ChannelDebugSnapshot, ChannelReceiverState, ConnectionCloseReason, ConnectionDebugSnapshot,
34 ConnectionDebugState, DriverTaskStatus, RequestDebugSnapshot, RequestDebugState,
35 VoxDebugSnapshot,
36};
37
38use crate::session::{
39 ConnectionHandle, ConnectionMessage, ConnectionSender, DropControlRequest, FailureDisposition,
40};
41use crate::{InMemoryOperationStore, OperationStore};
42use moire::sync::mpsc;
43use vox_types::{OperationId, PostcardPayload};
44
/// Value delivered through a [`ResponseSlot`] to whoever is waiting on a request id
/// (slots live in `DriverShared::pending_responses`).
struct PendingResponse {
    /// The received message, kept alive together with its backing buffer.
    msg: SelfRef<RequestMessage<'static>>,
    /// Schema tracker handed over alongside the message.
    schemas: Arc<vox_types::SchemaRecvTracker>,
}
54
/// One-shot sender that completes a pending request with its [`PendingResponse`];
/// stored per request id in `DriverShared::pending_responses`.
type ResponseSlot = moire::sync::oneshot::Sender<PendingResponse>;
56
/// Bookkeeping kept for a request handler; the name suggests it covers handlers that are
/// still executing (NOTE(review): confirm against the code that inserts/removes these).
struct InFlightHandler {
    /// Handle that aborts the handler's `Abortable` future.
    abort: AbortHandle,
    /// Method the handler is serving.
    method_id: vox_types::MethodId,
    /// Retry policy of the method being served.
    retry: vox_types::RetryPolicy,
    /// Operation id attached to the request, if it carried one.
    operation_id: Option<OperationId>,
}
67
/// Boxed, abortable handler future. It resolves to a `RequestId`, presumably the id of the
/// request it handled (TODO confirm at the spawn site).
type HandlerFut = Abortable<Pin<Box<dyn MaybeSendFuture<Output = RequestId> + 'static>>>;
81
/// Presumably selects how per-channel runtime state is torn down (usage is outside this
/// view — confirm): `DropOnly` for a plain drop, or `ConnectionClosed` carrying the
/// connection's close reason.
#[derive(Clone, Copy, Debug)]
enum ChannelRuntimeTeardown {
    DropOnly,
    ConnectionClosed(ConnectionCloseReason),
}
87
/// Tracks operations that are currently live, so that a request carrying an already-live
/// operation id (same method, same arguments) attaches as a waiter instead of starting the
/// work again.
struct LiveOperationTracker {
    /// Live operations keyed by operation id.
    live: HashMap<OperationId, LiveOperation>,
    /// Reverse index from every bound request (owner and waiters) to its operation.
    request_to_operation: HashMap<RequestId, OperationId>,
}
103
/// State of a single live operation inside [`LiveOperationTracker`].
struct LiveOperation {
    /// Method the operation was started for.
    method_id: vox_types::MethodId,
    /// Hash of the method id followed by the raw argument bytes; a mismatch on re-admission
    /// is reported as a conflict.
    args_hash: u64,
    /// Request that started the operation.
    owner_request_id: RequestId,
    /// Every request bound to the operation; starts as `[owner_request_id]`.
    waiters: Vec<RequestId>,
    /// Retry policy captured when the operation started; `persist` controls cancellation.
    retry: vox_types::RetryPolicy,
}
111
/// Outcome of [`LiveOperationTracker::admit`].
enum AdmitResult {
    /// The operation was not live; it is now registered with the request as owner.
    Start,
    /// The operation is already live with the same method and arguments; the request was
    /// added as a waiter.
    Attached,
    /// The operation id is live for a different method or different arguments.
    Conflict,
}
120
121impl LiveOperationTracker {
122 fn new() -> Self {
123 Self {
124 live: HashMap::new(),
125 request_to_operation: HashMap::new(),
126 }
127 }
128
129 fn admit(
130 &mut self,
131 operation_id: OperationId,
132 method_id: vox_types::MethodId,
133 args: &[u8],
134 retry: vox_types::RetryPolicy,
135 request_id: RequestId,
136 ) -> AdmitResult {
137 use std::hash::{Hash, Hasher};
138 let args_hash = {
139 let mut h = std::collections::hash_map::DefaultHasher::new();
140 method_id.hash(&mut h);
141 args.hash(&mut h);
142 h.finish()
143 };
144 let live_operations = self.live.len();
145
146 if let Some(live) = self.live.get_mut(&operation_id) {
147 if live.method_id != method_id || live.args_hash != args_hash {
148 let request_bindings = self.request_to_operation.len();
149 tracing::trace!(
150 %operation_id,
151 %request_id,
152 ?method_id,
153 live_operations,
154 request_bindings,
155 "live operation conflict"
156 );
157 return AdmitResult::Conflict;
158 }
159 live.waiters.push(request_id);
160 self.request_to_operation.insert(request_id, operation_id);
161 let waiters = live.waiters.len();
162 let request_bindings = self.request_to_operation.len();
163 tracing::trace!(
164 %operation_id,
165 %request_id,
166 ?method_id,
167 waiters,
168 live_operations,
169 request_bindings,
170 "live operation attached"
171 );
172 return AdmitResult::Attached;
173 }
174
175 self.live.insert(
176 operation_id,
177 LiveOperation {
178 method_id,
179 args_hash,
180 owner_request_id: request_id,
181 waiters: vec![request_id],
182 retry,
183 },
184 );
185 self.request_to_operation.insert(request_id, operation_id);
186 let live_operations = self.live.len();
187 let request_bindings = self.request_to_operation.len();
188 tracing::trace!(
189 %operation_id,
190 %request_id,
191 ?method_id,
192 live_operations,
193 request_bindings,
194 "live operation admitted"
195 );
196 AdmitResult::Start
197 }
198
199 fn seal(&mut self, operation_id: OperationId) -> Vec<RequestId> {
201 if let Some(live) = self.live.remove(&operation_id) {
202 for waiter in &live.waiters {
203 self.request_to_operation.remove(waiter);
204 }
205 let waiters = live.waiters.len();
206 let live_operations = self.live.len();
207 let request_bindings = self.request_to_operation.len();
208 tracing::trace!(
209 %operation_id,
210 waiters,
211 live_operations,
212 request_bindings,
213 "live operation sealed"
214 );
215 live.waiters
216 } else {
217 vec![]
218 }
219 }
220
221 fn release(&mut self, operation_id: OperationId) -> Option<LiveOperation> {
223 if let Some(live) = self.live.remove(&operation_id) {
224 for waiter in &live.waiters {
225 self.request_to_operation.remove(waiter);
226 }
227 let waiters = live.waiters.len();
228 let live_operations = self.live.len();
229 let request_bindings = self.request_to_operation.len();
230 tracing::trace!(
231 %operation_id,
232 waiters,
233 live_operations,
234 request_bindings,
235 "live operation released"
236 );
237 Some(live)
238 } else {
239 None
240 }
241 }
242
243 fn cancel(&mut self, request_id: RequestId) -> CancelResult {
245 let Some(&operation_id) = self.request_to_operation.get(&request_id) else {
246 return CancelResult::NotFound;
247 };
248 let live_operations = self.live.len();
249 let Some(live) = self.live.get_mut(&operation_id) else {
250 self.request_to_operation.remove(&request_id);
251 return CancelResult::NotFound;
252 };
253
254 if live.retry.persist {
255 if live.owner_request_id == request_id {
257 return CancelResult::NotFound; }
259 live.waiters.retain(|w| *w != request_id);
260 self.request_to_operation.remove(&request_id);
261 let waiters = live.waiters.len();
262 let request_bindings = self.request_to_operation.len();
263 tracing::trace!(
264 %operation_id,
265 %request_id,
266 waiters,
267 live_operations,
268 request_bindings,
269 "live operation detached waiter"
270 );
271 CancelResult::Detached
272 } else {
273 let live = self.live.remove(&operation_id).unwrap();
275 for waiter in &live.waiters {
276 self.request_to_operation.remove(waiter);
277 }
278 let waiters = live.waiters.len();
279 let live_operations = self.live.len();
280 let request_bindings = self.request_to_operation.len();
281 tracing::trace!(
282 %operation_id,
283 %request_id,
284 waiters,
285 live_operations,
286 request_bindings,
287 "live operation aborted"
288 );
289 CancelResult::Abort {
290 owner_request_id: live.owner_request_id,
291 waiters: live.waiters,
292 }
293 }
294 }
295}
296
/// Outcome of [`LiveOperationTracker::cancel`].
enum CancelResult {
    /// Nothing was cancelled. Either the request was not bound, or it owns a persistent
    /// operation.
    NotFound,
    /// The request was removed as a waiter of a persistent operation, which stays live.
    Detached,
    /// A non-persistent operation was removed; every request bound to it is listed.
    Abort {
        /// Request that started the operation.
        owner_request_id: RequestId,
        /// All requests that were bound to the operation, owner included.
        waiters: Vec<RequestId>,
    },
}
305
/// Per-request diagnostic record kept in `DriverShared::request_debug` and rendered into a
/// [`RequestDebugSnapshot`].
#[derive(Clone)]
struct RequestRuntimeDebug {
    method_id: vox_types::MethodId,
    /// Service name, when known.
    service: Option<&'static str>,
    /// Method name, when known.
    method: Option<&'static str>,
    /// When the record was created; snapshots report `now - started_at` as the age.
    started_at: Instant,
    state: RequestDebugState,
    /// `DriverShared::start_request` initialises this to `Some(false)`.
    response_sender_blocked: Option<bool>,
    /// Channels associated with the request; empty when created by `start_request`.
    associated_channels: Vec<ChannelId>,
}
316
317impl RequestRuntimeDebug {
318 fn snapshot(&self, request_id: RequestId, now: Instant) -> RequestDebugSnapshot {
319 RequestDebugSnapshot {
320 request_id,
321 service: self.service,
322 method: self.method,
323 method_id: self.method_id,
324 age: now.saturating_duration_since(self.started_at),
325 state: self.state,
326 response_sender_blocked: self.response_sender_blocked,
327 associated_channels: self.associated_channels.clone(),
328 }
329 }
330}
331
/// Per-channel diagnostic counters and timestamps, kept in `DriverShared::channel_debug` and
/// rendered into a [`ChannelDebugSnapshot`]. All counters are updated with saturating
/// arithmetic.
#[derive(Clone)]
struct ChannelRuntimeDebug {
    direction: ChannelDirection,
    /// Debug context; the first non-`None` value wins (see `merge_debug`).
    debug: Option<ChannelDebugContext>,
    initial_credit: u32,
    /// Items received minus items consumed or not enqueued.
    inbound_queue_len: usize,
    /// `Some(initial_credit)` for `Rx` channels, `None` for `Tx`.
    inbound_queue_capacity: Option<usize>,
    receiver_state: ChannelReceiverState,
    last_item_sent_at: Option<Instant>,
    last_item_received_at: Option<Instant>,
    last_item_consumed_at: Option<Instant>,
    last_credit_granted_at: Option<Instant>,
    last_credit_received_at: Option<Instant>,
    last_credit_granted_amount: Option<u32>,
    last_credit_received_amount: Option<u32>,
    /// Credit accumulated locally but not yet granted; reset to 0 on each grant.
    pending_local_grant_credit: u32,
    total_credit_granted: u64,
    total_credit_received: u64,
    /// Items successfully sent, via either `send` or `try_send`.
    sent: u64,
    sends_started: u64,
    sends_completed: u64,
    sends_waited_for_credit: u64,
    try_send_full_credit: u64,
    try_send_full_runtime_queue: u64,
    closed: u64,
    reset: u64,
    dropped: u64,
    items_received: u64,
    items_consumed: u64,
    /// Number of grant events (not credit units; see `total_credit_granted`).
    credit_granted: u64,
    /// Number of credit-received events (not credit units; see `total_credit_received`).
    credit_received: u64,
    close_reason: Option<ChannelCloseReason>,
    reset_reason: Option<ChannelResetReason>,
}
366
367impl ChannelRuntimeDebug {
368 fn new(
369 direction: ChannelDirection,
370 initial_credit: u32,
371 debug: Option<ChannelDebugContext>,
372 ) -> Self {
373 Self {
374 direction,
375 debug,
376 initial_credit,
377 inbound_queue_len: 0,
378 inbound_queue_capacity: match direction {
379 ChannelDirection::Rx => Some(initial_credit as usize),
380 ChannelDirection::Tx => None,
381 },
382 receiver_state: ChannelReceiverState::Present,
383 last_item_sent_at: None,
384 last_item_received_at: None,
385 last_item_consumed_at: None,
386 last_credit_granted_at: None,
387 last_credit_received_at: None,
388 last_credit_granted_amount: None,
389 last_credit_received_amount: None,
390 pending_local_grant_credit: 0,
391 total_credit_granted: 0,
392 total_credit_received: 0,
393 sent: 0,
394 sends_started: 0,
395 sends_completed: 0,
396 sends_waited_for_credit: 0,
397 try_send_full_credit: 0,
398 try_send_full_runtime_queue: 0,
399 closed: 0,
400 reset: 0,
401 dropped: 0,
402 items_received: 0,
403 items_consumed: 0,
404 credit_granted: 0,
405 credit_received: 0,
406 close_reason: None,
407 reset_reason: None,
408 }
409 }
410
411 fn merge_debug(&mut self, debug: Option<ChannelDebugContext>) {
412 if self.debug.is_none() {
413 self.debug = debug;
414 }
415 }
416
417 fn mark_item_received(&mut self, now: Instant) {
418 self.items_received = self.items_received.saturating_add(1);
419 self.inbound_queue_len = self.inbound_queue_len.saturating_add(1);
420 self.last_item_received_at = Some(now);
421 }
422
423 fn mark_closed(&mut self, reason: ChannelCloseReason) {
424 self.closed = self.closed.saturating_add(1);
425 self.close_reason = Some(reason);
426 self.receiver_state = ChannelReceiverState::Closed;
427 if reason == ChannelCloseReason::Dropped {
428 self.dropped = self.dropped.saturating_add(1);
429 self.receiver_state = ChannelReceiverState::Dropped;
430 }
431 }
432
433 fn mark_reset(&mut self, reason: ChannelResetReason) {
434 self.reset = self.reset.saturating_add(1);
435 self.reset_reason = Some(reason);
436 self.receiver_state = ChannelReceiverState::Reset;
437 }
438
439 fn mark_send_started(&mut self) {
440 self.sends_started = self.sends_started.saturating_add(1);
441 }
442
443 fn mark_send_waiting_for_credit(&mut self) {
444 self.sends_waited_for_credit = self.sends_waited_for_credit.saturating_add(1);
445 }
446
447 fn mark_send_finished(&mut self, outcome: ChannelSendOutcome, now: Instant) {
448 self.sends_completed = self.sends_completed.saturating_add(1);
449 if outcome == ChannelSendOutcome::Sent {
450 self.sent = self.sent.saturating_add(1);
451 self.last_item_sent_at = Some(now);
452 }
453 }
454
455 fn mark_try_send_outcome(&mut self, outcome: ChannelTrySendOutcome, now: Instant) {
456 match outcome {
457 ChannelTrySendOutcome::Sent => {
458 self.sent = self.sent.saturating_add(1);
459 self.last_item_sent_at = Some(now);
460 }
461 ChannelTrySendOutcome::FullCredit => {
462 self.try_send_full_credit = self.try_send_full_credit.saturating_add(1);
463 }
464 ChannelTrySendOutcome::FullRuntimeQueue => {
465 self.try_send_full_runtime_queue =
466 self.try_send_full_runtime_queue.saturating_add(1);
467 }
468 ChannelTrySendOutcome::Unbound | ChannelTrySendOutcome::Closed => {}
469 }
470 }
471
472 fn mark_item_consumed(&mut self, now: Instant) {
473 self.items_consumed = self.items_consumed.saturating_add(1);
474 self.inbound_queue_len = self.inbound_queue_len.saturating_sub(1);
475 self.last_item_consumed_at = Some(now);
476 }
477
478 fn mark_inbound_item_not_enqueued(&mut self) {
479 self.inbound_queue_len = self.inbound_queue_len.saturating_sub(1);
480 }
481
482 fn mark_credit_granted(&mut self, amount: u32, now: Instant) {
483 self.credit_granted = self.credit_granted.saturating_add(1);
484 self.total_credit_granted = self.total_credit_granted.saturating_add(amount as u64);
485 self.last_credit_granted_at = Some(now);
486 self.last_credit_granted_amount = Some(amount);
487 self.pending_local_grant_credit = 0;
488 }
489
490 fn mark_credit_received(&mut self, amount: u32, now: Instant) {
491 self.credit_received = self.credit_received.saturating_add(1);
492 self.total_credit_received = self.total_credit_received.saturating_add(amount as u64);
493 self.last_credit_received_at = Some(now);
494 self.last_credit_received_amount = Some(amount);
495 }
496
497 fn mark_receiver_dropped(&mut self) {
498 self.reset = self.reset.saturating_add(1);
499 self.reset_reason = Some(ChannelResetReason::ReceiverDropped);
500 self.receiver_state = ChannelReceiverState::Dropped;
501 self.dropped = self.dropped.saturating_add(1);
502 }
503
504 fn snapshot(
505 &self,
506 connection_id: ConnectionId,
507 channel_id: ChannelId,
508 available_send_credit: Option<u32>,
509 ) -> ChannelDebugSnapshot {
510 ChannelDebugSnapshot {
511 connection_id,
512 channel_id,
513 direction: self.direction,
514 debug: self.debug,
515 initial_credit: self.initial_credit,
516 available_send_credit,
517 inbound_queue_len: Some(self.inbound_queue_len),
518 inbound_queue_capacity: self.inbound_queue_capacity,
519 outbound_runtime_queue_len: None,
520 outbound_runtime_queue_capacity: None,
521 send_waiters_count: None,
522 receiver_state: self.receiver_state,
523 last_item_sent_at: self.last_item_sent_at,
524 last_item_received_at: self.last_item_received_at,
525 last_item_consumed_at: self.last_item_consumed_at,
526 last_credit_granted_at: self.last_credit_granted_at,
527 last_credit_received_at: self.last_credit_received_at,
528 last_credit_granted_amount: self.last_credit_granted_amount,
529 last_credit_received_amount: self.last_credit_received_amount,
530 pending_local_grant_credit: self.pending_local_grant_credit,
531 total_credit_granted: self.total_credit_granted,
532 total_credit_received: self.total_credit_received,
533 current_permit_count: available_send_credit,
534 zero_credit_with_blocked_senders: available_send_credit == Some(0)
535 && self.sends_waited_for_credit > 0,
536 sent: self.sent,
537 sends_started: self.sends_started,
538 sends_completed: self.sends_completed,
539 sends_waited_for_credit: self.sends_waited_for_credit,
540 try_send_full_credit: self.try_send_full_credit,
541 try_send_full_runtime_queue: self.try_send_full_runtime_queue,
542 closed: self.closed,
543 reset: self.reset,
544 dropped: self.dropped,
545 items_received: self.items_received,
546 items_consumed: self.items_consumed,
547 credit_granted: self.credit_granted,
548 credit_received: self.credit_received,
549 close_reason: self.close_reason,
550 reset_reason: self.reset_reason,
551 }
552 }
553}
554
555struct DriverShared {
560 connection_id: ConnectionId,
561 pending_responses: SyncMutex<BTreeMap<RequestId, ResponseSlot>>,
562 request_ids: SyncMutex<IdAllocator<RequestId>>,
563 next_operation_id: AtomicU64,
564 operations: Arc<dyn OperationStore>,
565 channel_ids: SyncMutex<IdAllocator<ChannelId>>,
566 channel_senders: SyncMutex<BTreeMap<ChannelId, ChannelMailboxSender<IncomingChannelMessage>>>,
568 channel_receivers:
571 SyncMutex<BTreeMap<ChannelId, ChannelMailboxReceiver<IncomingChannelMessage>>>,
572 channel_credits: SyncMutex<BTreeMap<ChannelId, Arc<Semaphore>>>,
575 channel_contexts: SyncMutex<BTreeMap<ChannelId, ChannelDebugContext>>,
577 request_debug: SyncMutex<BTreeMap<RequestId, RequestRuntimeDebug>>,
579 channel_debug: SyncMutex<BTreeMap<ChannelId, ChannelRuntimeDebug>>,
581 last_inbound_message_at: SyncMutex<Option<Instant>>,
582 last_outbound_message_at: SyncMutex<Option<Instant>>,
583 close_reason: SyncMutex<Option<ConnectionCloseReason>>,
584 terminal_channels: SyncMutex<HashSet<ChannelId>>,
588 stale_close_channels: SyncMutex<std::collections::HashSet<ChannelId>>,
593 local_initial_channel_credit: u32,
595 peer_initial_channel_credit: u32,
597 observer: Option<VoxObserverHandle>,
598}
599
600impl DriverShared {
601 fn remember_channel_context(
602 &self,
603 channel_id: ChannelId,
604 debug_context: Option<ChannelDebugContext>,
605 ) {
606 if let Some(debug_context) = debug_context.and_then(ChannelDebugContext::into_option) {
607 self.channel_contexts
608 .lock()
609 .insert(channel_id, debug_context);
610 if let Some(channel) = self.channel_debug.lock().get_mut(&channel_id) {
611 channel.debug = Some(debug_context);
612 }
613 }
614 }
615
616 fn channel_event_context(
617 &self,
618 channel_id: ChannelId,
619 debug_context: Option<ChannelDebugContext>,
620 ) -> ChannelEventContext {
621 let debug = debug_context
622 .and_then(ChannelDebugContext::into_option)
623 .or_else(|| self.channel_contexts.lock().get(&channel_id).copied());
624 ChannelEventContext {
625 connection_id: Some(self.connection_id),
626 channel_id,
627 debug,
628 }
629 }
630
631 fn emit_channel_event(
632 &self,
633 channel_id: ChannelId,
634 debug_context: Option<ChannelDebugContext>,
635 event: impl FnOnce(ChannelEventContext) -> ChannelEvent,
636 ) {
637 if let Some(observer) = &self.observer {
638 observer.channel_event(event(self.channel_event_context(channel_id, debug_context)));
639 }
640 }
641
642 fn observe_channel(
643 &self,
644 channel_id: ChannelId,
645 debug_context: Option<ChannelDebugContext>,
646 event: impl FnOnce(ChannelEventContext) -> ChannelEvent,
647 ) {
648 let event = event(self.channel_event_context(channel_id, debug_context));
649 self.record_channel_event(event);
650 if let Some(observer) = &self.observer {
651 observer.channel_event(event);
652 }
653 }
654
655 fn update_channel_debug(
656 &self,
657 channel: ChannelEventContext,
658 default_direction: ChannelDirection,
659 default_initial_credit: u32,
660 update: impl FnOnce(&mut ChannelRuntimeDebug),
661 ) {
662 let mut channels = self.channel_debug.lock();
663 let entry = channels.entry(channel.channel_id).or_insert_with(|| {
664 ChannelRuntimeDebug::new(default_direction, default_initial_credit, channel.debug)
665 });
666 entry.merge_debug(channel.debug);
667 update(entry);
668 }
669
670 fn update_existing_channel_debug(
671 &self,
672 channel_id: ChannelId,
673 update: impl FnOnce(&mut ChannelRuntimeDebug),
674 ) {
675 if let Some(channel) = self.channel_debug.lock().get_mut(&channel_id) {
676 update(channel);
677 }
678 }
679
680 fn record_channel_event(&self, event: ChannelEvent) {
681 let now = Instant::now();
682 match event {
683 ChannelEvent::Opened {
684 channel,
685 direction,
686 initial_credit,
687 } => {
688 self.channel_debug.lock().insert(
689 channel.channel_id,
690 ChannelRuntimeDebug::new(direction, initial_credit, channel.debug),
691 );
692 }
693 ChannelEvent::ItemReceived { channel } => {
694 self.update_channel_debug(channel, ChannelDirection::Rx, 0, |entry| {
695 entry.mark_item_received(now);
696 });
697 }
698 ChannelEvent::Closed { channel, reason } => {
699 self.update_channel_debug(channel, ChannelDirection::Rx, 0, |entry| {
700 entry.mark_closed(reason);
701 });
702 }
703 ChannelEvent::Reset { channel, reason } => {
704 self.update_channel_debug(channel, ChannelDirection::Rx, 0, |entry| {
705 entry.mark_reset(reason);
706 });
707 }
708 ChannelEvent::CreditGranted { channel, amount } => {
709 self.record_credit_granted_at(channel.channel_id, amount, now);
710 }
711 ChannelEvent::SendStarted { channel } => {
712 self.record_send_started(channel.channel_id);
713 }
714 ChannelEvent::SendWaitingForCredit { channel } => {
715 self.record_send_waiting_for_credit(channel.channel_id);
716 }
717 ChannelEvent::SendFinished {
718 channel, outcome, ..
719 } => {
720 self.record_send_finished(channel.channel_id, outcome);
721 }
722 ChannelEvent::TrySend { channel, outcome } => {
723 self.record_try_send_outcome(channel.channel_id, outcome);
724 }
725 ChannelEvent::ItemConsumed { channel } => {
726 self.record_item_consumed(channel.channel_id);
727 }
728 }
729 }
730
731 fn mark_inbound_progress(&self) {
732 *self.last_inbound_message_at.lock() = Some(Instant::now());
733 }
734
735 fn mark_outbound_progress(&self) {
736 *self.last_outbound_message_at.lock() = Some(Instant::now());
737 }
738
739 fn start_request(
740 &self,
741 request_id: RequestId,
742 method_id: vox_types::MethodId,
743 service: Option<&'static str>,
744 method: Option<&'static str>,
745 state: RequestDebugState,
746 ) {
747 self.request_debug.lock().insert(
748 request_id,
749 RequestRuntimeDebug {
750 method_id,
751 service,
752 method,
753 started_at: Instant::now(),
754 state,
755 response_sender_blocked: Some(false),
756 associated_channels: Vec::new(),
757 },
758 );
759 }
760
761 fn finish_request(&self, request_id: RequestId, state: RequestDebugState) {
762 if let Some(request) = self.request_debug.lock().get_mut(&request_id) {
763 request.state = state;
764 }
765 self.request_debug.lock().remove(&request_id);
766 }
767
768 fn record_send_started(&self, channel_id: ChannelId) {
769 self.update_existing_channel_debug(channel_id, ChannelRuntimeDebug::mark_send_started);
770 }
771
772 fn record_send_waiting_for_credit(&self, channel_id: ChannelId) {
773 self.update_existing_channel_debug(
774 channel_id,
775 ChannelRuntimeDebug::mark_send_waiting_for_credit,
776 );
777 }
778
779 fn record_send_finished(&self, channel_id: ChannelId, outcome: ChannelSendOutcome) {
780 let now = Instant::now();
781 self.update_existing_channel_debug(channel_id, |channel| {
782 channel.mark_send_finished(outcome, now);
783 });
784 }
785
786 fn record_try_send_outcome(&self, channel_id: ChannelId, outcome: ChannelTrySendOutcome) {
787 let now = Instant::now();
788 self.update_existing_channel_debug(channel_id, |channel| {
789 channel.mark_try_send_outcome(outcome, now);
790 });
791 }
792
793 fn record_item_consumed(&self, channel_id: ChannelId) {
794 let now = Instant::now();
795 self.update_existing_channel_debug(channel_id, |channel| {
796 channel.mark_item_consumed(now);
797 });
798 }
799
800 fn record_inbound_item_not_enqueued(&self, channel_id: ChannelId) {
801 self.update_existing_channel_debug(
802 channel_id,
803 ChannelRuntimeDebug::mark_inbound_item_not_enqueued,
804 );
805 }
806
807 fn record_pending_local_grant(&self, channel_id: ChannelId, pending: u32) {
808 self.update_existing_channel_debug(channel_id, |channel| {
809 channel.pending_local_grant_credit = pending;
810 });
811 }
812
813 fn record_credit_granted_at(&self, channel_id: ChannelId, amount: u32, now: Instant) {
814 self.update_existing_channel_debug(channel_id, |channel| {
815 channel.mark_credit_granted(amount, now);
816 });
817 }
818
819 fn record_credit_received(&self, channel_id: ChannelId, amount: u32) {
820 let now = Instant::now();
821 self.update_existing_channel_debug(channel_id, |channel| {
822 channel.mark_credit_received(amount, now);
823 });
824 }
825
826 fn record_receiver_dropped(&self, channel_id: ChannelId) {
827 self.update_existing_channel_debug(channel_id, ChannelRuntimeDebug::mark_receiver_dropped);
828 }
829
830 fn new_channel_mailbox(
831 &self,
832 ) -> (
833 ChannelMailboxSender<IncomingChannelMessage>,
834 ChannelMailboxReceiver<IncomingChannelMessage>,
835 ) {
836 channel_mailbox(
837 "driver.channel_mailbox",
838 self.local_initial_channel_credit as usize,
839 )
840 }
841
842 fn inbound_channel_sender(
843 &self,
844 channel_id: ChannelId,
845 ) -> ChannelMailboxSender<IncomingChannelMessage> {
846 let mut senders = self.channel_senders.lock();
847 if let Some(sender) = senders.get(&channel_id) {
848 return sender.clone();
849 }
850
851 let (sender, receiver) = self.new_channel_mailbox();
852 senders.insert(channel_id, sender.clone());
853 self.channel_receivers.lock().insert(channel_id, receiver);
854 sender
855 }
856
857 fn register_inbound_channel_receiver(
858 &self,
859 channel_id: ChannelId,
860 ) -> (ChannelMailboxReceiver<IncomingChannelMessage>, bool) {
861 let terminal = self.terminal_channels.lock().contains(&channel_id);
862 let mut senders = self.channel_senders.lock();
863 let mut receivers = self.channel_receivers.lock();
864
865 if let Some(receiver) = receivers.remove(&channel_id) {
866 return (receiver, terminal);
867 }
868
869 let (sender, receiver) = self.new_channel_mailbox();
870 if terminal {
871 drop(sender);
872 } else {
873 senders.insert(channel_id, sender);
874 }
875 (receiver, terminal)
876 }
877
878 fn debug_snapshot(
879 &self,
880 sender: &ConnectionSender,
881 state: ConnectionDebugState,
882 driver_task_status: DriverTaskStatus,
883 ) -> VoxDebugSnapshot {
884 let now = Instant::now();
885 let requests: Vec<_> = self
886 .request_debug
887 .lock()
888 .iter()
889 .map(|(request_id, request)| request.snapshot(*request_id, now))
890 .collect();
891 let credits = self.shared_channel_credit_snapshot();
892 let open_channels: Vec<_> = self
893 .channel_debug
894 .lock()
895 .iter()
896 .map(|(channel_id, channel)| {
897 channel.snapshot(
898 self.connection_id,
899 *channel_id,
900 credits.get(channel_id).copied().flatten(),
901 )
902 })
903 .collect();
904 let last_inbound_message_at = *self.last_inbound_message_at.lock();
905 let last_outbound_message_at = *self.last_outbound_message_at.lock();
906 let last_progress_at = match (last_inbound_message_at, last_outbound_message_at) {
907 (Some(inbound), Some(outbound)) => Some(inbound.max(outbound)),
908 (Some(inbound), None) => Some(inbound),
909 (None, Some(outbound)) => Some(outbound),
910 (None, None) => None,
911 };
912 let (outbound_queue_depth, outbound_queue_capacity) =
913 sender.sess_core.outbound_queue_stats();
914 VoxDebugSnapshot {
915 connections: vec![ConnectionDebugSnapshot {
916 connection_id: self.connection_id,
917 endpoint: None,
918 surface: None,
919 component: None,
920 state,
921 outstanding_requests: requests.len(),
922 requests,
923 open_channels,
924 outbound_queue_depth: Some(outbound_queue_depth),
925 outbound_queue_capacity: Some(outbound_queue_capacity),
926 local_control_queue_depth: None,
927 local_control_queue_capacity: None,
928 last_inbound_message_at,
929 last_outbound_message_at,
930 last_progress_at,
931 close_reason: *self.close_reason.lock(),
932 driver_task_status,
933 }],
934 }
935 }
936
937 fn shared_channel_credit_snapshot(&self) -> BTreeMap<ChannelId, Option<u32>> {
938 self.channel_credits
939 .lock()
940 .iter()
941 .map(|(channel_id, semaphore)| {
942 (
943 *channel_id,
944 Some(semaphore.available_permits().min(u32::MAX as usize) as u32),
945 )
946 })
947 .collect()
948 }
949
950 fn set_connection_closed(&self, reason: ConnectionCloseReason) {
951 *self.close_reason.lock() = Some(reason);
952 }
953
954 fn connection_debug_state(&self, closed: bool) -> ConnectionDebugState {
955 if closed {
956 ConnectionDebugState::Closed
957 } else {
958 ConnectionDebugState::Open
959 }
960 }
961}
962
/// Guard that sends `request` over `control_tx` when it is dropped.
struct CallerDropGuard {
    /// Channel the drop notification is sent on.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Request sent on drop.
    request: DropControlRequest,
}
967
968impl Drop for CallerDropGuard {
969 fn drop(&mut self) {
970 let _ = self.control_tx.send(self.request);
971 }
972}
973
/// Tests for the driver's credit replenisher batching.
#[cfg(test)]
mod tests {
    use super::{DriverChannelCreditReplenisher, DriverLocalControl};
    use vox_types::{ChannelCreditReplenisher, ChannelId};

    /// With an initial window of 16, seven consumed items produce no grant within 20 ms.
    /// The eighth produces a single grant of 8 for the channel.
    #[tokio::test]
    async fn replenisher_batches_at_half_the_initial_window() {
        let (tx, mut rx) = moire::sync::mpsc::unbounded_channel("test.replenisher");
        let replenisher = DriverChannelCreditReplenisher::new(
            vox_types::ConnectionId::ROOT,
            ChannelId(7),
            None,
            std::sync::Weak::new(),
            16,
            tx,
            None,
        );

        for _ in 0..7 {
            replenisher.on_item_consumed();
        }
        // A timeout (Err) means nothing was emitted yet.
        assert!(
            vox_types::time::tokio::timeout(std::time::Duration::from_millis(20), rx.recv())
                .await
                .is_err(),
            "should not emit credit before reaching the batch threshold"
        );

        replenisher.on_item_consumed();
        let Some(DriverLocalControl::GrantCredit {
            channel_id,
            additional,
        }) = rx.recv().await
        else {
            panic!("expected batched credit grant");
        };
        assert_eq!(channel_id, ChannelId(7));
        assert_eq!(additional, 8);
    }

    /// With an initial window of 1, each consumed item produces an immediate grant of 1.
    #[tokio::test]
    async fn replenisher_grants_one_by_one_for_single_credit_windows() {
        let (tx, mut rx) = moire::sync::mpsc::unbounded_channel("test.replenisher.single");
        let replenisher = DriverChannelCreditReplenisher::new(
            vox_types::ConnectionId::ROOT,
            ChannelId(9),
            None,
            std::sync::Weak::new(),
            1,
            tx,
            None,
        );

        replenisher.on_item_consumed();
        let Some(DriverLocalControl::GrantCredit {
            channel_id,
            additional,
        }) = rx.recv().await
        else {
            panic!("expected immediate credit grant");
        };
        assert_eq!(channel_id, ChannelId(9));
        assert_eq!(additional, 1);
    }
}
1039
/// [`ReplySink`] used by the driver to send a handler's response for one request back over
/// the connection.
pub struct DriverReplySink {
    /// Taken (left as `None`) when the reply is sent.
    sender: Option<ConnectionSender>,
    request_id: RequestId,
    method_id: vox_types::MethodId,
    retry: vox_types::RetryPolicy,
    /// When this and `operations` are both set, the encoded response is sealed in the
    /// operation store after it is sent.
    operation_id: Option<OperationId>,
    operations: Option<Arc<dyn OperationStore>>,
    /// Shared tracker; the other waiters of a sealed operation receive a replay of the
    /// response.
    live_operations: Option<Arc<SyncMutex<LiveOperationTracker>>>,
    binder: DriverChannelBinder,
    /// Shape used to prepare schemas for replayed responses; when `None`, replays carry
    /// empty schemas.
    handler_response_shape: Option<&'static facet_core::Shape>,
}
1061
1062async fn replay_sealed_response(
1068 sender: ConnectionSender,
1069 request_id: RequestId,
1070 method_id: vox_types::MethodId,
1071 encoded_response: &[u8],
1072 response_shape: Option<&'static facet_core::Shape>,
1073) -> Result<(), ()> {
1074 let mut response: RequestResponse<'_> =
1075 vox_postcard::from_slice_borrowed(encoded_response).map_err(|_| ())?;
1076 if let Some(shape) = response_shape {
1077 sender.prepare_replay_schemas(request_id, method_id, shape, &mut response);
1078 } else {
1079 response.schemas = Default::default();
1080 }
1081 sender.send_response(request_id, response).await
1082}
1083
1084fn incoming_args_bytes<'a>(call: &'a RequestCall<'a>) -> &'a [u8] {
1085 match &call.args {
1086 Payload::PostcardBytes(bytes) => bytes,
1087 Payload::Value { .. } => {
1088 panic!("incoming request payload should always be decoded as incoming bytes")
1089 }
1090 }
1091}
1092
/// Sends a handler's reply back over the connection it arrived on.
///
/// For calls carrying an operation id (with an operation store available), the encoded
/// response is also sealed in the store and in the live-operation tracker, and then replayed
/// to other requests waiting on the same operation.
impl ReplySink for DriverReplySink {
    /// Deliver `response` for this request.
    ///
    /// Taking `self.sender` here means the `Drop` impl will not report a failure for this
    /// request afterwards. A failed wire send marks the request (or replay waiter) as
    /// `FailureDisposition::Cancelled`.
    async fn send_reply(mut self, response: RequestResponse<'_>) {
        // `sender` is only ever `None` after this method or `Drop` has run, and both consume
        // the sink, so it is always present here.
        let sender = self
            .sender
            .take()
            .expect("unreachable: send_reply takes self by value");

        vox_types::dlog!(
            "[driver] send_reply: conn={:?} req={:?} method={:?} payload={} operation_id={:?}",
            sender.connection_id(),
            self.request_id,
            self.method_id,
            match &response.ret {
                Payload::Value { .. } => "Value",
                Payload::PostcardBytes(_) => "PostcardBytes",
            },
            self.operation_id
        );
        self.binder.shared.mark_outbound_progress();

        // Schema extraction here only feeds the debug log below.
        if let Payload::Value { shape, .. } = &response.ret
            && let Ok(extracted) = vox_types::extract_schemas(shape)
        {
            vox_types::dlog!(
                "[schema] driver send_reply: method={:?} root={:?}",
                self.method_id,
                extracted.root
            );
        }

        // Sealed path: requires both an operation id and an operation store.
        if let (Some(operation_id), Some(operations)) = (self.operation_id, self.operations.take())
        {
            let mut response = response;
            sender.prepare_response_for_method(self.request_id, self.method_id, &mut response);

            // Encode the stored copy without schemas; schemas are restored afterwards so the
            // wire response to this request still carries them. Replays re-derive schemas per
            // waiter in `replay_sealed_response`.
            let schemas_for_wire = std::mem::take(&mut response.schemas);
            #[cfg(not(target_arch = "wasm32"))]
            let encoded_bytes: Vec<u8> =
                vox_jit::encode!(&response).expect("JIT encode failed for response store");
            #[cfg(target_arch = "wasm32")]
            let encoded_bytes: Vec<u8> =
                vox_postcard::to_vec(&response).expect("postcard encode failed for response store");
            let encoded_for_store: PostcardPayload = encoded_bytes.into();
            response.schemas = schemas_for_wire;

            vox_types::dlog!(
                "[driver] send_reply wire send: conn={:?} req={:?} method={:?} schemas={}",
                sender.connection_id(),
                self.request_id,
                self.method_id,
                response.schemas.0.len()
            );
            if let Err(_e) = sender.send_response(self.request_id, response).await {
                sender.mark_failure(self.request_id, FailureDisposition::Cancelled);
            }

            // The operation is sealed even if the wire send above failed.
            operations.seal(operation_id, self.method_id, &encoded_for_store);

            // Sealing the live operation yields the requests attached to it; each one other
            // than this request receives a replay of the stored bytes.
            let waiters = self
                .live_operations
                .as_ref()
                .map(|lo| lo.lock().seal(operation_id))
                .unwrap_or_default();
            let response_shape = self.handler_response_shape;
            for waiter in waiters {
                if waiter == self.request_id {
                    continue;
                }
                if replay_sealed_response(
                    sender.clone(),
                    waiter,
                    self.method_id,
                    encoded_for_store.as_bytes(),
                    response_shape,
                )
                .await
                .is_err()
                {
                    sender.mark_failure(waiter, FailureDisposition::Cancelled);
                }
            }
        } else {
            // Direct path: nothing is stored or replayed.
            vox_types::dlog!(
                "[driver] send_reply direct send: conn={:?} req={:?} method={:?}",
                sender.connection_id(),
                self.request_id,
                self.method_id
            );
            if let Err(_e) = sender
                .send_response_for_method(self.request_id, self.method_id, response)
                .await
            {
                sender.mark_failure(self.request_id, FailureDisposition::Cancelled);
            }
        }
    }

    /// Channel binder used to bind channels that appear in this call.
    fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
        Some(&self.binder)
    }

    fn request_id(&self) -> Option<RequestId> {
        Some(self.request_id)
    }

    /// Connection id, or `None` once the sender has been taken.
    fn connection_id(&self) -> Option<vox_types::ConnectionId> {
        self.sender.as_ref().map(|sender| sender.connection_id())
    }
}
1206
1207impl Drop for DriverReplySink {
1209 fn drop(&mut self) {
1210 if let Some(sender) = self.sender.take() {
1211 let disposition = if self.retry.persist {
1212 FailureDisposition::Indeterminate
1213 } else {
1214 FailureDisposition::Cancelled
1215 };
1216
1217 if let Some(operation_id) = self.operation_id {
1218 if let Some(live_ops) = self.live_operations.take()
1224 && let Some(live) = live_ops.lock().release(operation_id)
1225 {
1226 for waiter in live.waiters {
1227 sender.mark_failure(waiter, disposition);
1228 }
1229 return;
1230 }
1231 }
1232
1233 sender.mark_failure(self.request_id, disposition);
1234 }
1235 }
1236}
1237
/// Outbound (`Tx`) end of a channel on one connection.
///
/// Created by `make_tx_channel_sink` and wrapped in a `CreditSink`. Items go out as
/// `ChannelBody::Item` messages through `sender`. Once `channel_id` is in
/// `shared.terminal_channels`, sends are rejected.
pub struct DriverChannelSink {
    /// Connection the channel's messages are sent on.
    sender: ConnectionSender,
    /// Driver state shared with the connection's other endpoints.
    shared: Arc<DriverShared>,
    channel_id: ChannelId,
    /// Context supplied at creation. `debug_context()` falls back to
    /// `shared.channel_contexts` when this yields nothing.
    debug_context: Option<ChannelDebugContext>,
    /// Used by `close_channel_on_drop` to queue `DriverLocalControl::CloseChannel`.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
}
1252
1253impl ChannelSink for DriverChannelSink {
1254 fn send_payload<'payload>(
1255 &self,
1256 payload: Payload<'payload>,
1257 ) -> Pin<Box<dyn vox_types::MaybeSendFuture<Output = Result<(), TxError>> + 'payload>> {
1258 let sender = self.sender.clone();
1259 let shared = Arc::clone(&self.shared);
1260 let channel_id = self.channel_id;
1261 Box::pin(async move {
1262 if shared.terminal_channels.lock().contains(&channel_id) {
1263 return Err(TxError::Transport("channel closed".into()));
1264 }
1265
1266 shared.mark_outbound_progress();
1267 sender
1268 .send(ConnectionMessage::Channel(ChannelMessage {
1269 id: channel_id,
1270 body: ChannelBody::Item(ChannelItem { item: payload }),
1271 }))
1272 .await
1273 .map_err(|()| TxError::Transport("connection closed".into()))
1274 })
1275 }
1276
1277 fn channel_id(&self) -> Option<ChannelId> {
1278 Some(self.channel_id)
1279 }
1280
1281 fn connection_id(&self) -> Option<vox_types::ConnectionId> {
1282 Some(self.sender.connection_id())
1283 }
1284
1285 fn debug_context(&self) -> Option<ChannelDebugContext> {
1286 self.debug_context
1287 .and_then(ChannelDebugContext::into_option)
1288 .or_else(|| {
1289 self.shared
1290 .channel_contexts
1291 .lock()
1292 .get(&self.channel_id)
1293 .copied()
1294 })
1295 }
1296
1297 fn observer(&self) -> Option<VoxObserverHandle> {
1298 self.shared.observer.clone()
1299 }
1300
1301 fn note_send_started(&self) {
1302 self.shared.record_send_started(self.channel_id);
1303 }
1304
1305 fn note_send_waiting_for_credit(&self) {
1306 self.shared.record_send_waiting_for_credit(self.channel_id);
1307 }
1308
1309 fn note_send_finished(&self, outcome: ChannelSendOutcome) {
1310 self.shared.record_send_finished(self.channel_id, outcome);
1311 }
1312
1313 fn note_try_send_outcome(&self, outcome: ChannelTrySendOutcome) {
1314 self.shared
1315 .record_try_send_outcome(self.channel_id, outcome);
1316 }
1317
1318 fn try_send_payload_with_outcome<'payload>(
1321 &self,
1322 payload: Payload<'payload>,
1323 ) -> Result<(), ChannelTrySendOutcome> {
1324 if self
1325 .shared
1326 .terminal_channels
1327 .lock()
1328 .contains(&self.channel_id)
1329 {
1330 return Err(ChannelTrySendOutcome::Closed);
1331 }
1332
1333 self.shared.mark_outbound_progress();
1334 self.sender
1335 .try_send(ConnectionMessage::Channel(ChannelMessage {
1336 id: self.channel_id,
1337 body: ChannelBody::Item(ChannelItem { item: payload }),
1338 }))
1339 .map_err(|err| match err {
1340 TrySendError::Closed(()) => ChannelTrySendOutcome::Closed,
1341 TrySendError::Full(()) => ChannelTrySendOutcome::FullRuntimeQueue,
1342 })
1343 }
1344
1345 fn close_channel(
1346 &self,
1347 _metadata: vox_types::Metadata,
1348 ) -> Pin<Box<dyn vox_types::MaybeSendFuture<Output = Result<(), TxError>> + 'static>> {
1349 let sender = self.sender.clone();
1353 let shared = Arc::clone(&self.shared);
1354 let channel_id = self.channel_id;
1355 let debug_context = self.debug_context;
1356 Box::pin(async move {
1357 shared.terminal_channels.lock().insert(channel_id);
1358 shared.observe_channel(channel_id, debug_context, |channel| ChannelEvent::Closed {
1359 channel,
1360 reason: ChannelCloseReason::Local,
1361 });
1362
1363 shared.mark_outbound_progress();
1364 sender
1365 .send(ConnectionMessage::Channel(ChannelMessage {
1366 id: channel_id,
1367 body: ChannelBody::Close(ChannelClose {
1368 metadata: Default::default(),
1369 }),
1370 }))
1371 .await
1372 .map_err(|()| TxError::Transport("connection closed".into()))
1373 })
1374 }
1375
1376 fn close_channel_on_drop(&self) {
1377 self.shared.terminal_channels.lock().insert(self.channel_id);
1378 self.shared
1379 .observe_channel(self.channel_id, self.debug_context, |channel| {
1380 ChannelEvent::Closed {
1381 channel,
1382 reason: ChannelCloseReason::Dropped,
1383 }
1384 });
1385 let _ = self
1386 .local_control_tx
1387 .send(DriverLocalControl::CloseChannel {
1388 channel_id: self.channel_id,
1389 });
1390 }
1391}
1392
1393pub trait ErasedHandler: MaybeSend + MaybeSync + 'static {
1398 fn retry_policy(&self, method_id: vox_types::MethodId) -> vox_types::RetryPolicy {
1399 let _ = method_id;
1400 vox_types::RetryPolicy::VOLATILE
1401 }
1402
1403 fn args_have_channels(&self, method_id: vox_types::MethodId) -> bool {
1404 let _ = method_id;
1405 false
1406 }
1407
1408 fn response_wire_shape(&self, method_id: vox_types::MethodId) -> Option<&'static facet::Shape> {
1409 let _ = method_id;
1410 None
1411 }
1412
1413 fn handle_erased(
1414 &self,
1415 call: SelfRef<RequestCall<'static>>,
1416 reply: DriverReplySink,
1417 schemas: std::sync::Arc<vox_types::SchemaRecvTracker>,
1418 ) -> BoxFut<'_, ()>;
1419}
1420
1421impl<H: Handler<DriverReplySink>> ErasedHandler for H {
1422 fn retry_policy(&self, method_id: vox_types::MethodId) -> vox_types::RetryPolicy {
1423 Handler::retry_policy(self, method_id)
1424 }
1425
1426 fn args_have_channels(&self, method_id: vox_types::MethodId) -> bool {
1427 Handler::args_have_channels(self, method_id)
1428 }
1429
1430 fn response_wire_shape(&self, method_id: vox_types::MethodId) -> Option<&'static facet::Shape> {
1431 Handler::response_wire_shape(self, method_id)
1432 }
1433
1434 fn handle_erased(
1435 &self,
1436 call: SelfRef<RequestCall<'static>>,
1437 reply: DriverReplySink,
1438 schemas: std::sync::Arc<vox_types::SchemaRecvTracker>,
1439 ) -> BoxFut<'_, ()> {
1440 Box::pin(Handler::handle(self, call, reply, schemas))
1441 }
1442}
1443
1444impl Handler<DriverReplySink> for Box<dyn ErasedHandler> {
1445 fn retry_policy(&self, method_id: vox_types::MethodId) -> vox_types::RetryPolicy {
1446 (**self).retry_policy(method_id)
1447 }
1448
1449 fn args_have_channels(&self, method_id: vox_types::MethodId) -> bool {
1450 (**self).args_have_channels(method_id)
1451 }
1452
1453 fn response_wire_shape(&self, method_id: vox_types::MethodId) -> Option<&'static facet::Shape> {
1454 (**self).response_wire_shape(method_id)
1455 }
1456
1457 async fn handle(
1458 &self,
1459 call: SelfRef<RequestCall<'static>>,
1460 reply: DriverReplySink,
1461 schemas: std::sync::Arc<vox_types::SchemaRecvTracker>,
1462 ) {
1463 (**self).handle_erased(call, reply, schemas).await
1464 }
1465}
1466
/// Cloneable client handle for issuing requests over one connection.
///
/// Wraps a shared `DriverCaller`, optionally with a service descriptor and client middleware
/// that `call` runs around each request.
#[must_use = "Dropping this caller may close the connection if it is the last caller."]
#[derive(Clone)]
pub struct Caller {
    /// Connection-level caller shared by all clones.
    inner: Arc<DriverCaller>,
    /// Set by the first `with_middleware` call; later middleware must name the same service.
    service: Option<&'static vox_types::ServiceDescriptor>,
    /// `pre` hooks run in this order before a call; `post` hooks run in reverse afterwards.
    middlewares: Vec<Arc<dyn vox_types::ClientMiddleware>>,
}
1479
1480impl Caller {
1481 pub fn new(driver: DriverCaller) -> Self {
1483 Self {
1484 inner: Arc::new(driver),
1485 service: None,
1486 middlewares: vec![],
1487 }
1488 }
1489
1490 #[cfg(test)]
1492 pub(crate) fn driver(&self) -> &DriverCaller {
1493 &self.inner
1494 }
1495
1496 pub fn with_middleware(
1498 mut self,
1499 service: &'static vox_types::ServiceDescriptor,
1500 middleware: impl vox_types::ClientMiddleware,
1501 ) -> Self {
1502 if let Some(existing_service) = self.service {
1503 assert_eq!(
1504 existing_service.service_name, service.service_name,
1505 "Caller middleware service mismatch"
1506 );
1507 } else {
1508 self.service = Some(service);
1509 }
1510 self.middlewares.push(Arc::new(middleware));
1511 self
1512 }
1513
1514 pub async fn call(&self, mut call: RequestCall<'_>) -> CallResult {
1517 use vox_types::{
1518 ClientCallOutcome, ClientContext, ClientRequest, Extensions, OwnedMetadata,
1519 };
1520
1521 let Some(service) = self.service else {
1522 return self.inner.call_inner(call, None).await;
1523 };
1524
1525 let extensions = Extensions::new();
1526 let method = service.by_id(call.method_id);
1527 let context = ClientContext::new(method, call.method_id, &extensions);
1528 let mut owned_metadata = OwnedMetadata::default();
1529
1530 if !self.middlewares.is_empty() {
1531 for middleware in &self.middlewares {
1532 let mut request = ClientRequest::new(&mut call, &mut owned_metadata);
1533 middleware.pre(&context, &mut request).await;
1534 }
1535 }
1536
1537 let request_debug = method.map(|method| (method.service_name, method.method_name));
1538 let result = self.inner.call_inner(call, request_debug).await;
1539 if !self.middlewares.is_empty() {
1540 let outcome = match &result {
1541 Ok(_) => ClientCallOutcome::Response,
1542 Err(error) => ClientCallOutcome::Error(error),
1543 };
1544 for middleware in self.middlewares.iter().rev() {
1545 middleware.post(&context, outcome).await;
1546 }
1547 }
1548 result
1549 }
1550
1551 pub async fn closed(&self) {
1553 if self.inner.closed_rx.borrow().is_some() {
1554 return;
1555 }
1556 let mut rx = self.inner.closed_rx.clone();
1557 while rx.changed().await.is_ok() {
1558 if rx.borrow().is_some() {
1559 return;
1560 }
1561 }
1562 }
1563
1564 pub fn is_connected(&self) -> bool {
1566 self.inner.closed_rx.borrow().is_none()
1567 }
1568
1569 pub fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
1571 Some(self.inner.as_ref())
1572 }
1573
1574 pub fn debug_snapshot(&self) -> VoxDebugSnapshot {
1576 self.inner.debug_snapshot()
1577 }
1578
1579 pub fn dump_debug_snapshot(&self) -> VoxDebugSnapshot {
1580 let snapshot = self.debug_snapshot();
1581 tracing::info!(?snapshot, "vox debug snapshot");
1582 snapshot
1583 }
1584}
1585
/// Build a client type from a connection's `Caller` and an optional session handle.
pub trait FromVoxSession {
    /// Name of the service this client talks to (for example `"Noop"` for `NoopClient`).
    const SERVICE_NAME: &'static str;

    /// Construct the client. `session_handle` is `None` when no session handle is available.
    fn from_vox_session(
        caller: Caller,
        session_handle: Option<crate::session::SessionHandle>,
    ) -> Self;
}
1601
/// Client with no service methods; it only holds the `Caller` and optional session handle.
#[must_use = "Dropping NoopClient may close the connection if it is the last caller."]
#[derive(Clone)]
pub struct NoopClient {
    /// Caller for the underlying connection.
    pub caller: Caller,
    /// Session handle, if one was supplied at construction.
    pub session: Option<crate::session::SessionHandle>,
}
1614
1615impl FromVoxSession for NoopClient {
1616 const SERVICE_NAME: &'static str = "Noop";
1617
1618 fn from_vox_session(caller: Caller, session: Option<crate::session::SessionHandle>) -> Self {
1619 Self { caller, session }
1620 }
1621}
1622
/// Channel binder handed to request handlers (through `DriverReplySink::channel_binder`).
///
/// Creates and binds channels on the connection the request arrived on.
#[derive(Clone)]
struct DriverChannelBinder {
    sender: ConnectionSender,
    shared: Arc<DriverShared>,
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// When present, exposed as the channels' liveness handle (see `endpoint_liveness`).
    drop_guard: Option<Arc<CallerDropGuard>>,
}
1630
1631fn register_rx_channel_impl(
1632 shared: &Arc<DriverShared>,
1633 channel_id: ChannelId,
1634 initial_channel_credit: u32,
1635 debug_context: Option<ChannelDebugContext>,
1636 liveness: Option<ChannelLivenessHandle>,
1637 local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
1638) -> vox_types::BoundChannelReceiver {
1639 observe_channel_opened(
1640 shared,
1641 channel_id,
1642 ChannelDirection::Rx,
1643 initial_channel_credit,
1644 debug_context,
1645 );
1646 let (rx, terminal) = shared.register_inbound_channel_receiver(channel_id);
1647
1648 if terminal {
1649 shared.channel_credits.lock().remove(&channel_id);
1650 return vox_types::BoundChannelReceiver {
1651 receiver: rx,
1652 liveness,
1653 replenisher: None,
1654 };
1655 }
1656
1657 vox_types::BoundChannelReceiver {
1658 receiver: rx,
1659 liveness,
1660 replenisher: Some(Arc::new(DriverChannelCreditReplenisher::new(
1661 shared.connection_id,
1662 channel_id,
1663 debug_context,
1664 Arc::downgrade(shared),
1665 initial_channel_credit,
1666 local_control_tx,
1667 shared.observer.clone(),
1668 )) as ChannelCreditReplenisherHandle),
1669 }
1670}
1671
1672fn observe_channel_opened(
1674 shared: &DriverShared,
1675 channel_id: ChannelId,
1676 direction: ChannelDirection,
1677 initial_credit: u32,
1678 debug_context: Option<ChannelDebugContext>,
1679) {
1680 shared.remember_channel_context(channel_id, debug_context);
1681 shared.observe_channel(channel_id, debug_context, |channel| ChannelEvent::Opened {
1682 channel,
1683 direction,
1684 initial_credit,
1685 });
1686}
1687
1688fn make_tx_channel_sink(
1689 sender: &ConnectionSender,
1690 shared: &Arc<DriverShared>,
1691 local_control_tx: &mpsc::UnboundedSender<DriverLocalControl>,
1692 channel_id: ChannelId,
1693 debug_context: Option<ChannelDebugContext>,
1694) -> Arc<CreditSink<DriverChannelSink>> {
1695 observe_channel_opened(
1696 shared,
1697 channel_id,
1698 ChannelDirection::Tx,
1699 shared.peer_initial_channel_credit,
1700 debug_context,
1701 );
1702 let inner = DriverChannelSink {
1703 sender: sender.clone(),
1704 shared: Arc::clone(shared),
1705 channel_id,
1706 debug_context: debug_context.and_then(ChannelDebugContext::into_option),
1707 local_control_tx: local_control_tx.clone(),
1708 };
1709 let sink = Arc::new(CreditSink::new(inner, shared.peer_initial_channel_credit));
1710 shared
1711 .channel_credits
1712 .lock()
1713 .insert(channel_id, Arc::clone(sink.credit()));
1714 sink
1715}
1716
1717trait DriverChannelEndpoint {
1718 fn endpoint_sender(&self) -> &ConnectionSender;
1719 fn endpoint_shared(&self) -> &Arc<DriverShared>;
1720 fn endpoint_local_control_tx(&self) -> &mpsc::UnboundedSender<DriverLocalControl>;
1721 fn endpoint_liveness(&self) -> Option<ChannelLivenessHandle>;
1722
1723 fn create_tx_credit_sink(
1724 &self,
1725 debug_context: Option<ChannelDebugContext>,
1726 ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
1727 let shared = self.endpoint_shared();
1728 let channel_id = shared.channel_ids.lock().alloc();
1729 let sink = make_tx_channel_sink(
1730 self.endpoint_sender(),
1731 shared,
1732 self.endpoint_local_control_tx(),
1733 channel_id,
1734 debug_context,
1735 );
1736 (channel_id, sink)
1737 }
1738
1739 fn create_tx_dyn(
1740 &self,
1741 debug_context: Option<ChannelDebugContext>,
1742 ) -> (ChannelId, Arc<dyn ChannelSink>) {
1743 let (id, sink) = self.create_tx_credit_sink(debug_context);
1744 (id, sink as Arc<dyn ChannelSink>)
1745 }
1746
1747 fn create_rx_bound(
1748 &self,
1749 debug_context: Option<ChannelDebugContext>,
1750 ) -> (ChannelId, vox_types::BoundChannelReceiver) {
1751 let channel_id = self.endpoint_shared().channel_ids.lock().alloc();
1752 let rx = self.register_rx_bound(channel_id, debug_context);
1753 (channel_id, rx)
1754 }
1755
1756 fn bind_tx_dyn(
1757 &self,
1758 channel_id: ChannelId,
1759 debug_context: Option<ChannelDebugContext>,
1760 ) -> Arc<dyn ChannelSink> {
1761 make_tx_channel_sink(
1762 self.endpoint_sender(),
1763 self.endpoint_shared(),
1764 self.endpoint_local_control_tx(),
1765 channel_id,
1766 debug_context,
1767 )
1768 }
1769
1770 fn register_rx_bound(
1771 &self,
1772 channel_id: ChannelId,
1773 debug_context: Option<ChannelDebugContext>,
1774 ) -> vox_types::BoundChannelReceiver {
1775 let shared = self.endpoint_shared();
1776 register_rx_channel_impl(
1777 shared,
1778 channel_id,
1779 shared.local_initial_channel_credit,
1780 debug_context,
1781 self.endpoint_liveness(),
1782 self.endpoint_local_control_tx().clone(),
1783 )
1784 }
1785}
1786
1787impl DriverChannelEndpoint for DriverChannelBinder {
1788 fn endpoint_sender(&self) -> &ConnectionSender {
1789 &self.sender
1790 }
1791
1792 fn endpoint_shared(&self) -> &Arc<DriverShared> {
1793 &self.shared
1794 }
1795
1796 fn endpoint_local_control_tx(&self) -> &mpsc::UnboundedSender<DriverLocalControl> {
1797 &self.local_control_tx
1798 }
1799
1800 fn endpoint_liveness(&self) -> Option<ChannelLivenessHandle> {
1801 self.drop_guard
1802 .as_ref()
1803 .map(|guard| guard.clone() as ChannelLivenessHandle)
1804 }
1805}
1806
1807impl ChannelBinder for DriverChannelBinder {
1808 fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>) {
1809 self.create_tx_dyn(None)
1810 }
1811
1812 fn create_tx_with_context(
1813 &self,
1814 debug_context: Option<ChannelDebugContext>,
1815 ) -> (ChannelId, Arc<dyn ChannelSink>) {
1816 self.create_tx_dyn(debug_context)
1817 }
1818
1819 fn create_rx(&self) -> (ChannelId, vox_types::BoundChannelReceiver) {
1820 self.create_rx_bound(None)
1821 }
1822
1823 fn create_rx_with_context(
1824 &self,
1825 debug_context: Option<ChannelDebugContext>,
1826 ) -> (ChannelId, vox_types::BoundChannelReceiver) {
1827 self.create_rx_bound(debug_context)
1828 }
1829
1830 fn bind_tx(&self, channel_id: ChannelId) -> Arc<dyn ChannelSink> {
1831 self.bind_tx_dyn(channel_id, None)
1832 }
1833
1834 fn bind_tx_with_context(
1835 &self,
1836 channel_id: ChannelId,
1837 debug_context: Option<ChannelDebugContext>,
1838 ) -> Arc<dyn ChannelSink> {
1839 self.bind_tx_dyn(channel_id, debug_context)
1840 }
1841
1842 fn register_rx(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
1843 self.register_rx_bound(channel_id, None)
1844 }
1845
1846 fn register_rx_with_context(
1847 &self,
1848 channel_id: ChannelId,
1849 debug_context: Option<ChannelDebugContext>,
1850 ) -> vox_types::BoundChannelReceiver {
1851 self.register_rx_bound(channel_id, debug_context)
1852 }
1853
1854 fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
1855 self.endpoint_liveness()
1856 }
1857}
1858
/// Connection-level request issuer behind `Caller`.
///
/// Allocates request ids, tracks pending responses in the shared driver state, and watches
/// the connection's close/resume signals while a call is outstanding.
#[derive(Clone)]
pub struct DriverCaller {
    sender: ConnectionSender,
    shared: Arc<DriverShared>,
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// `None` while connected; `Some(reason)` once the connection is closed.
    closed_rx: watch::Receiver<Option<ConnectionCloseReason>>,
    /// Resume generation counter. A change while a call is waiting triggers re-send logic.
    resumed_rx: watch::Receiver<u64>,
    /// Highest resume generation the driver has finished processing.
    resume_processed_rx: watch::Receiver<u64>,
    /// When set, calls get an operation id and react to resumes.
    peer_supports_retry: bool,
    /// Optional guard, also exposed as the channel liveness handle.
    _drop_guard: Option<Arc<CallerDropGuard>>,
}
1873
1874impl DriverCaller {
1875 pub fn create_tx_channel(&self) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
1880 self.create_tx_credit_sink(None)
1881 }
1882
1883 #[cfg(test)]
1888 pub(crate) fn connection_sender(&self) -> &ConnectionSender {
1889 &self.sender
1890 }
1891
1892 pub fn register_rx_channel(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
1897 self.register_rx_bound(channel_id, None)
1898 }
1899}
1900
1901impl DriverChannelEndpoint for DriverCaller {
1902 fn endpoint_sender(&self) -> &ConnectionSender {
1903 &self.sender
1904 }
1905
1906 fn endpoint_shared(&self) -> &Arc<DriverShared> {
1907 &self.shared
1908 }
1909
1910 fn endpoint_local_control_tx(&self) -> &mpsc::UnboundedSender<DriverLocalControl> {
1911 &self.local_control_tx
1912 }
1913
1914 fn endpoint_liveness(&self) -> Option<ChannelLivenessHandle> {
1915 self._drop_guard
1916 .as_ref()
1917 .map(|guard| guard.clone() as ChannelLivenessHandle)
1918 }
1919}
1920
1921impl ChannelBinder for DriverCaller {
1922 fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>) {
1923 self.create_tx_dyn(None)
1924 }
1925
1926 fn create_tx_with_context(
1927 &self,
1928 debug_context: Option<ChannelDebugContext>,
1929 ) -> (ChannelId, Arc<dyn ChannelSink>) {
1930 self.create_tx_dyn(debug_context)
1931 }
1932
1933 fn create_rx(&self) -> (ChannelId, vox_types::BoundChannelReceiver) {
1934 self.create_rx_bound(None)
1935 }
1936
1937 fn create_rx_with_context(
1938 &self,
1939 debug_context: Option<ChannelDebugContext>,
1940 ) -> (ChannelId, vox_types::BoundChannelReceiver) {
1941 self.create_rx_bound(debug_context)
1942 }
1943
1944 fn bind_tx(&self, channel_id: ChannelId) -> Arc<dyn ChannelSink> {
1945 self.bind_tx_dyn(channel_id, None)
1946 }
1947
1948 fn bind_tx_with_context(
1949 &self,
1950 channel_id: ChannelId,
1951 debug_context: Option<ChannelDebugContext>,
1952 ) -> Arc<dyn ChannelSink> {
1953 self.bind_tx_dyn(channel_id, debug_context)
1954 }
1955
1956 fn register_rx(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
1957 self.register_rx_bound(channel_id, None)
1958 }
1959
1960 fn register_rx_with_context(
1961 &self,
1962 channel_id: ChannelId,
1963 debug_context: Option<ChannelDebugContext>,
1964 ) -> vox_types::BoundChannelReceiver {
1965 self.register_rx_bound(channel_id, debug_context)
1966 }
1967
1968 fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
1969 self.endpoint_liveness()
1970 }
1971}
1972
1973impl DriverCaller {
1974 pub fn debug_snapshot(&self) -> VoxDebugSnapshot {
1976 self.shared.debug_snapshot(
1977 &self.sender,
1978 self.shared
1979 .connection_debug_state(self.closed_rx.borrow().is_some()),
1980 if self.closed_rx.borrow().is_some() {
1981 DriverTaskStatus::Dead
1982 } else {
1983 DriverTaskStatus::Alive
1984 },
1985 )
1986 }
1987
1988 pub fn dump_debug_snapshot(&self) -> VoxDebugSnapshot {
1989 let snapshot = self.debug_snapshot();
1990 tracing::info!(?snapshot, "vox debug snapshot");
1991 snapshot
1992 }
1993
1994 async fn call_inner(
1996 &self,
1997 mut call: RequestCall<'_>,
1998 request_debug: Option<(&'static str, &'static str)>,
1999 ) -> CallResult {
2000 if self.peer_supports_retry {
2001 let operation_id = OperationId(
2002 self.shared
2003 .next_operation_id
2004 .fetch_add(1, Ordering::Relaxed),
2005 );
2006 ensure_operation_id(&mut call.metadata, operation_id);
2007 }
2008
2009 let req_id = self.shared.request_ids.lock().alloc();
2011 let request_started_at = Instant::now();
2012 if let Some(observer) = &self.shared.observer {
2013 observer.driver_event(DriverEvent::RequestStarted {
2014 connection_id: self.sender.connection_id(),
2015 request_id: req_id,
2016 method_id: call.method_id,
2017 });
2018 }
2019 let finish_request = |outcome: RpcOutcome| {
2020 self.shared.finish_request(
2021 req_id,
2022 if outcome == RpcOutcome::Ok {
2023 RequestDebugState::Finished
2024 } else {
2025 RequestDebugState::Failed
2026 },
2027 );
2028 if let Some(observer) = &self.shared.observer {
2029 observer.driver_event(DriverEvent::RequestFinished {
2030 connection_id: self.sender.connection_id(),
2031 request_id: req_id,
2032 outcome,
2033 elapsed: request_started_at.elapsed(),
2034 });
2035 }
2036 };
2037
2038 let (tx, rx) = moire::sync::oneshot::channel("driver.response");
2041 self.shared.pending_responses.lock().insert(req_id, tx);
2042 self.shared.start_request(
2043 req_id,
2044 call.method_id,
2045 request_debug.map(|(service, _)| service),
2046 request_debug.map(|(_, method)| method),
2047 RequestDebugState::WaitingForResponse,
2048 );
2049
2050 self.shared.mark_outbound_progress();
2058 if self
2059 .sender
2060 .send_with_binder(
2061 ConnectionMessage::Request(RequestMessage {
2062 id: req_id,
2063 body: RequestBody::Call(RequestCall {
2064 method_id: call.method_id,
2065 args: call.args.reborrow(),
2066 metadata: call.metadata.clone(),
2067 schemas: Default::default(),
2068 }),
2069 }),
2070 Some(self),
2071 )
2072 .await
2073 .is_err()
2074 {
2075 self.shared.pending_responses.lock().remove(&req_id);
2076 finish_request(RpcOutcome::SendFailed);
2077 return Err(VoxError::SendFailed);
2078 }
2079
2080 let mut resumed_rx = self.resumed_rx.clone();
2081 let mut seen_resume_generation = *resumed_rx.borrow();
2082 let mut resume_processed_rx = self.resume_processed_rx.clone();
2083 let mut closed_rx = self.closed_rx.clone();
2084 let mut response = std::pin::pin!(rx.named("awaiting_response"));
2085
2086 let pending: PendingResponse = loop {
2087 tokio::select! {
2088 result = &mut response => {
2089 match result {
2090 Ok(pending) => break pending,
2091 Err(_) => {
2092 finish_request(RpcOutcome::Closed);
2093 return Err(VoxError::ConnectionClosed);
2094 }
2095 }
2096 }
2097 changed = resumed_rx.changed(), if self.peer_supports_retry => {
2098 vox_types::dlog!("[CALLER] resumed_rx fired");
2099 if changed.is_err() {
2100 self.shared.pending_responses.lock().remove(&req_id);
2101 finish_request(RpcOutcome::Closed);
2102 return Err(VoxError::SessionShutdown);
2103 }
2104 let generation = *resumed_rx.borrow();
2105 if generation == seen_resume_generation {
2106 continue;
2107 }
2108 seen_resume_generation = generation;
2109 while *resume_processed_rx.borrow() < generation {
2110 if resume_processed_rx.changed().await.is_err() {
2111 self.shared.pending_responses.lock().remove(&req_id);
2112 finish_request(RpcOutcome::Closed);
2113 return Err(VoxError::SessionShutdown);
2114 }
2115 }
2116 match metadata_channel_retry_mode(&call.metadata) {
2117 ChannelRetryMode::NonIdem => {
2118 self.shared.pending_responses.lock().remove(&req_id);
2119 finish_request(RpcOutcome::Indeterminate);
2120 return Err(VoxError::Indeterminate);
2121 }
2122 ChannelRetryMode::Idem | ChannelRetryMode::None => {}
2123 }
2124 self.shared.mark_outbound_progress();
2128 let _ = self.sender.send_with_binder(
2129 ConnectionMessage::Request(RequestMessage {
2130 id: req_id,
2131 body: RequestBody::Call(RequestCall {
2132 method_id: call.method_id,
2133 args: call.args.reborrow(),
2134 metadata: call.metadata.clone(),
2135 schemas: Default::default(),
2136 }),
2137 }),
2138 Some(self),
2139 ).await;
2140 }
2141 changed = closed_rx.changed() => {
2142 vox_types::dlog!("[CALLER] closed_rx fired, value={:?}", *closed_rx.borrow());
2143 if changed.is_err() || closed_rx.borrow().is_some() {
2144 self.shared.pending_responses.lock().remove(&req_id);
2145 finish_request(RpcOutcome::Closed);
2146 return Err(VoxError::ConnectionClosed);
2147 }
2148 }
2149 }
2150 };
2151
2152 let PendingResponse {
2154 msg: response_msg,
2155 schemas: response_schemas,
2156 } = pending;
2157 let response = response_msg.map(|m| match m.body {
2158 RequestBody::Response(r) => r,
2159 _ => unreachable!("pending_responses only gets Response variants"),
2160 });
2161
2162 finish_request(RpcOutcome::Ok);
2163 Ok(vox_types::WithTracker {
2164 value: response,
2165 tracker: response_schemas,
2166 })
2167 }
2168}
2169
/// Per-connection driver state: incoming messages, handler dispatch, failure reports,
/// local control messages and channel bookkeeping for one connection.
pub struct Driver<H: Handler<DriverReplySink>> {
    sender: ConnectionSender,
    rx: mpsc::Receiver<crate::session::RecvMessage>,
    /// Failure reports for requests, paired with how they should be surfaced.
    failures_rx: mpsc::UnboundedReceiver<(RequestId, FailureDisposition)>,
    closed_rx: watch::Receiver<Option<ConnectionCloseReason>>,
    resumed_rx: watch::Receiver<u64>,
    /// Publishes the resume generation this driver has finished processing; callers wait on
    /// the matching receiver before re-sending after a resume.
    resume_processed_tx: watch::Sender<u64>,
    peer_supports_retry: bool,
    /// Receives `DriverLocalControl` messages from sinks and credit replenishers.
    local_control_rx: mpsc::UnboundedReceiver<DriverLocalControl>,
    handler: Arc<H>,
    shared: Arc<DriverShared>,
    /// Handlers currently running, keyed by request id; each carries an abort handle,
    /// method id and optional operation id (see `abort_channel_handlers`).
    in_flight_handlers: BTreeMap<RequestId, InFlightHandler>,
    handler_futs: FuturesUnordered<HandlerFut>,
    /// Tracks operations that are still executing, so they can be sealed or released.
    live_operations: Arc<SyncMutex<LiveOperationTracker>>,
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    drop_control_seed: Option<mpsc::UnboundedSender<DropControlRequest>>,
    drop_control_request: DropControlRequest,
    /// Weak reference to the callers' drop guard (presumably so the driver does not keep the
    /// connection alive on its own — TODO confirm).
    drop_guard: SyncMutex<Option<Weak<CallerDropGuard>>>,
}
2205
/// Control messages sent from channel endpoints to the driver over `local_control_tx`.
enum DriverLocalControl {
    /// Queued by `DriverChannelSink::close_channel_on_drop` when an outbound sink is dropped.
    CloseChannel {
        channel_id: ChannelId,
    },
    /// Queued by `DriverChannelCreditReplenisher::on_receiver_dropped` when an inbound
    /// receiver is dropped.
    ResetChannel {
        channel_id: ChannelId,
    },
    /// Queued by `DriverChannelCreditReplenisher::on_item_consumed` once enough items have
    /// been consumed; `additional` is the number of credits to grant.
    GrantCredit {
        channel_id: ChannelId,
        additional: u32,
    },
}
2218
/// Returns flow-control credit for an inbound channel as its items are consumed.
struct DriverChannelCreditReplenisher {
    connection_id: ConnectionId,
    channel_id: ChannelId,
    debug_context: Option<ChannelDebugContext>,
    /// Weak so the replenisher does not keep driver state alive; recording is skipped when
    /// the upgrade fails.
    shared: Weak<DriverShared>,
    /// Number of consumed items that triggers a `GrantCredit` (half the initial credit, at
    /// least 1).
    threshold: u32,
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    observer: Option<VoxObserverHandle>,
    /// Items consumed since the last grant.
    pending: std::sync::Mutex<u32>,
}
2229
2230impl DriverChannelCreditReplenisher {
2231 fn new(
2232 connection_id: ConnectionId,
2233 channel_id: ChannelId,
2234 debug_context: Option<ChannelDebugContext>,
2235 shared: Weak<DriverShared>,
2236 initial_credit: u32,
2237 local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
2238 observer: Option<VoxObserverHandle>,
2239 ) -> Self {
2240 Self {
2241 connection_id,
2242 channel_id,
2243 debug_context,
2244 shared,
2245 threshold: (initial_credit / 2).max(1),
2246 local_control_tx,
2247 observer,
2248 pending: std::sync::Mutex::new(0),
2249 }
2250 }
2251}
2252
2253impl ChannelCreditReplenisher for DriverChannelCreditReplenisher {
2254 fn on_item_consumed(&self) {
2255 let mut pending = self.pending.lock().expect("pending credit mutex poisoned");
2256 *pending += 1;
2257 if let Some(shared) = self.shared.upgrade() {
2258 shared.record_item_consumed(self.channel_id);
2259 shared.record_pending_local_grant(self.channel_id, *pending);
2260 }
2261 if *pending < self.threshold {
2262 return;
2263 }
2264
2265 let additional = *pending;
2266 *pending = 0;
2267 if let Some(shared) = self.shared.upgrade() {
2268 shared.record_pending_local_grant(self.channel_id, additional);
2269 }
2270 let _ = self.local_control_tx.send(DriverLocalControl::GrantCredit {
2271 channel_id: self.channel_id,
2272 additional,
2273 });
2274 }
2275
2276 fn on_receiver_dropped(&self) {
2277 if let Some(shared) = self.shared.upgrade() {
2278 shared.record_receiver_dropped(self.channel_id);
2279 }
2280 let _ = self
2281 .local_control_tx
2282 .send(DriverLocalControl::ResetChannel {
2283 channel_id: self.channel_id,
2284 });
2285 }
2286
2287 fn channel_id(&self) -> Option<ChannelId> {
2288 Some(self.channel_id)
2289 }
2290
2291 fn connection_id(&self) -> Option<ConnectionId> {
2292 Some(self.connection_id)
2293 }
2294
2295 fn debug_context(&self) -> Option<ChannelDebugContext> {
2296 self.debug_context
2297 }
2298
2299 fn observer(&self) -> Option<VoxObserverHandle> {
2300 self.observer.clone()
2301 }
2302}
2303
2304impl<H: Handler<DriverReplySink>> Driver<H> {
2305 fn close_all_channel_runtime_state(&self, teardown: ChannelRuntimeTeardown) {
2307 let mut credits = self.shared.channel_credits.lock();
2308 for semaphore in credits.values() {
2309 semaphore.close();
2310 }
2311 let mut stale = self.shared.stale_close_channels.lock();
2314 stale.extend(credits.keys().copied());
2315 credits.clear();
2316 drop(credits);
2317
2318 let channel_senders = {
2319 let mut senders = self.shared.channel_senders.lock();
2320 std::mem::take(&mut *senders)
2321 };
2322 if let ChannelRuntimeTeardown::ConnectionClosed(reason) = teardown {
2323 for (channel_id, sender) in channel_senders {
2324 let _ = sender.force_send(IncomingChannelMessage::ConnectionClosed(reason));
2325 self.shared
2326 .observe_channel(channel_id, None, |channel| ChannelEvent::Closed {
2327 channel,
2328 reason: ChannelCloseReason::ConnectionClosed,
2329 });
2330 }
2331 }
2332 self.shared.channel_receivers.lock().clear();
2333 self.shared.terminal_channels.lock().clear();
2334 }
2335
2336 fn close_outbound_channel(&self, channel_id: ChannelId) {
2337 self.shared.terminal_channels.lock().insert(channel_id);
2338 if let Some(semaphore) = self.shared.channel_credits.lock().remove(&channel_id) {
2339 semaphore.close();
2340 }
2341 }
2342
2343 fn abort_channel_handlers(&mut self) {
2344 for in_flight in self.in_flight_handlers.values() {
2345 if self.handler.args_have_channels(in_flight.method_id) {
2346 if let Some(operation_id) = in_flight.operation_id {
2347 self.shared.operations.remove(operation_id);
2348 self.live_operations.lock().release(operation_id);
2349 }
2350 in_flight.abort.abort();
2351 }
2352 }
2353 }
2354
2355 pub fn new(handle: ConnectionHandle, handler: H) -> Self {
2356 Self::with_operation_store(handle, handler, Arc::new(InMemoryOperationStore::default()))
2357 }
2358
2359 pub fn with_operation_store(
2360 handle: ConnectionHandle,
2361 handler: H,
2362 operation_store: Arc<dyn OperationStore>,
2363 ) -> Self {
2364 let conn_id = handle.connection_id();
2365 let ConnectionHandle {
2366 sender,
2367 rx,
2368 failures_rx,
2369 control_tx,
2370 closed_rx,
2371 resumed_rx,
2372 local_settings,
2373 peer_settings,
2374 parity,
2375 peer_supports_retry,
2376 observer,
2377 } = handle;
2378 let drop_control_request = DropControlRequest::Close(conn_id);
2379 let (local_control_tx, local_control_rx) = mpsc::unbounded_channel("driver.local_control");
2380 let (resume_processed_tx, _resume_processed_rx) = watch::channel(0_u64);
2381 Self {
2382 sender,
2383 rx,
2384 failures_rx,
2385 closed_rx,
2386 resumed_rx,
2387 resume_processed_tx,
2388 peer_supports_retry,
2389 local_control_rx,
2390 handler: Arc::new(handler),
2391 shared: Arc::new(DriverShared {
2392 connection_id: conn_id,
2393 pending_responses: SyncMutex::new("driver.pending_responses", BTreeMap::new()),
2394 request_ids: SyncMutex::new("driver.request_ids", IdAllocator::new(parity)),
2395 next_operation_id: AtomicU64::new(1),
2396 operations: operation_store,
2397 channel_ids: SyncMutex::new("driver.channel_ids", IdAllocator::new(parity)),
2398 channel_senders: SyncMutex::new("driver.channel_senders", BTreeMap::new()),
2399 channel_receivers: SyncMutex::new("driver.channel_receivers", BTreeMap::new()),
2400 channel_credits: SyncMutex::new("driver.channel_credits", BTreeMap::new()),
2401 channel_contexts: SyncMutex::new("driver.channel_contexts", BTreeMap::new()),
2402 request_debug: SyncMutex::new("driver.request_debug", BTreeMap::new()),
2403 channel_debug: SyncMutex::new("driver.channel_debug", BTreeMap::new()),
2404 last_inbound_message_at: SyncMutex::new("driver.last_inbound_message_at", None),
2405 last_outbound_message_at: SyncMutex::new("driver.last_outbound_message_at", None),
2406 close_reason: SyncMutex::new("driver.close_reason", None),
2407 terminal_channels: SyncMutex::new("driver.terminal_channels", HashSet::new()),
2408 stale_close_channels: SyncMutex::new(
2409 "driver.stale_close_channels",
2410 std::collections::HashSet::new(),
2411 ),
2412 local_initial_channel_credit: local_settings.initial_channel_credit,
2413 peer_initial_channel_credit: peer_settings.initial_channel_credit,
2414 observer,
2415 }),
2416 in_flight_handlers: BTreeMap::new(),
2417 handler_futs: FuturesUnordered::new(),
2418 live_operations: Arc::new(SyncMutex::new(
2419 "driver.live_operations",
2420 LiveOperationTracker::new(),
2421 )),
2422 local_control_tx,
2423 drop_control_seed: control_tx,
2424 drop_control_request,
2425 drop_guard: SyncMutex::new("driver.drop_guard", None),
2426 }
2427 }
2428
2429 fn existing_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
2435 self.drop_guard.lock().as_ref().and_then(Weak::upgrade)
2436 }
2437
2438 fn connection_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
2439 if let Some(existing) = self.existing_drop_guard() {
2440 Some(existing)
2441 } else if let Some(seed) = &self.drop_control_seed {
2442 let mut guard = self.drop_guard.lock();
2443 if let Some(existing) = guard.as_ref().and_then(Weak::upgrade) {
2444 Some(existing)
2445 } else {
2446 let arc = Arc::new(CallerDropGuard {
2447 control_tx: seed.clone(),
2448 request: self.drop_control_request,
2449 });
2450 *guard = Some(Arc::downgrade(&arc));
2451 Some(arc)
2452 }
2453 } else {
2454 None
2455 }
2456 }
2457
2458 pub fn caller(&self) -> DriverCaller {
2459 let drop_guard = self.connection_drop_guard();
2460 DriverCaller {
2461 sender: self.sender.clone(),
2462 shared: Arc::clone(&self.shared),
2463 local_control_tx: self.local_control_tx.clone(),
2464 closed_rx: self.closed_rx.clone(),
2465 resumed_rx: self.resumed_rx.clone(),
2466 resume_processed_rx: self.resume_processed_tx.subscribe(),
2467 peer_supports_retry: self.peer_supports_retry,
2468 _drop_guard: drop_guard,
2469 }
2470 }
2471
2472 pub fn debug_snapshot(&self) -> VoxDebugSnapshot {
2474 self.shared.debug_snapshot(
2475 &self.sender,
2476 self.shared
2477 .connection_debug_state(self.closed_rx.borrow().is_some()),
2478 DriverTaskStatus::Alive,
2479 )
2480 }
2481
2482 pub fn dump_debug_snapshot(&self) -> VoxDebugSnapshot {
2483 let snapshot = self.debug_snapshot();
2484 tracing::info!(?snapshot, "vox debug snapshot");
2485 snapshot
2486 }
2487
2488 fn internal_binder(&self) -> DriverChannelBinder {
2489 DriverChannelBinder {
2490 sender: self.sender.clone(),
2491 shared: Arc::clone(&self.shared),
2492 local_control_tx: self.local_control_tx.clone(),
2493 drop_guard: self.existing_drop_guard(),
2494 }
2495 }
2496
2497 pub async fn run(&mut self) {
2502 let mut resumed_rx = self.resumed_rx.clone();
2503 let mut seen_resume_generation = *resumed_rx.borrow();
2504 loop {
2505 tracing::trace!("driver select loop top");
2506 tokio::select! {
2507 biased;
2508 changed = resumed_rx.changed() => {
2509 if changed.is_err() {
2510 tracing::trace!(
2511 conn_id = self.sender.connection_id().0,
2512 "resume notifier closed, exiting driver"
2513 );
2514 break;
2515 }
2516 let generation = *resumed_rx.borrow();
2517 if generation != seen_resume_generation {
2518 seen_resume_generation = generation;
2519 self.close_all_channel_runtime_state(ChannelRuntimeTeardown::DropOnly);
2520 self.abort_channel_handlers();
2521 let _ = self.resume_processed_tx.send(generation);
2522 }
2523 }
2524 Some(ctrl) = self.local_control_rx.recv() => {
2525 self.handle_local_control(ctrl).await;
2526 }
2527 Some((req_id, disposition)) = self.failures_rx.recv() => {
2528 tracing::trace!(%req_id, ?disposition, "failures_rx fired");
2529 let in_flight_found = self.in_flight_handlers.contains_key(&req_id);
2530 let in_flight_method_id =
2531 self.in_flight_handlers.get(&req_id).map(|in_flight| in_flight.method_id);
2532 let reply_disposition = self
2533 .in_flight_handlers
2534 .get(&req_id)
2535 .map(|in_flight| {
2536 let has_channels =
2537 self.handler.args_have_channels(in_flight.method_id);
2538 if has_channels && !in_flight.retry.idem {
2539 Some(FailureDisposition::Indeterminate)
2540 } else if has_channels && in_flight.retry.idem {
2541 None
2542 } else {
2543 Some(disposition)
2544 }
2545 })
2546 .unwrap_or(Some(disposition));
2547 tracing::trace!(%req_id, in_flight_found, ?reply_disposition, "failures_rx computed disposition");
2548 self.in_flight_handlers.remove(&req_id);
2550 self.shared.finish_request(req_id, RequestDebugState::Failed);
2551 tracing::trace!(%req_id, in_flight = self.in_flight_handlers.len(), "handler removed on failure");
2552 let had_pending = self.shared.pending_responses.lock().remove(&req_id).is_some();
2553 tracing::trace!(%req_id, had_pending, "failures_rx checked pending_responses");
2554 if !had_pending {
2555 let Some(reply_disposition) = reply_disposition else {
2556 tracing::trace!(%req_id, "failures_rx: no reply_disposition, skipping");
2557 continue;
2558 };
2559 tracing::trace!(%req_id, ?reply_disposition, "failures_rx: sending error response");
2560 let vox_error = match reply_disposition {
2561 FailureDisposition::Cancelled => VoxError::Cancelled,
2562 FailureDisposition::Indeterminate => VoxError::Indeterminate,
2563 };
2564 if let Some(method_id) = in_flight_method_id
2565 && let Some(response_shape) = self.handler.response_wire_shape(method_id)
2566 && let Ok(extracted) = vox_types::extract_schemas(response_shape)
2567 {
2568 let registry = vox_types::build_registry(&extracted.schemas);
2569 let error: Result<(), VoxError<core::convert::Infallible>> =
2570 Err(vox_error);
2571 let encoded = vox_postcard::to_vec(&error)
2572 .expect("serialize runtime-generated error response");
2573 let mut response = RequestResponse {
2574 ret: Payload::PostcardBytes(Box::leak(encoded.into_boxed_slice())),
2575 metadata: Default::default(),
2576 schemas: Default::default(),
2577 };
2578 self.sender.prepare_response_from_source(
2579 req_id,
2580 method_id,
2581 &extracted.root,
2582 ®istry,
2583 &mut response,
2584 );
2585 let _ = self.sender.send_response(req_id, response).await;
2586 } else {
2587 let error: Result<(), VoxError<core::convert::Infallible>> =
2588 Err(vox_error);
2589 let _ = self.sender.send_response(req_id, RequestResponse {
2590 ret: Payload::outgoing(&error),
2591 metadata: Default::default(),
2592 schemas: Default::default(),
2593 }).await;
2594 }
2595 tracing::trace!(%req_id, "failures_rx: error response sent");
2596 }
2597 }
2598 recv = self.rx.recv() => {
2599 match recv {
2600 Some(recv) => {
2601 self.handle_recv(recv).await;
2602 }
2603 None => {
2604 tracing::trace!("driver rx closed, exiting loop");
2605 break;
2606 }
2607 }
2608 }
2609 Some(item) = self.handler_futs.next(), if !self.handler_futs.is_empty() => {
2615 match item {
2616 Ok(req_id) => {
2617 let removed = self.in_flight_handlers.remove(&req_id).is_some();
2618 self.shared.finish_request(req_id, RequestDebugState::Finished);
2619 tracing::trace!(
2620 %req_id,
2621 removed,
2622 in_flight = self.in_flight_handlers.len(),
2623 "handler completion processed",
2624 );
2625 }
2626 Err(_aborted) => {
2627 }
2630 }
2631 }
2632 }
2633 }
2634
2635 for (_, in_flight) in std::mem::take(&mut self.in_flight_handlers) {
2636 if !in_flight.retry.persist {
2637 in_flight.abort.abort();
2638 }
2639 }
2640 self.shared.pending_responses.lock().clear();
2641 self.shared.request_debug.lock().clear();
2642 let close_reason =
2643 (*self.closed_rx.borrow()).unwrap_or(ConnectionCloseReason::SessionShutdown);
2644 self.shared.set_connection_closed(close_reason);
2645
2646 self.close_all_channel_runtime_state(ChannelRuntimeTeardown::ConnectionClosed(
2650 close_reason,
2651 ));
2652 }
2653
2654 async fn handle_local_control(&mut self, control: DriverLocalControl) {
2655 match control {
2656 DriverLocalControl::CloseChannel { channel_id } => {
2657 if self.shared.stale_close_channels.lock().remove(&channel_id) {
2662 tracing::trace!(%channel_id, "suppressing ChannelClose for stale channel");
2663 return;
2664 }
2665 self.close_outbound_channel(channel_id);
2666 self.shared
2667 .observe_channel(channel_id, None, |channel| ChannelEvent::Closed {
2668 channel,
2669 reason: ChannelCloseReason::Local,
2670 });
2671 self.shared.mark_outbound_progress();
2672 let _ = self
2673 .sender
2674 .send(ConnectionMessage::Channel(ChannelMessage {
2675 id: channel_id,
2676 body: ChannelBody::Close(ChannelClose {
2677 metadata: Default::default(),
2678 }),
2679 }))
2680 .await;
2681 }
2682 DriverLocalControl::ResetChannel { channel_id } => {
2683 self.shared.channel_senders.lock().remove(&channel_id);
2684 self.shared.channel_receivers.lock().remove(&channel_id);
2685 self.close_outbound_channel(channel_id);
2686 self.shared
2687 .observe_channel(channel_id, None, |channel| ChannelEvent::Reset {
2688 channel,
2689 reason: ChannelResetReason::Local,
2690 });
2691 self.shared.mark_outbound_progress();
2692 let _ = self
2693 .sender
2694 .send(ConnectionMessage::Channel(ChannelMessage {
2695 id: channel_id,
2696 body: ChannelBody::Reset(vox_types::ChannelReset {
2697 metadata: Default::default(),
2698 }),
2699 }))
2700 .await;
2701 }
2702 DriverLocalControl::GrantCredit {
2703 channel_id,
2704 additional,
2705 } => {
2706 self.shared.observe_channel(channel_id, None, |channel| {
2707 ChannelEvent::CreditGranted {
2708 channel,
2709 amount: additional,
2710 }
2711 });
2712 self.shared.mark_outbound_progress();
2713 let _ = self
2714 .sender
2715 .send(ConnectionMessage::Channel(ChannelMessage {
2716 id: channel_id,
2717 body: ChannelBody::GrantCredit(vox_types::ChannelGrantCredit {
2718 additional,
2719 }),
2720 }))
2721 .await;
2722 }
2723 }
2724 }
2725
2726 async fn handle_recv(&mut self, recv: crate::session::RecvMessage) {
2727 self.shared.mark_inbound_progress();
2728 let crate::session::RecvMessage { schemas, msg } = recv;
2729 let msg_ref = msg.get();
2730 let is_request = matches!(msg_ref, ConnectionMessage::Request(_));
2731 if is_request {
2732 if let ConnectionMessage::Request(req) = msg_ref {
2733 vox_types::dlog!(
2734 "[driver] handle_recv request: conn={:?} req={:?} body={} method={:?}",
2735 self.sender.connection_id(),
2736 req.id,
2737 match &req.body {
2738 RequestBody::Call(_) => "Call",
2739 RequestBody::Response(_) => "Response",
2740 RequestBody::Cancel(_) => "Cancel",
2741 },
2742 match &req.body {
2743 RequestBody::Call(call) => Some(call.method_id),
2744 RequestBody::Response(_) | RequestBody::Cancel(_) => None,
2745 }
2746 );
2747 match &req.body {
2748 RequestBody::Call(call) => tracing::trace!(
2749 conn_id = self.sender.connection_id().0,
2750 req_id = req.id.0,
2751 method_id = call.method_id.0,
2752 "driver received call"
2753 ),
2754 RequestBody::Response(_) => tracing::trace!(
2755 conn_id = self.sender.connection_id().0,
2756 req_id = req.id.0,
2757 "driver received response message"
2758 ),
2759 RequestBody::Cancel(_) => tracing::trace!(
2760 conn_id = self.sender.connection_id().0,
2761 req_id = req.id.0,
2762 "driver received cancel message"
2763 ),
2764 }
2765 }
2766 let msg = msg.map(|m| match m {
2767 ConnectionMessage::Request(r) => r,
2768 _ => unreachable!(),
2769 });
2770 self.handle_request(msg, schemas);
2771 } else {
2772 let msg = msg.map(|m| match m {
2773 ConnectionMessage::Channel(c) => c,
2774 _ => unreachable!(),
2775 });
2776 self.handle_channel(msg).await;
2777 }
2778 }
2779
2780 fn handle_request(
2781 &mut self,
2782 msg: SelfRef<RequestMessage<'static>>,
2783 schemas: Arc<vox_types::SchemaRecvTracker>,
2784 ) {
2785 let msg_ref = msg.get();
2786 let req_id = msg_ref.id;
2787 let is_call = matches!(&msg_ref.body, RequestBody::Call(_));
2788 let is_response = matches!(&msg_ref.body, RequestBody::Response(_));
2789 let is_cancel = matches!(&msg_ref.body, RequestBody::Cancel(_));
2790
2791 if is_call {
2792 let method_id = match &msg_ref.body {
2793 RequestBody::Call(call) => call.method_id,
2794 _ => unreachable!(),
2795 };
2796 vox_types::dlog!(
2797 "[driver] inbound call: conn={:?} req={:?} method={:?}",
2798 self.sender.connection_id(),
2799 req_id,
2800 method_id
2801 );
2802 let call = msg.map(|m| match m.body {
2805 RequestBody::Call(c) => c,
2806 _ => unreachable!(),
2807 });
2808 let call_ref = call.get();
2809 let handler = Arc::clone(&self.handler);
2810 let retry = handler.retry_policy(call_ref.method_id);
2811 let operation_id = metadata_operation_id(&call_ref.metadata).filter(|_| !retry.idem);
2813 let method_id = call_ref.method_id;
2814
2815 if let Some(operation_id) = operation_id {
2816 let admit = self.live_operations.lock().admit(
2818 operation_id,
2819 call_ref.method_id,
2820 incoming_args_bytes(call_ref),
2821 retry,
2822 req_id,
2823 );
2824 match admit {
2825 AdmitResult::Attached => return,
2826 AdmitResult::Conflict => {
2827 let sender = self.sender.clone();
2828 moire::task::spawn(
2829 async move {
2830 let error: Result<(), VoxError<core::convert::Infallible>> =
2831 Err(VoxError::InvalidPayload("operation ID conflict".into()));
2832 let _ = sender
2833 .send_response(
2834 req_id,
2835 RequestResponse {
2836 ret: Payload::outgoing(&error),
2837 metadata: Default::default(),
2838 schemas: Default::default(),
2839 },
2840 )
2841 .await;
2842 }
2843 .named("operation_reject"),
2844 );
2845 return;
2846 }
2847 AdmitResult::Start => {}
2848 }
2849
2850 match self.shared.operations.lookup(operation_id) {
2852 crate::OperationState::Sealed => {
2853 if let Some(sealed) = self.shared.operations.get_sealed(operation_id) {
2855 let sender = self.sender.clone();
2856 let method_id = call_ref.method_id;
2857 let response_shape = self.handler.response_wire_shape(method_id);
2858 self.live_operations.lock().seal(operation_id);
2860 moire::task::spawn(
2861 async move {
2862 if replay_sealed_response(
2863 sender.clone(),
2864 req_id,
2865 method_id,
2866 sealed.response.as_bytes(),
2867 response_shape,
2868 )
2869 .await
2870 .is_err()
2871 {
2872 sender.mark_failure(req_id, FailureDisposition::Cancelled);
2873 }
2874 }
2875 .named("operation_replay"),
2876 );
2877 return;
2878 }
2879 }
2880 crate::OperationState::Admitted => {
2881 self.live_operations.lock().seal(operation_id);
2883 let sender = self.sender.clone();
2884 moire::task::spawn(
2885 async move {
2886 let error: Result<(), VoxError<core::convert::Infallible>> =
2887 Err(VoxError::Indeterminate);
2888 let _ = sender
2889 .send_response(
2890 req_id,
2891 RequestResponse {
2892 ret: Payload::outgoing(&error),
2893 metadata: Default::default(),
2894 schemas: Default::default(),
2895 },
2896 )
2897 .await;
2898 }
2899 .named("operation_indeterminate"),
2900 );
2901 return;
2902 }
2903 crate::OperationState::Unknown => {
2904 if !retry.idem {
2907 self.shared.operations.admit(operation_id);
2908 }
2909 }
2910 }
2911 }
2912 let reply = DriverReplySink {
2913 sender: Some(self.sender.clone()),
2914 request_id: req_id,
2915 method_id: call_ref.method_id,
2916 retry,
2917 operation_id,
2918 operations: operation_id.map(|_| Arc::clone(&self.shared.operations)),
2919 live_operations: operation_id.map(|_| Arc::clone(&self.live_operations)),
2920 binder: self.internal_binder(),
2921 handler_response_shape: handler.response_wire_shape(call_ref.method_id),
2922 };
2923 self.shared.start_request(
2924 req_id,
2925 method_id,
2926 None,
2927 None,
2928 RequestDebugState::Dispatching,
2929 );
2930 let (abort, abort_reg) = AbortHandle::new_pair();
2931 let handler_fut: Pin<Box<dyn MaybeSendFuture<Output = RequestId> + 'static>> =
2932 Box::pin(async move {
2933 vox_types::dlog!(
2934 "[driver] handler start: req={:?} method={:?}",
2935 req_id,
2936 method_id
2937 );
2938 handler.handle(call, reply, schemas).await;
2939 vox_types::dlog!(
2940 "[driver] handler done: req={:?} method={:?}",
2941 req_id,
2942 method_id
2943 );
2944 req_id
2945 });
2946 self.handler_futs
2947 .push(Abortable::new(handler_fut, abort_reg));
2948 self.in_flight_handlers.insert(
2949 req_id,
2950 InFlightHandler {
2951 abort,
2952 method_id,
2953 retry,
2954 operation_id,
2955 },
2956 );
2957 tracing::trace!(%req_id, in_flight = self.in_flight_handlers.len(), "handler inserted");
2958 } else if is_response {
2959 vox_types::dlog!(
2961 "[driver] inbound response: conn={:?} req={:?}",
2962 self.sender.connection_id(),
2963 req_id
2964 );
2965 tracing::trace!(%req_id, "driver received response");
2966 if let Some(tx) = self.shared.pending_responses.lock().remove(&req_id) {
2967 vox_types::dlog!("[driver] routing response to waiter: req={:?}", req_id);
2968 tracing::trace!(%req_id, "routing response to pending oneshot");
2969 let _: Result<(), _> = tx.send(PendingResponse { msg, schemas });
2970 } else {
2971 vox_types::dlog!("[driver] dropped unmatched response: req={:?}", req_id);
2972 tracing::trace!(%req_id, "no pending response slot for this req_id");
2973 }
2974 } else if is_cancel {
2975 vox_types::dlog!(
2976 "[driver] inbound cancel: conn={:?} req={:?}",
2977 self.sender.connection_id(),
2978 req_id
2979 );
2980 tracing::trace!(%req_id, in_flight = self.in_flight_handlers.contains_key(&req_id), "received cancel");
2983 match self.live_operations.lock().cancel(req_id) {
2984 CancelResult::NotFound => {
2985 let should_abort = self
2986 .in_flight_handlers
2987 .get(&req_id)
2988 .map(|in_flight| !in_flight.retry.persist)
2989 .unwrap_or(false);
2990 tracing::trace!(%req_id, should_abort, "cancel: not in live operations");
2991 if should_abort && let Some(in_flight) = self.in_flight_handlers.remove(&req_id)
2992 {
2993 tracing::trace!(%req_id, "aborting handler");
2994 in_flight.abort.abort();
2995 self.shared
2996 .finish_request(req_id, RequestDebugState::Failed);
2997 tracing::trace!(%req_id, in_flight = self.in_flight_handlers.len(), "handler removed on cancel");
2998 }
2999 }
3000 CancelResult::Detached => {}
3001 CancelResult::Abort {
3002 owner_request_id,
3003 waiters,
3004 } => {
3005 if let Some(in_flight) = self.in_flight_handlers.remove(&owner_request_id) {
3006 if let Some(op_id) = in_flight.operation_id {
3007 self.shared.operations.remove(op_id);
3008 }
3009 in_flight.abort.abort();
3010 self.shared
3011 .finish_request(owner_request_id, RequestDebugState::Failed);
3012 tracing::trace!(%owner_request_id, in_flight = self.in_flight_handlers.len(), "owner handler removed on abort");
3013 }
3014 for waiter in waiters {
3015 self.sender
3016 .mark_failure(waiter, FailureDisposition::Cancelled);
3017 }
3018 }
3019 }
3020 }
3023 }
3024
    /// Applies one inbound channel message from the peer.
    ///
    /// - `Item`: dropped (and recorded as not enqueued) for terminal channels.
    ///   Otherwise it is sent into the channel's inbound mailbox. If that send
    ///   fails, the channel's sender/receiver entries are removed, its outbound
    ///   side is closed, and a local `ResetChannel` is queued (which sends a
    ///   `Reset` to the peer).
    /// - `Close` / `Reset`: ignored for terminal channels. Otherwise:
    ///   - the message is forwarded to the mailbox;
    ///   - the channel's sender entry is removed, it is marked terminal, and its
    ///     outbound side is closed;
    ///   - the receiver entry is removed only if delivery failed;
    ///   - the `Closed`/`Reset` event is observed only if delivery succeeded.
    /// - `GrantCredit`: recorded, emitted as an event, and added to the channel's
    ///   outbound credit semaphore when one is registered.
    ///
    /// Mailbox sends are awaited inline. If a send has to wait (e.g. a bounded
    /// mailbox is full — capacity not visible here), the driver loop waits with it.
    async fn handle_channel(&mut self, msg: SelfRef<ChannelMessage<'static>>) {
        let msg_ref = msg.get();
        let chan_id = msg_ref.id;
        // Copy out only the variant (and credit amount) so `msg` can later be
        // consumed by `msg.map` without an outstanding borrow of its body.
        enum ChannelBodyKind {
            Item,
            Close,
            Reset,
            GrantCredit(u32),
        }
        let body_kind = match &msg_ref.body {
            ChannelBody::Item(_) => ChannelBodyKind::Item,
            ChannelBody::Close(_) => ChannelBodyKind::Close,
            ChannelBody::Reset(_) => ChannelBodyKind::Reset,
            ChannelBody::GrantCredit(grant) => ChannelBodyKind::GrantCredit(grant.additional),
        };

        match body_kind {
            ChannelBodyKind::Item => {
                // Items for channels already closed/reset on our side are discarded.
                if self.shared.terminal_channels.lock().contains(&chan_id) {
                    self.shared.record_inbound_item_not_enqueued(chan_id);
                    tracing::trace!(
                        conn_id = self.sender.connection_id().0,
                        channel_id = chan_id.0,
                        "driver dropped item for terminal channel"
                    );
                    return;
                }

                tracing::trace!(
                    conn_id = self.sender.connection_id().0,
                    channel_id = chan_id.0,
                    "driver received channel item"
                );
                let item = msg.map(|m| match m.body {
                    ChannelBody::Item(item) => item,
                    _ => unreachable!(),
                });
                let sender = self.shared.inbound_channel_sender(chan_id);
                if sender
                    .send(IncomingChannelMessage::Item(item))
                    .await
                    .is_err()
                {
                    // The mailbox refused the item: tear down local state here and
                    // let the queued ResetChannel notify the peer.
                    self.shared.record_inbound_item_not_enqueued(chan_id);
                    self.shared.channel_senders.lock().remove(&chan_id);
                    self.shared.channel_receivers.lock().remove(&chan_id);
                    self.close_outbound_channel(chan_id);
                    let _ = self
                        .local_control_tx
                        .send(DriverLocalControl::ResetChannel {
                            channel_id: chan_id,
                        });
                    return;
                }
                self.shared
                    .observe_channel(chan_id, None, |channel| ChannelEvent::ItemReceived {
                        channel,
                    });
            }
            ChannelBodyKind::Close => {
                if self.shared.terminal_channels.lock().contains(&chan_id) {
                    return;
                }
                let sender = self.shared.inbound_channel_sender(chan_id);
                tracing::trace!(
                    conn_id = self.sender.connection_id().0,
                    channel_id = chan_id.0,
                    "driver received channel close"
                );
                let close = msg.map(|m| match m.body {
                    ChannelBody::Close(close) => close,
                    _ => unreachable!(),
                });
                let delivered = sender
                    .send(IncomingChannelMessage::Close(close))
                    .await
                    .is_ok();
                self.shared.channel_senders.lock().remove(&chan_id);
                // NOTE(review): redundant — `close_outbound_channel` below also
                // inserts `chan_id` into `terminal_channels`.
                self.shared.terminal_channels.lock().insert(chan_id);
                self.close_outbound_channel(chan_id);
                if !delivered {
                    self.shared.channel_receivers.lock().remove(&chan_id);
                    return;
                }
                self.shared
                    .observe_channel(chan_id, None, |channel| ChannelEvent::Closed {
                        channel,
                        reason: ChannelCloseReason::Remote,
                    });
            }
            ChannelBodyKind::Reset => {
                if self.shared.terminal_channels.lock().contains(&chan_id) {
                    return;
                }
                let sender = self.shared.inbound_channel_sender(chan_id);
                tracing::trace!(
                    conn_id = self.sender.connection_id().0,
                    channel_id = chan_id.0,
                    "driver received channel reset"
                );
                let reset = msg.map(|m| match m.body {
                    ChannelBody::Reset(reset) => reset,
                    _ => unreachable!(),
                });
                let delivered = sender
                    .send(IncomingChannelMessage::Reset(reset))
                    .await
                    .is_ok();
                self.shared.channel_senders.lock().remove(&chan_id);
                // NOTE(review): redundant — `close_outbound_channel` below also
                // inserts `chan_id` into `terminal_channels`.
                self.shared.terminal_channels.lock().insert(chan_id);
                self.close_outbound_channel(chan_id);
                if !delivered {
                    self.shared.channel_receivers.lock().remove(&chan_id);
                    return;
                }
                self.shared
                    .observe_channel(chan_id, None, |channel| ChannelEvent::Reset {
                        channel,
                        reason: ChannelResetReason::Remote,
                    });
            }
            ChannelBodyKind::GrantCredit(additional) => {
                self.shared.record_credit_received(chan_id, additional);
                self.shared.emit_channel_event(chan_id, None, |channel| {
                    ChannelEvent::CreditGranted {
                        channel,
                        amount: additional,
                    }
                });
                tracing::trace!(
                    conn_id = self.sender.connection_id().0,
                    channel_id = chan_id.0,
                    additional,
                    "driver received channel credit"
                );
                // NOTE(review): `additional` is peer-controlled. If this semaphore's
                // `add_permits` panics past a maximum (tokio's does, at
                // `Semaphore::MAX_PERMITS`, which is small on 32-bit targets),
                // oversized or repeated grants could panic the driver. Consider
                // clamping — confirm moire's semantics.
                if let Some(semaphore) = self.shared.channel_credits.lock().get(&chan_id) {
                    semaphore.add_permits(additional as usize);
                }
            }
        }
    }
3172}