1use std::{
2 collections::{BTreeMap, HashMap, HashSet},
3 pin::Pin,
4 sync::{
5 Arc, Weak,
6 atomic::{AtomicU64, Ordering},
7 },
8};
9
10use vox_types::time::Instant;
11
12use futures_util::future::{AbortHandle, Abortable};
13use futures_util::stream::{FuturesUnordered, StreamExt as _};
14use moire::sync::{Semaphore, SyncMutex};
15use tokio::sync::watch;
16
17use moire::task::FutureExt as _;
18use vox_types::{
19 BoxFut, CallResult, ChannelBinder, ChannelBody, ChannelClose, ChannelCreditReplenisher,
20 ChannelCreditReplenisherHandle, ChannelEventContext, ChannelId, ChannelItem,
21 ChannelLivenessHandle, ChannelMailboxReceiver, ChannelMailboxSender, ChannelMessage,
22 ChannelRetryMode, ChannelSink, ConnectionId, CreditSink, Handler, IdAllocator,
23 IncomingChannelMessage, MaybeSend, MaybeSendFuture, MaybeSync, Payload, ReplySink, RequestBody,
24 RequestCall, RequestId, RequestMessage, RequestResponse, SelfRef, TrySendError, TxError,
25 VoxError, channel_mailbox, ensure_operation_id, metadata_channel_retry_mode,
26 metadata_operation_id,
27};
28use vox_types::{
29 ChannelCloseReason, ChannelDebugContext, ChannelDirection, ChannelEvent, ChannelResetReason,
30 ChannelSendOutcome, ChannelTrySendOutcome, DriverEvent, RpcOutcome, VoxObserverHandle,
31};
32use vox_types::{
33 ChannelDebugSnapshot, ChannelReceiverState, ConnectionCloseReason, ConnectionDebugSnapshot,
34 ConnectionDebugState, DriverTaskStatus, RequestDebugSnapshot, RequestDebugState,
35 VoxDebugSnapshot,
36};
37
38use crate::session::{
39 ConnectionHandle, ConnectionMessage, ConnectionSender, DropControlRequest, FailureDisposition,
40};
41use crate::{InMemoryOperationStore, OperationStore};
42use moire::sync::mpsc;
43use vox_types::{OperationId, PostcardPayload};
44
/// A response handed to a waiting caller through a [`ResponseSlot`].
struct PendingResponse {
    /// The received message (presumably the response frame for the pending request — confirm).
    msg: SelfRef<RequestMessage<'static>>,
    /// Schema-receive tracking state delivered alongside `msg`.
    schemas: Arc<vox_types::SchemaRecvTracker>,
    /// File descriptors attached to the frame, if any (NOTE(review): inferred from the type name).
    fds: vox_types::FrameFds,
}
58
/// One-shot sender through which a [`PendingResponse`] is delivered to its waiter.
type ResponseSlot = moire::sync::oneshot::Sender<PendingResponse>;
60
/// Bookkeeping for a handler that is currently running.
///
/// NOTE(review): the collection that owns these entries is outside this section.
struct InFlightHandler {
    /// Handle that aborts the handler's future (handler futures are `Abortable`, see `HandlerFut`).
    abort: AbortHandle,
    /// Method the handler was started for (presumably — confirm against the spawn site).
    method_id: vox_types::MethodId,
    /// Retry policy associated with that method.
    retry: vox_types::RetryPolicy,
    /// Operation id of the call, when it carried one.
    operation_id: Option<OperationId>,
}
71
/// Boxed, abortable handler future that resolves to a `RequestId`
/// (presumably the id of the request it served — confirm at the spawn site).
type HandlerFut = Abortable<Pin<Box<dyn MaybeSendFuture<Output = RequestId> + 'static>>>;
85
/// How a channel's runtime state is being torn down.
///
/// NOTE(review): consumers are outside this section; the variant descriptions
/// below are inferred from their names — confirm.
#[derive(Clone, Copy, Debug)]
enum ChannelRuntimeTeardown {
    /// Only local state is dropped; presumably no connection-level close is reported.
    DropOnly,
    /// Teardown caused by the connection closing, with the close reason.
    ConnectionClosed(ConnectionCloseReason),
}
91
/// Tracks operations that are currently executing, so that requests reusing
/// an operation id can attach to, detach from, or abort the running call.
struct LiveOperationTracker {
    /// Live operations keyed by operation id.
    live: HashMap<OperationId, LiveOperation>,
    /// Reverse index: every request currently waiting on a live operation.
    request_to_operation: HashMap<RequestId, OperationId>,
}
107
/// State for a single live operation.
struct LiveOperation {
    /// Method the operation was admitted with.
    method_id: vox_types::MethodId,
    /// Hash of the method id and raw argument bytes, computed on admission;
    /// a later admission with a different hash is a conflict.
    args_hash: u64,
    /// Request that started the operation.
    owner_request_id: RequestId,
    /// Every request waiting on the result, including the owner.
    waiters: Vec<RequestId>,
    /// Retry policy; `retry.persist` decides whether a cancel detaches or aborts.
    retry: vox_types::RetryPolicy,
}
115
/// Outcome of [`LiveOperationTracker::admit`].
enum AdmitResult {
    /// No live operation had this id; a new one was created with the
    /// request as owner (presumably the caller now runs the handler).
    Start,
    /// A matching live operation exists; the request was added as a waiter.
    Attached,
    /// The id is live with a different method or arguments; nothing was recorded.
    Conflict,
}
124
125impl LiveOperationTracker {
126 fn new() -> Self {
127 Self {
128 live: HashMap::new(),
129 request_to_operation: HashMap::new(),
130 }
131 }
132
133 fn admit(
134 &mut self,
135 operation_id: OperationId,
136 method_id: vox_types::MethodId,
137 args: &[u8],
138 retry: vox_types::RetryPolicy,
139 request_id: RequestId,
140 ) -> AdmitResult {
141 use std::hash::{Hash, Hasher};
142 let args_hash = {
143 let mut h = std::collections::hash_map::DefaultHasher::new();
144 method_id.hash(&mut h);
145 args.hash(&mut h);
146 h.finish()
147 };
148 let live_operations = self.live.len();
149
150 if let Some(live) = self.live.get_mut(&operation_id) {
151 if live.method_id != method_id || live.args_hash != args_hash {
152 let request_bindings = self.request_to_operation.len();
153 tracing::trace!(
154 %operation_id,
155 %request_id,
156 ?method_id,
157 live_operations,
158 request_bindings,
159 "live operation conflict"
160 );
161 return AdmitResult::Conflict;
162 }
163 live.waiters.push(request_id);
164 self.request_to_operation.insert(request_id, operation_id);
165 let waiters = live.waiters.len();
166 let request_bindings = self.request_to_operation.len();
167 tracing::trace!(
168 %operation_id,
169 %request_id,
170 ?method_id,
171 waiters,
172 live_operations,
173 request_bindings,
174 "live operation attached"
175 );
176 return AdmitResult::Attached;
177 }
178
179 self.live.insert(
180 operation_id,
181 LiveOperation {
182 method_id,
183 args_hash,
184 owner_request_id: request_id,
185 waiters: vec![request_id],
186 retry,
187 },
188 );
189 self.request_to_operation.insert(request_id, operation_id);
190 let live_operations = self.live.len();
191 let request_bindings = self.request_to_operation.len();
192 tracing::trace!(
193 %operation_id,
194 %request_id,
195 ?method_id,
196 live_operations,
197 request_bindings,
198 "live operation admitted"
199 );
200 AdmitResult::Start
201 }
202
203 fn seal(&mut self, operation_id: OperationId) -> Vec<RequestId> {
205 if let Some(live) = self.live.remove(&operation_id) {
206 for waiter in &live.waiters {
207 self.request_to_operation.remove(waiter);
208 }
209 let waiters = live.waiters.len();
210 let live_operations = self.live.len();
211 let request_bindings = self.request_to_operation.len();
212 tracing::trace!(
213 %operation_id,
214 waiters,
215 live_operations,
216 request_bindings,
217 "live operation sealed"
218 );
219 live.waiters
220 } else {
221 vec![]
222 }
223 }
224
225 fn release(&mut self, operation_id: OperationId) -> Option<LiveOperation> {
227 if let Some(live) = self.live.remove(&operation_id) {
228 for waiter in &live.waiters {
229 self.request_to_operation.remove(waiter);
230 }
231 let waiters = live.waiters.len();
232 let live_operations = self.live.len();
233 let request_bindings = self.request_to_operation.len();
234 tracing::trace!(
235 %operation_id,
236 waiters,
237 live_operations,
238 request_bindings,
239 "live operation released"
240 );
241 Some(live)
242 } else {
243 None
244 }
245 }
246
247 fn cancel(&mut self, request_id: RequestId) -> CancelResult {
249 let Some(&operation_id) = self.request_to_operation.get(&request_id) else {
250 return CancelResult::NotFound;
251 };
252 let live_operations = self.live.len();
253 let Some(live) = self.live.get_mut(&operation_id) else {
254 self.request_to_operation.remove(&request_id);
255 return CancelResult::NotFound;
256 };
257
258 if live.retry.persist {
259 if live.owner_request_id == request_id {
261 return CancelResult::NotFound; }
263 live.waiters.retain(|w| *w != request_id);
264 self.request_to_operation.remove(&request_id);
265 let waiters = live.waiters.len();
266 let request_bindings = self.request_to_operation.len();
267 tracing::trace!(
268 %operation_id,
269 %request_id,
270 waiters,
271 live_operations,
272 request_bindings,
273 "live operation detached waiter"
274 );
275 CancelResult::Detached
276 } else {
277 let live = self.live.remove(&operation_id).unwrap();
279 for waiter in &live.waiters {
280 self.request_to_operation.remove(waiter);
281 }
282 let waiters = live.waiters.len();
283 let live_operations = self.live.len();
284 let request_bindings = self.request_to_operation.len();
285 tracing::trace!(
286 %operation_id,
287 %request_id,
288 waiters,
289 live_operations,
290 request_bindings,
291 "live operation aborted"
292 );
293 CancelResult::Abort {
294 owner_request_id: live.owner_request_id,
295 waiters: live.waiters,
296 }
297 }
298 }
299}
300
/// Outcome of [`LiveOperationTracker::cancel`].
enum CancelResult {
    /// Nothing to do: the request is not bound to a live operation, or it is
    /// the owner of a persistent operation (whose cancel is ignored).
    NotFound,
    /// A non-owner waiter was detached from a persistent operation.
    Detached,
    /// A non-persistent operation was removed entirely.
    Abort {
        /// Request that started the aborted operation.
        owner_request_id: RequestId,
        /// All requests that were waiting on it, including the owner.
        waiters: Vec<RequestId>,
    },
}
309
/// Per-request debug record kept by the driver and turned into a
/// [`RequestDebugSnapshot`] on demand.
#[derive(Clone)]
struct RequestRuntimeDebug {
    /// Method being called.
    method_id: vox_types::MethodId,
    /// Service name, when known.
    service: Option<&'static str>,
    /// Method name, when known.
    method: Option<&'static str>,
    /// When the record was created; used to compute the request's age.
    started_at: Instant,
    /// Current request state.
    state: RequestDebugState,
    /// Whether the response sender is blocked (`Some(false)` when the request starts).
    response_sender_blocked: Option<bool>,
    /// Channels associated with this request (empty when the request starts).
    associated_channels: Vec<ChannelId>,
}
320
321impl RequestRuntimeDebug {
322 fn snapshot(&self, request_id: RequestId, now: Instant) -> RequestDebugSnapshot {
323 RequestDebugSnapshot {
324 request_id,
325 service: self.service,
326 method: self.method,
327 method_id: self.method_id,
328 age: now.saturating_duration_since(self.started_at),
329 state: self.state,
330 response_sender_blocked: self.response_sender_blocked,
331 associated_channels: self.associated_channels.clone(),
332 }
333 }
334}
335
/// Per-channel debug counters and timestamps maintained by the driver and
/// turned into a [`ChannelDebugSnapshot`] on demand.
#[derive(Clone)]
struct ChannelRuntimeDebug {
    /// Whether the channel is receive-side or send-side from this end.
    direction: ChannelDirection,
    /// Debug context, filled in the first time a non-`None` one is seen.
    debug: Option<ChannelDebugContext>,
    /// Credit the channel started with.
    initial_credit: u32,
    // Inbound queue accounting: incremented per received item, decremented
    // per consumed item or per item that could not be enqueued.
    inbound_queue_len: usize,
    /// `Some(initial_credit)` for `Rx` channels, `None` for `Tx` channels.
    inbound_queue_capacity: Option<usize>,
    /// Receiver state (present / closed / reset / dropped).
    receiver_state: ChannelReceiverState,
    // Timestamps of the most recent activity of each kind.
    last_item_sent_at: Option<Instant>,
    last_item_received_at: Option<Instant>,
    last_item_consumed_at: Option<Instant>,
    last_credit_granted_at: Option<Instant>,
    last_credit_received_at: Option<Instant>,
    // Amount carried by the most recent grant / receipt of credit.
    last_credit_granted_amount: Option<u32>,
    last_credit_received_amount: Option<u32>,
    /// Locally accumulated credit not yet granted; reset to 0 when a grant is recorded.
    pending_local_grant_credit: u32,
    // Sum of all credit amounts granted / received.
    total_credit_granted: u64,
    total_credit_received: u64,
    // Event counters (all saturating).
    sent: u64,
    sends_started: u64,
    sends_completed: u64,
    sends_waited_for_credit: u64,
    try_send_full_credit: u64,
    try_send_full_runtime_queue: u64,
    closed: u64,
    reset: u64,
    dropped: u64,
    items_received: u64,
    items_consumed: u64,
    /// Number of credit grants (not the amount; see `total_credit_granted`).
    credit_granted: u64,
    /// Number of credit receipts (not the amount; see `total_credit_received`).
    credit_received: u64,
    /// Reason from the most recent close, if any.
    close_reason: Option<ChannelCloseReason>,
    /// Reason from the most recent reset, if any.
    reset_reason: Option<ChannelResetReason>,
}
370
371impl ChannelRuntimeDebug {
372 fn new(
373 direction: ChannelDirection,
374 initial_credit: u32,
375 debug: Option<ChannelDebugContext>,
376 ) -> Self {
377 Self {
378 direction,
379 debug,
380 initial_credit,
381 inbound_queue_len: 0,
382 inbound_queue_capacity: match direction {
383 ChannelDirection::Rx => Some(initial_credit as usize),
384 ChannelDirection::Tx => None,
385 },
386 receiver_state: ChannelReceiverState::Present,
387 last_item_sent_at: None,
388 last_item_received_at: None,
389 last_item_consumed_at: None,
390 last_credit_granted_at: None,
391 last_credit_received_at: None,
392 last_credit_granted_amount: None,
393 last_credit_received_amount: None,
394 pending_local_grant_credit: 0,
395 total_credit_granted: 0,
396 total_credit_received: 0,
397 sent: 0,
398 sends_started: 0,
399 sends_completed: 0,
400 sends_waited_for_credit: 0,
401 try_send_full_credit: 0,
402 try_send_full_runtime_queue: 0,
403 closed: 0,
404 reset: 0,
405 dropped: 0,
406 items_received: 0,
407 items_consumed: 0,
408 credit_granted: 0,
409 credit_received: 0,
410 close_reason: None,
411 reset_reason: None,
412 }
413 }
414
415 fn merge_debug(&mut self, debug: Option<ChannelDebugContext>) {
416 if self.debug.is_none() {
417 self.debug = debug;
418 }
419 }
420
421 fn mark_item_received(&mut self, now: Instant) {
422 self.items_received = self.items_received.saturating_add(1);
423 self.inbound_queue_len = self.inbound_queue_len.saturating_add(1);
424 self.last_item_received_at = Some(now);
425 }
426
427 fn mark_closed(&mut self, reason: ChannelCloseReason) {
428 self.closed = self.closed.saturating_add(1);
429 self.close_reason = Some(reason);
430 self.receiver_state = ChannelReceiverState::Closed;
431 if reason == ChannelCloseReason::Dropped {
432 self.dropped = self.dropped.saturating_add(1);
433 self.receiver_state = ChannelReceiverState::Dropped;
434 }
435 }
436
437 fn mark_reset(&mut self, reason: ChannelResetReason) {
438 self.reset = self.reset.saturating_add(1);
439 self.reset_reason = Some(reason);
440 self.receiver_state = ChannelReceiverState::Reset;
441 }
442
443 fn mark_send_started(&mut self) {
444 self.sends_started = self.sends_started.saturating_add(1);
445 }
446
447 fn mark_send_waiting_for_credit(&mut self) {
448 self.sends_waited_for_credit = self.sends_waited_for_credit.saturating_add(1);
449 }
450
451 fn mark_send_finished(&mut self, outcome: ChannelSendOutcome, now: Instant) {
452 self.sends_completed = self.sends_completed.saturating_add(1);
453 if outcome == ChannelSendOutcome::Sent {
454 self.sent = self.sent.saturating_add(1);
455 self.last_item_sent_at = Some(now);
456 }
457 }
458
459 fn mark_try_send_outcome(&mut self, outcome: ChannelTrySendOutcome, now: Instant) {
460 match outcome {
461 ChannelTrySendOutcome::Sent => {
462 self.sent = self.sent.saturating_add(1);
463 self.last_item_sent_at = Some(now);
464 }
465 ChannelTrySendOutcome::FullCredit => {
466 self.try_send_full_credit = self.try_send_full_credit.saturating_add(1);
467 }
468 ChannelTrySendOutcome::FullRuntimeQueue => {
469 self.try_send_full_runtime_queue =
470 self.try_send_full_runtime_queue.saturating_add(1);
471 }
472 ChannelTrySendOutcome::Unbound | ChannelTrySendOutcome::Closed => {}
473 }
474 }
475
476 fn mark_item_consumed(&mut self, now: Instant) {
477 self.items_consumed = self.items_consumed.saturating_add(1);
478 self.inbound_queue_len = self.inbound_queue_len.saturating_sub(1);
479 self.last_item_consumed_at = Some(now);
480 }
481
482 fn mark_inbound_item_not_enqueued(&mut self) {
483 self.inbound_queue_len = self.inbound_queue_len.saturating_sub(1);
484 }
485
486 fn mark_credit_granted(&mut self, amount: u32, now: Instant) {
487 self.credit_granted = self.credit_granted.saturating_add(1);
488 self.total_credit_granted = self.total_credit_granted.saturating_add(amount as u64);
489 self.last_credit_granted_at = Some(now);
490 self.last_credit_granted_amount = Some(amount);
491 self.pending_local_grant_credit = 0;
492 }
493
494 fn mark_credit_received(&mut self, amount: u32, now: Instant) {
495 self.credit_received = self.credit_received.saturating_add(1);
496 self.total_credit_received = self.total_credit_received.saturating_add(amount as u64);
497 self.last_credit_received_at = Some(now);
498 self.last_credit_received_amount = Some(amount);
499 }
500
501 fn mark_receiver_dropped(&mut self) {
502 self.reset = self.reset.saturating_add(1);
503 self.reset_reason = Some(ChannelResetReason::ReceiverDropped);
504 self.receiver_state = ChannelReceiverState::Dropped;
505 self.dropped = self.dropped.saturating_add(1);
506 }
507
508 fn snapshot(
509 &self,
510 connection_id: ConnectionId,
511 channel_id: ChannelId,
512 available_send_credit: Option<u32>,
513 ) -> ChannelDebugSnapshot {
514 ChannelDebugSnapshot {
515 connection_id,
516 channel_id,
517 direction: self.direction,
518 debug: self.debug,
519 initial_credit: self.initial_credit,
520 available_send_credit,
521 inbound_queue_len: Some(self.inbound_queue_len),
522 inbound_queue_capacity: self.inbound_queue_capacity,
523 outbound_runtime_queue_len: None,
524 outbound_runtime_queue_capacity: None,
525 send_waiters_count: None,
526 receiver_state: self.receiver_state,
527 last_item_sent_at: self.last_item_sent_at,
528 last_item_received_at: self.last_item_received_at,
529 last_item_consumed_at: self.last_item_consumed_at,
530 last_credit_granted_at: self.last_credit_granted_at,
531 last_credit_received_at: self.last_credit_received_at,
532 last_credit_granted_amount: self.last_credit_granted_amount,
533 last_credit_received_amount: self.last_credit_received_amount,
534 pending_local_grant_credit: self.pending_local_grant_credit,
535 total_credit_granted: self.total_credit_granted,
536 total_credit_received: self.total_credit_received,
537 current_permit_count: available_send_credit,
538 zero_credit_with_blocked_senders: available_send_credit == Some(0)
539 && self.sends_waited_for_credit > 0,
540 sent: self.sent,
541 sends_started: self.sends_started,
542 sends_completed: self.sends_completed,
543 sends_waited_for_credit: self.sends_waited_for_credit,
544 try_send_full_credit: self.try_send_full_credit,
545 try_send_full_runtime_queue: self.try_send_full_runtime_queue,
546 closed: self.closed,
547 reset: self.reset,
548 dropped: self.dropped,
549 items_received: self.items_received,
550 items_consumed: self.items_consumed,
551 credit_granted: self.credit_granted,
552 credit_received: self.credit_received,
553 close_reason: self.close_reason,
554 reset_reason: self.reset_reason,
555 }
556 }
557}
558
/// State shared between the driver and the handles/sinks it hands out for a
/// single connection.
struct DriverShared {
    /// Connection this state belongs to.
    connection_id: ConnectionId,
    /// Response slots keyed by request id (presumably for outgoing calls awaiting a reply — confirm).
    pending_responses: SyncMutex<BTreeMap<RequestId, ResponseSlot>>,
    /// Allocator for request ids.
    request_ids: SyncMutex<IdAllocator<RequestId>>,
    /// Counter for operation ids (NOTE(review): allocation site not in this section).
    next_operation_id: AtomicU64,
    /// Store for sealed operation results.
    operations: Arc<dyn OperationStore>,
    /// Allocator for channel ids.
    channel_ids: SyncMutex<IdAllocator<ChannelId>>,
    /// Sender halves of inbound channel mailboxes, keyed by channel.
    channel_senders: SyncMutex<BTreeMap<ChannelId, ChannelMailboxSender<IncomingChannelMessage>>>,
    /// Receiver halves created before the channel's receiver was registered;
    /// claimed by `register_inbound_channel_receiver`.
    channel_receivers:
        SyncMutex<BTreeMap<ChannelId, ChannelMailboxReceiver<IncomingChannelMessage>>>,
    /// Per-channel send-credit semaphores; their available permits appear in debug snapshots.
    channel_credits: SyncMutex<BTreeMap<ChannelId, Arc<Semaphore>>>,
    /// Last known non-empty debug context per channel, used as a fallback for events.
    channel_contexts: SyncMutex<BTreeMap<ChannelId, ChannelDebugContext>>,
    /// Debug records of requests currently in progress.
    request_debug: SyncMutex<BTreeMap<RequestId, RequestRuntimeDebug>>,
    /// Debug records of channels.
    channel_debug: SyncMutex<BTreeMap<ChannelId, ChannelRuntimeDebug>>,
    /// Time of the most recent inbound message.
    last_inbound_message_at: SyncMutex<Option<Instant>>,
    /// Time of the most recent outbound message.
    last_outbound_message_at: SyncMutex<Option<Instant>>,
    /// Why the connection closed, once it has.
    close_reason: SyncMutex<Option<ConnectionCloseReason>>,
    /// Channels already in a terminal state; a receiver registered for one of
    /// these gets a mailbox whose sender is dropped immediately.
    terminal_channels: SyncMutex<HashSet<ChannelId>>,
    /// NOTE(review): not used in this section — presumably channels whose close
    /// arrived after their state was gone; confirm.
    stale_close_channels: SyncMutex<std::collections::HashSet<ChannelId>>,
    /// Local initial credit; also the size passed when creating inbound mailboxes.
    local_initial_channel_credit: u32,
    /// Peer's initial channel credit (NOTE(review): not read in this section).
    peer_initial_channel_credit: u32,
    /// Optional observer that receives channel events.
    observer: Option<VoxObserverHandle>,
}
603
604impl DriverShared {
605 fn remember_channel_context(
606 &self,
607 channel_id: ChannelId,
608 debug_context: Option<ChannelDebugContext>,
609 ) {
610 if let Some(debug_context) = debug_context.and_then(ChannelDebugContext::into_option) {
611 self.channel_contexts
612 .lock()
613 .insert(channel_id, debug_context);
614 if let Some(channel) = self.channel_debug.lock().get_mut(&channel_id) {
615 channel.debug = Some(debug_context);
616 }
617 }
618 }
619
620 fn channel_event_context(
621 &self,
622 channel_id: ChannelId,
623 debug_context: Option<ChannelDebugContext>,
624 ) -> ChannelEventContext {
625 let debug = debug_context
626 .and_then(ChannelDebugContext::into_option)
627 .or_else(|| self.channel_contexts.lock().get(&channel_id).copied());
628 ChannelEventContext {
629 connection_id: Some(self.connection_id),
630 channel_id,
631 debug,
632 }
633 }
634
635 fn emit_channel_event(
636 &self,
637 channel_id: ChannelId,
638 debug_context: Option<ChannelDebugContext>,
639 event: impl FnOnce(ChannelEventContext) -> ChannelEvent,
640 ) {
641 if let Some(observer) = &self.observer {
642 observer.channel_event(event(self.channel_event_context(channel_id, debug_context)));
643 }
644 }
645
646 fn observe_channel(
647 &self,
648 channel_id: ChannelId,
649 debug_context: Option<ChannelDebugContext>,
650 event: impl FnOnce(ChannelEventContext) -> ChannelEvent,
651 ) {
652 let event = event(self.channel_event_context(channel_id, debug_context));
653 self.record_channel_event(event);
654 if let Some(observer) = &self.observer {
655 observer.channel_event(event);
656 }
657 }
658
659 fn update_channel_debug(
660 &self,
661 channel: ChannelEventContext,
662 default_direction: ChannelDirection,
663 default_initial_credit: u32,
664 update: impl FnOnce(&mut ChannelRuntimeDebug),
665 ) {
666 let mut channels = self.channel_debug.lock();
667 let entry = channels.entry(channel.channel_id).or_insert_with(|| {
668 ChannelRuntimeDebug::new(default_direction, default_initial_credit, channel.debug)
669 });
670 entry.merge_debug(channel.debug);
671 update(entry);
672 }
673
674 fn update_existing_channel_debug(
675 &self,
676 channel_id: ChannelId,
677 update: impl FnOnce(&mut ChannelRuntimeDebug),
678 ) {
679 if let Some(channel) = self.channel_debug.lock().get_mut(&channel_id) {
680 update(channel);
681 }
682 }
683
684 fn record_channel_event(&self, event: ChannelEvent) {
685 let now = Instant::now();
686 match event {
687 ChannelEvent::Opened {
688 channel,
689 direction,
690 initial_credit,
691 } => {
692 self.channel_debug.lock().insert(
693 channel.channel_id,
694 ChannelRuntimeDebug::new(direction, initial_credit, channel.debug),
695 );
696 }
697 ChannelEvent::ItemReceived { channel } => {
698 self.update_channel_debug(channel, ChannelDirection::Rx, 0, |entry| {
699 entry.mark_item_received(now);
700 });
701 }
702 ChannelEvent::Closed { channel, reason } => {
703 self.update_channel_debug(channel, ChannelDirection::Rx, 0, |entry| {
704 entry.mark_closed(reason);
705 });
706 }
707 ChannelEvent::Reset { channel, reason } => {
708 self.update_channel_debug(channel, ChannelDirection::Rx, 0, |entry| {
709 entry.mark_reset(reason);
710 });
711 }
712 ChannelEvent::CreditGranted { channel, amount } => {
713 self.record_credit_granted_at(channel.channel_id, amount, now);
714 }
715 ChannelEvent::SendStarted { channel } => {
716 self.record_send_started(channel.channel_id);
717 }
718 ChannelEvent::SendWaitingForCredit { channel } => {
719 self.record_send_waiting_for_credit(channel.channel_id);
720 }
721 ChannelEvent::SendFinished {
722 channel, outcome, ..
723 } => {
724 self.record_send_finished(channel.channel_id, outcome);
725 }
726 ChannelEvent::TrySend { channel, outcome } => {
727 self.record_try_send_outcome(channel.channel_id, outcome);
728 }
729 ChannelEvent::ItemConsumed { channel } => {
730 self.record_item_consumed(channel.channel_id);
731 }
732 }
733 }
734
735 fn mark_inbound_progress(&self) {
736 *self.last_inbound_message_at.lock() = Some(Instant::now());
737 }
738
739 fn mark_outbound_progress(&self) {
740 *self.last_outbound_message_at.lock() = Some(Instant::now());
741 }
742
743 fn start_request(
744 &self,
745 request_id: RequestId,
746 method_id: vox_types::MethodId,
747 service: Option<&'static str>,
748 method: Option<&'static str>,
749 state: RequestDebugState,
750 ) {
751 self.request_debug.lock().insert(
752 request_id,
753 RequestRuntimeDebug {
754 method_id,
755 service,
756 method,
757 started_at: Instant::now(),
758 state,
759 response_sender_blocked: Some(false),
760 associated_channels: Vec::new(),
761 },
762 );
763 }
764
765 fn finish_request(&self, request_id: RequestId, state: RequestDebugState) {
766 if let Some(request) = self.request_debug.lock().get_mut(&request_id) {
767 request.state = state;
768 }
769 self.request_debug.lock().remove(&request_id);
770 }
771
772 fn record_send_started(&self, channel_id: ChannelId) {
773 self.update_existing_channel_debug(channel_id, ChannelRuntimeDebug::mark_send_started);
774 }
775
776 fn record_send_waiting_for_credit(&self, channel_id: ChannelId) {
777 self.update_existing_channel_debug(
778 channel_id,
779 ChannelRuntimeDebug::mark_send_waiting_for_credit,
780 );
781 }
782
783 fn record_send_finished(&self, channel_id: ChannelId, outcome: ChannelSendOutcome) {
784 let now = Instant::now();
785 self.update_existing_channel_debug(channel_id, |channel| {
786 channel.mark_send_finished(outcome, now);
787 });
788 }
789
790 fn record_try_send_outcome(&self, channel_id: ChannelId, outcome: ChannelTrySendOutcome) {
791 let now = Instant::now();
792 self.update_existing_channel_debug(channel_id, |channel| {
793 channel.mark_try_send_outcome(outcome, now);
794 });
795 }
796
797 fn record_item_consumed(&self, channel_id: ChannelId) {
798 let now = Instant::now();
799 self.update_existing_channel_debug(channel_id, |channel| {
800 channel.mark_item_consumed(now);
801 });
802 }
803
804 fn record_inbound_item_not_enqueued(&self, channel_id: ChannelId) {
805 self.update_existing_channel_debug(
806 channel_id,
807 ChannelRuntimeDebug::mark_inbound_item_not_enqueued,
808 );
809 }
810
811 fn record_pending_local_grant(&self, channel_id: ChannelId, pending: u32) {
812 self.update_existing_channel_debug(channel_id, |channel| {
813 channel.pending_local_grant_credit = pending;
814 });
815 }
816
817 fn record_credit_granted_at(&self, channel_id: ChannelId, amount: u32, now: Instant) {
818 self.update_existing_channel_debug(channel_id, |channel| {
819 channel.mark_credit_granted(amount, now);
820 });
821 }
822
823 fn record_credit_received(&self, channel_id: ChannelId, amount: u32) {
824 let now = Instant::now();
825 self.update_existing_channel_debug(channel_id, |channel| {
826 channel.mark_credit_received(amount, now);
827 });
828 }
829
830 fn record_receiver_dropped(&self, channel_id: ChannelId) {
831 self.update_existing_channel_debug(channel_id, ChannelRuntimeDebug::mark_receiver_dropped);
832 }
833
834 fn new_channel_mailbox(
835 &self,
836 ) -> (
837 ChannelMailboxSender<IncomingChannelMessage>,
838 ChannelMailboxReceiver<IncomingChannelMessage>,
839 ) {
840 channel_mailbox(
841 "driver.channel_mailbox",
842 self.local_initial_channel_credit as usize,
843 )
844 }
845
846 fn inbound_channel_sender(
847 &self,
848 channel_id: ChannelId,
849 ) -> ChannelMailboxSender<IncomingChannelMessage> {
850 let mut senders = self.channel_senders.lock();
851 if let Some(sender) = senders.get(&channel_id) {
852 return sender.clone();
853 }
854
855 let (sender, receiver) = self.new_channel_mailbox();
856 senders.insert(channel_id, sender.clone());
857 self.channel_receivers.lock().insert(channel_id, receiver);
858 sender
859 }
860
861 fn register_inbound_channel_receiver(
862 &self,
863 channel_id: ChannelId,
864 ) -> (ChannelMailboxReceiver<IncomingChannelMessage>, bool) {
865 let terminal = self.terminal_channels.lock().contains(&channel_id);
866 let mut senders = self.channel_senders.lock();
867 let mut receivers = self.channel_receivers.lock();
868
869 if let Some(receiver) = receivers.remove(&channel_id) {
870 return (receiver, terminal);
871 }
872
873 let (sender, receiver) = self.new_channel_mailbox();
874 if terminal {
875 drop(sender);
876 } else {
877 senders.insert(channel_id, sender);
878 }
879 (receiver, terminal)
880 }
881
882 fn debug_snapshot(
883 &self,
884 sender: &ConnectionSender,
885 state: ConnectionDebugState,
886 driver_task_status: DriverTaskStatus,
887 ) -> VoxDebugSnapshot {
888 let now = Instant::now();
889 let requests: Vec<_> = self
890 .request_debug
891 .lock()
892 .iter()
893 .map(|(request_id, request)| request.snapshot(*request_id, now))
894 .collect();
895 let credits = self.shared_channel_credit_snapshot();
896 let open_channels: Vec<_> = self
897 .channel_debug
898 .lock()
899 .iter()
900 .map(|(channel_id, channel)| {
901 channel.snapshot(
902 self.connection_id,
903 *channel_id,
904 credits.get(channel_id).copied().flatten(),
905 )
906 })
907 .collect();
908 let last_inbound_message_at = *self.last_inbound_message_at.lock();
909 let last_outbound_message_at = *self.last_outbound_message_at.lock();
910 let last_progress_at = match (last_inbound_message_at, last_outbound_message_at) {
911 (Some(inbound), Some(outbound)) => Some(inbound.max(outbound)),
912 (Some(inbound), None) => Some(inbound),
913 (None, Some(outbound)) => Some(outbound),
914 (None, None) => None,
915 };
916 let (outbound_queue_depth, outbound_queue_capacity) =
917 sender.sess_core.outbound_queue_stats();
918 VoxDebugSnapshot {
919 connections: vec![ConnectionDebugSnapshot {
920 connection_id: self.connection_id,
921 endpoint: None,
922 surface: None,
923 component: None,
924 state,
925 outstanding_requests: requests.len(),
926 requests,
927 open_channels,
928 outbound_queue_depth: Some(outbound_queue_depth),
929 outbound_queue_capacity: Some(outbound_queue_capacity),
930 local_control_queue_depth: None,
931 local_control_queue_capacity: None,
932 last_inbound_message_at,
933 last_outbound_message_at,
934 last_progress_at,
935 close_reason: *self.close_reason.lock(),
936 driver_task_status,
937 }],
938 }
939 }
940
941 fn shared_channel_credit_snapshot(&self) -> BTreeMap<ChannelId, Option<u32>> {
942 self.channel_credits
943 .lock()
944 .iter()
945 .map(|(channel_id, semaphore)| {
946 (
947 *channel_id,
948 Some(semaphore.available_permits().min(u32::MAX as usize) as u32),
949 )
950 })
951 .collect()
952 }
953
954 fn set_connection_closed(&self, reason: ConnectionCloseReason) {
955 *self.close_reason.lock() = Some(reason);
956 }
957
958 fn connection_debug_state(&self, closed: bool) -> ConnectionDebugState {
959 if closed {
960 ConnectionDebugState::Closed
961 } else {
962 ConnectionDebugState::Open
963 }
964 }
965}
966
/// Guard that sends `request` to the driver's control channel when dropped
/// (see its `Drop` impl); send errors are ignored.
struct CallerDropGuard {
    /// Control channel the request is sent on.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Request delivered on drop.
    request: DropControlRequest,
}
971
972impl Drop for CallerDropGuard {
973 fn drop(&mut self) {
974 let _ = self.control_tx.send(self.request);
975 }
976}
977
#[cfg(test)]
mod tests {
    use super::{DriverChannelCreditReplenisher, DriverLocalControl};
    use vox_types::{ChannelCreditReplenisher, ChannelId};

    /// With a 16-credit window, seven consumptions must not produce a grant;
    /// the eighth releases a single batched grant of 8.
    #[tokio::test]
    async fn replenisher_batches_at_half_the_initial_window() {
        let (grant_tx, mut grant_rx) = moire::sync::mpsc::unbounded_channel("test.replenisher");
        let replenisher = DriverChannelCreditReplenisher::new(
            vox_types::ConnectionId::ROOT,
            ChannelId(7),
            None,
            std::sync::Weak::new(),
            16,
            grant_tx,
            None,
        );

        (0..7).for_each(|_| {
            replenisher.on_item_consumed();
        });
        let early_grant = vox_types::time::tokio::timeout(
            std::time::Duration::from_millis(20),
            grant_rx.recv(),
        )
        .await;
        assert!(
            early_grant.is_err(),
            "should not emit credit before reaching the batch threshold"
        );

        replenisher.on_item_consumed();
        match grant_rx.recv().await {
            Some(DriverLocalControl::GrantCredit {
                channel_id,
                additional,
            }) => {
                assert_eq!(channel_id, ChannelId(7));
                assert_eq!(additional, 8);
            }
            _ => panic!("expected batched credit grant"),
        }
    }

    /// With a one-credit window, a single consumption yields a grant of 1.
    #[tokio::test]
    async fn replenisher_grants_one_by_one_for_single_credit_windows() {
        let (grant_tx, mut grant_rx) =
            moire::sync::mpsc::unbounded_channel("test.replenisher.single");
        let replenisher = DriverChannelCreditReplenisher::new(
            vox_types::ConnectionId::ROOT,
            ChannelId(9),
            None,
            std::sync::Weak::new(),
            1,
            grant_tx,
            None,
        );

        replenisher.on_item_consumed();
        match grant_rx.recv().await {
            Some(DriverLocalControl::GrantCredit {
                channel_id,
                additional,
            }) => {
                assert_eq!(channel_id, ChannelId(9));
                assert_eq!(additional, 1);
            }
            _ => panic!("expected immediate credit grant"),
        }
    }
}
1043
/// [`ReplySink`] implementation that sends a handler's reply back over the
/// connection and, for operations, seals it and replays it to other waiters.
pub struct DriverReplySink {
    /// Connection used to send the reply; taken (left `None`) by `send_reply`.
    sender: Option<ConnectionSender>,
    /// Request being answered.
    request_id: RequestId,
    /// Method that produced the reply.
    method_id: vox_types::MethodId,
    /// Retry policy of the method (NOTE(review): not read in the visible part of `send_reply`).
    retry: vox_types::RetryPolicy,
    /// Operation id; together with `operations`, selects the seal-and-replay path.
    operation_id: Option<OperationId>,
    /// Store the encoded reply is sealed into.
    operations: Option<Arc<dyn OperationStore>>,
    /// Tracker whose waiters (other than this request) receive a replay of the sealed reply.
    live_operations: Option<Arc<SyncMutex<LiveOperationTracker>>>,
    /// Channel binder; also used to reach the driver's shared state.
    binder: DriverChannelBinder,
    /// Shape used to rebuild schemas when replaying the reply to other waiters.
    handler_response_shape: Option<&'static facet_core::Shape>,
}
1065
1066async fn replay_sealed_response(
1072 sender: ConnectionSender,
1073 request_id: RequestId,
1074 method_id: vox_types::MethodId,
1075 encoded_response: &[u8],
1076 response_shape: Option<&'static facet_core::Shape>,
1077) -> Result<(), ()> {
1078 let mut response: RequestResponse<'_> =
1079 vox_postcard::from_slice_borrowed(encoded_response).map_err(|_| ())?;
1080 if let Some(shape) = response_shape {
1081 sender.prepare_replay_schemas(request_id, method_id, shape, &mut response);
1082 } else {
1083 response.schemas = Default::default();
1084 }
1085 sender.send_response(request_id, response).await
1086}
1087
1088fn incoming_args_bytes<'a>(call: &'a RequestCall<'a>) -> &'a [u8] {
1089 match &call.args {
1090 Payload::PostcardBytes(bytes) => bytes,
1091 Payload::Value { .. } => {
1092 panic!("incoming request payload should always be decoded as incoming bytes")
1093 }
1094 }
1095}
1096
/// Delivers a handler's response for one request back over the connection.
impl ReplySink for DriverReplySink {
    /// Sends `response` for this sink's request.
    ///
    /// There are two paths:
    /// - If both an operation id and an operation store are present, the
    ///   response is prepared for the method and encoded without its schemas
    ///   for the store. It is then sent on the wire with the schemas restored
    ///   and sealed in the operation store. Finally it is replayed (via
    ///   `replay_sealed_response`) to every other request waiting on the same
    ///   live operation.
    /// - Otherwise it is sent directly with `send_response_for_method`.
    ///
    /// Any send or replay failure marks the affected request as failed with
    /// `FailureDisposition::Cancelled`. Taking `self.sender` here leaves it
    /// `None`, so the `Drop` impl does not report another failure.
    async fn send_reply(mut self, response: RequestResponse<'_>) {
        let sender = self
            .sender
            .take()
            .expect("unreachable: send_reply takes self by value");

        vox_types::dlog!(
            "[driver] send_reply: conn={:?} req={:?} method={:?} payload={} operation_id={:?}",
            sender.connection_id(),
            self.request_id,
            self.method_id,
            match &response.ret {
                Payload::Value { .. } => "Value",
                Payload::PostcardBytes(_) => "PostcardBytes",
            },
            self.operation_id
        );
        self.binder.shared.mark_outbound_progress();

        // Debug logging only: report the schema root of value responses.
        if let Payload::Value { shape, .. } = &response.ret
            && let Ok(extracted) = vox_types::extract_schemas(shape)
        {
            vox_types::dlog!(
                "[schema] driver send_reply: method={:?} root={:?}",
                self.method_id,
                extracted.root
            );
        }

        if let (Some(operation_id), Some(operations)) = (self.operation_id, self.operations.take())
        {
            let mut response = response;
            sender.prepare_response_for_method(self.request_id, self.method_id, &mut response);

            // The stored copy is encoded with an empty schema list. Replays
            // re-attach schemas themselves (see `replay_sealed_response`).
            // The wire copy gets the schemas back right after encoding.
            let schemas_for_wire = std::mem::take(&mut response.schemas);
            #[cfg(not(target_arch = "wasm32"))]
            let encoded_bytes: Vec<u8> =
                vox_jit::encode!(&response).expect("JIT encode failed for response store");
            #[cfg(target_arch = "wasm32")]
            let encoded_bytes: Vec<u8> =
                vox_postcard::to_vec(&response).expect("postcard encode failed for response store");
            let encoded_for_store: PostcardPayload = encoded_bytes.into();
            response.schemas = schemas_for_wire;

            vox_types::dlog!(
                "[driver] send_reply wire send: conn={:?} req={:?} method={:?} schemas={}",
                sender.connection_id(),
                self.request_id,
                self.method_id,
                response.schemas.0.len()
            );
            if let Err(_e) = sender.send_response(self.request_id, response).await {
                sender.mark_failure(self.request_id, FailureDisposition::Cancelled);
            }

            // Sealing happens whether or not the wire send above succeeded.
            operations.seal(operation_id, self.method_id, &encoded_for_store);

            // Requests attached to the same live operation. The originating
            // request already received the wire response above, so it is
            // skipped.
            let waiters = self
                .live_operations
                .as_ref()
                .map(|lo| lo.lock().seal(operation_id))
                .unwrap_or_default();
            let response_shape = self.handler_response_shape;
            for waiter in waiters {
                if waiter == self.request_id {
                    continue;
                }
                if replay_sealed_response(
                    sender.clone(),
                    waiter,
                    self.method_id,
                    encoded_for_store.as_bytes(),
                    response_shape,
                )
                .await
                .is_err()
                {
                    sender.mark_failure(waiter, FailureDisposition::Cancelled);
                }
            }
        } else {
            vox_types::dlog!(
                "[driver] send_reply direct send: conn={:?} req={:?} method={:?}",
                sender.connection_id(),
                self.request_id,
                self.method_id
            );
            if let Err(_e) = sender
                .send_response_for_method(self.request_id, self.method_id, response)
                .await
            {
                sender.mark_failure(self.request_id, FailureDisposition::Cancelled);
            }
        }
    }

    /// Channel binder that handlers use to bind channels carried by this call.
    fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
        Some(&self.binder)
    }

    fn request_id(&self) -> Option<RequestId> {
        Some(self.request_id)
    }

    /// `None` once the sender has been taken (i.e. after `send_reply`).
    fn connection_id(&self) -> Option<vox_types::ConnectionId> {
        self.sender.as_ref().map(|sender| sender.connection_id())
    }
}
1210
1211impl Drop for DriverReplySink {
1213 fn drop(&mut self) {
1214 if let Some(sender) = self.sender.take() {
1215 let disposition = if self.retry.persist {
1216 FailureDisposition::Indeterminate
1217 } else {
1218 FailureDisposition::Cancelled
1219 };
1220
1221 if let Some(operation_id) = self.operation_id {
1222 if let Some(live_ops) = self.live_operations.take()
1228 && let Some(live) = live_ops.lock().release(operation_id)
1229 {
1230 for waiter in live.waiters {
1231 sender.mark_failure(waiter, disposition);
1232 }
1233 return;
1234 }
1235 }
1236
1237 sender.mark_failure(self.request_id, disposition);
1238 }
1239 }
1240}
1241
/// Outbound (`Tx`) channel sink that writes channel items and close messages
/// onto the connection.
pub struct DriverChannelSink {
    /// Connection used to send `ChannelMessage`s for this channel.
    sender: ConnectionSender,
    /// Driver state shared across the connection (terminal-channel set,
    /// channel contexts, observer, outbound-progress tracking).
    shared: Arc<DriverShared>,
    /// Id of the channel this sink writes to.
    channel_id: ChannelId,
    /// Debug context attached to observer events. When the sink is built by
    /// `make_tx_channel_sink`, this was already normalized with
    /// `ChannelDebugContext::into_option`.
    debug_context: Option<ChannelDebugContext>,
    /// Driver-local control queue. Used to request a channel close when the
    /// sink is dropped.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
}
1256
1257impl ChannelSink for DriverChannelSink {
1258 fn send_payload<'payload>(
1259 &self,
1260 payload: Payload<'payload>,
1261 ) -> Pin<Box<dyn vox_types::MaybeSendFuture<Output = Result<(), TxError>> + 'payload>> {
1262 let sender = self.sender.clone();
1263 let shared = Arc::clone(&self.shared);
1264 let channel_id = self.channel_id;
1265 Box::pin(async move {
1266 if shared.terminal_channels.lock().contains(&channel_id) {
1267 return Err(TxError::Transport("channel closed".into()));
1268 }
1269
1270 shared.mark_outbound_progress();
1271 sender
1272 .send(ConnectionMessage::Channel(ChannelMessage {
1273 id: channel_id,
1274 body: ChannelBody::Item(ChannelItem { item: payload }),
1275 }))
1276 .await
1277 .map_err(|()| TxError::Transport("connection closed".into()))
1278 })
1279 }
1280
1281 fn channel_id(&self) -> Option<ChannelId> {
1282 Some(self.channel_id)
1283 }
1284
1285 fn connection_id(&self) -> Option<vox_types::ConnectionId> {
1286 Some(self.sender.connection_id())
1287 }
1288
1289 fn debug_context(&self) -> Option<ChannelDebugContext> {
1290 self.debug_context
1291 .and_then(ChannelDebugContext::into_option)
1292 .or_else(|| {
1293 self.shared
1294 .channel_contexts
1295 .lock()
1296 .get(&self.channel_id)
1297 .copied()
1298 })
1299 }
1300
1301 fn observer(&self) -> Option<VoxObserverHandle> {
1302 self.shared.observer.clone()
1303 }
1304
1305 fn note_send_started(&self) {
1306 self.shared.record_send_started(self.channel_id);
1307 }
1308
1309 fn note_send_waiting_for_credit(&self) {
1310 self.shared.record_send_waiting_for_credit(self.channel_id);
1311 }
1312
1313 fn note_send_finished(&self, outcome: ChannelSendOutcome) {
1314 self.shared.record_send_finished(self.channel_id, outcome);
1315 }
1316
1317 fn note_try_send_outcome(&self, outcome: ChannelTrySendOutcome) {
1318 self.shared
1319 .record_try_send_outcome(self.channel_id, outcome);
1320 }
1321
1322 fn try_send_payload_with_outcome<'payload>(
1325 &self,
1326 payload: Payload<'payload>,
1327 ) -> Result<(), ChannelTrySendOutcome> {
1328 if self
1329 .shared
1330 .terminal_channels
1331 .lock()
1332 .contains(&self.channel_id)
1333 {
1334 return Err(ChannelTrySendOutcome::Closed);
1335 }
1336
1337 self.shared.mark_outbound_progress();
1338 self.sender
1339 .try_send(ConnectionMessage::Channel(ChannelMessage {
1340 id: self.channel_id,
1341 body: ChannelBody::Item(ChannelItem { item: payload }),
1342 }))
1343 .map_err(|err| match err {
1344 TrySendError::Closed(()) => ChannelTrySendOutcome::Closed,
1345 TrySendError::Full(()) => ChannelTrySendOutcome::FullRuntimeQueue,
1346 })
1347 }
1348
1349 fn close_channel(
1350 &self,
1351 _metadata: vox_types::Metadata,
1352 ) -> Pin<Box<dyn vox_types::MaybeSendFuture<Output = Result<(), TxError>> + 'static>> {
1353 let sender = self.sender.clone();
1357 let shared = Arc::clone(&self.shared);
1358 let channel_id = self.channel_id;
1359 let debug_context = self.debug_context;
1360 Box::pin(async move {
1361 shared.terminal_channels.lock().insert(channel_id);
1362 shared.observe_channel(channel_id, debug_context, |channel| ChannelEvent::Closed {
1363 channel,
1364 reason: ChannelCloseReason::Local,
1365 });
1366
1367 shared.mark_outbound_progress();
1368 sender
1369 .send(ConnectionMessage::Channel(ChannelMessage {
1370 id: channel_id,
1371 body: ChannelBody::Close(ChannelClose {
1372 metadata: Default::default(),
1373 }),
1374 }))
1375 .await
1376 .map_err(|()| TxError::Transport("connection closed".into()))
1377 })
1378 }
1379
1380 fn close_channel_on_drop(&self) {
1381 self.shared.terminal_channels.lock().insert(self.channel_id);
1382 self.shared
1383 .observe_channel(self.channel_id, self.debug_context, |channel| {
1384 ChannelEvent::Closed {
1385 channel,
1386 reason: ChannelCloseReason::Dropped,
1387 }
1388 });
1389 let _ = self
1390 .local_control_tx
1391 .send(DriverLocalControl::CloseChannel {
1392 channel_id: self.channel_id,
1393 });
1394 }
1395}
1396
1397pub trait ErasedHandler: MaybeSend + MaybeSync + 'static {
1402 fn retry_policy(&self, method_id: vox_types::MethodId) -> vox_types::RetryPolicy {
1403 let _ = method_id;
1404 vox_types::RetryPolicy::VOLATILE
1405 }
1406
1407 fn args_have_channels(&self, method_id: vox_types::MethodId) -> bool {
1408 let _ = method_id;
1409 false
1410 }
1411
1412 fn response_wire_shape(&self, method_id: vox_types::MethodId) -> Option<&'static facet::Shape> {
1413 let _ = method_id;
1414 None
1415 }
1416
1417 fn handle_erased(
1418 &self,
1419 call: SelfRef<RequestCall<'static>>,
1420 reply: DriverReplySink,
1421 schemas: std::sync::Arc<vox_types::SchemaRecvTracker>,
1422 ) -> BoxFut<'_, ()>;
1423}
1424
1425impl<H: Handler<DriverReplySink>> ErasedHandler for H {
1426 fn retry_policy(&self, method_id: vox_types::MethodId) -> vox_types::RetryPolicy {
1427 Handler::retry_policy(self, method_id)
1428 }
1429
1430 fn args_have_channels(&self, method_id: vox_types::MethodId) -> bool {
1431 Handler::args_have_channels(self, method_id)
1432 }
1433
1434 fn response_wire_shape(&self, method_id: vox_types::MethodId) -> Option<&'static facet::Shape> {
1435 Handler::response_wire_shape(self, method_id)
1436 }
1437
1438 fn handle_erased(
1439 &self,
1440 call: SelfRef<RequestCall<'static>>,
1441 reply: DriverReplySink,
1442 schemas: std::sync::Arc<vox_types::SchemaRecvTracker>,
1443 ) -> BoxFut<'_, ()> {
1444 Box::pin(Handler::handle(self, call, reply, schemas))
1445 }
1446}
1447
1448impl Handler<DriverReplySink> for Box<dyn ErasedHandler> {
1449 fn retry_policy(&self, method_id: vox_types::MethodId) -> vox_types::RetryPolicy {
1450 (**self).retry_policy(method_id)
1451 }
1452
1453 fn args_have_channels(&self, method_id: vox_types::MethodId) -> bool {
1454 (**self).args_have_channels(method_id)
1455 }
1456
1457 fn response_wire_shape(&self, method_id: vox_types::MethodId) -> Option<&'static facet::Shape> {
1458 (**self).response_wire_shape(method_id)
1459 }
1460
1461 async fn handle(
1462 &self,
1463 call: SelfRef<RequestCall<'static>>,
1464 reply: DriverReplySink,
1465 schemas: std::sync::Arc<vox_types::SchemaRecvTracker>,
1466 ) {
1467 (**self).handle_erased(call, reply, schemas).await
1468 }
1469}
1470
/// Client-side handle for issuing requests over a driver connection,
/// optionally wrapped in client middlewares.
///
/// Clones share the same underlying `DriverCaller` through an `Arc`.
#[must_use = "Dropping this caller may close the connection if it is the last caller."]
#[derive(Clone)]
pub struct Caller {
    /// Driver-side caller that actually sends requests.
    inner: Arc<DriverCaller>,
    /// Service descriptor used to resolve method metadata for middlewares.
    /// Set by the first `with_middleware` call.
    service: Option<&'static vox_types::ServiceDescriptor>,
    /// Client middlewares. In `Caller::call`, `pre` hooks run in this order
    /// and `post` hooks run in reverse order.
    middlewares: Vec<Arc<dyn vox_types::ClientMiddleware>>,
}
1483
1484impl Caller {
1485 pub fn new(driver: DriverCaller) -> Self {
1487 Self {
1488 inner: Arc::new(driver),
1489 service: None,
1490 middlewares: vec![],
1491 }
1492 }
1493
1494 #[cfg(test)]
1496 pub(crate) fn driver(&self) -> &DriverCaller {
1497 &self.inner
1498 }
1499
1500 pub fn with_middleware(
1502 mut self,
1503 service: &'static vox_types::ServiceDescriptor,
1504 middleware: impl vox_types::ClientMiddleware,
1505 ) -> Self {
1506 if let Some(existing_service) = self.service {
1507 assert_eq!(
1508 existing_service.service_name, service.service_name,
1509 "Caller middleware service mismatch"
1510 );
1511 } else {
1512 self.service = Some(service);
1513 }
1514 self.middlewares.push(Arc::new(middleware));
1515 self
1516 }
1517
1518 pub async fn call(&self, mut call: RequestCall<'_>) -> CallResult {
1521 use vox_types::{
1522 ClientCallOutcome, ClientContext, ClientRequest, Extensions, OwnedMetadata,
1523 };
1524
1525 let Some(service) = self.service else {
1526 return self.inner.call_inner(call, None).await;
1527 };
1528
1529 let extensions = Extensions::new();
1530 let method = service.by_id(call.method_id);
1531 let context = ClientContext::new(method, call.method_id, &extensions);
1532 let mut owned_metadata = OwnedMetadata::default();
1533
1534 if !self.middlewares.is_empty() {
1535 for middleware in &self.middlewares {
1536 let mut request = ClientRequest::new(&mut call, &mut owned_metadata);
1537 middleware.pre(&context, &mut request).await;
1538 }
1539 }
1540
1541 let request_debug = method.map(|method| (method.service_name, method.method_name));
1542 let result = self.inner.call_inner(call, request_debug).await;
1543 if !self.middlewares.is_empty() {
1544 let outcome = match &result {
1545 Ok(_) => ClientCallOutcome::Response,
1546 Err(error) => ClientCallOutcome::Error(error),
1547 };
1548 for middleware in self.middlewares.iter().rev() {
1549 middleware.post(&context, outcome).await;
1550 }
1551 }
1552 result
1553 }
1554
1555 pub async fn closed(&self) {
1557 if self.inner.closed_rx.borrow().is_some() {
1558 return;
1559 }
1560 let mut rx = self.inner.closed_rx.clone();
1561 while rx.changed().await.is_ok() {
1562 if rx.borrow().is_some() {
1563 return;
1564 }
1565 }
1566 }
1567
1568 pub fn is_connected(&self) -> bool {
1570 self.inner.closed_rx.borrow().is_none()
1571 }
1572
1573 pub fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
1575 Some(self.inner.as_ref())
1576 }
1577
1578 pub fn debug_snapshot(&self) -> VoxDebugSnapshot {
1580 self.inner.debug_snapshot()
1581 }
1582
1583 pub fn dump_debug_snapshot(&self) -> VoxDebugSnapshot {
1584 let snapshot = self.debug_snapshot();
1585 tracing::info!(?snapshot, "vox debug snapshot");
1586 snapshot
1587 }
1588}
1589
/// Constructs a client value from a `Caller` and an optional session handle.
pub trait FromVoxSession {
    /// Service name associated with the implementing client type
    /// (`NoopClient` uses `"Noop"`).
    const SERVICE_NAME: &'static str;

    /// Builds the client from `caller` and, when one is available, the
    /// session handle.
    fn from_vox_session(
        caller: Caller,
        session_handle: Option<crate::session::SessionHandle>,
    ) -> Self;
}
1605
/// Client for the `"Noop"` service. It holds only the caller and an optional
/// session handle.
#[must_use = "Dropping NoopClient may close the connection if it is the last caller."]
#[derive(Clone)]
pub struct NoopClient {
    /// Caller used to issue requests on the connection.
    pub caller: Caller,
    /// Session handle, when the client was built from a session.
    pub session: Option<crate::session::SessionHandle>,
}
1618
1619impl FromVoxSession for NoopClient {
1620 const SERVICE_NAME: &'static str = "Noop";
1621
1622 fn from_vox_session(caller: Caller, session: Option<crate::session::SessionHandle>) -> Self {
1623 Self { caller, session }
1624 }
1625}
1626
/// Channel binder exposed to handlers through `DriverReplySink::channel_binder`.
#[derive(Clone)]
struct DriverChannelBinder {
    /// Connection used by the channel sinks this binder creates.
    sender: ConnectionSender,
    /// Driver state shared across the connection.
    shared: Arc<DriverShared>,
    /// Driver-local control queue handed to sinks and credit replenishers.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Caller drop guard, if any. It is exposed as the channels' liveness
    /// handle.
    drop_guard: Option<Arc<CallerDropGuard>>,
}
1634
1635fn register_rx_channel_impl(
1636 shared: &Arc<DriverShared>,
1637 channel_id: ChannelId,
1638 initial_channel_credit: u32,
1639 debug_context: Option<ChannelDebugContext>,
1640 liveness: Option<ChannelLivenessHandle>,
1641 local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
1642) -> vox_types::BoundChannelReceiver {
1643 observe_channel_opened(
1644 shared,
1645 channel_id,
1646 ChannelDirection::Rx,
1647 initial_channel_credit,
1648 debug_context,
1649 );
1650 let (rx, terminal) = shared.register_inbound_channel_receiver(channel_id);
1651
1652 if terminal {
1653 shared.channel_credits.lock().remove(&channel_id);
1654 return vox_types::BoundChannelReceiver {
1655 receiver: rx,
1656 liveness,
1657 replenisher: None,
1658 };
1659 }
1660
1661 vox_types::BoundChannelReceiver {
1662 receiver: rx,
1663 liveness,
1664 replenisher: Some(Arc::new(DriverChannelCreditReplenisher::new(
1665 shared.connection_id,
1666 channel_id,
1667 debug_context,
1668 Arc::downgrade(shared),
1669 initial_channel_credit,
1670 local_control_tx,
1671 shared.observer.clone(),
1672 )) as ChannelCreditReplenisherHandle),
1673 }
1674}
1675
1676fn observe_channel_opened(
1678 shared: &DriverShared,
1679 channel_id: ChannelId,
1680 direction: ChannelDirection,
1681 initial_credit: u32,
1682 debug_context: Option<ChannelDebugContext>,
1683) {
1684 shared.remember_channel_context(channel_id, debug_context);
1685 shared.observe_channel(channel_id, debug_context, |channel| ChannelEvent::Opened {
1686 channel,
1687 direction,
1688 initial_credit,
1689 });
1690}
1691
1692fn make_tx_channel_sink(
1693 sender: &ConnectionSender,
1694 shared: &Arc<DriverShared>,
1695 local_control_tx: &mpsc::UnboundedSender<DriverLocalControl>,
1696 channel_id: ChannelId,
1697 debug_context: Option<ChannelDebugContext>,
1698) -> Arc<CreditSink<DriverChannelSink>> {
1699 observe_channel_opened(
1700 shared,
1701 channel_id,
1702 ChannelDirection::Tx,
1703 shared.peer_initial_channel_credit,
1704 debug_context,
1705 );
1706 let inner = DriverChannelSink {
1707 sender: sender.clone(),
1708 shared: Arc::clone(shared),
1709 channel_id,
1710 debug_context: debug_context.and_then(ChannelDebugContext::into_option),
1711 local_control_tx: local_control_tx.clone(),
1712 };
1713 let sink = Arc::new(CreditSink::new(inner, shared.peer_initial_channel_credit));
1714 shared
1715 .channel_credits
1716 .lock()
1717 .insert(channel_id, Arc::clone(sink.credit()));
1718 sink
1719}
1720
1721trait DriverChannelEndpoint {
1722 fn endpoint_sender(&self) -> &ConnectionSender;
1723 fn endpoint_shared(&self) -> &Arc<DriverShared>;
1724 fn endpoint_local_control_tx(&self) -> &mpsc::UnboundedSender<DriverLocalControl>;
1725 fn endpoint_liveness(&self) -> Option<ChannelLivenessHandle>;
1726
1727 fn create_tx_credit_sink(
1728 &self,
1729 debug_context: Option<ChannelDebugContext>,
1730 ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
1731 let shared = self.endpoint_shared();
1732 let channel_id = shared.channel_ids.lock().alloc();
1733 let sink = make_tx_channel_sink(
1734 self.endpoint_sender(),
1735 shared,
1736 self.endpoint_local_control_tx(),
1737 channel_id,
1738 debug_context,
1739 );
1740 (channel_id, sink)
1741 }
1742
1743 fn create_tx_dyn(
1744 &self,
1745 debug_context: Option<ChannelDebugContext>,
1746 ) -> (ChannelId, Arc<dyn ChannelSink>) {
1747 let (id, sink) = self.create_tx_credit_sink(debug_context);
1748 (id, sink as Arc<dyn ChannelSink>)
1749 }
1750
1751 fn create_rx_bound(
1752 &self,
1753 debug_context: Option<ChannelDebugContext>,
1754 ) -> (ChannelId, vox_types::BoundChannelReceiver) {
1755 let channel_id = self.endpoint_shared().channel_ids.lock().alloc();
1756 let rx = self.register_rx_bound(channel_id, debug_context);
1757 (channel_id, rx)
1758 }
1759
1760 fn bind_tx_dyn(
1761 &self,
1762 channel_id: ChannelId,
1763 debug_context: Option<ChannelDebugContext>,
1764 ) -> Arc<dyn ChannelSink> {
1765 make_tx_channel_sink(
1766 self.endpoint_sender(),
1767 self.endpoint_shared(),
1768 self.endpoint_local_control_tx(),
1769 channel_id,
1770 debug_context,
1771 )
1772 }
1773
1774 fn register_rx_bound(
1775 &self,
1776 channel_id: ChannelId,
1777 debug_context: Option<ChannelDebugContext>,
1778 ) -> vox_types::BoundChannelReceiver {
1779 let shared = self.endpoint_shared();
1780 register_rx_channel_impl(
1781 shared,
1782 channel_id,
1783 shared.local_initial_channel_credit,
1784 debug_context,
1785 self.endpoint_liveness(),
1786 self.endpoint_local_control_tx().clone(),
1787 )
1788 }
1789}
1790
1791impl DriverChannelEndpoint for DriverChannelBinder {
1792 fn endpoint_sender(&self) -> &ConnectionSender {
1793 &self.sender
1794 }
1795
1796 fn endpoint_shared(&self) -> &Arc<DriverShared> {
1797 &self.shared
1798 }
1799
1800 fn endpoint_local_control_tx(&self) -> &mpsc::UnboundedSender<DriverLocalControl> {
1801 &self.local_control_tx
1802 }
1803
1804 fn endpoint_liveness(&self) -> Option<ChannelLivenessHandle> {
1805 self.drop_guard
1806 .as_ref()
1807 .map(|guard| guard.clone() as ChannelLivenessHandle)
1808 }
1809}
1810
1811impl ChannelBinder for DriverChannelBinder {
1812 fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>) {
1813 self.create_tx_dyn(None)
1814 }
1815
1816 fn create_tx_with_context(
1817 &self,
1818 debug_context: Option<ChannelDebugContext>,
1819 ) -> (ChannelId, Arc<dyn ChannelSink>) {
1820 self.create_tx_dyn(debug_context)
1821 }
1822
1823 fn create_rx(&self) -> (ChannelId, vox_types::BoundChannelReceiver) {
1824 self.create_rx_bound(None)
1825 }
1826
1827 fn create_rx_with_context(
1828 &self,
1829 debug_context: Option<ChannelDebugContext>,
1830 ) -> (ChannelId, vox_types::BoundChannelReceiver) {
1831 self.create_rx_bound(debug_context)
1832 }
1833
1834 fn bind_tx(&self, channel_id: ChannelId) -> Arc<dyn ChannelSink> {
1835 self.bind_tx_dyn(channel_id, None)
1836 }
1837
1838 fn bind_tx_with_context(
1839 &self,
1840 channel_id: ChannelId,
1841 debug_context: Option<ChannelDebugContext>,
1842 ) -> Arc<dyn ChannelSink> {
1843 self.bind_tx_dyn(channel_id, debug_context)
1844 }
1845
1846 fn register_rx(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
1847 self.register_rx_bound(channel_id, None)
1848 }
1849
1850 fn register_rx_with_context(
1851 &self,
1852 channel_id: ChannelId,
1853 debug_context: Option<ChannelDebugContext>,
1854 ) -> vox_types::BoundChannelReceiver {
1855 self.register_rx_bound(channel_id, debug_context)
1856 }
1857
1858 fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
1859 self.endpoint_liveness()
1860 }
1861}
1862
/// Driver-side implementation behind `Caller`. It sends requests on the
/// connection and waits for their responses (see `DriverCaller::call_inner`).
#[derive(Clone)]
pub struct DriverCaller {
    /// Connection used for requests and channel messages.
    sender: ConnectionSender,
    /// Driver state shared across the connection.
    shared: Arc<DriverShared>,
    /// Driver-local control queue handed to channel sinks and replenishers.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// `Some(reason)` once the connection has closed, `None` while open.
    closed_rx: watch::Receiver<Option<ConnectionCloseReason>>,
    /// Resume generation. `call_inner` treats a change as a session resume.
    resumed_rx: watch::Receiver<u64>,
    /// Highest resume generation the driver has finished processing.
    resume_processed_rx: watch::Receiver<u64>,
    /// Whether the peer supports retry. This gates operation ids and resume
    /// handling in `call_inner`.
    peer_supports_retry: bool,
    /// Held only to keep the caller drop guard alive. It is also exposed as
    /// channel liveness.
    _drop_guard: Option<Arc<CallerDropGuard>>,
}
1877
1878impl DriverCaller {
1879 pub fn create_tx_channel(&self) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
1884 self.create_tx_credit_sink(None)
1885 }
1886
1887 #[cfg(test)]
1892 pub(crate) fn connection_sender(&self) -> &ConnectionSender {
1893 &self.sender
1894 }
1895
1896 pub fn register_rx_channel(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
1901 self.register_rx_bound(channel_id, None)
1902 }
1903}
1904
1905impl DriverChannelEndpoint for DriverCaller {
1906 fn endpoint_sender(&self) -> &ConnectionSender {
1907 &self.sender
1908 }
1909
1910 fn endpoint_shared(&self) -> &Arc<DriverShared> {
1911 &self.shared
1912 }
1913
1914 fn endpoint_local_control_tx(&self) -> &mpsc::UnboundedSender<DriverLocalControl> {
1915 &self.local_control_tx
1916 }
1917
1918 fn endpoint_liveness(&self) -> Option<ChannelLivenessHandle> {
1919 self._drop_guard
1920 .as_ref()
1921 .map(|guard| guard.clone() as ChannelLivenessHandle)
1922 }
1923}
1924
1925impl ChannelBinder for DriverCaller {
1926 fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>) {
1927 self.create_tx_dyn(None)
1928 }
1929
1930 fn create_tx_with_context(
1931 &self,
1932 debug_context: Option<ChannelDebugContext>,
1933 ) -> (ChannelId, Arc<dyn ChannelSink>) {
1934 self.create_tx_dyn(debug_context)
1935 }
1936
1937 fn create_rx(&self) -> (ChannelId, vox_types::BoundChannelReceiver) {
1938 self.create_rx_bound(None)
1939 }
1940
1941 fn create_rx_with_context(
1942 &self,
1943 debug_context: Option<ChannelDebugContext>,
1944 ) -> (ChannelId, vox_types::BoundChannelReceiver) {
1945 self.create_rx_bound(debug_context)
1946 }
1947
1948 fn bind_tx(&self, channel_id: ChannelId) -> Arc<dyn ChannelSink> {
1949 self.bind_tx_dyn(channel_id, None)
1950 }
1951
1952 fn bind_tx_with_context(
1953 &self,
1954 channel_id: ChannelId,
1955 debug_context: Option<ChannelDebugContext>,
1956 ) -> Arc<dyn ChannelSink> {
1957 self.bind_tx_dyn(channel_id, debug_context)
1958 }
1959
1960 fn register_rx(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
1961 self.register_rx_bound(channel_id, None)
1962 }
1963
1964 fn register_rx_with_context(
1965 &self,
1966 channel_id: ChannelId,
1967 debug_context: Option<ChannelDebugContext>,
1968 ) -> vox_types::BoundChannelReceiver {
1969 self.register_rx_bound(channel_id, debug_context)
1970 }
1971
1972 fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
1973 self.endpoint_liveness()
1974 }
1975}
1976
1977impl DriverCaller {
1978 pub fn debug_snapshot(&self) -> VoxDebugSnapshot {
1980 self.shared.debug_snapshot(
1981 &self.sender,
1982 self.shared
1983 .connection_debug_state(self.closed_rx.borrow().is_some()),
1984 if self.closed_rx.borrow().is_some() {
1985 DriverTaskStatus::Dead
1986 } else {
1987 DriverTaskStatus::Alive
1988 },
1989 )
1990 }
1991
1992 pub fn dump_debug_snapshot(&self) -> VoxDebugSnapshot {
1993 let snapshot = self.debug_snapshot();
1994 tracing::info!(?snapshot, "vox debug snapshot");
1995 snapshot
1996 }
1997
1998 async fn call_inner(
2000 &self,
2001 mut call: RequestCall<'_>,
2002 request_debug: Option<(&'static str, &'static str)>,
2003 ) -> CallResult {
2004 if self.peer_supports_retry {
2005 let operation_id = OperationId(
2006 self.shared
2007 .next_operation_id
2008 .fetch_add(1, Ordering::Relaxed),
2009 );
2010 ensure_operation_id(&mut call.metadata, operation_id);
2011 }
2012
2013 let req_id = self.shared.request_ids.lock().alloc();
2015 let request_started_at = Instant::now();
2016 if let Some(observer) = &self.shared.observer {
2017 observer.driver_event(DriverEvent::RequestStarted {
2018 connection_id: self.sender.connection_id(),
2019 request_id: req_id,
2020 method_id: call.method_id,
2021 });
2022 }
2023 let finish_request = |outcome: RpcOutcome| {
2024 self.shared.finish_request(
2025 req_id,
2026 if outcome == RpcOutcome::Ok {
2027 RequestDebugState::Finished
2028 } else {
2029 RequestDebugState::Failed
2030 },
2031 );
2032 if let Some(observer) = &self.shared.observer {
2033 observer.driver_event(DriverEvent::RequestFinished {
2034 connection_id: self.sender.connection_id(),
2035 request_id: req_id,
2036 outcome,
2037 elapsed: request_started_at.elapsed(),
2038 });
2039 }
2040 };
2041
2042 let (tx, rx) = moire::sync::oneshot::channel("driver.response");
2045 self.shared.pending_responses.lock().insert(req_id, tx);
2046 self.shared.start_request(
2047 req_id,
2048 call.method_id,
2049 request_debug.map(|(service, _)| service),
2050 request_debug.map(|(_, method)| method),
2051 RequestDebugState::WaitingForResponse,
2052 );
2053
2054 self.shared.mark_outbound_progress();
2062 if self
2063 .sender
2064 .send_with_binder(
2065 ConnectionMessage::Request(RequestMessage {
2066 id: req_id,
2067 body: RequestBody::Call(RequestCall {
2068 method_id: call.method_id,
2069 args: call.args.reborrow(),
2070 metadata: call.metadata.clone(),
2071 schemas: Default::default(),
2072 }),
2073 }),
2074 Some(self),
2075 )
2076 .await
2077 .is_err()
2078 {
2079 self.shared.pending_responses.lock().remove(&req_id);
2080 finish_request(RpcOutcome::SendFailed);
2081 return Err(VoxError::SendFailed);
2082 }
2083
2084 let mut resumed_rx = self.resumed_rx.clone();
2085 let mut seen_resume_generation = *resumed_rx.borrow();
2086 let mut resume_processed_rx = self.resume_processed_rx.clone();
2087 let mut closed_rx = self.closed_rx.clone();
2088 let mut response = std::pin::pin!(rx.named("awaiting_response"));
2089
2090 let pending: PendingResponse = loop {
2091 tokio::select! {
2092 result = &mut response => {
2093 match result {
2094 Ok(pending) => break pending,
2095 Err(_) => {
2096 finish_request(RpcOutcome::Closed);
2097 return Err(VoxError::ConnectionClosed);
2098 }
2099 }
2100 }
2101 changed = resumed_rx.changed(), if self.peer_supports_retry => {
2102 vox_types::dlog!("[CALLER] resumed_rx fired");
2103 if changed.is_err() {
2104 self.shared.pending_responses.lock().remove(&req_id);
2105 finish_request(RpcOutcome::Closed);
2106 return Err(VoxError::SessionShutdown);
2107 }
2108 let generation = *resumed_rx.borrow();
2109 if generation == seen_resume_generation {
2110 continue;
2111 }
2112 seen_resume_generation = generation;
2113 while *resume_processed_rx.borrow() < generation {
2114 if resume_processed_rx.changed().await.is_err() {
2115 self.shared.pending_responses.lock().remove(&req_id);
2116 finish_request(RpcOutcome::Closed);
2117 return Err(VoxError::SessionShutdown);
2118 }
2119 }
2120 match metadata_channel_retry_mode(&call.metadata) {
2121 ChannelRetryMode::NonIdem => {
2122 self.shared.pending_responses.lock().remove(&req_id);
2123 finish_request(RpcOutcome::Indeterminate);
2124 return Err(VoxError::Indeterminate);
2125 }
2126 ChannelRetryMode::Idem | ChannelRetryMode::None => {}
2127 }
2128 self.shared.mark_outbound_progress();
2132 let _ = self.sender.send_with_binder(
2133 ConnectionMessage::Request(RequestMessage {
2134 id: req_id,
2135 body: RequestBody::Call(RequestCall {
2136 method_id: call.method_id,
2137 args: call.args.reborrow(),
2138 metadata: call.metadata.clone(),
2139 schemas: Default::default(),
2140 }),
2141 }),
2142 Some(self),
2143 ).await;
2144 }
2145 changed = closed_rx.changed() => {
2146 vox_types::dlog!("[CALLER] closed_rx fired, value={:?}", *closed_rx.borrow());
2147 if changed.is_err() || closed_rx.borrow().is_some() {
2148 self.shared.pending_responses.lock().remove(&req_id);
2149 finish_request(RpcOutcome::Closed);
2150 return Err(VoxError::ConnectionClosed);
2151 }
2152 }
2153 }
2154 };
2155
2156 let PendingResponse {
2158 msg: response_msg,
2159 schemas: response_schemas,
2160 fds: response_fds,
2161 } = pending;
2162 let response = response_msg.map(|m| match m.body {
2163 RequestBody::Response(r) => r,
2164 _ => unreachable!("pending_responses only gets Response variants"),
2165 });
2166
2167 finish_request(RpcOutcome::Ok);
2168 Ok(vox_types::WithTracker {
2169 value: response,
2170 tracker: response_schemas,
2171 fds: response_fds,
2172 })
2173 }
2174}
2175
/// Per-connection driver: consumes incoming connection messages and runs the
/// handler's futures for them.
///
/// NOTE(review): fields drop in declaration order. Keep that in mind before
/// reordering.
pub struct Driver<H: Handler<DriverReplySink>> {
    /// Outbound side of the connection.
    sender: ConnectionSender,
    /// Messages received from the session for this connection.
    rx: mpsc::Receiver<crate::session::RecvMessage>,
    /// Request failures paired with their disposition. Presumably fed by
    /// `ConnectionSender::mark_failure`; confirm.
    failures_rx: mpsc::UnboundedReceiver<(RequestId, FailureDisposition)>,
    /// `Some(reason)` once the connection has closed.
    closed_rx: watch::Receiver<Option<ConnectionCloseReason>>,
    /// Session resume generation.
    resumed_rx: watch::Receiver<u64>,
    /// Publishes the last resume generation the driver has processed.
    /// Presumably the counterpart of `DriverCaller::resume_processed_rx`.
    resume_processed_tx: watch::Sender<u64>,
    /// Whether the peer supports retry.
    peer_supports_retry: bool,
    /// Local control requests (close/reset channel, grant credit) from sinks
    /// and credit replenishers.
    local_control_rx: mpsc::UnboundedReceiver<DriverLocalControl>,
    /// Request handler for incoming calls.
    handler: Arc<H>,
    /// Driver state shared with callers, binders and sinks.
    shared: Arc<DriverShared>,
    /// Handlers currently running, keyed by request id. Each entry carries the
    /// method id, optional operation id and abort handle used by
    /// `abort_channel_handlers`.
    in_flight_handlers: BTreeMap<RequestId, InFlightHandler>,
    /// Futures of running handlers.
    handler_futs: FuturesUnordered<HandlerFut>,
    /// Tracks live (not yet sealed) operations and the requests waiting on them.
    live_operations: Arc<SyncMutex<LiveOperationTracker>>,
    /// Sender half of `local_control_rx`, handed to new sinks and binders.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// NOTE(review): drop-control plumbing; semantics are not visible in this chunk.
    drop_control_seed: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// NOTE(review): drop-control request; semantics are not visible in this chunk.
    drop_control_request: DropControlRequest,
    /// Weak reference to the caller drop guard, so the driver does not keep it alive.
    drop_guard: SyncMutex<Option<Weak<CallerDropGuard>>>,
}
2211
/// Requests sent to the driver over its local control queue.
enum DriverLocalControl {
    /// Sent by `DriverChannelSink::close_channel_on_drop` when an outbound
    /// channel sink is dropped.
    CloseChannel {
        channel_id: ChannelId,
    },
    /// Sent by `DriverChannelCreditReplenisher::on_receiver_dropped` when an
    /// inbound channel's receiver is dropped.
    ResetChannel {
        channel_id: ChannelId,
    },
    /// Sent by `DriverChannelCreditReplenisher::on_item_consumed` once enough
    /// items have been consumed. `additional` is the accumulated count.
    GrantCredit {
        channel_id: ChannelId,
        additional: u32,
    },
}
2224
/// Tracks items consumed from an inbound channel and batches credit grants
/// back to the driver.
struct DriverChannelCreditReplenisher {
    connection_id: ConnectionId,
    channel_id: ChannelId,
    debug_context: Option<ChannelDebugContext>,
    /// Weak so the replenisher does not keep driver state alive.
    shared: Weak<DriverShared>,
    /// Number of consumed items that triggers a grant:
    /// `(initial_credit / 2).max(1)`.
    threshold: u32,
    /// Queue used to send `GrantCredit` and `ResetChannel` requests.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    observer: Option<VoxObserverHandle>,
    /// Items consumed since the last grant was queued.
    pending: std::sync::Mutex<u32>,
}
2235
2236impl DriverChannelCreditReplenisher {
2237 fn new(
2238 connection_id: ConnectionId,
2239 channel_id: ChannelId,
2240 debug_context: Option<ChannelDebugContext>,
2241 shared: Weak<DriverShared>,
2242 initial_credit: u32,
2243 local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
2244 observer: Option<VoxObserverHandle>,
2245 ) -> Self {
2246 Self {
2247 connection_id,
2248 channel_id,
2249 debug_context,
2250 shared,
2251 threshold: (initial_credit / 2).max(1),
2252 local_control_tx,
2253 observer,
2254 pending: std::sync::Mutex::new(0),
2255 }
2256 }
2257}
2258
2259impl ChannelCreditReplenisher for DriverChannelCreditReplenisher {
2260 fn on_item_consumed(&self) {
2261 let mut pending = self.pending.lock().expect("pending credit mutex poisoned");
2262 *pending += 1;
2263 if let Some(shared) = self.shared.upgrade() {
2264 shared.record_item_consumed(self.channel_id);
2265 shared.record_pending_local_grant(self.channel_id, *pending);
2266 }
2267 if *pending < self.threshold {
2268 return;
2269 }
2270
2271 let additional = *pending;
2272 *pending = 0;
2273 if let Some(shared) = self.shared.upgrade() {
2274 shared.record_pending_local_grant(self.channel_id, additional);
2275 }
2276 let _ = self.local_control_tx.send(DriverLocalControl::GrantCredit {
2277 channel_id: self.channel_id,
2278 additional,
2279 });
2280 }
2281
2282 fn on_receiver_dropped(&self) {
2283 if let Some(shared) = self.shared.upgrade() {
2284 shared.record_receiver_dropped(self.channel_id);
2285 }
2286 let _ = self
2287 .local_control_tx
2288 .send(DriverLocalControl::ResetChannel {
2289 channel_id: self.channel_id,
2290 });
2291 }
2292
2293 fn channel_id(&self) -> Option<ChannelId> {
2294 Some(self.channel_id)
2295 }
2296
2297 fn connection_id(&self) -> Option<ConnectionId> {
2298 Some(self.connection_id)
2299 }
2300
2301 fn debug_context(&self) -> Option<ChannelDebugContext> {
2302 self.debug_context
2303 }
2304
2305 fn observer(&self) -> Option<VoxObserverHandle> {
2306 self.observer.clone()
2307 }
2308}
2309
2310impl<H: Handler<DriverReplySink>> Driver<H> {
2311 fn close_all_channel_runtime_state(&self, teardown: ChannelRuntimeTeardown) {
2313 let mut credits = self.shared.channel_credits.lock();
2314 for semaphore in credits.values() {
2315 semaphore.close();
2316 }
2317 let mut stale = self.shared.stale_close_channels.lock();
2320 stale.extend(credits.keys().copied());
2321 credits.clear();
2322 drop(credits);
2323
2324 let channel_senders = {
2325 let mut senders = self.shared.channel_senders.lock();
2326 std::mem::take(&mut *senders)
2327 };
2328 if let ChannelRuntimeTeardown::ConnectionClosed(reason) = teardown {
2329 for (channel_id, sender) in channel_senders {
2330 let _ = sender.force_send(IncomingChannelMessage::ConnectionClosed(reason));
2331 self.shared
2332 .observe_channel(channel_id, None, |channel| ChannelEvent::Closed {
2333 channel,
2334 reason: ChannelCloseReason::ConnectionClosed,
2335 });
2336 }
2337 }
2338 self.shared.channel_receivers.lock().clear();
2339 self.shared.terminal_channels.lock().clear();
2340 }
2341
2342 fn close_outbound_channel(&self, channel_id: ChannelId) {
2343 self.shared.terminal_channels.lock().insert(channel_id);
2344 if let Some(semaphore) = self.shared.channel_credits.lock().remove(&channel_id) {
2345 semaphore.close();
2346 }
2347 }
2348
2349 fn abort_channel_handlers(&mut self) {
2350 for in_flight in self.in_flight_handlers.values() {
2351 if self.handler.args_have_channels(in_flight.method_id) {
2352 if let Some(operation_id) = in_flight.operation_id {
2353 self.shared.operations.remove(operation_id);
2354 self.live_operations.lock().release(operation_id);
2355 }
2356 in_flight.abort.abort();
2357 }
2358 }
2359 }
2360
2361 pub fn new(handle: ConnectionHandle, handler: H) -> Self {
2362 Self::with_operation_store(handle, handler, Arc::new(InMemoryOperationStore::default()))
2363 }
2364
2365 pub fn with_operation_store(
2366 handle: ConnectionHandle,
2367 handler: H,
2368 operation_store: Arc<dyn OperationStore>,
2369 ) -> Self {
2370 let conn_id = handle.connection_id();
2371 let ConnectionHandle {
2372 sender,
2373 rx,
2374 failures_rx,
2375 control_tx,
2376 closed_rx,
2377 resumed_rx,
2378 local_settings,
2379 peer_settings,
2380 parity,
2381 peer_supports_retry,
2382 observer,
2383 } = handle;
2384 let drop_control_request = DropControlRequest::Close(conn_id);
2385 let (local_control_tx, local_control_rx) = mpsc::unbounded_channel("driver.local_control");
2386 let (resume_processed_tx, _resume_processed_rx) = watch::channel(0_u64);
2387 Self {
2388 sender,
2389 rx,
2390 failures_rx,
2391 closed_rx,
2392 resumed_rx,
2393 resume_processed_tx,
2394 peer_supports_retry,
2395 local_control_rx,
2396 handler: Arc::new(handler),
2397 shared: Arc::new(DriverShared {
2398 connection_id: conn_id,
2399 pending_responses: SyncMutex::new("driver.pending_responses", BTreeMap::new()),
2400 request_ids: SyncMutex::new("driver.request_ids", IdAllocator::new(parity)),
2401 next_operation_id: AtomicU64::new(1),
2402 operations: operation_store,
2403 channel_ids: SyncMutex::new("driver.channel_ids", IdAllocator::new(parity)),
2404 channel_senders: SyncMutex::new("driver.channel_senders", BTreeMap::new()),
2405 channel_receivers: SyncMutex::new("driver.channel_receivers", BTreeMap::new()),
2406 channel_credits: SyncMutex::new("driver.channel_credits", BTreeMap::new()),
2407 channel_contexts: SyncMutex::new("driver.channel_contexts", BTreeMap::new()),
2408 request_debug: SyncMutex::new("driver.request_debug", BTreeMap::new()),
2409 channel_debug: SyncMutex::new("driver.channel_debug", BTreeMap::new()),
2410 last_inbound_message_at: SyncMutex::new("driver.last_inbound_message_at", None),
2411 last_outbound_message_at: SyncMutex::new("driver.last_outbound_message_at", None),
2412 close_reason: SyncMutex::new("driver.close_reason", None),
2413 terminal_channels: SyncMutex::new("driver.terminal_channels", HashSet::new()),
2414 stale_close_channels: SyncMutex::new(
2415 "driver.stale_close_channels",
2416 std::collections::HashSet::new(),
2417 ),
2418 local_initial_channel_credit: local_settings.initial_channel_credit,
2419 peer_initial_channel_credit: peer_settings.initial_channel_credit,
2420 observer,
2421 }),
2422 in_flight_handlers: BTreeMap::new(),
2423 handler_futs: FuturesUnordered::new(),
2424 live_operations: Arc::new(SyncMutex::new(
2425 "driver.live_operations",
2426 LiveOperationTracker::new(),
2427 )),
2428 local_control_tx,
2429 drop_control_seed: control_tx,
2430 drop_control_request,
2431 drop_guard: SyncMutex::new("driver.drop_guard", None),
2432 }
2433 }
2434
2435 fn existing_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
2441 self.drop_guard.lock().as_ref().and_then(Weak::upgrade)
2442 }
2443
2444 fn connection_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
2445 if let Some(existing) = self.existing_drop_guard() {
2446 Some(existing)
2447 } else if let Some(seed) = &self.drop_control_seed {
2448 let mut guard = self.drop_guard.lock();
2449 if let Some(existing) = guard.as_ref().and_then(Weak::upgrade) {
2450 Some(existing)
2451 } else {
2452 let arc = Arc::new(CallerDropGuard {
2453 control_tx: seed.clone(),
2454 request: self.drop_control_request,
2455 });
2456 *guard = Some(Arc::downgrade(&arc));
2457 Some(arc)
2458 }
2459 } else {
2460 None
2461 }
2462 }
2463
2464 pub fn caller(&self) -> DriverCaller {
2465 let drop_guard = self.connection_drop_guard();
2466 DriverCaller {
2467 sender: self.sender.clone(),
2468 shared: Arc::clone(&self.shared),
2469 local_control_tx: self.local_control_tx.clone(),
2470 closed_rx: self.closed_rx.clone(),
2471 resumed_rx: self.resumed_rx.clone(),
2472 resume_processed_rx: self.resume_processed_tx.subscribe(),
2473 peer_supports_retry: self.peer_supports_retry,
2474 _drop_guard: drop_guard,
2475 }
2476 }
2477
2478 pub fn debug_snapshot(&self) -> VoxDebugSnapshot {
2480 self.shared.debug_snapshot(
2481 &self.sender,
2482 self.shared
2483 .connection_debug_state(self.closed_rx.borrow().is_some()),
2484 DriverTaskStatus::Alive,
2485 )
2486 }
2487
2488 pub fn dump_debug_snapshot(&self) -> VoxDebugSnapshot {
2489 let snapshot = self.debug_snapshot();
2490 tracing::info!(?snapshot, "vox debug snapshot");
2491 snapshot
2492 }
2493
2494 fn internal_binder(&self) -> DriverChannelBinder {
2495 DriverChannelBinder {
2496 sender: self.sender.clone(),
2497 shared: Arc::clone(&self.shared),
2498 local_control_tx: self.local_control_tx.clone(),
2499 drop_guard: self.existing_drop_guard(),
2500 }
2501 }
2502
2503 pub async fn run(&mut self) {
2508 let mut resumed_rx = self.resumed_rx.clone();
2509 let mut seen_resume_generation = *resumed_rx.borrow();
2510 loop {
2511 tracing::trace!("driver select loop top");
2512 tokio::select! {
2513 biased;
2514 changed = resumed_rx.changed() => {
2515 if changed.is_err() {
2516 tracing::trace!(
2517 conn_id = self.sender.connection_id().0,
2518 "resume notifier closed, exiting driver"
2519 );
2520 break;
2521 }
2522 let generation = *resumed_rx.borrow();
2523 if generation != seen_resume_generation {
2524 seen_resume_generation = generation;
2525 self.close_all_channel_runtime_state(ChannelRuntimeTeardown::DropOnly);
2526 self.abort_channel_handlers();
2527 let _ = self.resume_processed_tx.send(generation);
2528 }
2529 }
2530 Some(ctrl) = self.local_control_rx.recv() => {
2531 self.handle_local_control(ctrl).await;
2532 }
2533 Some((req_id, disposition)) = self.failures_rx.recv() => {
2534 tracing::trace!(%req_id, ?disposition, "failures_rx fired");
2535 let in_flight_found = self.in_flight_handlers.contains_key(&req_id);
2536 let in_flight_method_id =
2537 self.in_flight_handlers.get(&req_id).map(|in_flight| in_flight.method_id);
2538 let reply_disposition = self
2539 .in_flight_handlers
2540 .get(&req_id)
2541 .map(|in_flight| {
2542 let has_channels =
2543 self.handler.args_have_channels(in_flight.method_id);
2544 if has_channels && !in_flight.retry.idem {
2545 Some(FailureDisposition::Indeterminate)
2546 } else if has_channels && in_flight.retry.idem {
2547 None
2548 } else {
2549 Some(disposition)
2550 }
2551 })
2552 .unwrap_or(Some(disposition));
2553 tracing::trace!(%req_id, in_flight_found, ?reply_disposition, "failures_rx computed disposition");
2554 self.in_flight_handlers.remove(&req_id);
2556 self.shared.finish_request(req_id, RequestDebugState::Failed);
2557 tracing::trace!(%req_id, in_flight = self.in_flight_handlers.len(), "handler removed on failure");
2558 let had_pending = self.shared.pending_responses.lock().remove(&req_id).is_some();
2559 tracing::trace!(%req_id, had_pending, "failures_rx checked pending_responses");
2560 if !had_pending {
2561 let Some(reply_disposition) = reply_disposition else {
2562 tracing::trace!(%req_id, "failures_rx: no reply_disposition, skipping");
2563 continue;
2564 };
2565 tracing::trace!(%req_id, ?reply_disposition, "failures_rx: sending error response");
2566 let vox_error = match reply_disposition {
2567 FailureDisposition::Cancelled => VoxError::Cancelled,
2568 FailureDisposition::Indeterminate => VoxError::Indeterminate,
2569 };
2570 if let Some(method_id) = in_flight_method_id
2571 && let Some(response_shape) = self.handler.response_wire_shape(method_id)
2572 && let Ok(extracted) = vox_types::extract_schemas(response_shape)
2573 {
2574 let registry = vox_types::build_registry(&extracted.schemas);
2575 let error: Result<(), VoxError<core::convert::Infallible>> =
2576 Err(vox_error);
2577 let encoded = vox_postcard::to_vec(&error)
2578 .expect("serialize runtime-generated error response");
2579 let mut response = RequestResponse {
2580 ret: Payload::PostcardBytes(Box::leak(encoded.into_boxed_slice())),
2581 metadata: Default::default(),
2582 schemas: Default::default(),
2583 };
2584 self.sender.prepare_response_from_source(
2585 req_id,
2586 method_id,
2587 &extracted.root,
2588 ®istry,
2589 &mut response,
2590 );
2591 let _ = self.sender.send_response(req_id, response).await;
2592 } else {
2593 let error: Result<(), VoxError<core::convert::Infallible>> =
2594 Err(vox_error);
2595 let _ = self.sender.send_response(req_id, RequestResponse {
2596 ret: Payload::outgoing(&error),
2597 metadata: Default::default(),
2598 schemas: Default::default(),
2599 }).await;
2600 }
2601 tracing::trace!(%req_id, "failures_rx: error response sent");
2602 }
2603 }
2604 recv = self.rx.recv() => {
2605 match recv {
2606 Some(recv) => {
2607 self.handle_recv(recv).await;
2608 }
2609 None => {
2610 tracing::trace!("driver rx closed, exiting loop");
2611 break;
2612 }
2613 }
2614 }
2615 Some(item) = self.handler_futs.next(), if !self.handler_futs.is_empty() => {
2621 match item {
2622 Ok(req_id) => {
2623 let removed = self.in_flight_handlers.remove(&req_id).is_some();
2624 self.shared.finish_request(req_id, RequestDebugState::Finished);
2625 tracing::trace!(
2626 %req_id,
2627 removed,
2628 in_flight = self.in_flight_handlers.len(),
2629 "handler completion processed",
2630 );
2631 }
2632 Err(_aborted) => {
2633 }
2636 }
2637 }
2638 }
2639 }
2640
2641 for (_, in_flight) in std::mem::take(&mut self.in_flight_handlers) {
2642 if !in_flight.retry.persist {
2643 in_flight.abort.abort();
2644 }
2645 }
2646 self.shared.pending_responses.lock().clear();
2647 self.shared.request_debug.lock().clear();
2648 let close_reason =
2649 (*self.closed_rx.borrow()).unwrap_or(ConnectionCloseReason::SessionShutdown);
2650 self.shared.set_connection_closed(close_reason);
2651
2652 self.close_all_channel_runtime_state(ChannelRuntimeTeardown::ConnectionClosed(
2656 close_reason,
2657 ));
2658 }
2659
2660 async fn handle_local_control(&mut self, control: DriverLocalControl) {
2661 match control {
2662 DriverLocalControl::CloseChannel { channel_id } => {
2663 if self.shared.stale_close_channels.lock().remove(&channel_id) {
2668 tracing::trace!(%channel_id, "suppressing ChannelClose for stale channel");
2669 return;
2670 }
2671 self.close_outbound_channel(channel_id);
2672 self.shared
2673 .observe_channel(channel_id, None, |channel| ChannelEvent::Closed {
2674 channel,
2675 reason: ChannelCloseReason::Local,
2676 });
2677 self.shared.mark_outbound_progress();
2678 let _ = self
2679 .sender
2680 .send(ConnectionMessage::Channel(ChannelMessage {
2681 id: channel_id,
2682 body: ChannelBody::Close(ChannelClose {
2683 metadata: Default::default(),
2684 }),
2685 }))
2686 .await;
2687 }
2688 DriverLocalControl::ResetChannel { channel_id } => {
2689 self.shared.channel_senders.lock().remove(&channel_id);
2690 self.shared.channel_receivers.lock().remove(&channel_id);
2691 self.close_outbound_channel(channel_id);
2692 self.shared
2693 .observe_channel(channel_id, None, |channel| ChannelEvent::Reset {
2694 channel,
2695 reason: ChannelResetReason::Local,
2696 });
2697 self.shared.mark_outbound_progress();
2698 let _ = self
2699 .sender
2700 .send(ConnectionMessage::Channel(ChannelMessage {
2701 id: channel_id,
2702 body: ChannelBody::Reset(vox_types::ChannelReset {
2703 metadata: Default::default(),
2704 }),
2705 }))
2706 .await;
2707 }
2708 DriverLocalControl::GrantCredit {
2709 channel_id,
2710 additional,
2711 } => {
2712 self.shared.observe_channel(channel_id, None, |channel| {
2713 ChannelEvent::CreditGranted {
2714 channel,
2715 amount: additional,
2716 }
2717 });
2718 self.shared.mark_outbound_progress();
2719 let _ = self
2720 .sender
2721 .send(ConnectionMessage::Channel(ChannelMessage {
2722 id: channel_id,
2723 body: ChannelBody::GrantCredit(vox_types::ChannelGrantCredit {
2724 additional,
2725 }),
2726 }))
2727 .await;
2728 }
2729 }
2730 }
2731
2732 async fn handle_recv(&mut self, recv: crate::session::RecvMessage) {
2733 self.shared.mark_inbound_progress();
2734 let crate::session::RecvMessage { schemas, msg, fds } = recv;
2735 let msg_ref = msg.get();
2736 let is_request = matches!(msg_ref, ConnectionMessage::Request(_));
2737 if is_request {
2738 if let ConnectionMessage::Request(req) = msg_ref {
2739 vox_types::dlog!(
2740 "[driver] handle_recv request: conn={:?} req={:?} body={} method={:?}",
2741 self.sender.connection_id(),
2742 req.id,
2743 match &req.body {
2744 RequestBody::Call(_) => "Call",
2745 RequestBody::Response(_) => "Response",
2746 RequestBody::Cancel(_) => "Cancel",
2747 },
2748 match &req.body {
2749 RequestBody::Call(call) => Some(call.method_id),
2750 RequestBody::Response(_) | RequestBody::Cancel(_) => None,
2751 }
2752 );
2753 match &req.body {
2754 RequestBody::Call(call) => tracing::trace!(
2755 conn_id = self.sender.connection_id().0,
2756 req_id = req.id.0,
2757 method_id = call.method_id.0,
2758 "driver received call"
2759 ),
2760 RequestBody::Response(_) => tracing::trace!(
2761 conn_id = self.sender.connection_id().0,
2762 req_id = req.id.0,
2763 "driver received response message"
2764 ),
2765 RequestBody::Cancel(_) => tracing::trace!(
2766 conn_id = self.sender.connection_id().0,
2767 req_id = req.id.0,
2768 "driver received cancel message"
2769 ),
2770 }
2771 }
2772 let msg = msg.map(|m| match m {
2773 ConnectionMessage::Request(r) => r,
2774 _ => unreachable!(),
2775 });
2776 self.handle_request(msg, schemas, fds);
2777 } else {
2778 let msg = msg.map(|m| match m {
2779 ConnectionMessage::Channel(c) => c,
2780 _ => unreachable!(),
2781 });
2782 self.handle_channel(msg).await;
2783 }
2784 }
2785
2786 fn handle_request(
2787 &mut self,
2788 msg: SelfRef<RequestMessage<'static>>,
2789 schemas: Arc<vox_types::SchemaRecvTracker>,
2790 fds: vox_types::FrameFds,
2791 ) {
2792 let msg_ref = msg.get();
2793 let req_id = msg_ref.id;
2794 let is_call = matches!(&msg_ref.body, RequestBody::Call(_));
2795 let is_response = matches!(&msg_ref.body, RequestBody::Response(_));
2796 let is_cancel = matches!(&msg_ref.body, RequestBody::Cancel(_));
2797
2798 if is_call {
2799 let method_id = match &msg_ref.body {
2800 RequestBody::Call(call) => call.method_id,
2801 _ => unreachable!(),
2802 };
2803 vox_types::dlog!(
2804 "[driver] inbound call: conn={:?} req={:?} method={:?}",
2805 self.sender.connection_id(),
2806 req_id,
2807 method_id
2808 );
2809 let call = msg.map(|m| match m.body {
2812 RequestBody::Call(c) => c,
2813 _ => unreachable!(),
2814 });
2815 let call_ref = call.get();
2816 let handler = Arc::clone(&self.handler);
2817 let retry = handler.retry_policy(call_ref.method_id);
2818 let operation_id = metadata_operation_id(&call_ref.metadata).filter(|_| !retry.idem);
2820 let method_id = call_ref.method_id;
2821
2822 if let Some(operation_id) = operation_id {
2823 let admit = self.live_operations.lock().admit(
2825 operation_id,
2826 call_ref.method_id,
2827 incoming_args_bytes(call_ref),
2828 retry,
2829 req_id,
2830 );
2831 match admit {
2832 AdmitResult::Attached => return,
2833 AdmitResult::Conflict => {
2834 let sender = self.sender.clone();
2835 moire::task::spawn(
2836 async move {
2837 let error: Result<(), VoxError<core::convert::Infallible>> =
2838 Err(VoxError::InvalidPayload("operation ID conflict".into()));
2839 let _ = sender
2840 .send_response(
2841 req_id,
2842 RequestResponse {
2843 ret: Payload::outgoing(&error),
2844 metadata: Default::default(),
2845 schemas: Default::default(),
2846 },
2847 )
2848 .await;
2849 }
2850 .named("operation_reject"),
2851 );
2852 return;
2853 }
2854 AdmitResult::Start => {}
2855 }
2856
2857 match self.shared.operations.lookup(operation_id) {
2859 crate::OperationState::Sealed => {
2860 if let Some(sealed) = self.shared.operations.get_sealed(operation_id) {
2862 let sender = self.sender.clone();
2863 let method_id = call_ref.method_id;
2864 let response_shape = self.handler.response_wire_shape(method_id);
2865 self.live_operations.lock().seal(operation_id);
2867 moire::task::spawn(
2868 async move {
2869 if replay_sealed_response(
2870 sender.clone(),
2871 req_id,
2872 method_id,
2873 sealed.response.as_bytes(),
2874 response_shape,
2875 )
2876 .await
2877 .is_err()
2878 {
2879 sender.mark_failure(req_id, FailureDisposition::Cancelled);
2880 }
2881 }
2882 .named("operation_replay"),
2883 );
2884 return;
2885 }
2886 }
2887 crate::OperationState::Admitted => {
2888 self.live_operations.lock().seal(operation_id);
2890 let sender = self.sender.clone();
2891 moire::task::spawn(
2892 async move {
2893 let error: Result<(), VoxError<core::convert::Infallible>> =
2894 Err(VoxError::Indeterminate);
2895 let _ = sender
2896 .send_response(
2897 req_id,
2898 RequestResponse {
2899 ret: Payload::outgoing(&error),
2900 metadata: Default::default(),
2901 schemas: Default::default(),
2902 },
2903 )
2904 .await;
2905 }
2906 .named("operation_indeterminate"),
2907 );
2908 return;
2909 }
2910 crate::OperationState::Unknown => {
2911 if !retry.idem {
2914 self.shared.operations.admit(operation_id);
2915 }
2916 }
2917 }
2918 }
2919 let reply = DriverReplySink {
2920 sender: Some(self.sender.clone()),
2921 request_id: req_id,
2922 method_id: call_ref.method_id,
2923 retry,
2924 operation_id,
2925 operations: operation_id.map(|_| Arc::clone(&self.shared.operations)),
2926 live_operations: operation_id.map(|_| Arc::clone(&self.live_operations)),
2927 binder: self.internal_binder(),
2928 handler_response_shape: handler.response_wire_shape(call_ref.method_id),
2929 };
2930 self.shared.start_request(
2931 req_id,
2932 method_id,
2933 None,
2934 None,
2935 RequestDebugState::Dispatching,
2936 );
2937 let (abort, abort_reg) = AbortHandle::new_pair();
2938 let handler_fut: Pin<Box<dyn MaybeSendFuture<Output = RequestId> + 'static>> =
2939 Box::pin(async move {
2940 vox_types::dlog!(
2941 "[driver] handler start: req={:?} method={:?}",
2942 req_id,
2943 method_id
2944 );
2945 handler.handle(call, reply, schemas).await;
2946 vox_types::dlog!(
2947 "[driver] handler done: req={:?} method={:?}",
2948 req_id,
2949 method_id
2950 );
2951 req_id
2952 });
2953 self.handler_futs
2954 .push(Abortable::new(handler_fut, abort_reg));
2955 self.in_flight_handlers.insert(
2956 req_id,
2957 InFlightHandler {
2958 abort,
2959 method_id,
2960 retry,
2961 operation_id,
2962 },
2963 );
2964 tracing::trace!(%req_id, in_flight = self.in_flight_handlers.len(), "handler inserted");
2965 } else if is_response {
2966 vox_types::dlog!(
2968 "[driver] inbound response: conn={:?} req={:?}",
2969 self.sender.connection_id(),
2970 req_id
2971 );
2972 tracing::trace!(%req_id, "driver received response");
2973 if let Some(tx) = self.shared.pending_responses.lock().remove(&req_id) {
2974 vox_types::dlog!("[driver] routing response to waiter: req={:?}", req_id);
2975 tracing::trace!(%req_id, "routing response to pending oneshot");
2976 let _: Result<(), _> = tx.send(PendingResponse { msg, schemas, fds });
2977 } else {
2978 vox_types::dlog!("[driver] dropped unmatched response: req={:?}", req_id);
2979 tracing::trace!(%req_id, "no pending response slot for this req_id");
2980 }
2981 } else if is_cancel {
2982 vox_types::dlog!(
2983 "[driver] inbound cancel: conn={:?} req={:?}",
2984 self.sender.connection_id(),
2985 req_id
2986 );
2987 tracing::trace!(%req_id, in_flight = self.in_flight_handlers.contains_key(&req_id), "received cancel");
2990 match self.live_operations.lock().cancel(req_id) {
2991 CancelResult::NotFound => {
2992 let should_abort = self
2993 .in_flight_handlers
2994 .get(&req_id)
2995 .map(|in_flight| !in_flight.retry.persist)
2996 .unwrap_or(false);
2997 tracing::trace!(%req_id, should_abort, "cancel: not in live operations");
2998 if should_abort && let Some(in_flight) = self.in_flight_handlers.remove(&req_id)
2999 {
3000 tracing::trace!(%req_id, "aborting handler");
3001 in_flight.abort.abort();
3002 self.shared
3003 .finish_request(req_id, RequestDebugState::Failed);
3004 tracing::trace!(%req_id, in_flight = self.in_flight_handlers.len(), "handler removed on cancel");
3005 }
3006 }
3007 CancelResult::Detached => {}
3008 CancelResult::Abort {
3009 owner_request_id,
3010 waiters,
3011 } => {
3012 if let Some(in_flight) = self.in_flight_handlers.remove(&owner_request_id) {
3013 if let Some(op_id) = in_flight.operation_id {
3014 self.shared.operations.remove(op_id);
3015 }
3016 in_flight.abort.abort();
3017 self.shared
3018 .finish_request(owner_request_id, RequestDebugState::Failed);
3019 tracing::trace!(%owner_request_id, in_flight = self.in_flight_handlers.len(), "owner handler removed on abort");
3020 }
3021 for waiter in waiters {
3022 self.sender
3023 .mark_failure(waiter, FailureDisposition::Cancelled);
3024 }
3025 }
3026 }
3027 }
3030 }
3031
3032 async fn handle_channel(&mut self, msg: SelfRef<ChannelMessage<'static>>) {
3033 let msg_ref = msg.get();
3034 let chan_id = msg_ref.id;
3035 enum ChannelBodyKind {
3036 Item,
3037 Close,
3038 Reset,
3039 GrantCredit(u32),
3040 }
3041 let body_kind = match &msg_ref.body {
3042 ChannelBody::Item(_) => ChannelBodyKind::Item,
3043 ChannelBody::Close(_) => ChannelBodyKind::Close,
3044 ChannelBody::Reset(_) => ChannelBodyKind::Reset,
3045 ChannelBody::GrantCredit(grant) => ChannelBodyKind::GrantCredit(grant.additional),
3046 };
3047
3048 match body_kind {
3049 ChannelBodyKind::Item => {
3052 if self.shared.terminal_channels.lock().contains(&chan_id) {
3053 self.shared.record_inbound_item_not_enqueued(chan_id);
3054 tracing::trace!(
3055 conn_id = self.sender.connection_id().0,
3056 channel_id = chan_id.0,
3057 "driver dropped item for terminal channel"
3058 );
3059 return;
3060 }
3061
3062 tracing::trace!(
3063 conn_id = self.sender.connection_id().0,
3064 channel_id = chan_id.0,
3065 "driver received channel item"
3066 );
3067 let item = msg.map(|m| match m.body {
3068 ChannelBody::Item(item) => item,
3069 _ => unreachable!(),
3070 });
3071 let sender = self.shared.inbound_channel_sender(chan_id);
3072 if sender
3073 .send(IncomingChannelMessage::Item(item))
3074 .await
3075 .is_err()
3076 {
3077 self.shared.record_inbound_item_not_enqueued(chan_id);
3078 self.shared.channel_senders.lock().remove(&chan_id);
3079 self.shared.channel_receivers.lock().remove(&chan_id);
3080 self.close_outbound_channel(chan_id);
3081 let _ = self
3082 .local_control_tx
3083 .send(DriverLocalControl::ResetChannel {
3084 channel_id: chan_id,
3085 });
3086 return;
3087 }
3088 self.shared
3089 .observe_channel(chan_id, None, |channel| ChannelEvent::ItemReceived {
3090 channel,
3091 });
3092 }
3093 ChannelBodyKind::Close => {
3095 if self.shared.terminal_channels.lock().contains(&chan_id) {
3096 return;
3097 }
3098 let sender = self.shared.inbound_channel_sender(chan_id);
3099 tracing::trace!(
3100 conn_id = self.sender.connection_id().0,
3101 channel_id = chan_id.0,
3102 "driver received channel close"
3103 );
3104 let close = msg.map(|m| match m.body {
3105 ChannelBody::Close(close) => close,
3106 _ => unreachable!(),
3107 });
3108 let delivered = sender
3109 .send(IncomingChannelMessage::Close(close))
3110 .await
3111 .is_ok();
3112 self.shared.channel_senders.lock().remove(&chan_id);
3113 self.shared.terminal_channels.lock().insert(chan_id);
3114 self.close_outbound_channel(chan_id);
3115 if !delivered {
3116 self.shared.channel_receivers.lock().remove(&chan_id);
3117 return;
3118 }
3119 self.shared
3120 .observe_channel(chan_id, None, |channel| ChannelEvent::Closed {
3121 channel,
3122 reason: ChannelCloseReason::Remote,
3123 });
3124 }
3125 ChannelBodyKind::Reset => {
3127 if self.shared.terminal_channels.lock().contains(&chan_id) {
3128 return;
3129 }
3130 let sender = self.shared.inbound_channel_sender(chan_id);
3131 tracing::trace!(
3132 conn_id = self.sender.connection_id().0,
3133 channel_id = chan_id.0,
3134 "driver received channel reset"
3135 );
3136 let reset = msg.map(|m| match m.body {
3137 ChannelBody::Reset(reset) => reset,
3138 _ => unreachable!(),
3139 });
3140 let delivered = sender
3141 .send(IncomingChannelMessage::Reset(reset))
3142 .await
3143 .is_ok();
3144 self.shared.channel_senders.lock().remove(&chan_id);
3145 self.shared.terminal_channels.lock().insert(chan_id);
3146 self.close_outbound_channel(chan_id);
3147 if !delivered {
3148 self.shared.channel_receivers.lock().remove(&chan_id);
3149 return;
3150 }
3151 self.shared
3152 .observe_channel(chan_id, None, |channel| ChannelEvent::Reset {
3153 channel,
3154 reason: ChannelResetReason::Remote,
3155 });
3156 }
3157 ChannelBodyKind::GrantCredit(additional) => {
3160 self.shared.record_credit_received(chan_id, additional);
3161 self.shared.emit_channel_event(chan_id, None, |channel| {
3162 ChannelEvent::CreditGranted {
3163 channel,
3164 amount: additional,
3165 }
3166 });
3167 tracing::trace!(
3168 conn_id = self.sender.connection_id().0,
3169 channel_id = chan_id.0,
3170 additional,
3171 "driver received channel credit"
3172 );
3173 if let Some(semaphore) = self.shared.channel_credits.lock().get(&chan_id) {
3174 semaphore.add_permits(additional as usize);
3175 }
3176 }
3177 }
3178 }
3179}