1use std::{
2 collections::{BTreeMap, HashMap},
3 future::Future,
4 pin::Pin,
5 sync::Arc,
6 time::Duration,
7};
8
9use facet_core::Shape;
10use moire::sync::mpsc;
11use tokio::sync::{mpsc as tokio_mpsc, oneshot as tokio_oneshot, watch};
12use tracing::{trace, warn};
13use vox_types::{
14 BoxFut, ChannelMessage, ConduitRx, ConduitTx, ConnectionAccept, ConnectionClose, ConnectionId,
15 ConnectionOpen, ConnectionReject, ConnectionSettings, Handler, HandshakeResult, IdAllocator,
16 MaybeSend, MaybeSync, Message, MessageFamily, MessagePayload, Metadata, Parity, RequestBody,
17 RequestId, RequestMessage, RequestResponse, SchemaMessage, SelfRef, SessionResumeKey,
18 SessionRole, TrySendError, VoxDebugSnapshot, VoxObserverHandle,
19};
20use vox_types::{
21 ConnectionCloseReason, ConnectionDebugSnapshot, ConnectionDebugState, DecodeErrorKind,
22 DriverTaskStatus,
23};
24
25mod builders;
26pub use builders::*;
27
/// Keepalive timing settings for a session.
///
/// NOTE(review): the keepalive loop that consumes these values is outside this
/// view; the field descriptions below are inferred from the names — confirm.
#[derive(Debug, Clone, Copy)]
pub struct SessionKeepaliveConfig {
    /// Presumably the interval between outgoing keepalive pings.
    pub ping_interval: Duration,
    /// Presumably how long to wait for a pong before treating the peer as gone.
    pub pong_timeout: Duration,
}
34
35pub struct ConnectionRequest<'a> {
43 metadata: &'a [vox_types::MetadataEntry<'a>],
44 service: &'a str,
45}
46
47impl<'a> ConnectionRequest<'a> {
48 pub fn new(metadata: &'a [vox_types::MetadataEntry<'a>]) -> Result<Self, SessionError> {
52 let service = vox_types::metadata_get_str(metadata, "vox-service").ok_or_else(|| {
53 SessionError::Protocol("missing required vox-service metadata".into())
54 })?;
55 Ok(Self { metadata, service })
56 }
57
58 pub fn service(&self) -> &str {
60 self.service
61 }
62
63 pub fn transport(&self) -> Option<&str> {
65 vox_types::metadata_get_str(self.metadata, "vox-transport")
66 }
67
68 pub fn peer_addr(&self) -> Option<&str> {
70 vox_types::metadata_get_str(self.metadata, "vox-peer-addr")
71 }
72
73 pub fn is_root(&self) -> bool {
75 !self.is_virtual()
76 }
77
78 pub fn is_virtual(&self) -> bool {
80 vox_types::metadata_get_str(self.metadata, "vox-connection-kind") == Some("virtual")
81 }
82
83 pub fn get_str(&self, key: &str) -> Option<&str> {
85 vox_types::metadata_get_str(self.metadata, key)
86 }
87
88 pub fn get_u64(&self, key: &str) -> Option<u64> {
90 vox_types::metadata_get_u64(self.metadata, key)
91 }
92
93 pub fn metadata(&self) -> &[vox_types::MetadataEntry<'a>] {
95 self.metadata
96 }
97}
98
/// A connection that has been set up but not yet given a consumer.
///
/// It must be consumed through one of `handle_with`, `handle_with_client`,
/// `proxy_to` or `into_handle`; if it is dropped while still holding its
/// handle, the `Drop` impl requests that the connection be closed.
pub struct PendingConnection {
    /// The wrapped handle; `None` once one of the consuming methods took it.
    handle: Option<ConnectionHandle>,
    /// Optional shared slot that receives the `Caller` created when a driver is
    /// attached.
    caller_slot: Option<Arc<std::sync::Mutex<Option<crate::Caller>>>>,
    /// Optional operation store passed to `Driver::with_operation_store`.
    operation_store: Option<Arc<dyn crate::OperationStore>>,
}
110
111impl PendingConnection {
112 fn new(handle: ConnectionHandle) -> Self {
113 Self {
114 handle: Some(handle),
115 caller_slot: None,
116 operation_store: None,
117 }
118 }
119
120 fn with_caller_slot(
122 handle: ConnectionHandle,
123 caller_slot: Arc<std::sync::Mutex<Option<crate::Caller>>>,
124 operation_store: Option<Arc<dyn crate::OperationStore>>,
125 ) -> Self {
126 Self {
127 handle: Some(handle),
128 caller_slot: Some(caller_slot),
129 operation_store,
130 }
131 }
132
133 pub fn handle_with(mut self, handler: impl Handler<crate::DriverReplySink> + 'static) {
135 let handle = self
136 .handle
137 .take()
138 .expect("PendingConnection already consumed");
139 let conn_id = handle.connection_id();
140 trace!(%conn_id, "PendingConnection::handle_with: creating driver");
141 let mut driver = match self.operation_store.take() {
142 Some(store) => crate::Driver::with_operation_store(handle, handler, store),
143 None => crate::Driver::new(handle, handler),
144 };
145 if let Some(slot) = &self.caller_slot {
146 let caller = crate::Caller::new(driver.caller());
147 *slot.lock().unwrap() = Some(caller);
148 }
149 #[cfg(not(target_arch = "wasm32"))]
150 tokio::spawn(async move {
151 trace!(%conn_id, "PendingConnection driver starting");
152 driver.run().await;
153 trace!(%conn_id, "PendingConnection driver exited");
154 });
155 #[cfg(target_arch = "wasm32")]
156 wasm_bindgen_futures::spawn_local(async move { driver.run().await });
157 }
158
159 pub fn handle_with_client<C: crate::FromVoxSession>(
161 mut self,
162 handler: impl Handler<crate::DriverReplySink> + 'static,
163 ) -> C {
164 let handle = self
165 .handle
166 .take()
167 .expect("PendingConnection already consumed");
168 let conn_id = handle.connection_id();
169 trace!(%conn_id, "PendingConnection::handle_with_client: creating driver");
170 let mut driver = match self.operation_store.take() {
171 Some(store) => crate::Driver::with_operation_store(handle, handler, store),
172 None => crate::Driver::new(handle, handler),
173 };
174 let caller = crate::Caller::new(driver.caller());
175 if let Some(slot) = &self.caller_slot {
176 *slot.lock().unwrap() = Some(caller.clone());
177 }
178 #[cfg(not(target_arch = "wasm32"))]
179 tokio::spawn(async move {
180 trace!(%conn_id, "PendingConnection driver starting");
181 driver.run().await;
182 trace!(%conn_id, "PendingConnection driver exited");
183 });
184 #[cfg(target_arch = "wasm32")]
185 wasm_bindgen_futures::spawn_local(async move { driver.run().await });
186 C::from_vox_session(caller, None)
187 }
188
189 pub fn proxy_to(mut self, other: ConnectionHandle) {
191 let handle = self
192 .handle
193 .take()
194 .expect("PendingConnection already consumed");
195 #[cfg(not(target_arch = "wasm32"))]
196 tokio::spawn(async move {
197 let _ = proxy_connections(handle, other).await;
198 });
199 #[cfg(target_arch = "wasm32")]
200 wasm_bindgen_futures::spawn_local(async move {
201 let _ = proxy_connections(handle, other).await;
202 });
203 }
204
205 pub fn into_handle(mut self) -> ConnectionHandle {
207 self.handle
208 .take()
209 .expect("PendingConnection already consumed")
210 }
211}
212
213impl Drop for PendingConnection {
214 fn drop(&mut self) {
215 if let Some(handle) = self.handle.take() {
216 let conn_id = handle.connection_id();
217 warn!(%conn_id, "PendingConnection dropped without being consumed — closing connection");
218 if let Some(tx) = handle.control_tx.as_ref() {
219 let _ = send_drop_control(tx, DropControlRequest::Close(conn_id));
220 }
221 }
222 }
223}
224
/// Decides what happens to a newly requested connection.
///
/// Implemented in this file for every cloneable `Handler` (which attaches a
/// clone of itself to the connection) and for closures wrapped in `AcceptorFn`.
pub trait ConnectionAcceptor: MaybeSend + MaybeSync + 'static {
    /// Handles `connection`, described by `request`.
    ///
    /// Returns `Ok(())` once the connection has been taken over.
    /// NOTE(review): an `Err(metadata)` presumably rejects the connection with
    /// that metadata — the call site is outside this view; confirm.
    fn accept(
        &self,
        request: &ConnectionRequest,
        connection: PendingConnection,
    ) -> Result<(), Metadata<'static>>;
}
233
234impl<H> ConnectionAcceptor for H
236where
237 H: Handler<crate::DriverReplySink> + Clone + MaybeSend + MaybeSync + 'static,
238{
239 fn accept(
240 &self,
241 _request: &ConnectionRequest,
242 connection: PendingConnection,
243 ) -> Result<(), Metadata<'static>> {
244 connection.handle_with(self.clone());
245 Ok(())
246 }
247}
248
249pub struct AcceptorFn<F>(pub F);
251
252impl<F> ConnectionAcceptor for AcceptorFn<F>
253where
254 F: Fn(&ConnectionRequest, PendingConnection) -> Result<(), Metadata<'static>>
255 + MaybeSend
256 + MaybeSync
257 + 'static,
258{
259 fn accept(
260 &self,
261 request: &ConnectionRequest,
262 connection: PendingConnection,
263 ) -> Result<(), Metadata<'static>> {
264 (self.0)(request, connection)
265 }
266}
267
268pub fn acceptor_fn<F>(f: F) -> AcceptorFn<F>
270where
271 F: Fn(&ConnectionRequest, PendingConnection) -> Result<(), Metadata<'static>>
272 + MaybeSend
273 + MaybeSync
274 + 'static,
275{
276 AcceptorFn(f)
277}
278
/// Message sent by `SessionHandle::open_connection` asking for a new connection.
struct OpenRequest {
    /// Settings requested for the new connection.
    settings: ConnectionSettings,
    /// Metadata to attach to the open request.
    metadata: Metadata<'static>,
    /// Where the outcome (a handle or an error) is reported back.
    result_tx: moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>,
}
288
/// Message sent by `SessionHandle::close_connection` asking to close a connection.
struct CloseRequest {
    /// The connection to close.
    conn_id: ConnectionId,
    /// Metadata to attach to the close.
    metadata: Metadata<'static>,
    /// Where the outcome is reported back.
    result_tx: moire::sync::oneshot::Sender<Result<(), SessionError>>,
}
294
/// Message sent by `SessionHandle::resume_parts` carrying a replacement conduit.
struct ResumeRequest {
    /// Sending half of the new conduit.
    tx: Arc<dyn DynConduitTx>,
    /// Receiving half of the new conduit.
    rx: Box<dyn DynConduitRx>,
    /// Result of the handshake performed for the resume.
    handshake_result: HandshakeResult,
    /// Where the outcome is reported back.
    result_tx: moire::sync::oneshot::Sender<Result<(), SessionError>>,
}
301
/// Fire-and-forget control request sent over the unbounded control channel
/// (see `send_drop_control`).
#[derive(Debug, Clone, Copy)]
pub(crate) enum DropControlRequest {
    /// Sent by `SessionHandle::shutdown`.
    Shutdown,
    /// Ask for the given connection to be closed; sent when a
    /// `PendingConnection` is dropped unconsumed and when a proxy ends.
    Close(ConnectionId),
}
307
/// Outcome attached to a failed request by `ConnectionSender::mark_failure`.
///
/// NOTE(review): the consumer of the failures channel is outside this view;
/// the variant meanings below are inferred from the names — confirm.
#[derive(Clone, Copy, Debug)]
pub(crate) enum FailureDisposition {
    /// Presumably: the request was cancelled.
    Cancelled,
    /// Presumably: it is unknown whether the request took effect.
    Indeterminate,
}
313
314#[cfg(not(target_arch = "wasm32"))]
315fn send_drop_control(
316 tx: &mpsc::UnboundedSender<DropControlRequest>,
317 req: DropControlRequest,
318) -> Result<(), ()> {
319 tx.send(req).map_err(|_| ())
320}
321
/// Sends `req` on the control channel (wasm build, uses `try_send`).
///
/// Returns `Err(())` when the send fails; the underlying error is discarded.
#[cfg(target_arch = "wasm32")]
fn send_drop_control(
    tx: &mpsc::UnboundedSender<DropControlRequest>,
    req: DropControlRequest,
) -> Result<(), ()> {
    match tx.try_send(req) {
        Ok(()) => Ok(()),
        Err(_) => Err(()),
    }
}
329
330#[derive(Clone)]
341pub struct SessionHandle {
342 open_tx: mpsc::Sender<OpenRequest>,
343 close_tx: mpsc::Sender<CloseRequest>,
344 resume_tx: mpsc::Sender<ResumeRequest>,
345 control_tx: mpsc::UnboundedSender<DropControlRequest>,
346 resume_key: Option<SessionResumeKey>,
347}
348
349impl SessionHandle {
350 pub async fn open<Client: crate::FromVoxSession>(
356 &self,
357 settings: ConnectionSettings,
358 ) -> Result<Client, SessionError> {
359 use crate::{Caller, Driver};
360 use vox_types::{MetadataEntry, MetadataFlags, MetadataValue};
361
362 let metadata: Metadata<'static> = vec![MetadataEntry {
363 key: crate::session::builders::VOX_SERVICE_METADATA_KEY.into(),
364 value: MetadataValue::String(Client::SERVICE_NAME.into()),
365 flags: MetadataFlags::NONE,
366 }];
367 let handle = self.open_connection(settings, metadata).await?;
368 let mut driver = Driver::new(handle, ());
369 let caller = Caller::new(driver.caller());
370 #[cfg(not(target_arch = "wasm32"))]
371 tokio::spawn(async move { driver.run().await });
372 #[cfg(target_arch = "wasm32")]
373 wasm_bindgen_futures::spawn_local(async move { driver.run().await });
374 Ok(Client::from_vox_session(caller, None))
375 }
376
377 pub async fn open_connection(
384 &self,
385 settings: ConnectionSettings,
386 metadata: Metadata<'static>,
387 ) -> Result<ConnectionHandle, SessionError> {
388 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.open_result");
389 self.open_tx
390 .send(OpenRequest {
391 settings,
392 metadata,
393 result_tx,
394 })
395 .await
396 .map_err(|_| SessionError::Protocol("session closed".into()))?;
397 result_rx
398 .await
399 .map_err(|_| SessionError::Protocol("session closed".into()))?
400 }
401
402 pub async fn close_connection(
409 &self,
410 conn_id: ConnectionId,
411 metadata: Metadata<'static>,
412 ) -> Result<(), SessionError> {
413 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.close_result");
414 self.close_tx
415 .send(CloseRequest {
416 conn_id,
417 metadata,
418 result_tx,
419 })
420 .await
421 .map_err(|_| SessionError::Protocol("session closed".into()))?;
422 result_rx
423 .await
424 .map_err(|_| SessionError::Protocol("session closed".into()))?
425 }
426
427 pub(crate) async fn resume_parts(
428 &self,
429 tx: Arc<dyn DynConduitTx>,
430 rx: Box<dyn DynConduitRx>,
431 handshake_result: HandshakeResult,
432 ) -> Result<(), SessionError> {
433 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.resume_result");
434 self.resume_tx
435 .send(ResumeRequest {
436 tx,
437 rx,
438 handshake_result,
439 result_tx,
440 })
441 .await
442 .map_err(|_| SessionError::Protocol("session closed".into()))?;
443 result_rx
444 .await
445 .map_err(|_| SessionError::Protocol("session closed".into()))?
446 }
447
448 pub fn resume_key(&self) -> Option<&SessionResumeKey> {
450 self.resume_key.as_ref()
451 }
452
453 pub fn shutdown(&self) -> Result<(), SessionError> {
455 send_drop_control(&self.control_tx, DropControlRequest::Shutdown)
456 .map_err(|_| SessionError::Protocol("session closed".into()))
457 }
458}
459
/// State owned by the session task that `Session::run` drives.
///
/// Created by `pre_handshake` with placeholder role/parity values, which
/// `establish_from_handshake` then overwrites from the handshake result.
pub struct Session {
    /// Receiving half of the conduit; `run` polls it with `recv_msg`.
    rx: Box<dyn DynConduitRx>,

    /// Role of this side; replaced by the handshake result.
    role: SessionRole,

    /// Parity of this side; replaced by the handshake result.
    parity: Parity,

    /// Shared core, cloned into every `ConnectionSender`.
    sess_core: Arc<SessionCore>,
    /// Copied from the handshake result and into each `ConnectionHandle`.
    peer_supports_retry: bool,
    /// Local settings for the root connection.
    local_root_settings: ConnectionSettings,
    /// Peer settings for the root connection; `None` until the handshake.
    peer_root_settings: Option<ConnectionSettings>,
    /// When set, a handshake result without a resume key is rejected with
    /// `SessionError::NotResumable`.
    resumable: bool,
    /// Resume key taken from the handshake result.
    session_resume_key: Option<SessionResumeKey>,

    /// Connection slots by id; `make_connection_handle` inserts `Active` slots.
    conns: BTreeMap<ConnectionId, ConnectionSlot>,
    /// Starts `false`. NOTE(review): maintained outside this view.
    root_closed_internal: bool,

    /// Allocator for connection ids, re-created with the handshake parity.
    conn_ids: IdAllocator<ConnectionId>,

    /// Optional acceptor. NOTE(review): presumably consulted for incoming
    /// connection opens — the call site is outside this view.
    on_connection: Option<Arc<dyn ConnectionAcceptor>>,

    /// Receiver for `OpenRequest`s (presumably paired with `SessionHandle::open_tx`).
    open_rx: mpsc::Receiver<OpenRequest>,

    /// Receiver for `CloseRequest`s (presumably paired with `SessionHandle::close_tx`).
    close_rx: mpsc::Receiver<CloseRequest>,

    /// Receiver for `ResumeRequest`s (presumably paired with `SessionHandle::resume_tx`).
    resume_rx: mpsc::Receiver<ResumeRequest>,

    /// Control sender; a clone goes into every `ConnectionHandle`.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Receiving side for `DropControlRequest`s.
    control_rx: mpsc::UnboundedReceiver<DropControlRequest>,

    /// Keepalive configuration, if keepalive is enabled.
    keepalive: Option<SessionKeepaliveConfig>,
    /// Watch sender that every `ConnectionHandle` subscribes to (`resumed_rx`).
    resume_notifier: watch::Sender<u64>,
    /// Optional conduit recoverer. NOTE(review): used outside this view.
    recoverer: Option<Box<dyn ConduitRecoverer>>,
    /// Optional recovery timeout. NOTE(review): used outside this view.
    recovery_timeout: Option<Duration>,
    /// Starts `false`. NOTE(review): maintained outside this view.
    registered_in_registry: bool,

    /// Optional observer notified of driver and transport events.
    observer: Option<VoxObserverHandle>,
}
521
/// Mutable keepalive bookkeeping used while the session runs.
///
/// NOTE(review): the code that builds and updates this state is outside this
/// view; the field descriptions are inferred from names and types — confirm.
#[derive(Debug)]
struct KeepaliveRuntime {
    /// Presumably copied from `SessionKeepaliveConfig::ping_interval`.
    ping_interval: Duration,
    /// Presumably copied from `SessionKeepaliveConfig::pong_timeout`.
    pong_timeout: Duration,
    /// Presumably the instant at which the next ping is due.
    next_ping_at: vox_types::time::tokio::Instant,
    /// Presumably the nonce of the ping still awaiting its pong, if any.
    waiting_pong_nonce: Option<u64>,
    /// Presumably the deadline for the awaited pong.
    pong_deadline: vox_types::time::tokio::Instant,
    /// Presumably the nonce to use for the next ping.
    next_ping_nonce: u64,
}
531
/// Session-side state of an active connection, created by
/// `Session::make_connection_handle`.
#[derive(Debug)]
pub struct ConnectionState {
    /// Id of this connection.
    pub id: ConnectionId,

    /// Settings of the local side.
    pub local_settings: ConnectionSettings,

    /// Settings of the peer side.
    pub peer_settings: ConnectionSettings,

    /// Sender whose receiving end is the `rx` of the matching `ConnectionHandle`.
    conn_tx: mpsc::Sender<RecvMessage>,
    /// Watch sender whose receiving end is the handle's `closed_rx`; carries
    /// the close reason once there is one.
    closed_tx: watch::Sender<Option<ConnectionCloseReason>>,

    /// Per-connection tracker for schemas received from the peer.
    schema_recv_tracker: Arc<vox_types::SchemaRecvTracker>,
}
552
/// An entry in the session's connection table.
#[derive(Debug)]
enum ConnectionSlot {
    /// A connection with live state.
    Active(ConnectionState),
    /// NOTE(review): presumably an outbound open that has not been answered
    /// yet — the code that creates and resolves it is outside this view.
    PendingOutbound(PendingOutboundData),
}
558
559struct PendingOutboundData {
561 local_settings: ConnectionSettings,
562 result_tx: Option<moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>>,
563}
564
565impl std::fmt::Debug for PendingOutboundData {
566 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
567 f.debug_struct("PendingOutbound")
568 .field("local_settings", &self.local_settings)
569 .finish()
570 }
571}
572
/// Cloneable sending side of one connection.
#[derive(Clone)]
pub(crate) struct ConnectionSender {
    /// Connection id stamped on every outgoing message.
    connection_id: ConnectionId,
    /// Shared session core that performs the actual sends.
    pub(crate) sess_core: Arc<SessionCore>,
    /// Channel used by `mark_failure`; its receiver is the handle's `failures_rx`.
    failures: Arc<mpsc::UnboundedSender<(RequestId, FailureDisposition)>>,
}
579
580fn forwarded_payload<'a>(payload: &'a vox_types::Payload<'a>) -> vox_types::Payload<'a> {
581 let vox_types::Payload::PostcardBytes(bytes) = payload else {
582 unreachable!("proxy forwarding expects decoded incoming payload bytes")
583 };
584 vox_types::Payload::PostcardBytes(bytes)
585}
586
587fn forwarded_request_body<'a>(body: &'a RequestBody<'a>) -> RequestBody<'a> {
588 match body {
589 RequestBody::Call(call) => RequestBody::Call(vox_types::RequestCall {
590 method_id: call.method_id,
591 metadata: call.metadata.clone(),
592 args: forwarded_payload(&call.args),
593 schemas: call.schemas.clone(),
594 }),
595 RequestBody::Response(response) => RequestBody::Response(RequestResponse {
596 metadata: response.metadata.clone(),
597 ret: forwarded_payload(&response.ret),
598 schemas: response.schemas.clone(),
599 }),
600 RequestBody::Cancel(cancel) => RequestBody::Cancel(vox_types::RequestCancel {
601 metadata: cancel.metadata.clone(),
602 }),
603 }
604}
605
606fn forwarded_channel_body<'a>(body: &'a vox_types::ChannelBody<'a>) -> vox_types::ChannelBody<'a> {
607 match body {
608 vox_types::ChannelBody::Item(item) => {
609 vox_types::ChannelBody::Item(vox_types::ChannelItem {
610 item: forwarded_payload(&item.item),
611 })
612 }
613 vox_types::ChannelBody::Close(close) => {
614 vox_types::ChannelBody::Close(vox_types::ChannelClose {
615 metadata: close.metadata.clone(),
616 })
617 }
618 vox_types::ChannelBody::Reset(reset) => {
619 vox_types::ChannelBody::Reset(vox_types::ChannelReset {
620 metadata: reset.metadata.clone(),
621 })
622 }
623 vox_types::ChannelBody::GrantCredit(credit) => {
624 vox_types::ChannelBody::GrantCredit(vox_types::ChannelGrantCredit {
625 additional: credit.additional,
626 })
627 }
628 }
629}
630
631impl ConnectionSender {
632 pub(crate) fn connection_id(&self) -> ConnectionId {
633 self.connection_id
634 }
635
636 pub(crate) async fn send_with_binder<'a>(
637 &self,
638 msg: ConnectionMessage<'a>,
639 binder: Option<&'a dyn vox_types::ChannelBinder>,
640 ) -> Result<(), ()> {
641 let payload = match msg {
642 ConnectionMessage::Request(r) => MessagePayload::RequestMessage(r),
643 ConnectionMessage::Channel(c) => MessagePayload::ChannelMessage(c),
644 };
645 let message = Message {
646 connection_id: self.connection_id,
647 payload,
648 };
649 self.sess_core
650 .send(message, binder, None)
651 .await
652 .map_err(|_| ())
653 }
654
655 pub async fn send<'a>(&self, msg: ConnectionMessage<'a>) -> Result<(), ()> {
657 self.send_with_binder(msg, None).await
658 }
659
660 pub(crate) fn try_send<'a>(&self, msg: ConnectionMessage<'a>) -> Result<(), TrySendError<()>> {
662 let payload = match msg {
663 ConnectionMessage::Request(r) => MessagePayload::RequestMessage(r),
664 ConnectionMessage::Channel(c) => MessagePayload::ChannelMessage(c),
665 };
666 self.sess_core.try_send(
667 Message {
668 connection_id: self.connection_id,
669 payload,
670 },
671 None,
672 None,
673 )
674 }
675
676 pub(crate) async fn send_owned(
678 &self,
679 schemas: Arc<vox_types::SchemaRecvTracker>,
680 msg: SelfRef<ConnectionMessage<'static>>,
681 ) -> Result<(), ()> {
682 let msg_ref = msg.get();
683 let payload = match msg_ref {
684 ConnectionMessage::Request(request) => MessagePayload::RequestMessage(RequestMessage {
685 id: request.id,
686 body: forwarded_request_body(&request.body),
687 }),
688 ConnectionMessage::Channel(channel) => MessagePayload::ChannelMessage(ChannelMessage {
689 id: channel.id,
690 body: forwarded_channel_body(&channel.body),
691 }),
692 };
693
694 self.sess_core
695 .send(
696 Message {
697 connection_id: self.connection_id,
698 payload,
699 },
700 None,
701 Some(&*schemas),
702 )
703 .await
704 .map_err(|_| ())
705 }
706
707 pub async fn send_response<'a>(
709 &self,
710 request_id: RequestId,
711 response: RequestResponse<'a>,
712 ) -> Result<(), ()> {
713 self.send(ConnectionMessage::Request(RequestMessage {
714 id: request_id,
715 body: RequestBody::Response(response),
716 }))
717 .await
718 }
719
720 pub async fn send_response_for_method<'a>(
722 &self,
723 request_id: RequestId,
724 method_id: vox_types::MethodId,
725 mut response: RequestResponse<'a>,
726 ) -> Result<(), ()> {
727 self.prepare_response_for_method(request_id, method_id, &mut response);
728 self.send(ConnectionMessage::Request(RequestMessage {
729 id: request_id,
730 body: RequestBody::Response(response),
731 }))
732 .await
733 }
734
735 pub(crate) fn prepare_response_for_method(
737 &self,
738 request_id: RequestId,
739 method_id: vox_types::MethodId,
740 response: &mut RequestResponse<'_>,
741 ) {
742 self.sess_core.prepare_response_for_method(
743 self.connection_id,
744 request_id,
745 method_id,
746 response,
747 );
748 }
749
750 pub(crate) fn prepare_response_from_source(
752 &self,
753 request_id: RequestId,
754 method_id: vox_types::MethodId,
755 root_type: &vox_types::TypeRef,
756 source: &dyn vox_types::SchemaSource,
757 response: &mut RequestResponse<'_>,
758 ) {
759 self.sess_core.prepare_response_from_source(
760 self.connection_id,
761 request_id,
762 method_id,
763 root_type,
764 source,
765 response,
766 );
767 }
768
769 pub fn mark_failure(&self, request_id: RequestId, disposition: FailureDisposition) {
772 let _ = self.failures.send((request_id, disposition));
773 }
774
775 pub fn prepare_replay_schemas(
779 &self,
780 request_id: RequestId,
781 method_id: vox_types::MethodId,
782 response_shape: &'static Shape,
783 response: &mut RequestResponse<'_>,
784 ) {
785 self.sess_core.prepare_response_from_shape(
786 self.connection_id,
787 request_id,
788 method_id,
789 response_shape,
790 response,
791 );
792 }
793}
794
/// Owner-side handle of one connection, produced by
/// `Session::make_connection_handle`.
pub struct ConnectionHandle {
    /// Sending side for this connection.
    pub(crate) sender: ConnectionSender,
    /// Incoming messages routed to this connection.
    pub(crate) rx: mpsc::Receiver<RecvMessage>,
    /// Failure reports produced by `ConnectionSender::mark_failure`.
    pub(crate) failures_rx: mpsc::UnboundedReceiver<(RequestId, FailureDisposition)>,
    /// Control channel used to request a close of this connection.
    pub(crate) control_tx: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// Holds the close reason once the connection is closed.
    pub(crate) closed_rx: watch::Receiver<Option<ConnectionCloseReason>>,
    /// Subscription to the session's resume notifier.
    pub(crate) resumed_rx: watch::Receiver<u64>,
    /// Settings of the local side.
    pub(crate) local_settings: ConnectionSettings,
    /// Settings of the peer side.
    pub(crate) peer_settings: ConnectionSettings,
    /// Parity of the local side (copied from the local settings).
    pub parity: Parity,
    /// Copied from the session when the handle was created.
    pub(crate) peer_supports_retry: bool,
    /// Optional observer shared with the session.
    pub(crate) observer: Option<VoxObserverHandle>,
}
809
810impl std::fmt::Debug for ConnectionHandle {
811 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
812 f.debug_struct("ConnectionHandle")
813 .field("connection_id", &self.sender.connection_id)
814 .finish()
815 }
816}
817
/// A message scoped to a single connection: either request traffic or
/// channel traffic.
pub(crate) enum ConnectionMessage<'payload> {
    /// A request-level message (call, response or cancel).
    Request(RequestMessage<'payload>),
    /// A channel-level message (item, close, reset or credit grant).
    Channel(ChannelMessage<'payload>),
}

// NOTE(review): presumably generates the reborrow impl that lets this type be
// stored in a `SelfRef` — the macro is defined in `vox_types`; confirm.
vox_types::impl_reborrow!(ConnectionMessage);
824
/// A message delivered to a connection's receive queue.
pub(crate) struct RecvMessage {
    /// Schema tracker travelling with the message; the proxy passes it on to
    /// `ConnectionSender::send_owned`.
    pub schemas: Arc<vox_types::SchemaRecvTracker>,
    /// The decoded message, owned together with its backing storage.
    pub msg: SelfRef<ConnectionMessage<'static>>,
    /// Value returned by the conduit's `take_frame_fds` for this frame.
    pub fds: vox_types::FrameFds,
}
835
836impl ConnectionHandle {
837 pub fn connection_id(&self) -> ConnectionId {
839 self.sender.connection_id
840 }
841
842 pub async fn closed(&self) {
844 if self.closed_rx.borrow().is_some() {
845 return;
846 }
847 let mut rx = self.closed_rx.clone();
848 while rx.changed().await.is_ok() {
849 if rx.borrow().is_some() {
850 return;
851 }
852 }
853 }
854
855 pub fn is_connected(&self) -> bool {
857 self.closed_rx.borrow().is_none()
858 }
859
860 pub fn close_reason(&self) -> Option<ConnectionCloseReason> {
861 *self.closed_rx.borrow()
862 }
863
864 pub fn peer_supports_retry(&self) -> bool {
865 self.peer_supports_retry
866 }
867
868 pub fn debug_snapshot(&self) -> VoxDebugSnapshot {
870 let (outbound_queue_depth, outbound_queue_capacity) =
871 self.sender.sess_core.outbound_queue_stats();
872 VoxDebugSnapshot {
873 connections: vec![ConnectionDebugSnapshot {
874 connection_id: self.connection_id(),
875 endpoint: None,
876 surface: None,
877 component: None,
878 state: if self.closed_rx.borrow().is_some() {
879 ConnectionDebugState::Closed
880 } else {
881 ConnectionDebugState::Open
882 },
883 outstanding_requests: 0,
884 requests: Vec::new(),
885 open_channels: Vec::new(),
886 outbound_queue_depth: Some(outbound_queue_depth),
887 outbound_queue_capacity: Some(outbound_queue_capacity),
888 local_control_queue_depth: None,
889 local_control_queue_capacity: None,
890 last_inbound_message_at: None,
891 last_outbound_message_at: None,
892 last_progress_at: None,
893 close_reason: *self.closed_rx.borrow(),
894 driver_task_status: DriverTaskStatus::Unknown,
895 }],
896 }
897 }
898
899 pub fn dump_debug_snapshot(&self) -> VoxDebugSnapshot {
900 let snapshot = self.debug_snapshot();
901 tracing::info!(?snapshot, "vox debug snapshot");
902 snapshot
903 }
904}
905
906pub async fn proxy_connections(
912 left: ConnectionHandle,
913 right: ConnectionHandle,
914) -> Result<(), SessionError> {
915 if left.parity == right.parity {
916 return Err(SessionError::Protocol(
917 "proxy_connections requires opposite parities".into(),
918 ));
919 }
920 let left_conn_id = left.connection_id();
921 let right_conn_id = right.connection_id();
922 let ConnectionHandle {
923 sender: left_sender,
924 rx: mut left_rx,
925 failures_rx: _left_failures_rx,
926 control_tx: left_control_tx,
927 closed_rx: _left_closed_rx,
928 resumed_rx: _left_resumed_rx,
929 local_settings: _left_local_settings,
930 peer_settings: _left_peer_settings,
931 parity: _left_parity,
932 peer_supports_retry: _left_peer_supports_retry,
933 observer: _left_observer,
934 } = left;
935 let ConnectionHandle {
936 sender: right_sender,
937 rx: mut right_rx,
938 failures_rx: _right_failures_rx,
939 control_tx: right_control_tx,
940 closed_rx: _right_closed_rx,
941 resumed_rx: _right_resumed_rx,
942 local_settings: _right_local_settings,
943 peer_settings: _right_peer_settings,
944 parity: _right_parity,
945 peer_supports_retry: _right_peer_supports_retry,
946 observer: _right_observer,
947 } = right;
948
949 loop {
950 tokio::select! {
951 recv = left_rx.recv() => {
952 let Some(recv) = recv else {
953 break;
954 };
955 if right_sender.send_owned(recv.schemas, recv.msg).await.is_err() {
956 break;
957 }
958 }
959 recv = right_rx.recv() => {
960 let Some(recv) = recv else {
961 break;
962 };
963 if left_sender.send_owned(recv.schemas, recv.msg).await.is_err() {
964 break;
965 }
966 }
967 }
968 }
969
970 if let Some(tx) = left_control_tx.as_ref() {
971 let _ = send_drop_control(tx, DropControlRequest::Close(left_conn_id));
972 }
973 if let Some(tx) = right_control_tx.as_ref() {
974 let _ = send_drop_control(tx, DropControlRequest::Close(right_conn_id));
975 }
976 Ok(())
977}
978
/// Errors reported by session and connection operations.
#[derive(Debug)]
pub enum SessionError {
    /// An underlying I/O error.
    Io(std::io::Error),
    /// A protocol-level problem, described by the message. Also used for the
    /// "session closed" condition reported by `SessionHandle`.
    Protocol(String),
    /// The connection was rejected; carries the accompanying metadata.
    Rejected(Metadata<'static>),
    /// Returned by `establish_from_handshake` when the session is marked
    /// resumable but the handshake result carries no resume key.
    NotResumable,
    /// Displayed as "connect timeout".
    ConnectTimeout,
}
988
989impl SessionError {
990 pub fn is_retryable(&self) -> bool {
996 matches!(
997 self,
998 Self::Io(_) | Self::ConnectTimeout | Self::NotResumable
999 )
1000 }
1001}
1002
1003impl std::fmt::Display for SessionError {
1004 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1005 match self {
1006 Self::Io(e) => write!(f, "io error: {e}"),
1007 Self::Protocol(msg) => write!(f, "protocol error: {msg}"),
1008 Self::Rejected(_) => write!(f, "connection rejected"),
1009 Self::NotResumable => write!(f, "session is not resumable"),
1010 Self::ConnectTimeout => write!(f, "connect timeout"),
1011 }
1012 }
1013}
1014
// `source()` is left at its default (`None`): the `Display` impl already
// embeds the inner I/O error's message, so also exposing it as a source would
// make chain-walking reporters print it twice.
impl std::error::Error for SessionError {}
1016
1017fn classify_session_recv_error(error: &std::io::Error) -> ConnectionCloseReason {
1018 let message = error.to_string();
1019 if message.contains("decode error") || message.contains("protocol") {
1020 ConnectionCloseReason::Protocol
1021 } else {
1022 ConnectionCloseReason::Transport
1023 }
1024}
1025
1026fn classify_decode_error(error: &std::io::Error) -> Option<DecodeErrorKind> {
1027 let message = error.to_string();
1028 if message.contains("decode error") {
1029 Some(DecodeErrorKind::Payload)
1030 } else {
1031 None
1032 }
1033}
1034
1035impl Session {
1036 fn observe_session_recv_error(&self, error: &std::io::Error) {
1039 let Some(observer) = &self.observer else {
1040 return;
1041 };
1042
1043 if let Some(kind) = classify_decode_error(error) {
1044 for conn_id in self.conns.iter().filter_map(|(conn_id, slot)| {
1045 matches!(slot, ConnectionSlot::Active(_)).then_some(*conn_id)
1046 }) {
1047 observer.driver_event(vox_types::DriverEvent::DecodeError {
1048 connection_id: conn_id,
1049 kind,
1050 });
1051 }
1052 return;
1053 }
1054
1055 observer.transport_event(vox_types::TransportEvent::Closed {
1056 connection_id: None,
1057 reason: classify_session_recv_error(error),
1058 });
1059 }
1060
1061 fn close_connection_for_protocol_error(
1062 &mut self,
1063 conn_id: ConnectionId,
1064 detail: impl std::fmt::Display,
1065 ) {
1066 warn!(%conn_id, "closing connection after protocol error: {detail}");
1067 self.remove_connection_with_reason(&conn_id, ConnectionCloseReason::Protocol);
1068 self.maybe_request_shutdown_after_root_closed();
1069 }
1070
1071 fn record_received_schema_cbor(
1072 &mut self,
1073 conn_id: ConnectionId,
1074 schema_recv_tracker: Arc<vox_types::SchemaRecvTracker>,
1075 method_id: vox_types::MethodId,
1076 direction: vox_types::BindingDirection,
1077 schemas_cbor: &vox_types::CborPayload,
1078 context: &str,
1079 ) -> bool {
1080 let payload = match vox_types::SchemaPayload::from_cbor(&schemas_cbor.0) {
1081 Ok(payload) => payload,
1082 Err(error) => {
1083 self.close_connection_for_protocol_error(
1084 conn_id,
1085 format!("{context}: invalid schema CBOR: {error}"),
1086 );
1087 return false;
1088 }
1089 };
1090
1091 if let Err(error) = schema_recv_tracker.record_received(method_id, direction, payload) {
1092 self.close_connection_for_protocol_error(conn_id, format!("{context}: {error}"));
1093 return false;
1094 }
1095
1096 true
1097 }
1098
1099 #[allow(clippy::too_many_arguments)]
1100 fn pre_handshake<Tx, Rx>(
1101 tx: Tx,
1102 rx: Rx,
1103 on_connection: Option<Arc<dyn ConnectionAcceptor>>,
1104 open_rx: mpsc::Receiver<OpenRequest>,
1105 close_rx: mpsc::Receiver<CloseRequest>,
1106 resume_rx: mpsc::Receiver<ResumeRequest>,
1107 control_tx: mpsc::UnboundedSender<DropControlRequest>,
1108 control_rx: mpsc::UnboundedReceiver<DropControlRequest>,
1109 keepalive: Option<SessionKeepaliveConfig>,
1110 resumable: bool,
1111 recoverer: Option<Box<dyn ConduitRecoverer>>,
1112 recovery_timeout: Option<Duration>,
1113 observer: Option<VoxObserverHandle>,
1114 ) -> Self
1115 where
1116 Tx: ConduitTx<Msg = MessageFamily> + MaybeSend + MaybeSync + 'static,
1117 Rx: ConduitRx<Msg = MessageFamily> + MaybeSend + 'static,
1118 {
1119 let (outbound_tx, outbound_rx) = tokio_mpsc::channel(256);
1120 let sess_core = Arc::new(SessionCore {
1121 inner: std::sync::Mutex::new(SessionCoreInner {
1122 tx: Arc::new(tx) as Arc<dyn DynConduitTx>,
1123 conns: HashMap::new(),
1124 }),
1125 outbound_tx,
1126 observer: observer.clone(),
1127 });
1128 spawn_outbound_worker(outbound_rx);
1129 let (resume_notifier, _resume_rx) = watch::channel(0_u64);
1130 Session {
1131 rx: Box::new(rx),
1132 role: SessionRole::Initiator, parity: Parity::Odd, sess_core,
1135 peer_supports_retry: false,
1136 local_root_settings: ConnectionSettings {
1137 parity: Parity::Odd,
1138 max_concurrent_requests: 64,
1139 initial_channel_credit: 16,
1140 },
1141 peer_root_settings: None,
1142 resumable,
1143 session_resume_key: None,
1144 conns: BTreeMap::new(),
1145 root_closed_internal: false,
1146 conn_ids: IdAllocator::new(Parity::Odd), on_connection,
1148 open_rx,
1149 close_rx,
1150 resume_rx,
1151 control_tx,
1152 control_rx,
1153 keepalive,
1154 resume_notifier,
1155 recoverer,
1156 recovery_timeout,
1157 registered_in_registry: false,
1158 observer,
1159 }
1160 }
1161
1162 pub(crate) fn resume_key(&self) -> Option<SessionResumeKey> {
1163 self.session_resume_key
1164 }
1165
1166 fn establish_from_handshake(
1168 &mut self,
1169 result: HandshakeResult,
1170 ) -> Result<ConnectionHandle, SessionError> {
1171 self.role = result.role;
1172 self.parity = result.our_settings.parity;
1173 self.conn_ids = IdAllocator::new(result.our_settings.parity);
1174 self.local_root_settings = result.our_settings.clone();
1175 self.peer_root_settings = Some(result.peer_settings.clone());
1176 self.peer_supports_retry = result.peer_supports_retry;
1177 self.session_resume_key = result.session_resume_key;
1178
1179 if self.resumable && self.session_resume_key.is_none() {
1180 return Err(SessionError::NotResumable);
1181 }
1182
1183 Ok(self.make_root_handle(result.our_settings, result.peer_settings))
1184 }
1185
1186 fn make_root_handle(
1187 &mut self,
1188 local_settings: ConnectionSettings,
1189 peer_settings: ConnectionSettings,
1190 ) -> ConnectionHandle {
1191 self.make_connection_handle(ConnectionId::ROOT, local_settings, peer_settings)
1192 }
1193
1194 fn make_connection_handle(
1195 &mut self,
1196 conn_id: ConnectionId,
1197 local_settings: ConnectionSettings,
1198 peer_settings: ConnectionSettings,
1199 ) -> ConnectionHandle {
1200 let label = format!("session.conn{}", conn_id.0);
1201 let (conn_tx, conn_rx) = mpsc::channel::<RecvMessage>(&label, 64);
1202 let (failures_tx, failures_rx) = mpsc::unbounded_channel(format!("{label}.failures"));
1203 let (closed_tx, closed_rx) = watch::channel(None);
1204 let resumed_rx = self.resume_notifier.subscribe();
1205
1206 let sender = ConnectionSender {
1207 connection_id: conn_id,
1208 sess_core: Arc::clone(&self.sess_core),
1209 failures: Arc::new(failures_tx),
1210 };
1211
1212 let parity = local_settings.parity;
1213 let handle_local_settings = local_settings.clone();
1214 let handle_peer_settings = peer_settings.clone();
1215 trace!(%conn_id, "make_connection_handle: inserting slot into conns");
1216 if let Some(observer) = &self.observer {
1217 observer.driver_event(vox_types::DriverEvent::ConnectionOpened {
1218 connection_id: conn_id,
1219 });
1220 }
1221 self.conns.insert(
1222 conn_id,
1223 ConnectionSlot::Active(ConnectionState {
1224 id: conn_id,
1225 local_settings,
1226 peer_settings,
1227 conn_tx,
1228 closed_tx,
1229 schema_recv_tracker: Arc::new(vox_types::SchemaRecvTracker::new()),
1230 }),
1231 );
1232
1233 ConnectionHandle {
1234 sender,
1235 rx: conn_rx,
1236 failures_rx,
1237 control_tx: Some(self.control_tx.clone()),
1238 closed_rx,
1239 resumed_rx,
1240 local_settings: handle_local_settings,
1241 peer_settings: handle_peer_settings,
1242 parity,
1243 peer_supports_retry: self.peer_supports_retry,
1244 observer: self.observer.clone(),
1245 }
1246 }
1247
    /// Drives the session until it shuts down.
    ///
    /// Multiplexes, in priority order, inbound conduit messages, open/close
    /// requests, resume requests, drop-control requests and (when keepalive is
    /// configured) the keepalive tick. The loop ends when the conduit breaks
    /// without a successful recovery, when a control request asks for
    /// shutdown, or when keepalive fails. Every exit path closes all
    /// connections with a reason describing why.
    pub async fn run(&mut self) {
        let mut keepalive_runtime = self.make_keepalive_runtime();
        // The ticker only exists when keepalive is enabled. It fires every
        // 10 ms; `handle_keepalive_tick` decides whether anything is due.
        let mut keepalive_tick = keepalive_runtime.as_ref().map(|_| {
            let mut interval = vox_types::time::tokio::interval(Duration::from_millis(10));
            interval.set_missed_tick_behavior(vox_types::time::tokio::MissedTickBehavior::Delay);
            interval
        });

        loop {
            tokio::select! {
                // `biased` makes `select!` poll the branches top to bottom
                // instead of in random order, so inbound messages are
                // always considered first.
                biased;

                msg = self.rx.recv_msg() => {
                    vox_types::dlog!("[session {:?}] recv_msg returned", self.role);
                    match msg {
                        Ok(Some(msg)) => {
                            // Collected from the receiver right after the message
                            // and handed to the handler together with it.
                            let fds = self.rx.take_frame_fds();
                            self.handle_message(msg, fds, &mut keepalive_runtime).await;
                        }
                        Ok(None) => {
                            // Clean end of stream: try to recover; otherwise report
                            // the close as initiated by the remote side.
                            vox_types::dlog!("[session {:?}] recv loop: conduit returned EOF", self.role);
                            if !self.handle_conduit_break(&mut keepalive_runtime).await {
                                vox_types::dlog!("[session {:?}] recv loop: breaking (not resumable)", self.role);
                                self.close_all_connections(ConnectionCloseReason::Remote);
                                break;
                            }
                        }
                        Err(error) => {
                            // Receive failure: classify it up front so the reason is
                            // available if recovery turns out to be impossible.
                            let close_reason = classify_session_recv_error(&error);
                            self.observe_session_recv_error(&error);
                            warn!(
                                role = ?self.role,
                                %error,
                                ?close_reason,
                                "session receive failed; closing connections if recovery is unavailable"
                            );
                            vox_types::dlog!("[session {:?}] recv loop: conduit recv error: {}", self.role, error);
                            if !self.handle_conduit_break(&mut keepalive_runtime).await {
                                vox_types::dlog!("[session {:?}] recv loop: breaking (not resumable)", self.role);
                                self.close_all_connections(close_reason);
                                break;
                            }
                        }
                    }
                }
                Some(req) = self.open_rx.recv() => {
                    self.handle_open_request(req).await;
                }
                Some(req) = self.close_rx.recv() => {
                    self.handle_close_request(req).await;
                }
                Some(req) = self.resume_rx.recv() => {
                    // Resume requests are only honoured inside
                    // `handle_conduit_break`; while connected they are refused.
                    let _ = req.result_tx.send(Err(SessionError::Protocol(
                        "resume is only valid while the session is disconnected".into(),
                    )));
                }
                Some(req) = self.control_rx.recv() => {
                    // `false` means the session should stop.
                    if !self.handle_drop_control_request(req).await {
                        self.close_all_connections(ConnectionCloseReason::Local);
                        break;
                    }
                }
                // The `if` precondition disables this branch when keepalive is
                // off, so the async block never resolves without a real tick.
                _ = async {
                    if let Some(interval) = keepalive_tick.as_mut() {
                        interval.tick().await;
                    }
                }, if keepalive_tick.is_some() => {
                    // `false` means the pong timed out or the ping could not be sent.
                    if !self.handle_keepalive_tick(&mut keepalive_runtime).await {
                        self.close_all_connections(ConnectionCloseReason::Protocol);
                        break;
                    }
                }
            }
        }

        // Every `break` above has already closed and cleared the connections,
        // so this final call finds an empty map on those paths.
        self.close_all_connections(ConnectionCloseReason::SessionShutdown);
        trace!("session recv loop exited");
    }
1338
1339 async fn handle_conduit_break(
1340 &mut self,
1341 keepalive_runtime: &mut Option<KeepaliveRuntime>,
1342 ) -> bool {
1343 if let Some(recoverer) = self.recoverer.as_mut() {
1350 let recovery_fut = recoverer.next_conduit(self.session_resume_key.as_ref());
1351 let recovery_result = match self.recovery_timeout {
1352 Some(timeout) => match vox_types::time::tokio::timeout(timeout, recovery_fut).await
1353 {
1354 Ok(r) => r,
1355 Err(_) => return false,
1356 },
1357 None => recovery_fut.await,
1358 };
1359 match recovery_result {
1360 Ok(recovered) => {
1361 let result =
1362 self.resume_from_handshake(recovered.tx, recovered.rx, recovered.handshake);
1363 match result {
1364 Ok(()) => {
1365 let next_generation = self.resume_notifier.borrow().wrapping_add(1);
1366 let _ = self.resume_notifier.send(next_generation);
1367 *keepalive_runtime = self.make_keepalive_runtime();
1368 return true;
1369 }
1370 Err(_) => return false,
1371 }
1372 }
1373 Err(_) => return false,
1374 }
1375 }
1376
1377 if !self.registered_in_registry {
1378 return false;
1379 }
1380
1381 loop {
1382 tokio::select! {
1383 Some(req) = self.resume_rx.recv() => {
1384 let result =
1385 self.resume_from_handshake(req.tx, req.rx, req.handshake_result);
1386 let ok = result.is_ok();
1387 let _ = req.result_tx.send(result);
1388 if ok {
1389 let next_generation = self.resume_notifier.borrow().wrapping_add(1);
1390 let _ = self.resume_notifier.send(next_generation);
1391 *keepalive_runtime = self.make_keepalive_runtime();
1392 return true;
1393 }
1394 }
1395 Some(req) = self.control_rx.recv() => {
1396 if !self.handle_drop_control_request(req).await {
1397 return false;
1398 }
1399 }
1400 Some(req) = self.open_rx.recv() => {
1401 let _ = req.result_tx.send(Err(SessionError::Protocol(
1402 "session is disconnected; resume before opening connections".into(),
1403 )));
1404 }
1405 Some(req) = self.close_rx.recv() => {
1406 let _ = req.result_tx.send(Err(SessionError::Protocol(
1407 "session is disconnected; resume before closing connections".into(),
1408 )));
1409 }
1410 else => return false,
1411 }
1412 }
1413 }
1414
1415 fn resume_from_handshake(
1417 &mut self,
1418 tx: Arc<dyn DynConduitTx>,
1419 rx: Box<dyn DynConduitRx>,
1420 result: HandshakeResult,
1421 ) -> Result<(), SessionError> {
1422 let Some(peer_settings) = self.peer_root_settings.clone() else {
1423 return Err(SessionError::Protocol("missing peer root settings".into()));
1424 };
1425
1426 if result.our_settings != self.local_root_settings {
1427 return Err(SessionError::Protocol(
1428 "local root settings changed across session resume".into(),
1429 ));
1430 }
1431
1432 if result.peer_settings != peer_settings {
1433 return Err(SessionError::Protocol(
1434 "peer root settings changed across session resume".into(),
1435 ));
1436 }
1437
1438 self.peer_supports_retry = result.peer_supports_retry;
1439 self.session_resume_key = result.session_resume_key.or(self.session_resume_key);
1440
1441 self.sess_core.replace_tx_and_reset_schemas(tx);
1442 self.rx = rx;
1443 if let Some(ConnectionSlot::Active(state)) = self.conns.get_mut(&ConnectionId::ROOT) {
1446 state.schema_recv_tracker = Arc::new(vox_types::SchemaRecvTracker::new());
1447 }
1448 Ok(())
1449 }
1450
    /// Routes one inbound message.
    ///
    /// Pings, pongs and standalone schema messages are consumed here and never
    /// reach a connection. Connection lifecycle payloads are handed to the
    /// matching `handle_inbound_*` method. Request and channel messages are
    /// forwarded to the target connection's inbound channel together with
    /// `fds`; messages for a connection without an active slot are dropped.
    async fn handle_message(
        &mut self,
        msg: SelfRef<Message<'static>>,
        fds: vox_types::FrameFds,
        keepalive_runtime: &mut Option<KeepaliveRuntime>,
    ) {
        let msg_ref = msg.get();
        let conn_id = msg_ref.connection_id;
        // First pass: payloads that are fully handled by borrowing the message.
        match &msg_ref.payload {
            MessagePayload::Ping(ping) => {
                // Echo the nonce back on the same connection; a failed send is ignored.
                let _ = self
                    .sess_core
                    .send(
                        Message {
                            connection_id: conn_id,
                            payload: MessagePayload::Pong(vox_types::Pong { nonce: ping.nonce }),
                        },
                        None,
                        None,
                    )
                    .await;
                return;
            }
            MessagePayload::Pong(pong) => {
                // Keepalive pings are sent on the root connection only, so
                // pongs on any other connection are ignored.
                if conn_id.is_root() {
                    self.handle_keepalive_pong(pong.nonce, keepalive_runtime);
                }
                return;
            }
            MessagePayload::SchemaMessage(schema_msg) => {
                let schema_recv_tracker = match self.conns.get(&conn_id) {
                    Some(ConnectionSlot::Active(state)) => Arc::clone(&state.schema_recv_tracker),
                    _ => return,
                };
                // The result is deliberately ignored: there is nothing left
                // to dispatch for a standalone schema message.
                let _ = self.record_received_schema_cbor(
                    conn_id,
                    schema_recv_tracker,
                    schema_msg.method_id,
                    schema_msg.direction,
                    &schema_msg.schemas,
                    "standalone schema message",
                );
                return;
            }
            _ => {}
        }
        // Second pass: payloads whose contents are moved out of `msg`.
        vox_types::selfref_match!(msg, payload {
            MessagePayload::ConnectionClose(_) => {
                if conn_id.is_root() {
                    warn!("received ConnectionClose for root connection");
                } else {
                    trace!(conn_id = conn_id.0, "received ConnectionClose for virtual connection");
                }
                self.remove_connection_with_reason(&conn_id, ConnectionCloseReason::Remote);
                self.maybe_request_shutdown_after_root_closed();
            }
            MessagePayload::ConnectionOpen(open) => {
                self.handle_inbound_open(conn_id, open).await;
            }
            MessagePayload::ConnectionAccept(accept) => {
                self.handle_inbound_accept(conn_id, accept);
            }
            MessagePayload::ConnectionReject(reject) => {
                self.handle_inbound_reject(conn_id, reject);
            }
            MessagePayload::RequestMessage(r) => {
                let r_ref = r.get();
                vox_types::dlog!(
                    "[session {:?}] recv request: conn={:?} req={:?} body={} method={:?}",
                    self.role,
                    conn_id,
                    r_ref.id,
                    match &r_ref.body {
                        RequestBody::Call(_) => "Call",
                        RequestBody::Response(_) => "Response",
                        RequestBody::Cancel(_) => "Cancel",
                    },
                    match &r_ref.body {
                        RequestBody::Call(call) => Some(call.method_id),
                        RequestBody::Response(_) | RequestBody::Cancel(_) => None,
                    }
                );
                // Remembered so the inflight-call entry is taken exactly once:
                // in the schema branch below when the response carries
                // schemas, otherwise in the cleanup that follows it.
                let response_had_schema_payload = matches!(&r_ref.body, RequestBody::Response(resp) if !resp.schemas.is_empty());
                {
                    // Calls and responses may carry inlined schemas; cancels never do.
                    let schemas_cbor = match &r_ref.body {
                        RequestBody::Call(call) => Some(&call.schemas),
                        RequestBody::Response(resp) => Some(&resp.schemas),
                        _ => None,
                    };
                    vox_types::dlog!(
                        "[schema] recv ({:?}): req={:?} body={} schemas_len={:?}",
                        self.role,
                        r_ref.id,
                        match &r_ref.body {
                            RequestBody::Call(_) => "Call",
                            RequestBody::Response(_) => "Response",
                            RequestBody::Cancel(_) => "Cancel",
                        },
                        schemas_cbor.map(|s| s.0.len())
                    );
                    let schema_recv_tracker = match self.conns.get(&conn_id) {
                        Some(ConnectionSlot::Active(state)) => {
                            Arc::clone(&state.schema_recv_tracker)
                        }
                        _ => return,
                    };
                    if let Some(schemas_cbor) = schemas_cbor
                        && !schemas_cbor.is_empty()
                    {
                        let (method_id, direction) = match &r_ref.body {
                            RequestBody::Call(call) => {
                                (call.method_id, vox_types::BindingDirection::Args)
                            }
                            RequestBody::Response(_) => {
                                // A response does not name its method; look it up
                                // from the call recorded for this request id.
                                let Some(method_id) =
                                    self.sess_core.take_outgoing_call_method(conn_id, r_ref.id)
                                else {
                                    self.close_connection_for_protocol_error(
                                        conn_id,
                                        format!(
                                            "response schemas for unknown inflight request {:?}",
                                            r_ref.id
                                        ),
                                    );
                                    return;
                                };
                                (method_id, vox_types::BindingDirection::Response)
                            }
                            // `schemas_cbor` is `None` for cancels, so this arm
                            // cannot be reached inside the `if let Some(..)`.
                            RequestBody::Cancel(_) => unreachable!(),
                        };
                        // A `false` result stops processing of this message.
                        if !self.record_received_schema_cbor(
                            conn_id,
                            schema_recv_tracker,
                            method_id,
                            direction,
                            schemas_cbor,
                            "inlined request schemas",
                        ) {
                            return;
                        }
                    }
                }
                // Responses without schemas still retire the inflight call entry.
                if matches!(&r_ref.body, RequestBody::Response(_)) && !response_had_schema_payload {
                    let _ = self.sess_core.take_outgoing_call_method(conn_id, r_ref.id);
                }
                if let RequestBody::Call(call) = &r_ref.body {
                    self.sess_core.record_incoming_call(conn_id, r_ref.id, call.method_id);
                }
                let state = match self.conns.get(&conn_id) {
                    Some(ConnectionSlot::Active(state)) => state,
                    _ => return,
                };
                let conn_tx = state.conn_tx.clone();
                let request_id = r_ref.id;
                let body_kind = match &r_ref.body {
                    RequestBody::Call(_) => "Call",
                    RequestBody::Response(_) => "Response",
                    RequestBody::Cancel(_) => "Cancel",
                };
                let recv_msg = RecvMessage {
                    schemas: Arc::clone(&state.schema_recv_tracker),
                    msg: r.map(ConnectionMessage::Request),
                    fds,
                };
                vox_types::dlog!(
                    "[session {:?}] dispatch request: conn={:?} req={:?} body={}",
                    self.role,
                    conn_id,
                    request_id,
                    body_kind
                );
                // A failed send means the connection's receiver is gone, so
                // the slot is retired.
                if conn_tx.send(recv_msg).await.is_err() {
                    self.remove_connection_with_reason(&conn_id, ConnectionCloseReason::Unknown);
                    self.maybe_request_shutdown_after_root_closed();
                }
            }
            MessagePayload::ChannelMessage(c) => {
                let state = match self.conns.get(&conn_id) {
                    Some(ConnectionSlot::Active(state)) => state,
                    _ => return,
                };
                let conn_tx = state.conn_tx.clone();
                let recv_msg = RecvMessage {
                    schemas: Arc::clone(&state.schema_recv_tracker),
                    msg: c.map(ConnectionMessage::Channel),
                    fds,
                };
                // Same policy as for requests: a dead receiver retires the slot.
                if conn_tx.send(recv_msg).await.is_err() {
                    self.remove_connection_with_reason(&conn_id, ConnectionCloseReason::Unknown);
                    self.maybe_request_shutdown_after_root_closed();
                }
            }
        })
    }
1653
1654 fn make_keepalive_runtime(&self) -> Option<KeepaliveRuntime> {
1655 let config = self.keepalive?;
1656 if config.ping_interval.is_zero() || config.pong_timeout.is_zero() {
1657 warn!("keepalive disabled due to non-positive interval/timeout");
1658 return None;
1659 }
1660 let now = vox_types::time::tokio::Instant::now();
1661 Some(KeepaliveRuntime {
1662 ping_interval: config.ping_interval,
1663 pong_timeout: config.pong_timeout,
1664 next_ping_at: now + config.ping_interval,
1665 waiting_pong_nonce: None,
1666 pong_deadline: now,
1667 next_ping_nonce: 1,
1668 })
1669 }
1670
1671 fn handle_keepalive_pong(&self, nonce: u64, keepalive_runtime: &mut Option<KeepaliveRuntime>) {
1672 let Some(runtime) = keepalive_runtime.as_mut() else {
1673 return;
1674 };
1675 if runtime.waiting_pong_nonce != Some(nonce) {
1676 return;
1677 }
1678 runtime.waiting_pong_nonce = None;
1679 runtime.next_ping_at = vox_types::time::tokio::Instant::now() + runtime.ping_interval;
1680 }
1681
1682 async fn handle_keepalive_tick(
1683 &mut self,
1684 keepalive_runtime: &mut Option<KeepaliveRuntime>,
1685 ) -> bool {
1686 let Some(runtime) = keepalive_runtime.as_mut() else {
1687 return true;
1688 };
1689 let now = vox_types::time::tokio::Instant::now();
1690
1691 if let Some(waiting_nonce) = runtime.waiting_pong_nonce {
1692 if now >= runtime.pong_deadline {
1693 warn!(
1694 nonce = waiting_nonce,
1695 timeout_ms = runtime.pong_timeout.as_millis(),
1696 "keepalive timeout waiting for pong"
1697 );
1698 return false;
1699 }
1700 return true;
1701 }
1702
1703 if now < runtime.next_ping_at {
1704 return true;
1705 }
1706
1707 let nonce = runtime.next_ping_nonce;
1708 if self
1709 .sess_core
1710 .send(
1711 Message {
1712 connection_id: ConnectionId::ROOT,
1713 payload: MessagePayload::Ping(vox_types::Ping { nonce }),
1714 },
1715 None,
1716 None,
1717 )
1718 .await
1719 .is_err()
1720 {
1721 warn!("failed to send keepalive ping");
1722 return false;
1723 }
1724
1725 runtime.waiting_pong_nonce = Some(nonce);
1726 runtime.pong_deadline = now + runtime.pong_timeout;
1727 runtime.next_ping_at = now + runtime.ping_interval;
1728 runtime.next_ping_nonce = runtime.next_ping_nonce.wrapping_add(1);
1729 true
1730 }
1731
1732 async fn handle_inbound_open(
1733 &mut self,
1734 conn_id: ConnectionId,
1735 open: SelfRef<ConnectionOpen<'static>>,
1736 ) {
1737 let peer_parity = self.parity.other();
1739 if !conn_id.has_parity(peer_parity) {
1740 let _ = self
1742 .sess_core
1743 .send(
1744 Message {
1745 connection_id: conn_id,
1746 payload: MessagePayload::ConnectionReject(vox_types::ConnectionReject {
1747 metadata: vec![],
1748 }),
1749 },
1750 None,
1751 None,
1752 )
1753 .await;
1754 return;
1755 }
1756
1757 if self.conns.contains_key(&conn_id) {
1759 let _ = self
1761 .sess_core
1762 .send(
1763 Message {
1764 connection_id: conn_id,
1765 payload: MessagePayload::ConnectionReject(vox_types::ConnectionReject {
1766 metadata: vec![],
1767 }),
1768 },
1769 None,
1770 None,
1771 )
1772 .await;
1773 return;
1774 }
1775
1776 if self.on_connection.is_none() {
1779 let _ = self
1780 .sess_core
1781 .send(
1782 Message {
1783 connection_id: conn_id,
1784 payload: MessagePayload::ConnectionReject(vox_types::ConnectionReject {
1785 metadata: vec![],
1786 }),
1787 },
1788 None,
1789 None,
1790 )
1791 .await;
1792 return;
1793 }
1794
1795 let open = open.get();
1797 if open.connection_settings.initial_channel_credit == 0 {
1798 let _ = self
1799 .sess_core
1800 .send(
1801 Message {
1802 connection_id: conn_id,
1803 payload: MessagePayload::ConnectionReject(vox_types::ConnectionReject {
1804 metadata: vec![vox_types::MetadataEntry::str(
1805 "error",
1806 "initial_channel_credit must be greater than zero",
1807 )],
1808 }),
1809 },
1810 None,
1811 None,
1812 )
1813 .await;
1814 return;
1815 }
1816
1817 let our_settings = ConnectionSettings {
1818 parity: open.connection_settings.parity.other(),
1819 max_concurrent_requests: open.connection_settings.max_concurrent_requests,
1820 initial_channel_credit: open.connection_settings.initial_channel_credit,
1821 };
1822
1823 let handle = self.make_connection_handle(
1825 conn_id,
1826 our_settings.clone(),
1827 open.connection_settings.clone(),
1828 );
1829
1830 let mut metadata: Vec<vox_types::MetadataEntry<'_>> = open.metadata.to_vec();
1832 metadata.push(vox_types::MetadataEntry::str(
1833 "vox-connection-kind",
1834 "virtual",
1835 ));
1836 let request = match ConnectionRequest::new(&metadata) {
1837 Ok(r) => r,
1838 Err(e) => {
1839 trace!(%conn_id, %e, "rejecting virtual connection");
1840 self.conns.remove(&conn_id);
1841 let _ = self
1842 .sess_core
1843 .send(
1844 Message {
1845 connection_id: conn_id,
1846 payload: MessagePayload::ConnectionReject(
1847 vox_types::ConnectionReject {
1848 metadata: vec![vox_types::MetadataEntry::str(
1849 "error",
1850 e.to_string(),
1851 )],
1852 },
1853 ),
1854 },
1855 None,
1856 None,
1857 )
1858 .await;
1859 return;
1860 }
1861 };
1862 let pending = PendingConnection::new(handle);
1863 let acceptor = self.on_connection.as_ref().unwrap();
1864 trace!(%conn_id, "calling acceptor for virtual connection");
1865 match acceptor.accept(&request, pending) {
1866 Ok(()) => {
1867 trace!(%conn_id, "acceptor accepted virtual connection, sending ConnectionAccept");
1868 let _ = self
1869 .sess_core
1870 .send(
1871 Message {
1872 connection_id: conn_id,
1873 payload: MessagePayload::ConnectionAccept(
1874 vox_types::ConnectionAccept {
1875 connection_settings: our_settings,
1876 metadata: vec![],
1877 },
1878 ),
1879 },
1880 None,
1881 None,
1882 )
1883 .await;
1884 }
1885 Err(reject_metadata) => {
1886 trace!(%conn_id, "acceptor rejected, removing conn slot");
1888 self.conns.remove(&conn_id);
1889 let _ = self
1890 .sess_core
1891 .send(
1892 Message {
1893 connection_id: conn_id,
1894 payload: MessagePayload::ConnectionReject(
1895 vox_types::ConnectionReject {
1896 metadata: reject_metadata,
1897 },
1898 ),
1899 },
1900 None,
1901 None,
1902 )
1903 .await;
1904 }
1905 }
1906 }
1907
1908 fn handle_inbound_accept(
1909 &mut self,
1910 conn_id: ConnectionId,
1911 accept: SelfRef<ConnectionAccept<'static>>,
1912 ) {
1913 let accept = accept.get();
1914 let slot = self.remove_connection(&conn_id);
1915 match slot {
1916 Some(ConnectionSlot::PendingOutbound(mut pending))
1917 if accept.connection_settings.initial_channel_credit == 0 =>
1918 {
1919 if let Some(tx) = pending.result_tx.take() {
1920 let _ = tx.send(Err(SessionError::Protocol(
1921 "initial_channel_credit must be greater than zero".into(),
1922 )));
1923 }
1924 }
1925 Some(ConnectionSlot::PendingOutbound(mut pending)) => {
1926 let handle = self.make_connection_handle(
1927 conn_id,
1928 pending.local_settings.clone(),
1929 accept.connection_settings.clone(),
1930 );
1931
1932 if let Some(tx) = pending.result_tx.take() {
1933 let _ = tx.send(Ok(handle));
1934 }
1935 }
1936 Some(other) => {
1937 self.conns.insert(conn_id, other);
1939 }
1940 None => {
1941 }
1943 }
1944 }
1945
1946 fn handle_inbound_reject(
1947 &mut self,
1948 conn_id: ConnectionId,
1949 reject: SelfRef<ConnectionReject<'static>>,
1950 ) {
1951 let reject = reject.get();
1952 let slot = self.remove_connection(&conn_id);
1953 match slot {
1954 Some(ConnectionSlot::PendingOutbound(mut pending)) => {
1955 if let Some(tx) = pending.result_tx.take() {
1956 let _ = tx.send(Err(SessionError::Rejected(vox_types::metadata_into_owned(
1957 reject.metadata.to_vec(),
1958 ))));
1959 }
1960 }
1961 Some(other) => {
1962 self.conns.insert(conn_id, other);
1963 }
1964 None => {}
1965 }
1966 }
1967
1968 async fn handle_open_request(&mut self, req: OpenRequest) {
1970 if req.settings.initial_channel_credit == 0 {
1971 let _ = req.result_tx.send(Err(SessionError::Protocol(
1972 "initial_channel_credit must be greater than zero".into(),
1973 )));
1974 return;
1975 }
1976
1977 let conn_id = self.conn_ids.alloc();
1978
1979 let send_result = self
1981 .sess_core
1982 .send(
1983 Message {
1984 connection_id: conn_id,
1985 payload: MessagePayload::ConnectionOpen(ConnectionOpen {
1986 connection_settings: req.settings.clone(),
1987 metadata: req.metadata,
1988 }),
1989 },
1990 None,
1991 None,
1992 )
1993 .await;
1994
1995 if send_result.is_err() {
1996 let _ = req.result_tx.send(Err(SessionError::Protocol(
1997 "failed to send ConnectionOpen".into(),
1998 )));
1999 return;
2000 }
2001
2002 self.conns.insert(
2005 conn_id,
2006 ConnectionSlot::PendingOutbound(PendingOutboundData {
2007 local_settings: req.settings,
2008 result_tx: Some(req.result_tx),
2009 }),
2010 );
2011 }
2012
2013 async fn handle_close_request(&mut self, req: CloseRequest) {
2015 if req.conn_id.is_root() {
2016 let _ = req.result_tx.send(Err(SessionError::Protocol(
2017 "cannot close root connection".into(),
2018 )));
2019 return;
2020 }
2021
2022 if self
2025 .remove_connection_with_reason(&req.conn_id, ConnectionCloseReason::Local)
2026 .is_none()
2027 {
2028 let _ = req
2029 .result_tx
2030 .send(Err(SessionError::Protocol("connection not found".into())));
2031 return;
2032 }
2033
2034 let send_result = self
2036 .sess_core
2037 .send(
2038 Message {
2039 connection_id: req.conn_id,
2040 payload: MessagePayload::ConnectionClose(ConnectionClose {
2041 metadata: req.metadata,
2042 }),
2043 },
2044 None,
2045 None,
2046 )
2047 .await;
2048
2049 if send_result.is_err() {
2050 let _ = req.result_tx.send(Err(SessionError::Protocol(
2051 "failed to send ConnectionClose".into(),
2052 )));
2053 return;
2054 }
2055
2056 let _ = req.result_tx.send(Ok(()));
2057 self.maybe_request_shutdown_after_root_closed();
2058 }
2059
2060 async fn handle_drop_control_request(&mut self, req: DropControlRequest) -> bool {
2061 match req {
2062 DropControlRequest::Shutdown => {
2063 trace!("session shutdown requested");
2064 false
2065 }
2066 DropControlRequest::Close(conn_id) => {
2067 if conn_id.is_root() {
2069 trace!("root callers dropped; internally closing root connection");
2071 self.root_closed_internal = true;
2072 return self.has_virtual_connections();
2074 }
2075
2076 if self
2077 .remove_connection_with_reason(&conn_id, ConnectionCloseReason::Local)
2078 .is_some()
2079 {
2080 let _ = self
2081 .sess_core
2082 .send(
2083 Message {
2084 connection_id: conn_id,
2085 payload: MessagePayload::ConnectionClose(ConnectionClose {
2086 metadata: vec![],
2087 }),
2088 },
2089 None,
2090 None,
2091 )
2092 .await;
2093 }
2094
2095 !self.root_closed_internal || self.has_virtual_connections()
2096 }
2097 }
2098 }
2099
2100 fn has_virtual_connections(&self) -> bool {
2101 self.conns.keys().any(|id| !id.is_root())
2102 }
2103
2104 fn remove_connection(&mut self, conn_id: &ConnectionId) -> Option<ConnectionSlot> {
2105 self.remove_connection_with_reason(conn_id, ConnectionCloseReason::Unknown)
2106 }
2107
2108 fn remove_connection_with_reason(
2109 &mut self,
2110 conn_id: &ConnectionId,
2111 reason: ConnectionCloseReason,
2112 ) -> Option<ConnectionSlot> {
2113 trace!(%conn_id, "remove_connection called");
2114 let slot = self.conns.remove(conn_id);
2115 if let Some(ConnectionSlot::Active(state)) = &slot {
2116 let _ = state.closed_tx.send(Some(reason));
2117 if let Some(observer) = &self.observer {
2118 observer.driver_event(vox_types::DriverEvent::ConnectionClosed {
2119 connection_id: *conn_id,
2120 reason,
2121 });
2122 }
2123 }
2124 slot
2125 }
2126
2127 fn close_all_connections(&mut self, reason: ConnectionCloseReason) {
2129 trace!(role = ?self.role, count = self.conns.len(), "close_all_connections");
2130 vox_types::dlog!(
2131 "[session {:?}] close_all_connections: {} slots",
2132 self.role,
2133 self.conns.len()
2134 );
2135 for (conn_id, slot) in self.conns.iter() {
2136 if let ConnectionSlot::Active(state) = slot {
2137 vox_types::dlog!("[session {:?}] closing connection {:?}", self.role, conn_id);
2138 let _ = state.closed_tx.send(Some(reason));
2139 if let Some(observer) = &self.observer {
2140 observer.driver_event(vox_types::DriverEvent::ConnectionClosed {
2141 connection_id: *conn_id,
2142 reason,
2143 });
2144 }
2145 }
2146 }
2147 self.conns.clear();
2148 }
2149
2150 fn maybe_request_shutdown_after_root_closed(&self) {
2151 if self.root_closed_internal && !self.has_virtual_connections() {
2152 let _ = send_drop_control(&self.control_tx, DropControlRequest::Shutdown);
2153 }
2154 }
2155}
2156
/// Shared sending side of a session.
///
/// Connection senders hold this behind an `Arc` and use it to prepare and
/// queue outbound messages.
pub(crate) struct SessionCore {
    /// Current conduit sender and per-connection send state, guarded by a
    /// blocking mutex.
    inner: std::sync::Mutex<SessionCoreInner>,
    /// Queue of prepared outbound batches.
    // NOTE(review): presumably drained by `run_outbound_worker` — confirm where the receiver is created.
    outbound_tx: tokio_mpsc::Sender<OutboundBatch>,
    /// Optional observer handle.
    observer: Option<VoxObserverHandle>,
}
2162
2163pub trait OutboundSendFuture: Future<Output = std::io::Result<()>> + MaybeSend + 'static {}
2164impl<T> OutboundSendFuture for T where T: Future<Output = std::io::Result<()>> + MaybeSend + 'static {}
2165
2166type OutboundSend = Pin<Box<dyn OutboundSendFuture>>;
2167
/// One schema transmission planned to precede a payload.
///
/// The outbound worker turns this into a `SchemaMessage` for `method_id` and
/// `direction`, using `prepared` to work out which schemas to send.
#[derive(Clone)]
struct PendingSchemaSend {
    /// Method the schemas belong to.
    method_id: vox_types::MethodId,
    /// Binding direction the schemas describe.
    direction: vox_types::BindingDirection,
    /// Prepared plan handed to the connection's send tracker.
    prepared: vox_types::PreparedSchemaPlan,
}
2174
/// A unit of work for the outbound worker: optional schema messages followed
/// by one payload, plus the channel on which the outcome is reported.
struct OutboundBatch {
    /// Connection the payload is addressed to.
    conn_id: ConnectionId,
    /// Request id for request payloads, `None` for everything else.
    request_id: Option<RequestId>,
    /// Short label of the payload type, used in trace output.
    payload_kind: &'static str,
    /// Send-side state of `conn_id`, shared with `SessionCore`.
    conn_state: Arc<std::sync::Mutex<SendConnState>>,
    /// Conduit sender captured when the batch was prepared; used for the
    /// schema messages.
    tx: Arc<dyn DynConduitTx>,
    /// Schema messages to send, in order, before the payload.
    schema_sends: Vec<PendingSchemaSend>,
    /// Already prepared send of the payload itself.
    payload_send: OutboundSend,
    /// Receives the overall result once the batch has been processed.
    result_tx: tokio_oneshot::Sender<std::io::Result<()>>,
}
2185
/// Processes outbound batches one at a time until the queue is closed.
///
/// For each batch the planned schema messages are sent first, in order, and
/// then the payload. The first failure stops the batch: remaining schema
/// messages and the payload are skipped. The outcome is reported on the
/// batch's `result_tx`; a dropped receiver is ignored.
async fn run_outbound_worker(mut rx: tokio_mpsc::Receiver<OutboundBatch>) {
    while let Some(batch) = rx.recv().await {
        trace!(
            conn_id = %batch.conn_id,
            request_id = ?batch.request_id,
            payload_kind = batch.payload_kind,
            schema_count = batch.schema_sends.len(),
            "session outbound worker received batch"
        );
        let mut result = Ok(());
        for schema_send in batch.schema_sends {
            trace!(
                conn_id = %batch.conn_id,
                request_id = ?batch.request_id,
                method_id = ?schema_send.method_id,
                direction = ?schema_send.direction,
                "session outbound worker sending schema batch"
            );
            // The lock is confined to this block so the guard is released
            // before the send is awaited below.
            let schemas = {
                let mut conn_state = batch
                    .conn_state
                    .lock()
                    .expect("send conn state mutex poisoned");
                conn_state.send_tracker.preview_prepared_plan(
                    schema_send.method_id,
                    schema_send.direction,
                    &schema_send.prepared,
                )
            };
            // Nothing to transmit for this plan; the tracker is not updated.
            if schemas.is_empty() {
                continue;
            }

            let schema_msg = Message {
                connection_id: batch.conn_id,
                payload: MessagePayload::SchemaMessage(SchemaMessage {
                    method_id: schema_send.method_id,
                    direction: schema_send.direction,
                    schemas,
                }),
            };
            let send = match batch.tx.clone().prepare_msg(schema_msg, None) {
                Ok(send) => send,
                Err(error) => {
                    result = Err(error);
                    break;
                }
            };
            if let Err(error) = send.await {
                result = Err(error);
                break;
            }
            // Only after a successful send is the plan marked as sent and its
            // planned binding discarded.
            let mut conn_state = batch
                .conn_state
                .lock()
                .expect("send conn state mutex poisoned");
            conn_state.send_tracker.mark_prepared_plan_sent(
                schema_send.method_id,
                schema_send.direction,
                &schema_send.prepared,
            );
            conn_state
                .planned_bindings
                .remove(&(schema_send.direction, schema_send.method_id));
        }
        // The payload is only sent when every schema message went out.
        if result.is_ok()
            && let Err(error) = batch.payload_send.await
        {
            trace!(
                conn_id = %batch.conn_id,
                request_id = ?batch.request_id,
                payload_kind = batch.payload_kind,
                ?error,
                "session outbound worker payload send failed"
            );
            result = Err(error);
        }
        trace!(
            conn_id = %batch.conn_id,
            request_id = ?batch.request_id,
            payload_kind = batch.payload_kind,
            ok = result.is_ok(),
            "session outbound worker finished batch"
        );
        let _ = batch.result_tx.send(result);
    }
}
2273
2274#[cfg(not(target_arch = "wasm32"))]
2275fn spawn_outbound_worker(rx: tokio_mpsc::Receiver<OutboundBatch>) {
2276 if tokio::runtime::Handle::try_current().is_ok() {
2277 tokio::spawn(run_outbound_worker(rx));
2278 return;
2279 }
2280
2281 std::thread::spawn(move || {
2282 let runtime = tokio::runtime::Builder::new_current_thread()
2283 .enable_all()
2284 .build()
2285 .expect("build outbound worker runtime");
2286 runtime.block_on(run_outbound_worker(rx));
2287 });
2288}
2289
/// Starts the outbound worker for `rx` (wasm32 targets).
///
/// The worker is scheduled with `wasm_bindgen_futures::spawn_local`.
#[cfg(target_arch = "wasm32")]
fn spawn_outbound_worker(rx: tokio_mpsc::Receiver<OutboundBatch>) {
    let worker = run_outbound_worker(rx);
    wasm_bindgen_futures::spawn_local(worker);
}
2294
/// Send-side bookkeeping for a single connection.
struct SendConnState {
    /// Tracks schema transmission for this connection; the outbound worker
    /// previews plans against it and marks them as sent.
    send_tracker: vox_types::SchemaSendTracker,

    /// Method of each incoming call, keyed by request id. The entry is removed
    /// when the matching response is prepared for sending.
    inflight_incoming: HashMap<RequestId, vox_types::MethodId>,

    /// Method of each outgoing call, keyed by request id.
    // NOTE(review): presumably what `take_outgoing_call_method` reads when a response arrives — confirm.
    inflight_outgoing: HashMap<RequestId, vox_types::MethodId>,

    /// Prepared schema plans keyed by direction and method. The outbound
    /// worker removes an entry once the plan's schema message has been sent.
    planned_bindings:
        HashMap<(vox_types::BindingDirection, vox_types::MethodId), vox_types::PreparedSchemaPlan>,
}
2311
2312impl SendConnState {
2313 fn new() -> Self {
2314 SendConnState {
2315 send_tracker: vox_types::SchemaSendTracker::new(),
2316 inflight_incoming: HashMap::new(),
2317 inflight_outgoing: HashMap::new(),
2318 planned_bindings: HashMap::new(),
2319 }
2320 }
2321}
2322
/// Mutable part of `SessionCore`, accessed under its mutex.
struct SessionCoreInner {
    /// Conduit sender currently in use; cloned into every prepared batch.
    tx: Arc<dyn DynConduitTx>,

    /// Send-side state per connection, created on first use.
    conns: HashMap<ConnectionId, Arc<std::sync::Mutex<SendConnState>>>,
}
2330
2331fn get_or_create_send_conn_state(
2332 inner: &mut SessionCoreInner,
2333 conn_id: ConnectionId,
2334) -> Arc<std::sync::Mutex<SendConnState>> {
2335 inner
2336 .conns
2337 .entry(conn_id)
2338 .or_insert_with(|| Arc::new(std::sync::Mutex::new(SendConnState::new())))
2339 .clone()
2340}
2341
2342impl SessionCore {
2343 pub(crate) fn outbound_queue_stats(&self) -> (usize, usize) {
2344 let capacity = self.outbound_tx.max_capacity();
2345 let available = self.outbound_tx.capacity();
2346 (capacity.saturating_sub(available), capacity)
2347 }
2348
    /// Turns `msg` into an [`OutboundBatch`] ready for the outbound worker.
    ///
    /// For request messages, schemas are never sent inline: any schema
    /// transmission needed for a call or response is planned as a separate
    /// `PendingSchemaSend` and the message's own `schemas` field is cleared.
    /// The payload send is then prepared on the current conduit sender.
    ///
    /// Returns the batch together with the receiver on which the worker will
    /// report the result, or `Err(())` when the conduit refuses to prepare
    /// the payload.
    ///
    /// # Panics
    ///
    /// Panics if the session core mutex or the connection's send-state mutex
    /// is poisoned.
    fn prepare_outbound_batch<'a>(
        &self,
        mut msg: Message<'a>,
        binder: Option<&'a dyn vox_types::ChannelBinder>,
        forwarded_schemas: Option<&vox_types::SchemaRecvTracker>,
    ) -> Result<(OutboundBatch, tokio_oneshot::Receiver<std::io::Result<()>>), ()> {
        let conn_id = msg.connection_id;
        // Label the payload for tracing; only request messages carry an id.
        let (request_id, payload_kind) = match &msg.payload {
            MessagePayload::RequestMessage(req) => {
                let kind = match &req.body {
                    RequestBody::Call(_) => "request.call",
                    RequestBody::Response(_) => "request.response",
                    RequestBody::Cancel(_) => "request.cancel",
                };
                (Some(req.id), kind)
            }
            MessagePayload::SchemaMessage(_) => (None, "schema"),
            MessagePayload::ChannelMessage(_) => (None, "channel"),
            MessagePayload::ConnectionOpen(_) => (None, "connection.open"),
            MessagePayload::ConnectionAccept(_) => (None, "connection.accept"),
            MessagePayload::ConnectionReject(_) => (None, "connection.reject"),
            MessagePayload::ConnectionClose(_) => (None, "connection.close"),
            MessagePayload::ProtocolError(_) => (None, "protocol.error"),
            MessagePayload::Ping(_) => (None, "ping"),
            MessagePayload::Pong(_) => (None, "pong"),
        };
        trace!(
            conn_id = %conn_id,
            ?request_id,
            payload_kind,
            "session preparing outbound message"
        );
        let (tx, conn_state, schema_sends) = {
            // The session-wide lock is held only long enough to grab the
            // sender and the connection's state; schema planning below uses
            // the per-connection lock.
            let mut inner = self.inner.lock().expect("session core mutex poisoned");
            let tx = inner.tx.clone();
            let conn_state = get_or_create_send_conn_state(&mut inner, conn_id);
            drop(inner);

            if let MessagePayload::RequestMessage(req) = &mut msg.payload {
                vox_types::dlog!(
                    "[session-core] send request: conn={:?} req={:?} body={} forwarded={}",
                    conn_id,
                    req.id,
                    match &req.body {
                        RequestBody::Call(_) => "Call",
                        RequestBody::Response(_) => "Response",
                        RequestBody::Cancel(_) => "Cancel",
                    },
                    forwarded_schemas.is_some()
                );
                let schema_sends = {
                    let mut conn_state_guard =
                        conn_state.lock().expect("send conn state mutex poisoned");
                    let mut schema_sends = Vec::new();
                    match &mut req.body {
                        RequestBody::Call(call) => {
                            if let Some(schema_send) = Self::plan_call_schema_send(
                                &mut conn_state_guard,
                                req.id,
                                call.method_id,
                                call,
                                forwarded_schemas,
                            ) {
                                schema_sends.push(schema_send);
                            }
                            // Schemas travel in their own message, never inline.
                            call.schemas = Default::default();
                        }
                        RequestBody::Response(resp) => {
                            // A response has no method id of its own; it comes
                            // from the incoming call recorded for this request.
                            // Without such a record no schema send is planned.
                            if let Some(method_id) =
                                conn_state_guard.inflight_incoming.remove(&req.id)
                                && let Some(schema_send) = Self::plan_response_schema_send(
                                    &mut conn_state_guard,
                                    req.id,
                                    method_id,
                                    resp,
                                    forwarded_schemas,
                                )
                            {
                                schema_sends.push(schema_send);
                            }
                            resp.schemas = Default::default();
                        }
                        RequestBody::Cancel(_) => {}
                    }
                    schema_sends
                };
                (tx, conn_state, schema_sends)
            } else {
                // Non-request payloads never need schema messages.
                (tx, conn_state, Vec::new())
            }
        };
        trace!(
            conn_id = %conn_id,
            ?request_id,
            payload_kind,
            schema_count = schema_sends.len(),
            "session preparing outbound payload"
        );
        // The preparation error itself is discarded; callers only learn that
        // preparation failed.
        let payload_send = tx.clone().prepare_msg(msg, binder).map_err(|_| ())?;
        trace!(
            conn_id = %conn_id,
            ?request_id,
            payload_kind,
            "session prepared outbound payload"
        );

        let (result_tx, result_rx) = tokio_oneshot::channel();
        Ok((
            OutboundBatch {
                conn_id,
                request_id,
                payload_kind,
                conn_state,
                tx,
                schema_sends,
                payload_send,
                result_tx,
            },
            result_rx,
        ))
    }
2470
2471 pub(crate) async fn send<'a>(
2473 &self,
2474 msg: Message<'a>,
2475 binder: Option<&'a dyn vox_types::ChannelBinder>,
2476 forwarded_schemas: Option<&vox_types::SchemaRecvTracker>,
2477 ) -> Result<(), ()> {
2478 let connection_id = msg.connection_id;
2479 let (batch, result_rx) = self.prepare_outbound_batch(msg, binder, forwarded_schemas)?;
2480 if self.outbound_tx.send(batch).await.is_err() {
2481 if let Some(observer) = &self.observer {
2482 observer
2483 .driver_event(vox_types::DriverEvent::OutboundQueueClosed { connection_id });
2484 }
2485 return Err(());
2486 }
2487 trace!(conn_id = %connection_id, "session queued outbound batch");
2488 let result = result_rx.await.map_err(|_| ());
2489 trace!(
2490 conn_id = %connection_id,
2491 ok = result.as_ref().map(|inner| inner.is_ok()).unwrap_or(false),
2492 "session outbound batch completed"
2493 );
2494 match result? {
2495 Ok(()) => Ok(()),
2496 Err(_) => {
2497 if let Some(observer) = &self.observer {
2498 observer.driver_event(vox_types::DriverEvent::EncodeError {
2499 connection_id,
2500 kind: vox_types::EncodeErrorKind::Transport,
2501 });
2502 }
2503 Err(())
2504 }
2505 }
2506 }
2507
2508 pub(crate) fn try_send<'a>(
2510 &self,
2511 msg: Message<'a>,
2512 binder: Option<&'a dyn vox_types::ChannelBinder>,
2513 forwarded_schemas: Option<&vox_types::SchemaRecvTracker>,
2514 ) -> Result<(), TrySendError<()>> {
2515 let connection_id = msg.connection_id;
2516 let (batch, _result_rx) = self
2517 .prepare_outbound_batch(msg, binder, forwarded_schemas)
2518 .map_err(|_| TrySendError::Closed(()))?;
2519 self.outbound_tx.try_send(batch).map_err(|err| match err {
2520 tokio_mpsc::error::TrySendError::Full(_) => {
2521 if let Some(observer) = &self.observer {
2522 observer
2523 .driver_event(vox_types::DriverEvent::OutboundQueueFull { connection_id });
2524 }
2525 TrySendError::Full(())
2526 }
2527 tokio_mpsc::error::TrySendError::Closed(_) => {
2528 if let Some(observer) = &self.observer {
2529 observer.driver_event(vox_types::DriverEvent::OutboundQueueClosed {
2530 connection_id,
2531 });
2532 }
2533 TrySendError::Closed(())
2534 }
2535 })
2536 }
2537
2538 pub(crate) fn record_incoming_call(
2541 &self,
2542 conn_id: ConnectionId,
2543 request_id: RequestId,
2544 method_id: vox_types::MethodId,
2545 ) {
2546 let mut inner = self.inner.lock().expect("session core mutex poisoned");
2547 let conn_state = get_or_create_send_conn_state(&mut inner, conn_id);
2548 vox_types::dlog!(
2549 "[schema] record_incoming_call: conn={:?} req={:?} method={:?}",
2550 conn_id,
2551 request_id,
2552 method_id
2553 );
2554 conn_state
2555 .lock()
2556 .expect("send conn state mutex poisoned")
2557 .inflight_incoming
2558 .insert(request_id, method_id);
2559 }
2560
2561 pub(crate) fn take_outgoing_call_method(
2562 &self,
2563 conn_id: ConnectionId,
2564 request_id: RequestId,
2565 ) -> Option<vox_types::MethodId> {
2566 let inner = self.inner.lock().expect("session core mutex poisoned");
2567 inner.conns.get(&conn_id).and_then(|conn_state| {
2568 conn_state
2569 .lock()
2570 .expect("send conn state mutex poisoned")
2571 .inflight_outgoing
2572 .remove(&request_id)
2573 })
2574 }
2575
2576 pub(crate) fn prepare_response_for_method(
2577 &self,
2578 conn_id: ConnectionId,
2579 request_id: RequestId,
2580 method_id: vox_types::MethodId,
2581 response: &mut RequestResponse<'_>,
2582 ) {
2583 let mut inner = self.inner.lock().expect("session core mutex poisoned");
2584 let conn_state = get_or_create_send_conn_state(&mut inner, conn_id);
2585 let mut conn_state = conn_state.lock().expect("send conn state mutex poisoned");
2586 let key = (vox_types::BindingDirection::Response, method_id);
2587 if conn_state
2588 .send_tracker
2589 .has_sent_binding(method_id, vox_types::BindingDirection::Response)
2590 {
2591 response.schemas = Default::default();
2592 return;
2593 }
2594
2595 let prepared = match &response.ret {
2596 vox_types::Payload::Value { shape, .. } => {
2597 match Self::get_or_plan_binding_for_shape(
2598 &mut conn_state,
2599 key,
2600 request_id,
2601 "response",
2602 shape,
2603 ) {
2604 Some(prepared) => prepared,
2605 None => return,
2606 }
2607 }
2608 vox_types::Payload::PostcardBytes(_) => {
2609 tracing::error!(
2610 "schema attachment failed: missing forwarded response schemas for method {:?}",
2611 method_id
2612 );
2613 return;
2614 }
2615 };
2616 response.schemas = prepared.to_cbor();
2617 }
2618
2619 pub(crate) fn prepare_response_from_source(
2621 &self,
2622 conn_id: ConnectionId,
2623 _request_id: RequestId,
2624 method_id: vox_types::MethodId,
2625 root_type: &vox_types::TypeRef,
2626 source: &dyn vox_types::SchemaSource,
2627 response: &mut RequestResponse<'_>,
2628 ) {
2629 let mut inner = self.inner.lock().expect("session core mutex poisoned");
2630 let conn_state = get_or_create_send_conn_state(&mut inner, conn_id);
2631 let mut conn_state = conn_state.lock().expect("send conn state mutex poisoned");
2632 let key = (vox_types::BindingDirection::Response, method_id);
2633 if conn_state
2634 .send_tracker
2635 .has_sent_binding(method_id, vox_types::BindingDirection::Response)
2636 {
2637 response.schemas = Default::default();
2638 return;
2639 }
2640 let prepared =
2641 Self::get_or_plan_binding_from_source(&mut conn_state, key, root_type, source);
2642 response.schemas = prepared.to_cbor();
2643 }
2644
2645 pub(crate) fn prepare_response_from_shape(
2649 &self,
2650 conn_id: ConnectionId,
2651 request_id: RequestId,
2652 method_id: vox_types::MethodId,
2653 response_shape: &'static Shape,
2654 response: &mut RequestResponse<'_>,
2655 ) {
2656 let mut inner = self.inner.lock().expect("session core mutex poisoned");
2657 let conn_state = get_or_create_send_conn_state(&mut inner, conn_id);
2658 let mut conn_state = conn_state.lock().expect("send conn state mutex poisoned");
2659 let key = (vox_types::BindingDirection::Response, method_id);
2660 if conn_state
2661 .send_tracker
2662 .has_sent_binding(method_id, vox_types::BindingDirection::Response)
2663 {
2664 response.schemas = Default::default();
2665 return;
2666 }
2667 let prepared = match Self::get_or_plan_binding_for_shape(
2668 &mut conn_state,
2669 key,
2670 request_id,
2671 "response",
2672 response_shape,
2673 ) {
2674 Some(prepared) => prepared,
2675 None => return,
2676 };
2677 response.schemas = prepared.to_cbor();
2678 }
2679
2680 fn get_or_plan_binding_for_shape(
2681 conn_state: &mut SendConnState,
2682 key: (vox_types::BindingDirection, vox_types::MethodId),
2683 request_id: RequestId,
2684 kind: &str,
2685 shape: &'static Shape,
2686 ) -> Option<vox_types::PreparedSchemaPlan> {
2687 if let Some(prepared) = conn_state.planned_bindings.get(&key) {
2688 return Some(prepared.clone());
2689 }
2690 match vox_types::SchemaSendTracker::plan_for_shape(shape) {
2691 Ok(prepared) => {
2692 vox_types::dlog!(
2693 "[schema] planned {} {} schemas for method {:?} (req {:?})",
2694 prepared.schemas.len(),
2695 kind,
2696 key.1,
2697 request_id
2698 );
2699 conn_state.planned_bindings.insert(key, prepared.clone());
2700 Some(prepared)
2701 }
2702 Err(e) => {
2703 tracing::error!("schema extraction failed: {e}");
2704 None
2705 }
2706 }
2707 }
2708
2709 fn get_or_plan_binding_from_source(
2710 conn_state: &mut SendConnState,
2711 key: (vox_types::BindingDirection, vox_types::MethodId),
2712 root_type: &vox_types::TypeRef,
2713 source: &dyn vox_types::SchemaSource,
2714 ) -> vox_types::PreparedSchemaPlan {
2715 if let Some(prepared) = conn_state.planned_bindings.get(&key) {
2716 return prepared.clone();
2717 }
2718 let prepared = vox_types::SchemaSendTracker::plan_from_source(root_type, source);
2719 conn_state.planned_bindings.insert(key, prepared.clone());
2720 prepared
2721 }
2722
2723 fn plan_response_schema_send(
2724 conn_state: &mut SendConnState,
2725 request_id: RequestId,
2726 method_id: vox_types::MethodId,
2727 response: &mut RequestResponse<'_>,
2728 forwarded_schemas: Option<&vox_types::SchemaRecvTracker>,
2729 ) -> Option<PendingSchemaSend> {
2730 if conn_state
2731 .send_tracker
2732 .has_sent_binding(method_id, vox_types::BindingDirection::Response)
2733 {
2734 response.schemas = Default::default();
2735 return None;
2736 }
2737
2738 let key = (vox_types::BindingDirection::Response, method_id);
2739 let prepared = if !response.schemas.is_empty() {
2740 conn_state
2741 .planned_bindings
2742 .get(&key)
2743 .cloned()
2744 .unwrap_or_else(|| {
2745 let prepared_payload = vox_types::SchemaPayload::from_cbor(&response.schemas.0)
2746 .expect("prepared schema payloads must be valid CBOR");
2747 vox_types::PreparedSchemaPlan {
2748 schemas: prepared_payload.schemas,
2749 root: prepared_payload.root,
2750 }
2751 })
2752 } else {
2753 match &response.ret {
2754 vox_types::Payload::Value { shape, .. } => Self::get_or_plan_binding_for_shape(
2755 conn_state, key, request_id, "response", shape,
2756 )?,
2757 vox_types::Payload::PostcardBytes(_) => {
2758 let Some(source) = forwarded_schemas else {
2759 tracing::error!(
2760 "schema attachment failed: missing forwarded response schemas for method {:?}",
2761 method_id
2762 );
2763 return None;
2764 };
2765 let Some(root) = source.get_remote_response_root(method_id) else {
2766 tracing::error!(
2767 "schema attachment failed: missing forwarded response root for method {:?}",
2768 method_id
2769 );
2770 return None;
2771 };
2772 Self::get_or_plan_binding_from_source(conn_state, key, &root, source)
2773 }
2774 }
2775 };
2776
2777 Some(PendingSchemaSend {
2778 method_id,
2779 direction: vox_types::BindingDirection::Response,
2780 prepared,
2781 })
2782 }
2783
2784 fn plan_call_schema_send(
2785 conn_state: &mut SendConnState,
2786 request_id: RequestId,
2787 method_id: vox_types::MethodId,
2788 call: &mut vox_types::RequestCall<'_>,
2789 forwarded_schemas: Option<&vox_types::SchemaRecvTracker>,
2790 ) -> Option<PendingSchemaSend> {
2791 conn_state.inflight_outgoing.insert(request_id, method_id);
2792 if conn_state
2793 .send_tracker
2794 .has_sent_binding(method_id, vox_types::BindingDirection::Args)
2795 {
2796 call.schemas = Default::default();
2797 return None;
2798 }
2799
2800 let key = (vox_types::BindingDirection::Args, method_id);
2801 let prepared = match &call.args {
2802 vox_types::Payload::Value { shape, .. } => {
2803 Self::get_or_plan_binding_for_shape(conn_state, key, request_id, "args", shape)?
2804 }
2805 vox_types::Payload::PostcardBytes(_) => {
2806 let Some(source) = forwarded_schemas else {
2807 tracing::error!(
2808 "schema attachment failed: missing forwarded args schemas for method {:?}",
2809 method_id
2810 );
2811 return None;
2812 };
2813 let Some(root) = source.get_remote_args_root(method_id) else {
2814 tracing::error!(
2815 "schema attachment failed: missing forwarded args root for method {:?}",
2816 method_id
2817 );
2818 return None;
2819 };
2820 Self::get_or_plan_binding_from_source(conn_state, key, &root, source)
2821 }
2822 };
2823
2824 Some(PendingSchemaSend {
2825 method_id,
2826 direction: vox_types::BindingDirection::Args,
2827 prepared,
2828 })
2829 }
2830
2831 fn replace_tx_and_reset_schemas(&self, tx: Arc<dyn DynConduitTx>) {
2832 let mut inner = self.inner.lock().expect("session core mutex poisoned");
2833 inner.tx = tx;
2834 inner.conns.clear();
2835 }
2836}
2837
/// What a [`ConduitRecoverer`] hands back: both halves of a conduit plus the
/// result of the handshake that goes with them.
pub(crate) struct RecoveredConduit {
    /// Type-erased sending half of the conduit.
    pub tx: Arc<dyn DynConduitTx>,
    /// Type-erased receiving half of the conduit.
    pub rx: Box<dyn DynConduitRx>,
    /// Handshake result associated with this conduit.
    pub handshake: HandshakeResult,
}
2843
/// Source of conduits for a session, yielding a [`RecoveredConduit`] on demand.
pub(crate) trait ConduitRecoverer: MaybeSend {
    /// Produces the next conduit.
    ///
    /// `resume_key` is the session resume key to present, if there is one.
    /// NOTE(review): how implementations use the key is not visible here —
    /// confirm against the implementors.
    ///
    /// Resolves to the recovered conduit, or to a [`SessionError`] on failure.
    fn next_conduit<'a>(
        &'a mut self,
        resume_key: Option<&'a SessionResumeKey>,
    ) -> BoxFut<'a, Result<RecoveredConduit, SessionError>>;
}
2850
/// Sending half of a conduit in a form usable as `Arc<dyn DynConduitTx>`.
pub trait DynConduitTx: MaybeSend + MaybeSync {
    /// Prepares `msg` for sending and returns the deferred send operation.
    ///
    /// `binder`, when present, is the channel binder to use while the message
    /// is being prepared.
    ///
    /// # Errors
    ///
    /// Returns an I/O error when the message cannot be prepared.
    fn prepare_msg<'a>(
        self: Arc<Self>,
        msg: Message<'a>,
        binder: Option<&'a dyn vox_types::ChannelBinder>,
    ) -> std::io::Result<OutboundSend>;
}
/// Receiving half of a conduit in a form usable as `Box<dyn DynConduitRx>`.
pub trait DynConduitRx: MaybeSend {
    /// Receives the next message.
    ///
    /// Resolves to `Ok(Some(message))` for a received message, `Ok(None)` when
    /// the receiver yields no message (presumably end of stream — confirm
    /// against the `ConduitRx` contract), or an I/O error.
    fn recv_msg<'a>(&'a mut self)
    -> BoxFut<'a, std::io::Result<Option<SelfRef<Message<'static>>>>>;

    /// Takes the frame file descriptors held by the receiver.
    /// NOTE(review): which frame they belong to is not visible here — confirm
    /// against `ConduitRx::take_frame_fds`.
    fn take_frame_fds(&mut self) -> vox_types::FrameFds;
}
2866
2867impl<T> DynConduitTx for T
2870where
2871 T: ConduitTx<Msg = MessageFamily> + MaybeSend + MaybeSync + 'static,
2872{
2873 fn prepare_msg<'a>(
2874 self: Arc<Self>,
2875 msg: Message<'a>,
2876 binder: Option<&'a dyn vox_types::ChannelBinder>,
2877 ) -> std::io::Result<OutboundSend> {
2878 let prepared = if let Some(binder) = binder {
2879 vox_types::with_channel_binder(binder, || self.prepare_send(msg))
2880 } else {
2881 self.prepare_send(msg)
2882 };
2883 let prepared = prepared.map_err(|e| std::io::Error::other(e.to_string()))?;
2884 Ok(Box::pin(async move {
2885 self.send_prepared(prepared)
2886 .await
2887 .map_err(|e| std::io::Error::other(e.to_string()))
2888 }))
2889 }
2890}
2891
2892impl<T> DynConduitRx for T
2893where
2894 T: ConduitRx<Msg = MessageFamily> + MaybeSend,
2895{
2896 fn recv_msg<'a>(
2897 &'a mut self,
2898 ) -> BoxFut<'a, std::io::Result<Option<SelfRef<Message<'static>>>>> {
2899 Box::pin(async move {
2900 self.recv()
2901 .await
2902 .map_err(|error| std::io::Error::other(error.to_string()))
2903 })
2904 }
2905
2906 fn take_frame_fds(&mut self) -> vox_types::FrameFds {
2907 ConduitRx::take_frame_fds(self)
2908 }
2909}
2910
2911#[cfg(test)]
2912mod tests {
2913 use moire::sync::mpsc;
2914 use vox_types::{
2915 Backing, Conduit, ConnectionAccept, ConnectionReject, HandshakeResult, SelfRef,
2916 };
2917
2918 use super::*;
2919
2920 fn make_session() -> Session {
2921 let (a, b) = crate::memory_link_pair(32);
2922 std::mem::forget(b);
2924 let conduit = crate::BareConduit::new(a);
2925 let (tx, rx) = conduit.split();
2926 let (_open_tx, open_rx) = mpsc::channel::<OpenRequest>("session.open.test", 4);
2927 let (_close_tx, close_rx) = mpsc::channel::<CloseRequest>("session.close.test", 4);
2928 let (_resume_tx, resume_rx) = mpsc::channel::<ResumeRequest>("session.resume.test", 1);
2929 let (control_tx, control_rx) = mpsc::unbounded_channel("session.control.test");
2930 Session::pre_handshake(
2931 tx, rx, None, open_rx, close_rx, resume_rx, control_tx, control_rx, None, false, None,
2932 None, None,
2933 )
2934 }
2935
2936 fn resumed_handshake(
2937 our_settings: ConnectionSettings,
2938 peer_settings: ConnectionSettings,
2939 ) -> HandshakeResult {
2940 HandshakeResult {
2941 role: SessionRole::Initiator,
2942 our_settings,
2943 peer_settings,
2944 peer_supports_retry: true,
2945 session_resume_key: Some(SessionResumeKey([7; 16])),
2946 peer_resume_key: None,
2947 our_schema: vec![],
2948 peer_schema: vec![],
2949 peer_metadata: vec![],
2950 }
2951 }
2952
2953 fn accept_ref() -> SelfRef<ConnectionAccept<'static>> {
2954 SelfRef::owning(
2955 Backing::Boxed(Box::<[u8]>::default()),
2956 ConnectionAccept {
2957 connection_settings: ConnectionSettings {
2958 parity: Parity::Even,
2959 max_concurrent_requests: 64,
2960 initial_channel_credit: 16,
2961 },
2962 metadata: vec![],
2963 },
2964 )
2965 }
2966
2967 fn zero_credit_accept_ref() -> SelfRef<ConnectionAccept<'static>> {
2968 SelfRef::owning(
2969 Backing::Boxed(Box::<[u8]>::default()),
2970 ConnectionAccept {
2971 connection_settings: ConnectionSettings {
2972 parity: Parity::Even,
2973 max_concurrent_requests: 64,
2974 initial_channel_credit: 0,
2975 },
2976 metadata: vec![],
2977 },
2978 )
2979 }
2980
2981 fn reject_ref() -> SelfRef<ConnectionReject<'static>> {
2982 SelfRef::owning(
2983 Backing::Boxed(Box::<[u8]>::default()),
2984 ConnectionReject { metadata: vec![] },
2985 )
2986 }
2987
2988 #[tokio::test]
2989 async fn duplicate_connection_accept_is_ignored_after_first() {
2990 let mut session = make_session();
2991 let conn_id = ConnectionId(1);
2992 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");
2993
2994 session.conns.insert(
2995 conn_id,
2996 ConnectionSlot::PendingOutbound(PendingOutboundData {
2997 local_settings: ConnectionSettings {
2998 parity: Parity::Odd,
2999 max_concurrent_requests: 64,
3000 initial_channel_credit: 16,
3001 },
3002 result_tx: Some(result_tx),
3003 }),
3004 );
3005
3006 session.handle_inbound_accept(conn_id, accept_ref());
3007 let handle = result_rx
3008 .await
3009 .expect("pending outbound result should resolve")
3010 .expect("accept should resolve as Ok");
3011 assert_eq!(handle.connection_id(), conn_id);
3012
3013 session.handle_inbound_accept(conn_id, accept_ref());
3014 assert!(
3015 matches!(
3016 session.conns.get(&conn_id),
3017 Some(ConnectionSlot::Active(ConnectionState { id, .. })) if *id == conn_id
3018 ),
3019 "duplicate accept should keep existing active connection state"
3020 );
3021 }
3022
3023 #[tokio::test]
3024 async fn duplicate_connection_reject_is_ignored_after_first() {
3025 let mut session = make_session();
3026 let conn_id = ConnectionId(1);
3027 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");
3028
3029 session.conns.insert(
3030 conn_id,
3031 ConnectionSlot::PendingOutbound(PendingOutboundData {
3032 local_settings: ConnectionSettings {
3033 parity: Parity::Odd,
3034 max_concurrent_requests: 64,
3035 initial_channel_credit: 16,
3036 },
3037 result_tx: Some(result_tx),
3038 }),
3039 );
3040
3041 session.handle_inbound_reject(conn_id, reject_ref());
3042 let result = result_rx
3043 .await
3044 .expect("pending outbound result should resolve");
3045 assert!(
3046 matches!(result, Err(SessionError::Rejected(_))),
3047 "expected rejection, got: {result:?}"
3048 );
3049
3050 session.handle_inbound_reject(conn_id, reject_ref());
3051 assert!(
3052 !session.conns.contains_key(&conn_id),
3053 "duplicate reject should not recreate connection state"
3054 );
3055 }
3056
3057 #[tokio::test]
3059 async fn inbound_accept_with_zero_initial_credit_rejects_pending_open() {
3060 let mut session = make_session();
3061 let conn_id = ConnectionId(1);
3062 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");
3063
3064 session.conns.insert(
3065 conn_id,
3066 ConnectionSlot::PendingOutbound(PendingOutboundData {
3067 local_settings: ConnectionSettings {
3068 parity: Parity::Odd,
3069 max_concurrent_requests: 64,
3070 initial_channel_credit: 16,
3071 },
3072 result_tx: Some(result_tx),
3073 }),
3074 );
3075
3076 session.handle_inbound_accept(conn_id, zero_credit_accept_ref());
3077 let result = result_rx
3078 .await
3079 .expect("pending outbound result should resolve");
3080 assert!(
3081 matches!(
3082 result,
3083 Err(SessionError::Protocol(ref message))
3084 if message == "initial_channel_credit must be greater than zero"
3085 ),
3086 "expected zero-credit protocol error, got: {result:?}"
3087 );
3088 assert!(
3089 !session.conns.contains_key(&conn_id),
3090 "zero-credit accept should not create an active connection"
3091 );
3092 }
3093
3094 #[test]
3095 fn out_of_order_accept_or_reject_without_pending_is_ignored() {
3096 let mut session = make_session();
3097 let conn_id = ConnectionId(99);
3098
3099 session.handle_inbound_accept(conn_id, accept_ref());
3100 session.handle_inbound_reject(conn_id, reject_ref());
3101
3102 assert!(
3103 session.conns.is_empty(),
3104 "out-of-order accept/reject should not mutate empty connection table"
3105 );
3106 }
3107
3108 #[tokio::test]
3109 async fn close_request_clears_pending_outbound_open() {
3110 let mut session = make_session();
3111 let (open_result_tx, open_result_rx) = moire::sync::oneshot::channel("session.open.result");
3112 let (close_result_tx, close_result_rx) =
3113 moire::sync::oneshot::channel("session.close.result");
3114
3115 session.conns.insert(
3116 ConnectionId(1),
3117 ConnectionSlot::PendingOutbound(PendingOutboundData {
3118 local_settings: ConnectionSettings {
3119 parity: Parity::Odd,
3120 max_concurrent_requests: 64,
3121 initial_channel_credit: 16,
3122 },
3123 result_tx: Some(open_result_tx),
3124 }),
3125 );
3126
3127 session
3128 .handle_close_request(CloseRequest {
3129 conn_id: ConnectionId(1),
3130 metadata: vec![],
3131 result_tx: close_result_tx,
3132 })
3133 .await;
3134
3135 let close_result = close_result_rx
3136 .await
3137 .expect("close result should be delivered");
3138 assert!(
3139 close_result.is_ok(),
3140 "close should succeed for pending outbound connection"
3141 );
3142
3143 assert!(
3144 open_result_rx.await.is_err(),
3145 "pending open result channel should be closed once the pending slot is removed"
3146 );
3147 }
3148
3149 #[test]
3150 fn resume_rejects_changed_local_root_settings() {
3151 let mut session = make_session();
3152 let local_settings = ConnectionSettings {
3153 parity: Parity::Odd,
3154 max_concurrent_requests: 64,
3155 initial_channel_credit: 16,
3156 };
3157 let peer_settings = ConnectionSettings {
3158 parity: Parity::Even,
3159 max_concurrent_requests: 64,
3160 initial_channel_credit: 16,
3161 };
3162 let _root = session
3163 .establish_from_handshake(resumed_handshake(
3164 local_settings.clone(),
3165 peer_settings.clone(),
3166 ))
3167 .expect("initial handshake should establish session");
3168
3169 let (link_a, _link_b) = crate::memory_link_pair(32);
3170 let conduit = crate::BareConduit::new(link_a);
3171 let (tx, rx) = conduit.split();
3172
3173 let result = session.resume_from_handshake(
3174 Arc::new(tx),
3175 Box::new(rx),
3176 resumed_handshake(
3177 ConnectionSettings {
3178 parity: Parity::Odd,
3179 max_concurrent_requests: 65,
3180 initial_channel_credit: 16,
3181 },
3182 peer_settings,
3183 ),
3184 );
3185
3186 assert!(
3187 matches!(
3188 &result,
3189 Err(SessionError::Protocol(message))
3190 if message == "local root settings changed across session resume"
3191 ),
3192 "expected local-root-settings mismatch, got: {result:?}"
3193 );
3194 }
3195
3196 #[test]
3197 fn resume_rejects_changed_peer_root_settings() {
3198 let mut session = make_session();
3199 let local_settings = ConnectionSettings {
3200 parity: Parity::Odd,
3201 max_concurrent_requests: 64,
3202 initial_channel_credit: 16,
3203 };
3204 let peer_settings = ConnectionSettings {
3205 parity: Parity::Even,
3206 max_concurrent_requests: 64,
3207 initial_channel_credit: 16,
3208 };
3209 let _root = session
3210 .establish_from_handshake(resumed_handshake(
3211 local_settings.clone(),
3212 peer_settings.clone(),
3213 ))
3214 .expect("initial handshake should establish session");
3215
3216 let (link_a, _link_b) = crate::memory_link_pair(32);
3217 let conduit = crate::BareConduit::new(link_a);
3218 let (tx, rx) = conduit.split();
3219
3220 let result = session.resume_from_handshake(
3221 Arc::new(tx),
3222 Box::new(rx),
3223 resumed_handshake(
3224 local_settings,
3225 ConnectionSettings {
3226 parity: Parity::Even,
3227 max_concurrent_requests: 65,
3228 initial_channel_credit: 16,
3229 },
3230 ),
3231 );
3232
3233 assert!(
3234 matches!(
3235 &result,
3236 Err(SessionError::Protocol(message))
3237 if message == "peer root settings changed across session resume"
3238 ),
3239 "expected peer-root-settings mismatch, got: {result:?}"
3240 );
3241 }
3242}