1use std::{
2 collections::BTreeMap,
3 pin::Pin,
4 sync::{
5 Arc, Weak,
6 atomic::{AtomicU64, Ordering},
7 },
8};
9
10use moire::sync::{Semaphore, SyncMutex};
11use tokio::sync::watch;
12
13use moire::task::FutureExt as _;
14use vox_types::{
15 BoxFut, CallResult, Caller, ChannelBinder, ChannelBody, ChannelClose, ChannelCreditReplenisher,
16 ChannelCreditReplenisherHandle, ChannelId, ChannelItem, ChannelLivenessHandle, ChannelMessage,
17 ChannelRetryMode, ChannelSink, CreditSink, Handler, IdAllocator, IncomingChannelMessage,
18 Payload, ReplySink, RequestBody, RequestCall, RequestId, RequestMessage, RequestResponse,
19 SelfRef, VoxError, TxError, ensure_operation_id, metadata_channel_retry_mode,
20 metadata_operation_id,
21};
22
23use crate::session::{
24 ConnectionHandle, ConnectionMessage, ConnectionSender, DropControlRequest, FailureDisposition,
25};
26use crate::{InMemoryOperationStore, OperationStore};
27use moire::sync::mpsc;
28use vox_types::{OperationId, PostcardPayload, SchemaHash, TypeRef};
29
/// Response delivered to a waiting caller through a [`ResponseSlot`].
struct PendingResponse {
    /// The received message; `DriverCaller::call` expects its body to be the
    /// `RequestBody::Response` variant.
    msg: SelfRef<RequestMessage<'static>>,
    /// Schema tracker handed back to the caller together with the response.
    schemas: Arc<vox_types::SchemaRecvTracker>,
}
39
/// Oneshot sender used to hand a [`PendingResponse`] to the task awaiting it.
type ResponseSlot = moire::sync::oneshot::Sender<PendingResponse>;
41
/// Bookkeeping for a handler task tracked in `Driver::in_flight_handlers`.
struct InFlightHandler {
    /// Join handle of the handler task; the driver uses it to abort the task.
    handle: moire::task::JoinHandle<()>,
    /// Method the request targets.
    method_id: vox_types::MethodId,
    /// Retry policy of the request; the driver consults its `idem` and
    /// `persist` flags.
    retry: vox_types::RetryPolicy,
    /// Marks the handler as channel-bearing; such handlers are aborted by
    /// `Driver::abort_channel_handlers`.
    has_channels: bool,
    /// Operation id attached to the request, if any.
    operation_id: Option<OperationId>,
}
49
/// Tracks operations that were admitted and have not yet been sealed,
/// released, or aborted.
struct LiveOperationTracker {
    /// Live operations keyed by operation id.
    live: HashMap<OperationId, LiveOperation>,
    /// Reverse index: each attached request id mapped to its operation id.
    request_to_operation: HashMap<RequestId, OperationId>,
}
65
/// State of one live operation inside [`LiveOperationTracker`].
struct LiveOperation {
    /// Method of the request that started the operation.
    method_id: vox_types::MethodId,
    /// Hash over the method id and the argument bytes, used to detect a
    /// request that reuses the operation id with a different call.
    args_hash: u64,
    /// Request that started the operation.
    owner_request_id: RequestId,
    /// Every request attached to the operation, including the owner.
    waiters: Vec<RequestId>,
    /// Retry policy recorded when the operation started.
    retry: vox_types::RetryPolicy,
}
73
/// Outcome of [`LiveOperationTracker::admit`].
enum AdmitResult {
    /// No live operation existed for the id; the request was recorded as owner.
    Start,
    /// A matching live operation existed; the request was added as a waiter.
    Attached,
    /// A live operation exists for the id but with a different method or
    /// argument hash.
    Conflict,
}
82
83impl LiveOperationTracker {
84 fn new() -> Self {
85 Self {
86 live: HashMap::new(),
87 request_to_operation: HashMap::new(),
88 }
89 }
90
91 fn admit(
92 &mut self,
93 operation_id: OperationId,
94 method_id: vox_types::MethodId,
95 args: &[u8],
96 retry: vox_types::RetryPolicy,
97 request_id: RequestId,
98 ) -> AdmitResult {
99 use std::hash::{Hash, Hasher};
100 let args_hash = {
101 let mut h = std::collections::hash_map::DefaultHasher::new();
102 method_id.hash(&mut h);
103 args.hash(&mut h);
104 h.finish()
105 };
106
107 if let Some(live) = self.live.get_mut(&operation_id) {
108 if live.method_id != method_id || live.args_hash != args_hash {
109 return AdmitResult::Conflict;
110 }
111 live.waiters.push(request_id);
112 self.request_to_operation.insert(request_id, operation_id);
113 return AdmitResult::Attached;
114 }
115
116 self.live.insert(
117 operation_id,
118 LiveOperation {
119 method_id,
120 args_hash,
121 owner_request_id: request_id,
122 waiters: vec![request_id],
123 retry,
124 },
125 );
126 self.request_to_operation.insert(request_id, operation_id);
127 AdmitResult::Start
128 }
129
130 fn seal(&mut self, operation_id: OperationId) -> Vec<RequestId> {
132 if let Some(live) = self.live.remove(&operation_id) {
133 for waiter in &live.waiters {
134 self.request_to_operation.remove(waiter);
135 }
136 live.waiters
137 } else {
138 vec![]
139 }
140 }
141
142 fn release(&mut self, operation_id: OperationId) -> Option<LiveOperation> {
144 if let Some(live) = self.live.remove(&operation_id) {
145 for waiter in &live.waiters {
146 self.request_to_operation.remove(waiter);
147 }
148 Some(live)
149 } else {
150 None
151 }
152 }
153
154 fn cancel(&mut self, request_id: RequestId) -> CancelResult {
156 let Some(&operation_id) = self.request_to_operation.get(&request_id) else {
157 return CancelResult::NotFound;
158 };
159 let Some(live) = self.live.get_mut(&operation_id) else {
160 self.request_to_operation.remove(&request_id);
161 return CancelResult::NotFound;
162 };
163
164 if live.retry.persist {
165 if live.owner_request_id == request_id {
167 return CancelResult::NotFound; }
169 live.waiters.retain(|w| *w != request_id);
170 self.request_to_operation.remove(&request_id);
171 CancelResult::Detached
172 } else {
173 let live = self.live.remove(&operation_id).unwrap();
175 for waiter in &live.waiters {
176 self.request_to_operation.remove(waiter);
177 }
178 CancelResult::Abort {
179 owner_request_id: live.owner_request_id,
180 waiters: live.waiters,
181 }
182 }
183 }
184}
185
/// Outcome of [`LiveOperationTracker::cancel`].
enum CancelResult {
    /// Nothing was cancelled: the request is unknown, its index entry was
    /// stale, or it owns an operation whose retry policy has `persist` set.
    NotFound,
    /// The request was removed from the waiters of a `persist` operation,
    /// which stays live.
    Detached,
    /// The operation was removed from the tracker.
    Abort {
        /// Request that started the removed operation.
        owner_request_id: RequestId,
        /// All requests that were attached to the removed operation.
        waiters: Vec<RequestId>,
    },
}
194
195use std::collections::HashMap;
196
/// State shared between the [`Driver`], its callers, and its channel binders.
struct DriverShared {
    /// Response slots of outstanding outgoing calls, keyed by request id.
    pending_responses: SyncMutex<BTreeMap<RequestId, ResponseSlot>>,
    /// Allocator for outgoing request ids.
    request_ids: SyncMutex<IdAllocator<RequestId>>,
    /// Counter from which callers draw fresh operation ids.
    next_operation_id: AtomicU64,
    /// Operation store; used to seal and to remove operations.
    operations: Arc<dyn OperationStore>,
    /// Allocator for locally created channel ids.
    channel_ids: SyncMutex<IdAllocator<ChannelId>>,
    /// Per-channel senders feeding registered receivers.
    channel_senders: SyncMutex<BTreeMap<ChannelId, mpsc::Sender<IncomingChannelMessage>>>,
    /// Messages buffered per channel; drained into the receiver queue when
    /// the channel is registered.
    channel_buffers: SyncMutex<BTreeMap<ChannelId, Vec<IncomingChannelMessage>>>,
    /// Credit semaphores of sending channels, keyed by channel id.
    channel_credits: SyncMutex<BTreeMap<ChannelId, Arc<Semaphore>>>,
    /// Channel ids recorded when all channel runtime state is torn down.
    stale_close_channels: SyncMutex<std::collections::HashSet<ChannelId>>,
}
225
/// Guard that sends `request` on `control_tx` when it is dropped.
struct CallerDropGuard {
    /// Channel on which the drop request is sent.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Request emitted on drop.
    request: DropControlRequest,
}
230
231impl Drop for CallerDropGuard {
232 fn drop(&mut self) {
233 let _ = self.control_tx.send(self.request);
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::{DriverChannelCreditReplenisher, DriverLocalControl};
    use tokio::sync::mpsc::error::TryRecvError;
    use vox_types::{ChannelCreditReplenisher, ChannelId};

    /// With an initial window of 16, nothing is granted for the first seven
    /// consumed items and the eighth produces a single grant of 8.
    #[test]
    fn replenisher_batches_at_half_the_initial_window() {
        let (control_tx, mut control_rx) =
            moire::sync::mpsc::unbounded_channel("test.replenisher");
        let replenisher = DriverChannelCreditReplenisher::new(ChannelId(7), 16, control_tx);

        (0..7).for_each(|_| replenisher.on_item_consumed());
        assert!(
            matches!(control_rx.try_recv(), Err(TryRecvError::Empty)),
            "should not emit credit before reaching the batch threshold"
        );

        replenisher.on_item_consumed();
        match control_rx.try_recv() {
            Ok(DriverLocalControl::GrantCredit {
                channel_id,
                additional,
            }) => {
                assert_eq!(channel_id, ChannelId(7));
                assert_eq!(additional, 8);
            }
            _ => panic!("expected batched credit grant"),
        }
    }

    /// With an initial window of 1, every consumed item is granted back
    /// immediately.
    #[test]
    fn replenisher_grants_one_by_one_for_single_credit_windows() {
        let (control_tx, mut control_rx) =
            moire::sync::mpsc::unbounded_channel("test.replenisher.single");
        let replenisher = DriverChannelCreditReplenisher::new(ChannelId(9), 1, control_tx);

        replenisher.on_item_consumed();
        match control_rx.try_recv() {
            Ok(DriverLocalControl::GrantCredit {
                channel_id,
                additional,
            }) => {
                assert_eq!(channel_id, ChannelId(9));
                assert_eq!(additional, 1);
            }
            _ => panic!("expected immediate credit grant"),
        }
    }
}
285
/// Reply sink handed to a request handler.
///
/// Sending a reply consumes the sink; dropping it without replying reports a
/// failure for the request (see the `Drop` impl).
pub struct DriverReplySink {
    /// Connection sender; taken (set to `None`) once a reply was sent.
    sender: Option<ConnectionSender>,
    /// Request this sink replies to.
    request_id: RequestId,
    /// Method of the request.
    method_id: vox_types::MethodId,
    /// Retry policy of the request; `persist` selects the failure disposition
    /// used on drop.
    retry: vox_types::RetryPolicy,
    /// Operation id of the request, if any.
    operation_id: Option<OperationId>,
    /// Store in which the reply is sealed when an operation id is present.
    operations: Option<Arc<dyn OperationStore>>,
    /// Tracker of live operations, used to find the waiters of the operation.
    live_operations: Option<Arc<SyncMutex<LiveOperationTracker>>>,
    /// Channel binder exposed through `ReplySink::channel_binder`.
    binder: DriverChannelBinder,
}
303
304async fn replay_sealed_response(
310 sender: ConnectionSender,
311 request_id: RequestId,
312 method_id: vox_types::MethodId,
313 encoded_response: &[u8],
314 root_type: TypeRef,
315 operations: &dyn OperationStore,
316) -> Result<(), ()> {
317 let mut response: RequestResponse<'_> =
318 vox_postcard::from_slice_borrowed(encoded_response).map_err(|_| ())?;
319 sender.prepare_replay_schemas(request_id, method_id, &root_type, operations, &mut response);
320 sender.send_response(request_id, response).await
321}
322
323fn extract_root_type_ref(schemas_cbor: &vox_types::CborPayload) -> TypeRef {
325 if schemas_cbor.is_empty() {
326 return TypeRef::concrete(SchemaHash(0));
327 }
328 let payload =
329 vox_types::SchemaPayload::from_cbor(&schemas_cbor.0).expect("schema CBOR must be valid");
330 payload.root
331}
332
333fn incoming_args_bytes<'a>(call: &'a RequestCall<'a>) -> &'a [u8] {
334 match &call.args {
335 Payload::PostcardBytes(bytes) => bytes,
336 Payload::Value { .. } => {
337 panic!("incoming request payload should always be decoded as incoming bytes")
338 }
339 }
340}
341
342impl ReplySink for DriverReplySink {
343 async fn send_reply(mut self, response: RequestResponse<'_>) {
344 let sender = self
345 .sender
346 .take()
347 .expect("unreachable: send_reply takes self by value");
348
349 vox_types::dlog!(
350 "[driver] send_reply: conn={:?} req={:?} method={:?} payload={} operation_id={:?}",
351 sender.connection_id(),
352 self.request_id,
353 self.method_id,
354 match &response.ret {
355 Payload::Value { .. } => "Value",
356 Payload::PostcardBytes(_) => "PostcardBytes",
357 },
358 self.operation_id
359 );
360
361 if let Payload::Value { shape, .. } = &response.ret
362 && let Ok(extracted) = vox_types::extract_schemas(shape)
363 {
364 vox_types::dlog!(
365 "[schema] driver send_reply: method={:?} root={:?}",
366 self.method_id,
367 extracted.root
368 );
369 }
370
371 if let (Some(operation_id), Some(operations)) = (self.operation_id, self.operations.take())
372 {
373 let mut response = response;
374 sender.prepare_response_for_method(self.request_id, self.method_id, &mut response);
375
376 let root_type = extract_root_type_ref(&response.schemas);
378
379 let schemas_for_wire = std::mem::take(&mut response.schemas);
381 let encoded_for_store = PostcardPayload(
382 vox_postcard::to_vec(&response).expect("serialize operation response for store"),
383 );
384 response.schemas = schemas_for_wire;
385
386 vox_types::dlog!(
388 "[driver] send_reply wire send: conn={:?} req={:?} method={:?} schemas={}",
389 sender.connection_id(),
390 self.request_id,
391 self.method_id,
392 response.schemas.0.len()
393 );
394 if let Err(_e) = sender.send_response(self.request_id, response).await {
395 sender.mark_failure(self.request_id, FailureDisposition::Cancelled);
396 }
397
398 let registry = sender.schema_registry();
400 operations.seal(operation_id, &encoded_for_store, &root_type, ®istry);
401
402 let waiters = self
404 .live_operations
405 .as_ref()
406 .map(|lo| lo.lock().seal(operation_id))
407 .unwrap_or_default();
408 for waiter in waiters {
409 if waiter == self.request_id {
410 continue;
411 }
412 if replay_sealed_response(
413 sender.clone(),
414 waiter,
415 self.method_id,
416 encoded_for_store.as_bytes(),
417 root_type.clone(),
418 operations.as_ref(),
419 )
420 .await
421 .is_err()
422 {
423 sender.mark_failure(waiter, FailureDisposition::Cancelled);
424 }
425 }
426 } else {
427 vox_types::dlog!(
428 "[driver] send_reply direct send: conn={:?} req={:?} method={:?}",
429 sender.connection_id(),
430 self.request_id,
431 self.method_id
432 );
433 if let Err(_e) = sender
434 .send_response_for_method(self.request_id, self.method_id, response)
435 .await
436 {
437 sender.mark_failure(self.request_id, FailureDisposition::Cancelled);
438 }
439 }
440 }
441
442 fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
443 Some(&self.binder)
444 }
445}
446
447impl Drop for DriverReplySink {
449 fn drop(&mut self) {
450 if let Some(sender) = self.sender.take() {
451 let disposition = if self.retry.persist {
452 FailureDisposition::Indeterminate
453 } else {
454 FailureDisposition::Cancelled
455 };
456
457 if let Some(operation_id) = self.operation_id {
458 if let Some(live_ops) = self.live_operations.take()
464 && let Some(live) = live_ops.lock().release(operation_id)
465 {
466 for waiter in live.waiters {
467 sender.mark_failure(waiter, disposition);
468 }
469 return;
470 }
471 }
472
473 sender.mark_failure(self.request_id, disposition);
474 }
475 }
476}
477
/// Channel sink that sends items and close messages for one channel over the
/// connection.
pub struct DriverChannelSink {
    /// Connection sender used for channel messages.
    sender: ConnectionSender,
    /// Channel this sink writes to.
    channel_id: ChannelId,
    /// Control queue on which a close is requested when the sink is dropped.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
}
490
491impl ChannelSink for DriverChannelSink {
492 fn send_payload<'payload>(
493 &self,
494 payload: Payload<'payload>,
495 ) -> Pin<Box<dyn vox_types::MaybeSendFuture<Output = Result<(), TxError>> + 'payload>> {
496 let sender = self.sender.clone();
497 let channel_id = self.channel_id;
498 Box::pin(async move {
499 sender
500 .send(ConnectionMessage::Channel(ChannelMessage {
501 id: channel_id,
502 body: ChannelBody::Item(ChannelItem { item: payload }),
503 }))
504 .await
505 .map_err(|()| TxError::Transport("connection closed".into()))
506 })
507 }
508
509 fn close_channel(
510 &self,
511 _metadata: vox_types::Metadata,
512 ) -> Pin<Box<dyn vox_types::MaybeSendFuture<Output = Result<(), TxError>> + 'static>> {
513 let sender = self.sender.clone();
517 let channel_id = self.channel_id;
518 Box::pin(async move {
519 sender
520 .send(ConnectionMessage::Channel(ChannelMessage {
521 id: channel_id,
522 body: ChannelBody::Close(ChannelClose {
523 metadata: Default::default(),
524 }),
525 }))
526 .await
527 .map_err(|()| TxError::Transport("connection closed".into()))
528 })
529 }
530
531 fn close_channel_on_drop(&self) {
532 let _ = self
533 .local_control_tx
534 .send(DriverLocalControl::CloseChannel {
535 channel_id: self.channel_id,
536 });
537 }
538}
539
/// Wrapper that holds a [`DriverCaller`] without using it.
#[must_use = "Dropping NoopCaller may close the connection if it is the last caller."]
#[derive(Clone)]
pub struct NoopCaller(#[allow(dead_code)] DriverCaller);
546
547impl From<DriverCaller> for NoopCaller {
548 fn from(caller: DriverCaller) -> Self {
549 Self(caller)
550 }
551}
552
/// Channel binder used internally by the driver (for example inside
/// [`DriverReplySink`]).
#[derive(Clone)]
struct DriverChannelBinder {
    /// Connection sender given to created sinks.
    sender: ConnectionSender,
    /// Shared driver state (id allocators, channel maps).
    shared: Arc<DriverShared>,
    /// Control queue given to created sinks and credit replenishers.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Drop guard exposed as the channel liveness handle, if one exists.
    drop_guard: Option<Arc<CallerDropGuard>>,
}
560
/// Initial credit given to every channel sink and credit replenisher created
/// in this file.
const DEFAULT_CHANNEL_CREDIT: u32 = 16;
563
/// Registers the receiving side of `channel_id` and returns the bound
/// receiver.
///
/// A bounded queue (capacity 64) is created and its sender is stored in
/// `shared.channel_senders`. Messages already buffered for the channel in
/// `shared.channel_buffers` are forwarded into the queue. If a buffered
/// `Close` or `Reset` is found, forwarding stops there, the sender and credit
/// entries of the channel are removed, and the receiver is returned without a
/// credit replenisher. Otherwise the receiver gets a replenisher created with
/// [`DEFAULT_CHANNEL_CREDIT`].
fn register_rx_channel_impl(
    shared: &Arc<DriverShared>,
    channel_id: ChannelId,
    queue_name: &'static str,
    liveness: Option<ChannelLivenessHandle>,
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
) -> vox_types::BoundChannelReceiver {
    let (tx, rx) = mpsc::channel(queue_name, 64);

    let mut terminal_buffered = false;
    {
        // `channel_senders` stays locked for this whole scope, i.e. while the
        // buffered messages are taken and forwarded.
        let mut senders = shared.channel_senders.lock();

        senders.insert(channel_id, tx.clone());

        let buffered = shared.channel_buffers.lock().remove(&channel_id);
        if let Some(buffered) = buffered {
            for msg in buffered {
                let is_terminal = matches!(
                    msg,
                    IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
                );
                // A failed `try_send` is ignored.
                let _ = tx.try_send(msg);
                if is_terminal {
                    terminal_buffered = true;
                    break;
                }
            }
        }

        // The channel already ended, so its sender is not kept registered.
        if terminal_buffered {
            senders.remove(&channel_id);
        }
    }

    if terminal_buffered {
        shared.channel_credits.lock().remove(&channel_id);
        return vox_types::BoundChannelReceiver {
            receiver: rx,
            liveness,
            replenisher: None,
        };
    }

    vox_types::BoundChannelReceiver {
        receiver: rx,
        liveness,
        replenisher: Some(Arc::new(DriverChannelCreditReplenisher::new(
            channel_id,
            DEFAULT_CHANNEL_CREDIT,
            local_control_tx,
        )) as ChannelCreditReplenisherHandle),
    }
}
626
627impl DriverChannelBinder {
628 fn create_tx_channel(&self) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
629 let channel_id = self.shared.channel_ids.lock().alloc();
630 let inner = DriverChannelSink {
631 sender: self.sender.clone(),
632 channel_id,
633 local_control_tx: self.local_control_tx.clone(),
634 };
635 let sink = Arc::new(CreditSink::new(inner, DEFAULT_CHANNEL_CREDIT));
636 self.shared
637 .channel_credits
638 .lock()
639 .insert(channel_id, Arc::clone(sink.credit()));
640 (channel_id, sink)
641 }
642
643 fn register_rx_channel(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
644 register_rx_channel_impl(
645 &self.shared,
646 channel_id,
647 "driver.register_rx_channel",
648 self.channel_liveness(),
649 self.local_control_tx.clone(),
650 )
651 }
652}
653
654impl ChannelBinder for DriverChannelBinder {
655 fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>) {
656 let (id, sink) = self.create_tx_channel();
657 (id, sink as Arc<dyn ChannelSink>)
658 }
659
660 fn create_rx(&self) -> (ChannelId, vox_types::BoundChannelReceiver) {
661 let channel_id = self.shared.channel_ids.lock().alloc();
662 let rx = self.register_rx_channel(channel_id);
663 (channel_id, rx)
664 }
665
666 fn bind_tx(&self, channel_id: ChannelId) -> Arc<dyn ChannelSink> {
667 let inner = DriverChannelSink {
668 sender: self.sender.clone(),
669 channel_id,
670 local_control_tx: self.local_control_tx.clone(),
671 };
672 let sink = Arc::new(CreditSink::new(inner, DEFAULT_CHANNEL_CREDIT));
673 self.shared
674 .channel_credits
675 .lock()
676 .insert(channel_id, Arc::clone(sink.credit()));
677 sink
678 }
679
680 fn register_rx(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
681 self.register_rx_channel(channel_id)
682 }
683
684 fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
685 self.drop_guard
686 .as_ref()
687 .map(|guard| guard.clone() as ChannelLivenessHandle)
688 }
689}
690
/// Cloneable handle for issuing calls and binding channels on a driver's
/// connection.
#[derive(Clone)]
pub struct DriverCaller {
    /// Connection sender used for requests and channel messages.
    sender: ConnectionSender,
    /// State shared with the driver.
    shared: Arc<DriverShared>,
    /// Control queue given to created sinks and credit replenishers.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Becomes `true` once the connection is closed.
    closed_rx: watch::Receiver<bool>,
    /// Resume generation counter; a change signals that the session resumed.
    resumed_rx: watch::Receiver<u64>,
    /// Latest resume generation the driver has finished processing.
    resume_processed_rx: watch::Receiver<u64>,
    /// Enables operation ids and resend-on-resume in `call`.
    peer_supports_retry: bool,
    /// Drop guard kept alive by this caller; also exposed as the channel
    /// liveness handle.
    _drop_guard: Option<Arc<CallerDropGuard>>,
}
705
706impl DriverCaller {
707 pub fn create_tx_channel(&self) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
712 let channel_id = self.shared.channel_ids.lock().alloc();
713 let inner = DriverChannelSink {
714 sender: self.sender.clone(),
715 channel_id,
716 local_control_tx: self.local_control_tx.clone(),
717 };
718 let sink = Arc::new(CreditSink::new(inner, DEFAULT_CHANNEL_CREDIT));
719 self.shared
720 .channel_credits
721 .lock()
722 .insert(channel_id, Arc::clone(sink.credit()));
723 (channel_id, sink)
724 }
725
726 #[cfg(test)]
731 pub(crate) fn connection_sender(&self) -> &ConnectionSender {
732 &self.sender
733 }
734
735 pub fn register_rx_channel(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
740 register_rx_channel_impl(
741 &self.shared,
742 channel_id,
743 "driver.caller.register_rx_channel",
744 self.channel_liveness(),
745 self.local_control_tx.clone(),
746 )
747 }
748}
749
750impl ChannelBinder for DriverCaller {
751 fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>) {
752 let (id, sink) = self.create_tx_channel();
753 (id, sink as Arc<dyn ChannelSink>)
754 }
755
756 fn create_rx(&self) -> (ChannelId, vox_types::BoundChannelReceiver) {
757 let channel_id = self.shared.channel_ids.lock().alloc();
758 let rx = self.register_rx_channel(channel_id);
759 (channel_id, rx)
760 }
761
762 fn bind_tx(&self, channel_id: ChannelId) -> Arc<dyn ChannelSink> {
763 let inner = DriverChannelSink {
764 sender: self.sender.clone(),
765 channel_id,
766 local_control_tx: self.local_control_tx.clone(),
767 };
768 let sink = Arc::new(CreditSink::new(inner, DEFAULT_CHANNEL_CREDIT));
769 self.shared
770 .channel_credits
771 .lock()
772 .insert(channel_id, Arc::clone(sink.credit()));
773 sink
774 }
775
776 fn register_rx(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
777 self.register_rx_channel(channel_id)
778 }
779
780 fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
781 self._drop_guard
782 .as_ref()
783 .map(|guard| guard.clone() as ChannelLivenessHandle)
784 }
785}
786
impl Caller for DriverCaller {
    /// Sends `call` as a request and waits for its response.
    ///
    /// While waiting, the call also reacts to session resume notifications
    /// (only when the peer supports retry) and to connection closure.
    ///
    /// # Errors
    ///
    /// * `VoxError::SendFailed` — the initial send failed.
    /// * `VoxError::ConnectionClosed` — the response slot was dropped, or the
    ///   connection was reported closed.
    /// * `VoxError::SessionShutdown` — a resume watch channel was closed.
    /// * `VoxError::Indeterminate` — the session resumed and the call's
    ///   channel retry mode is `NonIdem`.
    async fn call<'a>(&'a self, mut call: RequestCall<'a>) -> CallResult {
        if self.peer_supports_retry {
            // Draw a fresh operation id and hand it to `ensure_operation_id`
            // (presumably a no-op when the metadata already carries one —
            // TODO confirm).
            let operation_id = OperationId(
                self.shared
                    .next_operation_id
                    .fetch_add(1, Ordering::Relaxed),
            );
            ensure_operation_id(&mut call.metadata, operation_id);
        }

        let req_id = self.shared.request_ids.lock().alloc();

        // The response slot is registered before the request is sent.
        let (tx, rx) = moire::sync::oneshot::channel("driver.response");
        self.shared.pending_responses.lock().insert(req_id, tx);

        if self
            .sender
            .send_with_binder(
                ConnectionMessage::Request(RequestMessage {
                    id: req_id,
                    body: RequestBody::Call(RequestCall {
                        method_id: call.method_id,
                        args: call.args.reborrow(),
                        metadata: call.metadata.clone(),
                        schemas: Default::default(),
                    }),
                }),
                Some(self),
            )
            .await
            .is_err()
        {
            self.shared.pending_responses.lock().remove(&req_id);
            return Err(VoxError::SendFailed);
        }

        let mut resumed_rx = self.resumed_rx.clone();
        let mut seen_resume_generation = *resumed_rx.borrow();
        let mut resume_processed_rx = self.resume_processed_rx.clone();
        let mut closed_rx = self.closed_rx.clone();
        let mut response = std::pin::pin!(rx.named("awaiting_response"));

        let pending: PendingResponse = loop {
            tokio::select! {
                result = &mut response => {
                    match result {
                        Ok(pending) => break pending,
                        // The sending half of the slot was dropped without a response.
                        Err(_) => {
                            return Err(VoxError::ConnectionClosed);
                        }
                    }
                }
                changed = resumed_rx.changed(), if self.peer_supports_retry => {
                    vox_types::dlog!("[CALLER] resumed_rx fired");
                    if changed.is_err() {
                        self.shared.pending_responses.lock().remove(&req_id);
                        return Err(VoxError::SessionShutdown);
                    }
                    let generation = *resumed_rx.borrow();
                    if generation == seen_resume_generation {
                        continue;
                    }
                    seen_resume_generation = generation;
                    // Wait until the driver reports this generation as processed.
                    while *resume_processed_rx.borrow() < generation {
                        if resume_processed_rx.changed().await.is_err() {
                            self.shared.pending_responses.lock().remove(&req_id);
                            return Err(VoxError::SessionShutdown);
                        }
                    }
                    match metadata_channel_retry_mode(&call.metadata) {
                        ChannelRetryMode::NonIdem => {
                            self.shared.pending_responses.lock().remove(&req_id);
                            return Err(VoxError::Indeterminate);
                        }
                        ChannelRetryMode::Idem | ChannelRetryMode::None => {}
                    }
                    // Resend the request under the same request id; the
                    // result of the resend is ignored.
                    let _ = self.sender.send_with_binder(
                        ConnectionMessage::Request(RequestMessage {
                            id: req_id,
                            body: RequestBody::Call(RequestCall {
                                method_id: call.method_id,
                                args: call.args.reborrow(),
                                metadata: call.metadata.clone(),
                                schemas: Default::default(),
                            }),
                        }),
                        Some(self),
                    ).await;
                }
                changed = closed_rx.changed() => {
                    vox_types::dlog!("[CALLER] closed_rx fired, value={}", *closed_rx.borrow());
                    if changed.is_err() || *closed_rx.borrow() {
                        self.shared.pending_responses.lock().remove(&req_id);
                        return Err(VoxError::ConnectionClosed);
                    }
                }
            }
        };

        let PendingResponse {
            msg: response_msg,
            schemas: response_schemas,
        } = pending;
        let response = response_msg.map(|m| match m.body {
            RequestBody::Response(r) => r,
            _ => unreachable!("pending_responses only gets Response variants"),
        });

        Ok(vox_types::WithTracker {
            value: response,
            tracker: response_schemas,
        })
    }

    /// Resolves once the connection is reported closed.
    ///
    /// Also resolves if the closed-state watch channel itself is closed.
    fn closed(&self) -> BoxFut<'_, ()> {
        Box::pin(async move {
            if *self.closed_rx.borrow() {
                return;
            }
            let mut rx = self.closed_rx.clone();
            while rx.changed().await.is_ok() {
                if *rx.borrow() {
                    return;
                }
            }
        })
    }

    /// Returns `true` while the connection has not been reported closed.
    fn is_connected(&self) -> bool {
        !*self.closed_rx.borrow()
    }

    /// The caller acts as its own channel binder.
    fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
        Some(self)
    }
}
938
/// Drives one connection: receives messages, tracks in-flight handlers,
/// answers failures, and processes local control requests (see `run`).
pub struct Driver<H: Handler<DriverReplySink>> {
    /// Connection sender for outgoing messages.
    sender: ConnectionSender,
    /// Incoming messages from the session.
    rx: mpsc::Receiver<crate::session::RecvMessage>,
    /// Failure reports for requests, each with its disposition.
    failures_rx: mpsc::UnboundedReceiver<(RequestId, FailureDisposition)>,
    /// Closed state of the connection; cloned into callers.
    closed_rx: watch::Receiver<bool>,
    /// Resume generation counter of the session.
    resumed_rx: watch::Receiver<u64>,
    /// Publishes the latest resume generation the driver has processed.
    resume_processed_tx: watch::Sender<u64>,
    /// Whether the peer supports retry; copied into callers.
    peer_supports_retry: bool,
    /// Receiving end of the local control queue.
    local_control_rx: mpsc::UnboundedReceiver<DriverLocalControl>,
    /// Request handler.
    handler: Arc<H>,
    /// State shared with callers and binders.
    shared: Arc<DriverShared>,
    /// Handler tasks currently tracked, keyed by request id.
    in_flight_handlers: BTreeMap<RequestId, InFlightHandler>,
    /// Tracker of live operations.
    live_operations: Arc<SyncMutex<LiveOperationTracker>>,
    /// Sending end of the local control queue; cloned into callers, binders
    /// and sinks.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Control channel used to build drop guards; `None` disables them.
    drop_control_seed: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// Request a drop guard sends when it is dropped.
    drop_control_request: DropControlRequest,
    /// Weak reference to the drop guard currently shared by callers.
    drop_guard: SyncMutex<Option<Weak<CallerDropGuard>>>,
}
967
/// Control requests sent to the driver over its local control queue.
enum DriverLocalControl {
    /// Sent by a channel sink's `close_channel_on_drop`.
    CloseChannel {
        /// Channel to close.
        channel_id: ChannelId,
    },
    /// Sent by a credit replenisher once enough items were consumed.
    GrantCredit {
        /// Channel the credit applies to.
        channel_id: ChannelId,
        /// Number of credits to grant.
        additional: u32,
    },
}
977
/// Batches consumed-item notifications of one channel into `GrantCredit`
/// control messages.
struct DriverChannelCreditReplenisher {
    /// Channel the grants are issued for.
    channel_id: ChannelId,
    /// Number of consumed items that triggers a grant.
    threshold: u32,
    /// Control queue on which grants are sent.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Items consumed since the last grant.
    pending: std::sync::Mutex<u32>,
}
984
985impl DriverChannelCreditReplenisher {
986 fn new(
987 channel_id: ChannelId,
988 initial_credit: u32,
989 local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
990 ) -> Self {
991 Self {
992 channel_id,
993 threshold: (initial_credit / 2).max(1),
994 local_control_tx,
995 pending: std::sync::Mutex::new(0),
996 }
997 }
998}
999
1000impl ChannelCreditReplenisher for DriverChannelCreditReplenisher {
1001 fn on_item_consumed(&self) {
1002 let mut pending = self.pending.lock().expect("pending credit mutex poisoned");
1003 *pending += 1;
1004 if *pending < self.threshold {
1005 return;
1006 }
1007
1008 let additional = *pending;
1009 *pending = 0;
1010 let _ = self.local_control_tx.send(DriverLocalControl::GrantCredit {
1011 channel_id: self.channel_id,
1012 additional,
1013 });
1014 }
1015}
1016
1017impl<H: Handler<DriverReplySink>> Driver<H> {
    /// Tears down all per-channel runtime state.
    ///
    /// Closes every credit semaphore, records the ids of those channels in
    /// `stale_close_channels`, and clears the credit, sender, and buffer maps.
    fn close_all_channel_runtime_state(&self) {
        let mut credits = self.shared.channel_credits.lock();
        for semaphore in credits.values() {
            semaphore.close();
        }
        // `stale` is locked while `credits` is still held and stays locked
        // until the end of the function.
        let mut stale = self.shared.stale_close_channels.lock();
        stale.extend(credits.keys().copied());
        credits.clear();
        drop(credits);

        self.shared.channel_senders.lock().clear();
        self.shared.channel_buffers.lock().clear();
    }
1033
1034 fn close_outbound_channel(&self, channel_id: ChannelId) {
1035 if let Some(semaphore) = self.shared.channel_credits.lock().remove(&channel_id) {
1036 semaphore.close();
1037 }
1038 }
1039
1040 fn abort_channel_handlers(&mut self) {
1041 for (_req_id, in_flight) in &self.in_flight_handlers {
1042 if in_flight.has_channels {
1043 if let Some(operation_id) = in_flight.operation_id {
1044 self.shared.operations.remove(operation_id);
1045 self.live_operations.lock().release(operation_id);
1046 }
1047 in_flight.handle.abort();
1048 }
1049 }
1050 }
1051
1052 pub fn new(handle: ConnectionHandle, handler: H) -> Self {
1053 Self::with_operation_store(handle, handler, Arc::new(InMemoryOperationStore::default()))
1054 }
1055
1056 pub fn with_operation_store(
1057 handle: ConnectionHandle,
1058 handler: H,
1059 operation_store: Arc<dyn OperationStore>,
1060 ) -> Self {
1061 let conn_id = handle.connection_id();
1062 let ConnectionHandle {
1063 sender,
1064 rx,
1065 failures_rx,
1066 control_tx,
1067 closed_rx,
1068 resumed_rx,
1069 parity,
1070 peer_supports_retry,
1071 } = handle;
1072 let drop_control_request = DropControlRequest::Close(conn_id);
1073 let (local_control_tx, local_control_rx) = mpsc::unbounded_channel("driver.local_control");
1074 let (resume_processed_tx, _resume_processed_rx) = watch::channel(0_u64);
1075 Self {
1076 sender,
1077 rx,
1078 failures_rx,
1079 closed_rx,
1080 resumed_rx,
1081 resume_processed_tx,
1082 peer_supports_retry,
1083 local_control_rx,
1084 handler: Arc::new(handler),
1085 shared: Arc::new(DriverShared {
1086 pending_responses: SyncMutex::new("driver.pending_responses", BTreeMap::new()),
1087 request_ids: SyncMutex::new("driver.request_ids", IdAllocator::new(parity)),
1088 next_operation_id: AtomicU64::new(1),
1089 operations: operation_store,
1090 channel_ids: SyncMutex::new("driver.channel_ids", IdAllocator::new(parity)),
1091 channel_senders: SyncMutex::new("driver.channel_senders", BTreeMap::new()),
1092 channel_buffers: SyncMutex::new("driver.channel_buffers", BTreeMap::new()),
1093 channel_credits: SyncMutex::new("driver.channel_credits", BTreeMap::new()),
1094 stale_close_channels: SyncMutex::new(
1095 "driver.stale_close_channels",
1096 std::collections::HashSet::new(),
1097 ),
1098 }),
1099 in_flight_handlers: BTreeMap::new(),
1100 live_operations: Arc::new(SyncMutex::new(
1101 "driver.live_operations",
1102 LiveOperationTracker::new(),
1103 )),
1104 local_control_tx,
1105 drop_control_seed: control_tx,
1106 drop_control_request,
1107 drop_guard: SyncMutex::new("driver.drop_guard", None),
1108 }
1109 }
1110
1111 fn existing_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
1117 self.drop_guard.lock().as_ref().and_then(Weak::upgrade)
1118 }
1119
1120 fn connection_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
1121 if let Some(existing) = self.existing_drop_guard() {
1122 Some(existing)
1123 } else if let Some(seed) = &self.drop_control_seed {
1124 let mut guard = self.drop_guard.lock();
1125 if let Some(existing) = guard.as_ref().and_then(Weak::upgrade) {
1126 Some(existing)
1127 } else {
1128 let arc = Arc::new(CallerDropGuard {
1129 control_tx: seed.clone(),
1130 request: self.drop_control_request,
1131 });
1132 *guard = Some(Arc::downgrade(&arc));
1133 Some(arc)
1134 }
1135 } else {
1136 None
1137 }
1138 }
1139
1140 pub fn caller(&self) -> DriverCaller {
1141 let drop_guard = self.connection_drop_guard();
1142 DriverCaller {
1143 sender: self.sender.clone(),
1144 shared: Arc::clone(&self.shared),
1145 local_control_tx: self.local_control_tx.clone(),
1146 closed_rx: self.closed_rx.clone(),
1147 resumed_rx: self.resumed_rx.clone(),
1148 resume_processed_rx: self.resume_processed_tx.subscribe(),
1149 peer_supports_retry: self.peer_supports_retry,
1150 _drop_guard: drop_guard,
1151 }
1152 }
1153
1154 fn internal_binder(&self) -> DriverChannelBinder {
1155 DriverChannelBinder {
1156 sender: self.sender.clone(),
1157 shared: Arc::clone(&self.shared),
1158 local_control_tx: self.local_control_tx.clone(),
1159 drop_guard: self.existing_drop_guard(),
1160 }
1161 }
1162
1163 pub async fn run(&mut self) {
1168 let mut resumed_rx = self.resumed_rx.clone();
1169 let mut seen_resume_generation = *resumed_rx.borrow();
1170 loop {
1171 tracing::trace!("driver select loop top");
1172 tokio::select! {
1173 biased;
1174 changed = resumed_rx.changed() => {
1175 if changed.is_err() {
1176 break;
1177 }
1178 let generation = *resumed_rx.borrow();
1179 if generation != seen_resume_generation {
1180 seen_resume_generation = generation;
1181 self.close_all_channel_runtime_state();
1182 self.abort_channel_handlers();
1183 let _ = self.resume_processed_tx.send(generation);
1184 }
1185 }
1186 recv = self.rx.recv() => {
1187 match recv {
1188 Some(recv) => {
1189 tracing::debug!("driver rx received message");
1190 self.handle_recv(recv);
1191 }
1192 None => {
1193 tracing::debug!("driver rx closed, exiting loop");
1194 break;
1195 }
1196 }
1197 }
1198 Some((req_id, disposition)) = self.failures_rx.recv() => {
1199 tracing::debug!(%req_id, ?disposition, "failures_rx fired");
1200 let in_flight_found = self.in_flight_handlers.contains_key(&req_id);
1201 let in_flight_method_id =
1202 self.in_flight_handlers.get(&req_id).map(|in_flight| in_flight.method_id);
1203 let reply_disposition = self
1204 .in_flight_handlers
1205 .get(&req_id)
1206 .map(|in_flight| {
1207 if in_flight.has_channels && !in_flight.retry.idem {
1208 Some(FailureDisposition::Indeterminate)
1209 } else if in_flight.has_channels && in_flight.retry.idem {
1210 None
1211 } else {
1212 Some(disposition)
1213 }
1214 })
1215 .unwrap_or(Some(disposition));
1216 tracing::debug!(%req_id, in_flight_found, ?reply_disposition, "failures_rx computed disposition");
1217 self.in_flight_handlers.remove(&req_id);
1219 let had_pending = self.shared.pending_responses.lock().remove(&req_id).is_some();
1220 tracing::debug!(%req_id, had_pending, "failures_rx checked pending_responses");
1221 if !had_pending {
1222 let Some(reply_disposition) = reply_disposition else {
1223 tracing::debug!(%req_id, "failures_rx: no reply_disposition, skipping");
1224 continue;
1225 };
1226 tracing::debug!(%req_id, ?reply_disposition, "failures_rx: sending error response");
1227 let vox_error = match reply_disposition {
1228 FailureDisposition::Cancelled => VoxError::Cancelled,
1229 FailureDisposition::Indeterminate => VoxError::Indeterminate,
1230 };
1231 if let Some(method_id) = in_flight_method_id
1232 && let Some(response_shape) = self.handler.response_wire_shape(method_id)
1233 && let Ok(extracted) = vox_types::extract_schemas(response_shape)
1234 {
1235 let registry = vox_types::build_registry(&extracted.schemas);
1236 let error: Result<(), VoxError<core::convert::Infallible>> =
1237 Err(vox_error);
1238 let encoded = vox_postcard::to_vec(&error)
1239 .expect("serialize runtime-generated error response");
1240 let mut response = RequestResponse {
1241 ret: Payload::PostcardBytes(Box::leak(encoded.into_boxed_slice())),
1242 metadata: Default::default(),
1243 schemas: Default::default(),
1244 };
1245 self.sender.prepare_response_from_source(
1246 req_id,
1247 method_id,
1248 &extracted.root,
1249 ®istry,
1250 &mut response,
1251 );
1252 let _ = self.sender.send_response(req_id, response).await;
1253 } else {
1254 let error: Result<(), VoxError<core::convert::Infallible>> =
1255 Err(vox_error);
1256 let _ = self.sender.send_response(req_id, RequestResponse {
1257 ret: Payload::outgoing(&error),
1258 metadata: Default::default(),
1259 schemas: Default::default(),
1260 }).await;
1261 }
1262 tracing::debug!(%req_id, "failures_rx: error response sent");
1263 }
1264 }
1265 Some(ctrl) = self.local_control_rx.recv() => {
1266 self.handle_local_control(ctrl).await;
1267 }
1268 }
1269 }
1270
1271 for (_, in_flight) in std::mem::take(&mut self.in_flight_handlers) {
1272 if !in_flight.retry.persist {
1273 in_flight.handle.abort();
1274 }
1275 }
1276 self.shared.pending_responses.lock().clear();
1277
1278 self.close_all_channel_runtime_state();
1282 }
1283
1284 async fn handle_local_control(&mut self, control: DriverLocalControl) {
1285 match control {
1286 DriverLocalControl::CloseChannel { channel_id } => {
1287 if self.shared.stale_close_channels.lock().remove(&channel_id) {
1292 tracing::debug!(%channel_id, "suppressing ChannelClose for stale channel");
1293 return;
1294 }
1295 let _ = self
1296 .sender
1297 .send(ConnectionMessage::Channel(ChannelMessage {
1298 id: channel_id,
1299 body: ChannelBody::Close(ChannelClose {
1300 metadata: Default::default(),
1301 }),
1302 }))
1303 .await;
1304 }
1305 DriverLocalControl::GrantCredit {
1306 channel_id,
1307 additional,
1308 } => {
1309 let _ = self
1310 .sender
1311 .send(ConnectionMessage::Channel(ChannelMessage {
1312 id: channel_id,
1313 body: ChannelBody::GrantCredit(vox_types::ChannelGrantCredit {
1314 additional,
1315 }),
1316 }))
1317 .await;
1318 }
1319 }
1320 }
1321
1322 fn handle_recv(&mut self, recv: crate::session::RecvMessage) {
1323 let crate::session::RecvMessage { schemas, msg } = recv;
1324 let is_request = matches!(&*msg, ConnectionMessage::Request(_));
1325 if is_request {
1326 if let ConnectionMessage::Request(req) = &*msg {
1327 vox_types::dlog!(
1328 "[driver] handle_recv request: conn={:?} req={:?} body={} method={:?}",
1329 self.sender.connection_id(),
1330 req.id,
1331 match &req.body {
1332 RequestBody::Call(_) => "Call",
1333 RequestBody::Response(_) => "Response",
1334 RequestBody::Cancel(_) => "Cancel",
1335 },
1336 match &req.body {
1337 RequestBody::Call(call) => Some(call.method_id),
1338 RequestBody::Response(_) | RequestBody::Cancel(_) => None,
1339 }
1340 );
1341 }
1342 let msg = msg.map(|m| match m {
1343 ConnectionMessage::Request(r) => r,
1344 _ => unreachable!(),
1345 });
1346 self.handle_request(msg, schemas);
1347 } else {
1348 let msg = msg.map(|m| match m {
1349 ConnectionMessage::Channel(c) => c,
1350 _ => unreachable!(),
1351 });
1352 self.handle_channel(msg);
1353 }
1354 }
1355
1356 fn handle_request(
1357 &mut self,
1358 msg: SelfRef<RequestMessage<'static>>,
1359 schemas: Arc<vox_types::SchemaRecvTracker>,
1360 ) {
1361 let req_id = msg.id;
1362 let is_call = matches!(&msg.body, RequestBody::Call(_));
1363 let is_response = matches!(&msg.body, RequestBody::Response(_));
1364 let is_cancel = matches!(&msg.body, RequestBody::Cancel(_));
1365
1366 if is_call {
1367 let method_id = match &msg.body {
1368 RequestBody::Call(call) => call.method_id,
1369 _ => unreachable!(),
1370 };
1371 vox_types::dlog!(
1372 "[driver] inbound call: conn={:?} req={:?} method={:?}",
1373 self.sender.connection_id(),
1374 req_id,
1375 method_id
1376 );
1377 let call = msg.map(|m| match m.body {
1380 RequestBody::Call(c) => c,
1381 _ => unreachable!(),
1382 });
1383 let handler = Arc::clone(&self.handler);
1384 let retry = handler.retry_policy(call.method_id);
1385 let operation_id = metadata_operation_id(&call.metadata);
1386 let method_id = call.method_id;
1387
1388 if let Some(operation_id) = operation_id {
1389 let admit = self.live_operations.lock().admit(
1391 operation_id,
1392 call.method_id,
1393 incoming_args_bytes(&call),
1394 retry,
1395 req_id,
1396 );
1397 match admit {
1398 AdmitResult::Attached => return,
1399 AdmitResult::Conflict => {
1400 let sender = self.sender.clone();
1401 moire::task::spawn(
1402 async move {
1403 let error: Result<(), VoxError<core::convert::Infallible>> =
1404 Err(VoxError::InvalidPayload("operation ID conflict".into()));
1405 let _ = sender
1406 .send_response(
1407 req_id,
1408 RequestResponse {
1409 ret: Payload::outgoing(&error),
1410 metadata: Default::default(),
1411 schemas: Default::default(),
1412 },
1413 )
1414 .await;
1415 }
1416 .named("operation_reject"),
1417 );
1418 return;
1419 }
1420 AdmitResult::Start => {}
1421 }
1422
1423 match self.shared.operations.lookup(operation_id) {
1425 crate::OperationState::Sealed => {
1426 if let Some(sealed) = self.shared.operations.get_sealed(operation_id) {
1428 let sender = self.sender.clone();
1429 let method_id = call.method_id;
1430 let operations = Arc::clone(&self.shared.operations);
1431 self.live_operations.lock().seal(operation_id);
1433 moire::task::spawn(
1434 async move {
1435 if replay_sealed_response(
1436 sender.clone(),
1437 req_id,
1438 method_id,
1439 sealed.response.as_bytes(),
1440 sealed.root_type,
1441 operations.as_ref(),
1442 )
1443 .await
1444 .is_err()
1445 {
1446 sender.mark_failure(req_id, FailureDisposition::Cancelled);
1447 }
1448 }
1449 .named("operation_replay"),
1450 );
1451 return;
1452 }
1453 }
1454 crate::OperationState::Admitted => {
1455 self.live_operations.lock().seal(operation_id);
1457 let sender = self.sender.clone();
1458 moire::task::spawn(
1459 async move {
1460 let error: Result<(), VoxError<core::convert::Infallible>> =
1461 Err(VoxError::Indeterminate);
1462 let _ = sender
1463 .send_response(
1464 req_id,
1465 RequestResponse {
1466 ret: Payload::outgoing(&error),
1467 metadata: Default::default(),
1468 schemas: Default::default(),
1469 },
1470 )
1471 .await;
1472 }
1473 .named("operation_indeterminate"),
1474 );
1475 return;
1476 }
1477 crate::OperationState::Unknown => {
1478 if !retry.idem {
1481 self.shared.operations.admit(operation_id);
1482 }
1483 }
1484 }
1485 }
1486 let reply = DriverReplySink {
1487 sender: Some(self.sender.clone()),
1488 request_id: req_id,
1489 method_id: call.method_id,
1490 retry,
1491 operation_id,
1492 operations: operation_id.map(|_| Arc::clone(&self.shared.operations)),
1493 live_operations: operation_id.map(|_| Arc::clone(&self.live_operations)),
1494 binder: self.internal_binder(),
1495 };
1496 let has_channels = handler.args_have_channels(call.method_id);
1497 let join_handle = moire::task::spawn(
1498 async move {
1499 vox_types::dlog!(
1500 "[driver] handler start: req={:?} method={:?}",
1501 req_id,
1502 method_id
1503 );
1504 handler.handle(call, reply, schemas).await;
1505 vox_types::dlog!(
1506 "[driver] handler done: req={:?} method={:?}",
1507 req_id,
1508 method_id
1509 );
1510 }
1511 .named("handler"),
1512 );
1513 self.in_flight_handlers.insert(
1514 req_id,
1515 InFlightHandler {
1516 handle: join_handle,
1517 method_id,
1518 retry,
1519 has_channels,
1520 operation_id,
1521 },
1522 );
1523 } else if is_response {
1524 vox_types::dlog!(
1526 "[driver] inbound response: conn={:?} req={:?}",
1527 self.sender.connection_id(),
1528 req_id
1529 );
1530 tracing::debug!(%req_id, "driver received response");
1531 if let Some(tx) = self.shared.pending_responses.lock().remove(&req_id) {
1532 vox_types::dlog!("[driver] routing response to waiter: req={:?}", req_id);
1533 tracing::debug!(%req_id, "routing response to pending oneshot");
1534 let _: Result<(), _> = tx.send(PendingResponse { msg, schemas });
1535 } else {
1536 vox_types::dlog!("[driver] dropped unmatched response: req={:?}", req_id);
1537 tracing::debug!(%req_id, "no pending response slot for this req_id");
1538 }
1539 } else if is_cancel {
1540 vox_types::dlog!(
1541 "[driver] inbound cancel: conn={:?} req={:?}",
1542 self.sender.connection_id(),
1543 req_id
1544 );
1545 tracing::debug!(%req_id, in_flight = self.in_flight_handlers.contains_key(&req_id), "received cancel");
1548 match self.live_operations.lock().cancel(req_id) {
1549 CancelResult::NotFound => {
1550 let should_abort = self
1551 .in_flight_handlers
1552 .get(&req_id)
1553 .map(|in_flight| !in_flight.retry.persist)
1554 .unwrap_or(false);
1555 tracing::debug!(%req_id, should_abort, "cancel: not in live operations");
1556 if should_abort && let Some(in_flight) = self.in_flight_handlers.remove(&req_id)
1557 {
1558 tracing::debug!(%req_id, "aborting handler");
1559 in_flight.handle.abort();
1560 }
1561 }
1562 CancelResult::Detached => {}
1563 CancelResult::Abort {
1564 owner_request_id,
1565 waiters,
1566 } => {
1567 if let Some(in_flight) = self.in_flight_handlers.remove(&owner_request_id) {
1568 if let Some(op_id) = in_flight.operation_id {
1569 self.shared.operations.remove(op_id);
1570 }
1571 in_flight.handle.abort();
1572 }
1573 for waiter in waiters {
1574 self.sender
1575 .mark_failure(waiter, FailureDisposition::Cancelled);
1576 }
1577 }
1578 }
1579 }
1582 }
1583
1584 fn handle_channel(&mut self, msg: SelfRef<ChannelMessage<'static>>) {
1585 let chan_id = msg.id;
1586
1587 let sender = self.shared.channel_senders.lock().get(&chan_id).cloned();
1590
1591 match &msg.body {
1592 ChannelBody::Item(_item) => {
1594 if let Some(tx) = &sender {
1595 let item = msg.map(|m| match m.body {
1596 ChannelBody::Item(item) => item,
1597 _ => unreachable!(),
1598 });
1599 let _ = tx.try_send(IncomingChannelMessage::Item(item));
1601 } else {
1602 let item = msg.map(|m| match m.body {
1604 ChannelBody::Item(item) => item,
1605 _ => unreachable!(),
1606 });
1607 self.shared
1608 .channel_buffers
1609 .lock()
1610 .entry(chan_id)
1611 .or_default()
1612 .push(IncomingChannelMessage::Item(item));
1613 }
1614 }
1615 ChannelBody::Close(_close) => {
1617 if let Some(tx) = &sender {
1618 let close = msg.map(|m| match m.body {
1619 ChannelBody::Close(close) => close,
1620 _ => unreachable!(),
1621 });
1622 let _ = tx.try_send(IncomingChannelMessage::Close(close));
1623 } else {
1624 let close = msg.map(|m| match m.body {
1626 ChannelBody::Close(close) => close,
1627 _ => unreachable!(),
1628 });
1629 self.shared
1630 .channel_buffers
1631 .lock()
1632 .entry(chan_id)
1633 .or_default()
1634 .push(IncomingChannelMessage::Close(close));
1635 }
1636 self.shared.channel_senders.lock().remove(&chan_id);
1637 self.close_outbound_channel(chan_id);
1638 }
1639 ChannelBody::Reset(_reset) => {
1641 if let Some(tx) = &sender {
1642 let reset = msg.map(|m| match m.body {
1643 ChannelBody::Reset(reset) => reset,
1644 _ => unreachable!(),
1645 });
1646 let _ = tx.try_send(IncomingChannelMessage::Reset(reset));
1647 } else {
1648 let reset = msg.map(|m| match m.body {
1650 ChannelBody::Reset(reset) => reset,
1651 _ => unreachable!(),
1652 });
1653 self.shared
1654 .channel_buffers
1655 .lock()
1656 .entry(chan_id)
1657 .or_default()
1658 .push(IncomingChannelMessage::Reset(reset));
1659 }
1660 self.shared.channel_senders.lock().remove(&chan_id);
1661 self.close_outbound_channel(chan_id);
1662 }
1663 ChannelBody::GrantCredit(grant) => {
1666 if let Some(semaphore) = self.shared.channel_credits.lock().get(&chan_id) {
1667 semaphore.add_permits(grant.additional as usize);
1668 }
1669 }
1670 }
1671 }
1672}