1use std::{
2 collections::{BTreeMap, HashMap},
3 future::Future,
4 pin::Pin,
5 sync::Arc,
6 time::Duration,
7};
8
9use facet_core::Shape;
10use moire::sync::mpsc;
11use tokio::sync::{mpsc as tokio_mpsc, oneshot as tokio_oneshot, watch};
12use tracing::{trace, warn};
13use vox_types::{
14 BoxFut, ChannelMessage, ConduitRx, ConduitTx, ConnectionAccept, ConnectionClose, ConnectionId,
15 ConnectionOpen, ConnectionReject, ConnectionSettings, Handler, HandshakeResult, IdAllocator,
16 MaybeSend, MaybeSync, Message, MessageFamily, MessagePayload, Metadata, Parity, RequestBody,
17 RequestId, RequestMessage, RequestResponse, SchemaMessage, SelfRef, SessionResumeKey,
18 SessionRole,
19};
20
21mod builders;
22pub use builders::*;
23
/// Keepalive timing for a session.
///
/// NOTE(review): consumed by `make_keepalive_runtime` (outside this view); the field
/// descriptions below are inferred from the names and the mirrored `KeepaliveRuntime`
/// fields — confirm against that function.
#[derive(Debug, Clone, Copy)]
pub struct SessionKeepaliveConfig {
    /// Presumably the interval between outgoing pings.
    pub ping_interval: Duration,
    /// Presumably how long to wait for a pong before treating the peer as unresponsive.
    pub pong_timeout: Duration,
}
30
/// Read-only view of the metadata attached to a connection request, passed to
/// [`ConnectionAcceptor::accept`].
///
/// Built through [`ConnectionRequest::new`], which requires a `vox-service` entry.
pub struct ConnectionRequest<'a> {
    /// All metadata entries attached to the request.
    metadata: &'a [vox_types::MetadataEntry<'a>],
    /// Value of the mandatory `vox-service` metadata key.
    service: &'a str,
}
42
43impl<'a> ConnectionRequest<'a> {
44 pub fn new(metadata: &'a [vox_types::MetadataEntry<'a>]) -> Result<Self, SessionError> {
48 let service = vox_types::metadata_get_str(metadata, "vox-service").ok_or_else(|| {
49 SessionError::Protocol("missing required vox-service metadata".into())
50 })?;
51 Ok(Self { metadata, service })
52 }
53
54 pub fn service(&self) -> &str {
56 self.service
57 }
58
59 pub fn transport(&self) -> Option<&str> {
61 vox_types::metadata_get_str(self.metadata, "vox-transport")
62 }
63
64 pub fn peer_addr(&self) -> Option<&str> {
66 vox_types::metadata_get_str(self.metadata, "vox-peer-addr")
67 }
68
69 pub fn is_root(&self) -> bool {
71 !self.is_virtual()
72 }
73
74 pub fn is_virtual(&self) -> bool {
76 vox_types::metadata_get_str(self.metadata, "vox-connection-kind") == Some("virtual")
77 }
78
79 pub fn get_str(&self, key: &str) -> Option<&str> {
81 vox_types::metadata_get_str(self.metadata, key)
82 }
83
84 pub fn get_u64(&self, key: &str) -> Option<u64> {
86 vox_types::metadata_get_u64(self.metadata, key)
87 }
88
89 pub fn metadata(&self) -> &[vox_types::MetadataEntry<'a>] {
91 self.metadata
92 }
93}
94
/// A connection that has been accepted but not yet given to a driver.
///
/// Consume it with `handle_with`, `handle_with_client`, `proxy_to` or `into_handle`.
/// Dropping it unconsumed logs a warning and sends `DropControlRequest::Close` for the
/// connection.
pub struct PendingConnection {
    /// The wrapped connection; `None` once consumed.
    handle: Option<ConnectionHandle>,
    /// Optional shared slot that receives the `Caller` built for the driver.
    caller_slot: Option<Arc<std::sync::Mutex<Option<crate::Caller>>>>,
    /// Store passed to `Driver::with_operation_store` when present.
    operation_store: Option<Arc<dyn crate::OperationStore>>,
}
106
107impl PendingConnection {
108 fn new(handle: ConnectionHandle) -> Self {
109 Self {
110 handle: Some(handle),
111 caller_slot: None,
112 operation_store: None,
113 }
114 }
115
116 fn with_caller_slot(
118 handle: ConnectionHandle,
119 caller_slot: Arc<std::sync::Mutex<Option<crate::Caller>>>,
120 operation_store: Option<Arc<dyn crate::OperationStore>>,
121 ) -> Self {
122 Self {
123 handle: Some(handle),
124 caller_slot: Some(caller_slot),
125 operation_store,
126 }
127 }
128
129 pub fn handle_with(mut self, handler: impl Handler<crate::DriverReplySink> + 'static) {
131 let handle = self
132 .handle
133 .take()
134 .expect("PendingConnection already consumed");
135 let conn_id = handle.connection_id();
136 trace!(%conn_id, "PendingConnection::handle_with: creating driver");
137 let mut driver = match self.operation_store.take() {
138 Some(store) => crate::Driver::with_operation_store(handle, handler, store),
139 None => crate::Driver::new(handle, handler),
140 };
141 if let Some(slot) = &self.caller_slot {
142 let caller = crate::Caller::new(driver.caller());
143 *slot.lock().unwrap() = Some(caller);
144 }
145 #[cfg(not(target_arch = "wasm32"))]
146 tokio::spawn(async move {
147 trace!(%conn_id, "PendingConnection driver starting");
148 driver.run().await;
149 trace!(%conn_id, "PendingConnection driver exited");
150 });
151 #[cfg(target_arch = "wasm32")]
152 wasm_bindgen_futures::spawn_local(async move { driver.run().await });
153 }
154
155 pub fn handle_with_client<C: crate::FromVoxSession>(
157 mut self,
158 handler: impl Handler<crate::DriverReplySink> + 'static,
159 ) -> C {
160 let handle = self
161 .handle
162 .take()
163 .expect("PendingConnection already consumed");
164 let conn_id = handle.connection_id();
165 trace!(%conn_id, "PendingConnection::handle_with_client: creating driver");
166 let mut driver = match self.operation_store.take() {
167 Some(store) => crate::Driver::with_operation_store(handle, handler, store),
168 None => crate::Driver::new(handle, handler),
169 };
170 let caller = crate::Caller::new(driver.caller());
171 if let Some(slot) = &self.caller_slot {
172 *slot.lock().unwrap() = Some(caller.clone());
173 }
174 #[cfg(not(target_arch = "wasm32"))]
175 tokio::spawn(async move {
176 trace!(%conn_id, "PendingConnection driver starting");
177 driver.run().await;
178 trace!(%conn_id, "PendingConnection driver exited");
179 });
180 #[cfg(target_arch = "wasm32")]
181 wasm_bindgen_futures::spawn_local(async move { driver.run().await });
182 C::from_vox_session(caller, None)
183 }
184
185 pub fn proxy_to(mut self, other: ConnectionHandle) {
187 let handle = self
188 .handle
189 .take()
190 .expect("PendingConnection already consumed");
191 #[cfg(not(target_arch = "wasm32"))]
192 tokio::spawn(async move {
193 let _ = proxy_connections(handle, other).await;
194 });
195 #[cfg(target_arch = "wasm32")]
196 wasm_bindgen_futures::spawn_local(async move {
197 let _ = proxy_connections(handle, other).await;
198 });
199 }
200
201 pub fn into_handle(mut self) -> ConnectionHandle {
203 self.handle
204 .take()
205 .expect("PendingConnection already consumed")
206 }
207}
208
209impl Drop for PendingConnection {
210 fn drop(&mut self) {
211 if let Some(handle) = self.handle.take() {
212 let conn_id = handle.connection_id();
213 warn!(%conn_id, "PendingConnection dropped without being consumed — closing connection");
214 if let Some(tx) = handle.control_tx.as_ref() {
215 let _ = send_drop_control(tx, DropControlRequest::Close(conn_id));
216 }
217 }
218 }
219}
220
/// Decides what happens to a connection described by a [`ConnectionRequest`].
///
/// Implementations receive the request metadata plus the [`PendingConnection`] and must
/// consume the connection. NOTE(review): an `Err(metadata)` return presumably rejects the
/// connection with that metadata (cf. `SessionError::Rejected`) — confirm at the call site.
pub trait ConnectionAcceptor: MaybeSend + MaybeSync + 'static {
    fn accept(
        &self,
        request: &ConnectionRequest,
        connection: PendingConnection,
    ) -> Result<(), Metadata<'static>>;
}
229
230impl<H> ConnectionAcceptor for H
232where
233 H: Handler<crate::DriverReplySink> + Clone + MaybeSend + MaybeSync + 'static,
234{
235 fn accept(
236 &self,
237 _request: &ConnectionRequest,
238 connection: PendingConnection,
239 ) -> Result<(), Metadata<'static>> {
240 connection.handle_with(self.clone());
241 Ok(())
242 }
243}
244
245pub struct AcceptorFn<F>(pub F);
247
248impl<F> ConnectionAcceptor for AcceptorFn<F>
249where
250 F: Fn(&ConnectionRequest, PendingConnection) -> Result<(), Metadata<'static>>
251 + MaybeSend
252 + MaybeSync
253 + 'static,
254{
255 fn accept(
256 &self,
257 request: &ConnectionRequest,
258 connection: PendingConnection,
259 ) -> Result<(), Metadata<'static>> {
260 (self.0)(request, connection)
261 }
262}
263
264pub fn acceptor_fn<F>(f: F) -> AcceptorFn<F>
266where
267 F: Fn(&ConnectionRequest, PendingConnection) -> Result<(), Metadata<'static>>
268 + MaybeSend
269 + MaybeSync
270 + 'static,
271{
272 AcceptorFn(f)
273}
274
/// Request from `SessionHandle::open_connection` asking the session loop to open a new
/// connection; the outcome is sent back on `result_tx`.
struct OpenRequest {
    settings: ConnectionSettings,
    metadata: Metadata<'static>,
    result_tx: moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>,
}

/// Request from `SessionHandle::close_connection` to close `conn_id`; the outcome is sent
/// back on `result_tx`.
struct CloseRequest {
    conn_id: ConnectionId,
    metadata: Metadata<'static>,
    result_tx: moire::sync::oneshot::Sender<Result<(), SessionError>>,
}

/// Request from `SessionHandle::resume_parts` to continue the session on a new conduit.
/// The session loop rejects it unless the session is currently disconnected.
struct ResumeRequest {
    tx: Arc<dyn DynConduitTx>,
    rx: Box<dyn DynConduitRx>,
    handshake_result: HandshakeResult,
    result_tx: moire::sync::oneshot::Sender<Result<(), SessionError>>,
}

/// Fire-and-forget requests sent on the unbounded control channel. Because sending does
/// not await, they can be issued from synchronous code such as `Drop`.
#[derive(Debug, Clone, Copy)]
pub(crate) enum DropControlRequest {
    /// Sent by `SessionHandle::shutdown`.
    Shutdown,
    /// Sent when a `PendingConnection` is dropped unconsumed or a proxy ends.
    Close(ConnectionId),
}

/// Classification attached to a failed request via `ConnectionSender::mark_failure`.
/// NOTE(review): presumably `Cancelled` means the request is known not to have taken
/// effect and `Indeterminate` means it may have — confirm with the `failures_rx` consumer.
#[derive(Clone, Copy, Debug)]
pub(crate) enum FailureDisposition {
    Cancelled,
    Indeterminate,
}
309
310#[cfg(not(target_arch = "wasm32"))]
311fn send_drop_control(
312 tx: &mpsc::UnboundedSender<DropControlRequest>,
313 req: DropControlRequest,
314) -> Result<(), ()> {
315 tx.send(req).map_err(|_| ())
316}
317
318#[cfg(target_arch = "wasm32")]
319fn send_drop_control(
320 tx: &mpsc::UnboundedSender<DropControlRequest>,
321 req: DropControlRequest,
322) -> Result<(), ()> {
323 tx.try_send(req).map_err(|_| ())
324}
325
/// Cloneable handle for issuing requests to a running [`Session`].
///
/// The async operations send a request over a bounded channel and await the session
/// loop's reply. `shutdown` uses the unbounded control channel and does not wait.
#[derive(Clone)]
pub struct SessionHandle {
    /// Carries `open_connection` requests.
    open_tx: mpsc::Sender<OpenRequest>,
    /// Carries `close_connection` requests.
    close_tx: mpsc::Sender<CloseRequest>,
    /// Carries `resume_parts` requests.
    resume_tx: mpsc::Sender<ResumeRequest>,
    /// Control channel used by `shutdown`.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Resume key of the session, if any; exposed via `resume_key()`.
    resume_key: Option<SessionResumeKey>,
}
344
345impl SessionHandle {
346 pub async fn open<Client: crate::FromVoxSession>(
352 &self,
353 settings: ConnectionSettings,
354 ) -> Result<Client, SessionError> {
355 use crate::{Caller, Driver};
356 use vox_types::{MetadataEntry, MetadataFlags, MetadataValue};
357
358 let metadata: Metadata<'static> = vec![MetadataEntry {
359 key: crate::session::builders::VOX_SERVICE_METADATA_KEY.into(),
360 value: MetadataValue::String(Client::SERVICE_NAME.into()),
361 flags: MetadataFlags::NONE,
362 }];
363 let handle = self.open_connection(settings, metadata).await?;
364 let mut driver = Driver::new(handle, ());
365 let caller = Caller::new(driver.caller());
366 #[cfg(not(target_arch = "wasm32"))]
367 tokio::spawn(async move { driver.run().await });
368 #[cfg(target_arch = "wasm32")]
369 wasm_bindgen_futures::spawn_local(async move { driver.run().await });
370 Ok(Client::from_vox_session(caller, None))
371 }
372
373 pub async fn open_connection(
380 &self,
381 settings: ConnectionSettings,
382 metadata: Metadata<'static>,
383 ) -> Result<ConnectionHandle, SessionError> {
384 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.open_result");
385 self.open_tx
386 .send(OpenRequest {
387 settings,
388 metadata,
389 result_tx,
390 })
391 .await
392 .map_err(|_| SessionError::Protocol("session closed".into()))?;
393 result_rx
394 .await
395 .map_err(|_| SessionError::Protocol("session closed".into()))?
396 }
397
398 pub async fn close_connection(
405 &self,
406 conn_id: ConnectionId,
407 metadata: Metadata<'static>,
408 ) -> Result<(), SessionError> {
409 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.close_result");
410 self.close_tx
411 .send(CloseRequest {
412 conn_id,
413 metadata,
414 result_tx,
415 })
416 .await
417 .map_err(|_| SessionError::Protocol("session closed".into()))?;
418 result_rx
419 .await
420 .map_err(|_| SessionError::Protocol("session closed".into()))?
421 }
422
423 pub(crate) async fn resume_parts(
424 &self,
425 tx: Arc<dyn DynConduitTx>,
426 rx: Box<dyn DynConduitRx>,
427 handshake_result: HandshakeResult,
428 ) -> Result<(), SessionError> {
429 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.resume_result");
430 self.resume_tx
431 .send(ResumeRequest {
432 tx,
433 rx,
434 handshake_result,
435 result_tx,
436 })
437 .await
438 .map_err(|_| SessionError::Protocol("session closed".into()))?;
439 result_rx
440 .await
441 .map_err(|_| SessionError::Protocol("session closed".into()))?
442 }
443
444 pub fn resume_key(&self) -> Option<&SessionResumeKey> {
446 self.resume_key.as_ref()
447 }
448
449 pub fn shutdown(&self) -> Result<(), SessionError> {
451 send_drop_control(&self.control_tx, DropControlRequest::Shutdown)
452 .map_err(|_| SessionError::Protocol("session closed".into()))
453 }
454}
455
/// Owns one conduit and multiplexes every connection on it; driven by [`Session::run`].
///
/// [`SessionHandle`]s talk to it through `open_rx`, `close_rx`, `resume_rx` and
/// `control_rx`.
pub struct Session {
    /// Receive half of the current conduit; replaced in `resume_from_handshake`.
    rx: Box<dyn DynConduitRx>,

    /// Handshake role: `SessionRole::Initiator` placeholder from `pre_handshake`, then the
    /// value from `establish_from_handshake`.
    role: SessionRole,

    /// Parity from our root settings; also used to seed `conn_ids`.
    parity: Parity,

    /// Shared send-side core; every `ConnectionSender` holds a clone of this `Arc`.
    sess_core: Arc<SessionCore>,
    /// Peer's retry support from the handshake; copied into each new `ConnectionHandle`.
    peer_supports_retry: bool,
    /// Our root connection settings; a resume must present identical settings.
    local_root_settings: ConnectionSettings,
    /// Peer root connection settings (known after the handshake); must match on resume.
    peer_root_settings: Option<ConnectionSettings>,
    /// When `true`, a handshake without a resume key fails with `NotResumable`.
    resumable: bool,
    /// Resume key from the handshake, if any.
    session_resume_key: Option<SessionResumeKey>,

    /// Connection slots keyed by connection id.
    conns: BTreeMap<ConnectionId, ConnectionSlot>,
    /// NOTE(review): not read in this view; presumably records that the root connection
    /// was closed from inside the session — confirm against its users.
    root_closed_internal: bool,

    /// Allocator for locally assigned connection ids (uses our parity).
    conn_ids: IdAllocator<ConnectionId>,

    /// Acceptor for connections, if configured (presumably peer-opened ones).
    on_connection: Option<Arc<dyn ConnectionAcceptor>>,

    /// Receives `SessionHandle::open_connection` requests.
    open_rx: mpsc::Receiver<OpenRequest>,

    /// Receives `SessionHandle::close_connection` requests.
    close_rx: mpsc::Receiver<CloseRequest>,

    /// Receives resume requests; only honored while the conduit is broken.
    resume_rx: mpsc::Receiver<ResumeRequest>,

    /// Cloned into every `ConnectionHandle` so it can send control requests.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Receives `Shutdown` / `Close` control requests.
    control_rx: mpsc::UnboundedReceiver<DropControlRequest>,

    /// Keepalive configuration, if any.
    keepalive: Option<SessionKeepaliveConfig>,
    /// Resume generation counter bumped after each successful resume; every
    /// `ConnectionHandle` holds a subscription (`resumed_rx`).
    resume_notifier: watch::Sender<u64>,
    /// Source of replacement conduits after a conduit break, if configured.
    recoverer: Option<Box<dyn ConduitRecoverer>>,
    /// Upper bound on waiting for the recoverer; `None` waits indefinitely.
    recovery_timeout: Option<Duration>,
    /// Whether the session is registered in a registry (set outside this view). Without a
    /// recoverer, a conduit break waits for an external resume only when this is `true`.
    registered_in_registry: bool,
}
515
/// Mutable keepalive state owned by the session loop; built by `make_keepalive_runtime`
/// (outside this view) and rebuilt after each successful resume.
///
/// NOTE(review): the field descriptions below are inferred from the names — confirm
/// against the keepalive tick / pong handling code.
#[derive(Debug)]
struct KeepaliveRuntime {
    /// Presumably copied from `SessionKeepaliveConfig::ping_interval`.
    ping_interval: Duration,
    /// Presumably copied from `SessionKeepaliveConfig::pong_timeout`.
    pong_timeout: Duration,
    /// Presumably when the next ping is due.
    next_ping_at: tokio::time::Instant,
    /// Presumably the nonce of the ping still awaiting a pong, if any.
    waiting_pong_nonce: Option<u64>,
    /// Presumably the deadline for the outstanding pong.
    pong_deadline: tokio::time::Instant,
    /// Presumably the nonce to use for the next ping.
    next_ping_nonce: u64,
}
525
/// Session-side state of an established connection (stored in `ConnectionSlot::Active`).
#[derive(Debug)]
pub struct ConnectionState {
    /// Id of the connection.
    pub id: ConnectionId,

    /// Our settings for this connection.
    pub local_settings: ConnectionSettings,

    /// The peer's settings for this connection.
    pub peer_settings: ConnectionSettings,

    /// Delivers inbound messages to the paired `ConnectionHandle::rx`.
    conn_tx: mpsc::Sender<RecvMessage>,
    /// Close flag observed through the paired `ConnectionHandle::closed_rx`.
    closed_tx: watch::Sender<bool>,

    /// Receive-side schema tracker; replaced for the root connection on resume.
    schema_recv_tracker: Arc<vox_types::SchemaRecvTracker>,
}
546
/// Entry in `Session::conns`.
#[derive(Debug)]
enum ConnectionSlot {
    /// Established connection with live state.
    Active(ConnectionState),
    /// Presumably a locally requested connection still waiting for the peer's answer —
    /// confirm with `handle_open_request`.
    PendingOutbound(PendingOutboundData),
}

/// Bookkeeping for an outbound connection open that is still in flight.
struct PendingOutboundData {
    /// Settings we proposed for the connection.
    local_settings: ConnectionSettings,
    /// Where the resulting handle or error is reported; `Option`, presumably so the
    /// sender can be taken exactly once.
    result_tx: Option<moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>>,
}
558
559impl std::fmt::Debug for PendingOutboundData {
560 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
561 f.debug_struct("PendingOutbound")
562 .field("local_settings", &self.local_settings)
563 .finish()
564 }
565}
566
/// Cloneable sending side of a single connection.
#[derive(Clone)]
pub(crate) struct ConnectionSender {
    /// Id stamped on every outgoing `Message`.
    connection_id: ConnectionId,
    /// Session core through which messages are actually sent.
    pub(crate) sess_core: Arc<SessionCore>,
    /// Failure reports; paired with `ConnectionHandle::failures_rx`.
    failures: Arc<mpsc::UnboundedSender<(RequestId, FailureDisposition)>>,
}
573
574fn forwarded_payload<'a>(payload: &'a vox_types::Payload<'a>) -> vox_types::Payload<'a> {
575 let vox_types::Payload::PostcardBytes(bytes) = payload else {
576 unreachable!("proxy forwarding expects decoded incoming payload bytes")
577 };
578 vox_types::Payload::PostcardBytes(bytes)
579}
580
581fn forwarded_request_body<'a>(body: &'a RequestBody<'a>) -> RequestBody<'a> {
582 match body {
583 RequestBody::Call(call) => RequestBody::Call(vox_types::RequestCall {
584 method_id: call.method_id,
585 metadata: call.metadata.clone(),
586 args: forwarded_payload(&call.args),
587 schemas: call.schemas.clone(),
588 }),
589 RequestBody::Response(response) => RequestBody::Response(RequestResponse {
590 metadata: response.metadata.clone(),
591 ret: forwarded_payload(&response.ret),
592 schemas: response.schemas.clone(),
593 }),
594 RequestBody::Cancel(cancel) => RequestBody::Cancel(vox_types::RequestCancel {
595 metadata: cancel.metadata.clone(),
596 }),
597 }
598}
599
600fn forwarded_channel_body<'a>(body: &'a vox_types::ChannelBody<'a>) -> vox_types::ChannelBody<'a> {
601 match body {
602 vox_types::ChannelBody::Item(item) => {
603 vox_types::ChannelBody::Item(vox_types::ChannelItem {
604 item: forwarded_payload(&item.item),
605 })
606 }
607 vox_types::ChannelBody::Close(close) => {
608 vox_types::ChannelBody::Close(vox_types::ChannelClose {
609 metadata: close.metadata.clone(),
610 })
611 }
612 vox_types::ChannelBody::Reset(reset) => {
613 vox_types::ChannelBody::Reset(vox_types::ChannelReset {
614 metadata: reset.metadata.clone(),
615 })
616 }
617 vox_types::ChannelBody::GrantCredit(credit) => {
618 vox_types::ChannelBody::GrantCredit(vox_types::ChannelGrantCredit {
619 additional: credit.additional,
620 })
621 }
622 }
623}
624
625impl ConnectionSender {
626 pub(crate) fn connection_id(&self) -> ConnectionId {
627 self.connection_id
628 }
629
630 pub(crate) async fn send_with_binder<'a>(
631 &self,
632 msg: ConnectionMessage<'a>,
633 binder: Option<&'a dyn vox_types::ChannelBinder>,
634 ) -> Result<(), ()> {
635 let payload = match msg {
636 ConnectionMessage::Request(r) => MessagePayload::RequestMessage(r),
637 ConnectionMessage::Channel(c) => MessagePayload::ChannelMessage(c),
638 };
639 let message = Message {
640 connection_id: self.connection_id,
641 payload,
642 };
643 self.sess_core
644 .send(message, binder, None)
645 .await
646 .map_err(|_| ())
647 }
648
649 pub async fn send<'a>(&self, msg: ConnectionMessage<'a>) -> Result<(), ()> {
651 self.send_with_binder(msg, None).await
652 }
653
654 pub(crate) async fn send_owned(
656 &self,
657 schemas: Arc<vox_types::SchemaRecvTracker>,
658 msg: SelfRef<ConnectionMessage<'static>>,
659 ) -> Result<(), ()> {
660 let msg_ref = msg.get();
661 let payload = match msg_ref {
662 ConnectionMessage::Request(request) => MessagePayload::RequestMessage(RequestMessage {
663 id: request.id,
664 body: forwarded_request_body(&request.body),
665 }),
666 ConnectionMessage::Channel(channel) => MessagePayload::ChannelMessage(ChannelMessage {
667 id: channel.id,
668 body: forwarded_channel_body(&channel.body),
669 }),
670 };
671
672 self.sess_core
673 .send(
674 Message {
675 connection_id: self.connection_id,
676 payload,
677 },
678 None,
679 Some(&*schemas),
680 )
681 .await
682 .map_err(|_| ())
683 }
684
685 pub async fn send_response<'a>(
687 &self,
688 request_id: RequestId,
689 response: RequestResponse<'a>,
690 ) -> Result<(), ()> {
691 self.send(ConnectionMessage::Request(RequestMessage {
692 id: request_id,
693 body: RequestBody::Response(response),
694 }))
695 .await
696 }
697
698 pub async fn send_response_for_method<'a>(
700 &self,
701 request_id: RequestId,
702 method_id: vox_types::MethodId,
703 mut response: RequestResponse<'a>,
704 ) -> Result<(), ()> {
705 self.prepare_response_for_method(request_id, method_id, &mut response);
706 self.send(ConnectionMessage::Request(RequestMessage {
707 id: request_id,
708 body: RequestBody::Response(response),
709 }))
710 .await
711 }
712
713 pub(crate) fn prepare_response_for_method(
715 &self,
716 request_id: RequestId,
717 method_id: vox_types::MethodId,
718 response: &mut RequestResponse<'_>,
719 ) {
720 self.sess_core.prepare_response_for_method(
721 self.connection_id,
722 request_id,
723 method_id,
724 response,
725 );
726 }
727
728 pub(crate) fn prepare_response_from_source(
730 &self,
731 request_id: RequestId,
732 method_id: vox_types::MethodId,
733 root_type: &vox_types::TypeRef,
734 source: &dyn vox_types::SchemaSource,
735 response: &mut RequestResponse<'_>,
736 ) {
737 self.sess_core.prepare_response_from_source(
738 self.connection_id,
739 request_id,
740 method_id,
741 root_type,
742 source,
743 response,
744 );
745 }
746
747 pub fn mark_failure(&self, request_id: RequestId, disposition: FailureDisposition) {
750 let _ = self.failures.send((request_id, disposition));
751 }
752
753 pub fn schema_registry(&self) -> vox_types::SchemaRegistry {
755 self.sess_core.schema_registry(self.connection_id)
756 }
757
758 pub fn prepare_replay_schemas(
760 &self,
761 request_id: RequestId,
762 method_id: vox_types::MethodId,
763 root_type: &vox_types::TypeRef,
764 store: &dyn crate::OperationStore,
765 response: &mut RequestResponse<'_>,
766 ) {
767 self.prepare_response_from_source(
768 request_id,
769 method_id,
770 root_type,
771 store.schema_source(),
772 response,
773 );
774 }
775}
776
/// User-facing end of one connection: its sender, its inbound queue and the signals the
/// session publishes for it.
pub struct ConnectionHandle {
    /// Sending side of the connection.
    pub(crate) sender: ConnectionSender,
    /// Inbound messages routed to this connection (paired with `ConnectionState::conn_tx`).
    pub(crate) rx: mpsc::Receiver<RecvMessage>,
    /// Failures reported through `ConnectionSender::mark_failure`.
    pub(crate) failures_rx: mpsc::UnboundedReceiver<(RequestId, FailureDisposition)>,
    /// Control channel to the owning session, used to request a close from synchronous code.
    pub(crate) control_tx: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// Close flag; `closed()` resolves when it reads `true` or its sender is dropped.
    pub(crate) closed_rx: watch::Receiver<bool>,
    /// Session resume generation; changes after each successful resume.
    pub(crate) resumed_rx: watch::Receiver<u64>,
    /// Parity taken from this connection's local settings.
    pub parity: Parity,
    /// Peer retry support, captured when the handle was created.
    pub(crate) peer_supports_retry: bool,
}
788
789impl std::fmt::Debug for ConnectionHandle {
790 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
791 f.debug_struct("ConnectionHandle")
792 .field("connection_id", &self.sender.connection_id)
793 .finish()
794 }
795}
796
/// Message scoped to a single connection: request-level or channel-level.
pub(crate) enum ConnectionMessage<'payload> {
    Request(RequestMessage<'payload>),
    Channel(ChannelMessage<'payload>),
}

// NOTE(review): presumably implements vox_types' reborrow support for this
// lifetime-parameterised enum (needed for `SelfRef<ConnectionMessage<'static>>`) — confirm.
vox_types::impl_reborrow!(ConnectionMessage);

/// Inbound message delivered to a `ConnectionHandle`, with the schema tracker of the
/// connection it arrived on.
pub(crate) struct RecvMessage {
    /// Receive-side schema tracker; passed to `ConnectionSender::send_owned` when proxying.
    pub schemas: Arc<vox_types::SchemaRecvTracker>,
    /// The message; read via `SelfRef::get`.
    pub msg: SelfRef<ConnectionMessage<'static>>,
}
811
812impl ConnectionHandle {
813 pub fn connection_id(&self) -> ConnectionId {
815 self.sender.connection_id
816 }
817
818 pub async fn closed(&self) {
820 if *self.closed_rx.borrow() {
821 return;
822 }
823 let mut rx = self.closed_rx.clone();
824 while rx.changed().await.is_ok() {
825 if *rx.borrow() {
826 return;
827 }
828 }
829 }
830
831 pub fn is_connected(&self) -> bool {
833 !*self.closed_rx.borrow()
834 }
835
836 pub fn peer_supports_retry(&self) -> bool {
837 self.peer_supports_retry
838 }
839}
840
841pub async fn proxy_connections(
847 left: ConnectionHandle,
848 right: ConnectionHandle,
849) -> Result<(), SessionError> {
850 if left.parity == right.parity {
851 return Err(SessionError::Protocol(
852 "proxy_connections requires opposite parities".into(),
853 ));
854 }
855 let left_conn_id = left.connection_id();
856 let right_conn_id = right.connection_id();
857 let ConnectionHandle {
858 sender: left_sender,
859 rx: mut left_rx,
860 failures_rx: _left_failures_rx,
861 control_tx: left_control_tx,
862 closed_rx: _left_closed_rx,
863 resumed_rx: _left_resumed_rx,
864 parity: _left_parity,
865 peer_supports_retry: _left_peer_supports_retry,
866 } = left;
867 let ConnectionHandle {
868 sender: right_sender,
869 rx: mut right_rx,
870 failures_rx: _right_failures_rx,
871 control_tx: right_control_tx,
872 closed_rx: _right_closed_rx,
873 resumed_rx: _right_resumed_rx,
874 parity: _right_parity,
875 peer_supports_retry: _right_peer_supports_retry,
876 } = right;
877
878 loop {
879 tokio::select! {
880 recv = left_rx.recv() => {
881 let Some(recv) = recv else {
882 break;
883 };
884 if right_sender.send_owned(recv.schemas, recv.msg).await.is_err() {
885 break;
886 }
887 }
888 recv = right_rx.recv() => {
889 let Some(recv) = recv else {
890 break;
891 };
892 if left_sender.send_owned(recv.schemas, recv.msg).await.is_err() {
893 break;
894 }
895 }
896 }
897 }
898
899 if let Some(tx) = left_control_tx.as_ref() {
900 let _ = send_drop_control(tx, DropControlRequest::Close(left_conn_id));
901 }
902 if let Some(tx) = right_control_tx.as_ref() {
903 let _ = send_drop_control(tx, DropControlRequest::Close(right_conn_id));
904 }
905 Ok(())
906}
907
908#[derive(Debug)]
910pub enum SessionError {
911 Io(std::io::Error),
912 Protocol(String),
913 Rejected(Metadata<'static>),
914 NotResumable,
915 ConnectTimeout,
916}
917
918impl SessionError {
919 pub fn is_retryable(&self) -> bool {
925 matches!(
926 self,
927 Self::Io(_) | Self::ConnectTimeout | Self::NotResumable
928 )
929 }
930}
931
932impl std::fmt::Display for SessionError {
933 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
934 match self {
935 Self::Io(e) => write!(f, "io error: {e}"),
936 Self::Protocol(msg) => write!(f, "protocol error: {msg}"),
937 Self::Rejected(_) => write!(f, "connection rejected"),
938 Self::NotResumable => write!(f, "session is not resumable"),
939 Self::ConnectTimeout => write!(f, "connect timeout"),
940 }
941 }
942}
943
944impl std::error::Error for SessionError {}
945
946impl Session {
947 fn close_connection_for_protocol_error(
948 &mut self,
949 conn_id: ConnectionId,
950 detail: impl std::fmt::Display,
951 ) {
952 warn!(%conn_id, "closing connection after protocol error: {detail}");
953 self.remove_connection(&conn_id);
954 self.maybe_request_shutdown_after_root_closed();
955 }
956
957 fn record_received_schema_cbor(
958 &mut self,
959 conn_id: ConnectionId,
960 schema_recv_tracker: Arc<vox_types::SchemaRecvTracker>,
961 method_id: vox_types::MethodId,
962 direction: vox_types::BindingDirection,
963 schemas_cbor: &vox_types::CborPayload,
964 context: &str,
965 ) -> bool {
966 let payload = match vox_types::SchemaPayload::from_cbor(&schemas_cbor.0) {
967 Ok(payload) => payload,
968 Err(error) => {
969 self.close_connection_for_protocol_error(
970 conn_id,
971 format!("{context}: invalid schema CBOR: {error}"),
972 );
973 return false;
974 }
975 };
976
977 if let Err(error) = schema_recv_tracker.record_received(method_id, direction, payload) {
978 self.close_connection_for_protocol_error(conn_id, format!("{context}: {error}"));
979 return false;
980 }
981
982 true
983 }
984
985 #[allow(clippy::too_many_arguments)]
986 fn pre_handshake<Tx, Rx>(
987 tx: Tx,
988 rx: Rx,
989 on_connection: Option<Arc<dyn ConnectionAcceptor>>,
990 open_rx: mpsc::Receiver<OpenRequest>,
991 close_rx: mpsc::Receiver<CloseRequest>,
992 resume_rx: mpsc::Receiver<ResumeRequest>,
993 control_tx: mpsc::UnboundedSender<DropControlRequest>,
994 control_rx: mpsc::UnboundedReceiver<DropControlRequest>,
995 keepalive: Option<SessionKeepaliveConfig>,
996 resumable: bool,
997 recoverer: Option<Box<dyn ConduitRecoverer>>,
998 recovery_timeout: Option<Duration>,
999 ) -> Self
1000 where
1001 Tx: ConduitTx<Msg = MessageFamily> + MaybeSend + MaybeSync + 'static,
1002 Rx: ConduitRx<Msg = MessageFamily> + MaybeSend + 'static,
1003 {
1004 let (outbound_tx, outbound_rx) = tokio_mpsc::channel(256);
1005 let sess_core = Arc::new(SessionCore {
1006 inner: std::sync::Mutex::new(SessionCoreInner {
1007 tx: Arc::new(tx) as Arc<dyn DynConduitTx>,
1008 conns: HashMap::new(),
1009 }),
1010 outbound_tx,
1011 });
1012 spawn_outbound_worker(outbound_rx);
1013 let (resume_notifier, _resume_rx) = watch::channel(0_u64);
1014 Session {
1015 rx: Box::new(rx),
1016 role: SessionRole::Initiator, parity: Parity::Odd, sess_core,
1019 peer_supports_retry: false,
1020 local_root_settings: ConnectionSettings {
1021 parity: Parity::Odd,
1022 max_concurrent_requests: 64,
1023 },
1024 peer_root_settings: None,
1025 resumable,
1026 session_resume_key: None,
1027 conns: BTreeMap::new(),
1028 root_closed_internal: false,
1029 conn_ids: IdAllocator::new(Parity::Odd), on_connection,
1031 open_rx,
1032 close_rx,
1033 resume_rx,
1034 control_tx,
1035 control_rx,
1036 keepalive,
1037 resume_notifier,
1038 recoverer,
1039 recovery_timeout,
1040 registered_in_registry: false,
1041 }
1042 }
1043
1044 pub(crate) fn resume_key(&self) -> Option<SessionResumeKey> {
1045 self.session_resume_key
1046 }
1047
1048 fn establish_from_handshake(
1050 &mut self,
1051 result: HandshakeResult,
1052 ) -> Result<ConnectionHandle, SessionError> {
1053 self.role = result.role;
1054 self.parity = result.our_settings.parity;
1055 self.conn_ids = IdAllocator::new(result.our_settings.parity);
1056 self.local_root_settings = result.our_settings.clone();
1057 self.peer_root_settings = Some(result.peer_settings.clone());
1058 self.peer_supports_retry = result.peer_supports_retry;
1059 self.session_resume_key = result.session_resume_key;
1060
1061 if self.resumable && self.session_resume_key.is_none() {
1062 return Err(SessionError::NotResumable);
1063 }
1064
1065 Ok(self.make_root_handle(result.our_settings, result.peer_settings))
1066 }
1067
1068 fn make_root_handle(
1069 &mut self,
1070 local_settings: ConnectionSettings,
1071 peer_settings: ConnectionSettings,
1072 ) -> ConnectionHandle {
1073 self.make_connection_handle(ConnectionId::ROOT, local_settings, peer_settings)
1074 }
1075
1076 fn make_connection_handle(
1077 &mut self,
1078 conn_id: ConnectionId,
1079 local_settings: ConnectionSettings,
1080 peer_settings: ConnectionSettings,
1081 ) -> ConnectionHandle {
1082 let label = format!("session.conn{}", conn_id.0);
1083 let (conn_tx, conn_rx) = mpsc::channel::<RecvMessage>(&label, 64);
1084 let (failures_tx, failures_rx) = mpsc::unbounded_channel(format!("{label}.failures"));
1085 let (closed_tx, closed_rx) = watch::channel(false);
1086 let resumed_rx = self.resume_notifier.subscribe();
1087
1088 let sender = ConnectionSender {
1089 connection_id: conn_id,
1090 sess_core: Arc::clone(&self.sess_core),
1091 failures: Arc::new(failures_tx),
1092 };
1093
1094 let parity = local_settings.parity;
1095 trace!(%conn_id, "make_connection_handle: inserting slot into conns");
1096 self.conns.insert(
1097 conn_id,
1098 ConnectionSlot::Active(ConnectionState {
1099 id: conn_id,
1100 local_settings,
1101 peer_settings,
1102 conn_tx,
1103 closed_tx,
1104 schema_recv_tracker: Arc::new(vox_types::SchemaRecvTracker::new()),
1105 }),
1106 );
1107
1108 ConnectionHandle {
1109 sender,
1110 rx: conn_rx,
1111 failures_rx,
1112 control_tx: Some(self.control_tx.clone()),
1113 closed_rx,
1114 resumed_rx,
1115 parity,
1116 peer_supports_retry: self.peer_supports_retry,
1117 }
1118 }
1119
    /// Main session loop.
    ///
    /// Multiplexes inbound conduit messages, `SessionHandle` requests (open / close /
    /// resume), drop-control requests and keepalive ticks. It exits when a branch reports
    /// that the session cannot continue, then calls `close_all_connections`.
    pub async fn run(&mut self) {
        let mut keepalive_runtime = self.make_keepalive_runtime();
        // A 10 ms ticker drives keepalive checks; it only exists when
        // `make_keepalive_runtime` returned `Some`.
        let mut keepalive_tick = keepalive_runtime.as_ref().map(|_| {
            let mut interval = tokio::time::interval(Duration::from_millis(10));
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            interval
        });

        loop {
            tokio::select! {
                // `biased` polls branches in the order written, so inbound conduit traffic
                // is checked before handle requests, control requests and keepalive ticks.
                // NOTE(review): a continuously busy conduit can therefore delay the other
                // branches.
                biased;

                msg = self.rx.recv_msg() => {
                    vox_types::dlog!("[session {:?}] recv_msg returned", self.role);
                    match msg {
                        Ok(Some(msg)) => {
                            self.handle_message(msg, &mut keepalive_runtime).await;
                        }
                        // EOF and receive errors are handled the same way: try to recover
                        // or resume, otherwise stop the loop.
                        Ok(None) => {
                            vox_types::dlog!("[session {:?}] recv loop: conduit returned EOF", self.role);
                            if !self.handle_conduit_break(&mut keepalive_runtime).await {
                                vox_types::dlog!("[session {:?}] recv loop: breaking (not resumable)", self.role);
                                break;
                            }
                        }
                        Err(error) => {
                            vox_types::dlog!("[session {:?}] recv loop: conduit recv error: {}", self.role, error);
                            if !self.handle_conduit_break(&mut keepalive_runtime).await {
                                vox_types::dlog!("[session {:?}] recv loop: breaking (not resumable)", self.role);
                                break;
                            }
                        }
                    }
                }
                Some(req) = self.open_rx.recv() => {
                    self.handle_open_request(req).await;
                }
                Some(req) = self.close_rx.recv() => {
                    self.handle_close_request(req).await;
                }
                // While connected, resume requests are always rejected.
                Some(req) = self.resume_rx.recv() => {
                    let _ = req.result_tx.send(Err(SessionError::Protocol(
                        "resume is only valid while the session is disconnected".into(),
                    )));
                }
                Some(req) = self.control_rx.recv() => {
                    if !self.handle_drop_control_request(req).await {
                        break;
                    }
                }
                // Branch is disabled entirely when keepalive is off.
                _ = async {
                    if let Some(interval) = keepalive_tick.as_mut() {
                        interval.tick().await;
                    }
                }, if keepalive_tick.is_some() => {
                    if !self.handle_keepalive_tick(&mut keepalive_runtime).await {
                        break;
                    }
                }
            }
        }

        self.close_all_connections();
        trace!("session recv loop exited");
    }
1195
1196 async fn handle_conduit_break(
1197 &mut self,
1198 keepalive_runtime: &mut Option<KeepaliveRuntime>,
1199 ) -> bool {
1200 if let Some(recoverer) = self.recoverer.as_mut() {
1207 let recovery_fut = recoverer.next_conduit(self.session_resume_key.as_ref());
1208 let recovery_result = match self.recovery_timeout {
1209 Some(timeout) => match tokio::time::timeout(timeout, recovery_fut).await {
1210 Ok(r) => r,
1211 Err(_) => return false,
1212 },
1213 None => recovery_fut.await,
1214 };
1215 match recovery_result {
1216 Ok(recovered) => {
1217 let result =
1218 self.resume_from_handshake(recovered.tx, recovered.rx, recovered.handshake);
1219 match result {
1220 Ok(()) => {
1221 let next_generation = self.resume_notifier.borrow().wrapping_add(1);
1222 let _ = self.resume_notifier.send(next_generation);
1223 *keepalive_runtime = self.make_keepalive_runtime();
1224 return true;
1225 }
1226 Err(_) => return false,
1227 }
1228 }
1229 Err(_) => return false,
1230 }
1231 }
1232
1233 if !self.registered_in_registry {
1234 return false;
1235 }
1236
1237 loop {
1238 tokio::select! {
1239 Some(req) = self.resume_rx.recv() => {
1240 let result =
1241 self.resume_from_handshake(req.tx, req.rx, req.handshake_result);
1242 let ok = result.is_ok();
1243 let _ = req.result_tx.send(result);
1244 if ok {
1245 let next_generation = self.resume_notifier.borrow().wrapping_add(1);
1246 let _ = self.resume_notifier.send(next_generation);
1247 *keepalive_runtime = self.make_keepalive_runtime();
1248 return true;
1249 }
1250 }
1251 Some(req) = self.control_rx.recv() => {
1252 if !self.handle_drop_control_request(req).await {
1253 return false;
1254 }
1255 }
1256 Some(req) = self.open_rx.recv() => {
1257 let _ = req.result_tx.send(Err(SessionError::Protocol(
1258 "session is disconnected; resume before opening connections".into(),
1259 )));
1260 }
1261 Some(req) = self.close_rx.recv() => {
1262 let _ = req.result_tx.send(Err(SessionError::Protocol(
1263 "session is disconnected; resume before closing connections".into(),
1264 )));
1265 }
1266 else => return false,
1267 }
1268 }
1269 }
1270
1271 fn resume_from_handshake(
1273 &mut self,
1274 tx: Arc<dyn DynConduitTx>,
1275 rx: Box<dyn DynConduitRx>,
1276 result: HandshakeResult,
1277 ) -> Result<(), SessionError> {
1278 let Some(peer_settings) = self.peer_root_settings.clone() else {
1279 return Err(SessionError::Protocol("missing peer root settings".into()));
1280 };
1281
1282 if result.our_settings != self.local_root_settings {
1283 return Err(SessionError::Protocol(
1284 "local root settings changed across session resume".into(),
1285 ));
1286 }
1287
1288 if result.peer_settings != peer_settings {
1289 return Err(SessionError::Protocol(
1290 "peer root settings changed across session resume".into(),
1291 ));
1292 }
1293
1294 self.peer_supports_retry = result.peer_supports_retry;
1295 self.session_resume_key = result.session_resume_key.or(self.session_resume_key);
1296
1297 self.sess_core.replace_tx_and_reset_schemas(tx);
1298 self.rx = rx;
1299 if let Some(ConnectionSlot::Active(state)) = self.conns.get_mut(&ConnectionId::ROOT) {
1302 state.schema_recv_tracker = Arc::new(vox_types::SchemaRecvTracker::new());
1303 }
1304 Ok(())
1305 }
1306
    /// Dispatches one message received from the conduit.
    ///
    /// Session-level traffic is handled inline:
    /// - `Ping` is answered with a `Pong` echoing its nonce on the same
    ///   connection id.
    /// - `Pong` on the root connection feeds the keepalive state.
    /// - Standalone `SchemaMessage`s are recorded in the target connection's
    ///   received-schema tracker.
    ///
    /// Connection lifecycle messages (close/open/accept/reject) go to their
    /// handlers. Request and channel messages are forwarded to the owning
    /// connection's `conn_tx`, and the connection is removed if that channel
    /// is closed.
    async fn handle_message(
        &mut self,
        msg: SelfRef<Message<'static>>,
        keepalive_runtime: &mut Option<KeepaliveRuntime>,
    ) {
        let msg_ref = msg.get();
        let conn_id = msg_ref.connection_id;
        // Session-level messages are consumed here and never reach a connection.
        match &msg_ref.payload {
            MessagePayload::Ping(ping) => {
                // Echo the nonce back on the same connection id.
                let _ = self
                    .sess_core
                    .send(
                        Message {
                            connection_id: conn_id,
                            payload: MessagePayload::Pong(vox_types::Pong { nonce: ping.nonce }),
                        },
                        None,
                        None,
                    )
                    .await;
                return;
            }
            MessagePayload::Pong(pong) => {
                // Only root-connection pongs drive keepalive; others are ignored.
                if conn_id.is_root() {
                    self.handle_keepalive_pong(pong.nonce, keepalive_runtime);
                }
                return;
            }
            MessagePayload::SchemaMessage(schema_msg) => {
                // Schemas for unknown or non-active connections are dropped.
                let schema_recv_tracker = match self.conns.get(&conn_id) {
                    Some(ConnectionSlot::Active(state)) => Arc::clone(&state.schema_recv_tracker),
                    _ => return,
                };
                let _ = self.record_received_schema_cbor(
                    conn_id,
                    schema_recv_tracker,
                    schema_msg.method_id,
                    schema_msg.direction,
                    &schema_msg.schemas,
                    "standalone schema message",
                );
                return;
            }
            _ => {}
        }
        vox_types::selfref_match!(msg, payload {
            MessagePayload::ConnectionClose(_) => {
                if conn_id.is_root() {
                    warn!("received ConnectionClose for root connection");
                } else {
                    trace!(conn_id = conn_id.0, "received ConnectionClose for virtual connection");
                }
                self.remove_connection(&conn_id);
                // If the root was already closed internally and this was the last
                // virtual connection, this requests session shutdown.
                self.maybe_request_shutdown_after_root_closed();
            }
            MessagePayload::ConnectionOpen(open) => {
                self.handle_inbound_open(conn_id, open).await;
            }
            MessagePayload::ConnectionAccept(accept) => {
                self.handle_inbound_accept(conn_id, accept);
            }
            MessagePayload::ConnectionReject(reject) => {
                self.handle_inbound_reject(conn_id, reject);
            }
            MessagePayload::RequestMessage(r) => {
                let r_ref = r.get();
                vox_types::dlog!(
                    "[session {:?}] recv request: conn={:?} req={:?} body={} method={:?}",
                    self.role,
                    conn_id,
                    r_ref.id,
                    match &r_ref.body {
                        RequestBody::Call(_) => "Call",
                        RequestBody::Response(_) => "Response",
                        RequestBody::Cancel(_) => "Cancel",
                    },
                    match &r_ref.body {
                        RequestBody::Call(call) => Some(call.method_id),
                        RequestBody::Response(_) | RequestBody::Cancel(_) => None,
                    }
                );
                // Remembered up front: a Response that carried schemas already
                // consumes its outgoing-call entry inside the block below.
                let response_had_schema_payload = matches!(&r_ref.body, RequestBody::Response(resp) if !resp.schemas.is_empty());
                // Record any schemas inlined in a Call/Response before dispatch.
                {
                    let schemas_cbor = match &r_ref.body {
                        RequestBody::Call(call) => Some(&call.schemas),
                        RequestBody::Response(resp) => Some(&resp.schemas),
                        _ => None,
                    };
                    vox_types::dlog!(
                        "[schema] recv ({:?}): req={:?} body={} schemas_len={:?}",
                        self.role,
                        r_ref.id,
                        match &r_ref.body {
                            RequestBody::Call(_) => "Call",
                            RequestBody::Response(_) => "Response",
                            RequestBody::Cancel(_) => "Cancel",
                        },
                        schemas_cbor.map(|s| s.0.len())
                    );
                    let schema_recv_tracker = match self.conns.get(&conn_id) {
                        Some(ConnectionSlot::Active(state)) => {
                            Arc::clone(&state.schema_recv_tracker)
                        }
                        _ => return,
                    };
                    if let Some(schemas_cbor) = schemas_cbor
                        && !schemas_cbor.is_empty()
                    {
                        let (method_id, direction) = match &r_ref.body {
                            RequestBody::Call(call) => {
                                (call.method_id, vox_types::BindingDirection::Args)
                            }
                            RequestBody::Response(_) => {
                                // Response schemas are bound to the method of the call
                                // we sent; an unknown request id is a protocol error.
                                let Some(method_id) =
                                    self.sess_core.take_outgoing_call_method(conn_id, r_ref.id)
                                else {
                                    self.close_connection_for_protocol_error(
                                        conn_id,
                                        format!(
                                            "response schemas for unknown inflight request {:?}",
                                            r_ref.id
                                        ),
                                    );
                                    return;
                                };
                                (method_id, vox_types::BindingDirection::Response)
                            }
                            // `schemas_cbor` is only `Some` for Call and Response.
                            RequestBody::Cancel(_) => unreachable!(),
                        };
                        if !self.record_received_schema_cbor(
                            conn_id,
                            schema_recv_tracker,
                            method_id,
                            direction,
                            schemas_cbor,
                            "inlined request schemas",
                        ) {
                            return;
                        }
                    }
                }
                // A Response without inline schemas still drops its outgoing-call
                // entry so it does not linger.
                if matches!(&r_ref.body, RequestBody::Response(_)) && !response_had_schema_payload {
                    let _ = self.sess_core.take_outgoing_call_method(conn_id, r_ref.id);
                }
                // Remember which method an incoming call targets; the send path
                // looks this up when our response to it goes out.
                if let RequestBody::Call(call) = &r_ref.body {
                    self.sess_core.record_incoming_call(conn_id, r_ref.id, call.method_id);
                }
                let state = match self.conns.get(&conn_id) {
                    Some(ConnectionSlot::Active(state)) => state,
                    _ => return,
                };
                let conn_tx = state.conn_tx.clone();
                let request_id = r_ref.id;
                let body_kind = match &r_ref.body {
                    RequestBody::Call(_) => "Call",
                    RequestBody::Response(_) => "Response",
                    RequestBody::Cancel(_) => "Cancel",
                };
                let recv_msg = RecvMessage {
                    schemas: Arc::clone(&state.schema_recv_tracker),
                    msg: r.map(ConnectionMessage::Request),
                };
                vox_types::dlog!(
                    "[session {:?}] dispatch request: conn={:?} req={:?} body={}",
                    self.role,
                    conn_id,
                    request_id,
                    body_kind
                );
                // The connection's receiver is gone: treat it as closed.
                if conn_tx.send(recv_msg).await.is_err() {
                    self.remove_connection(&conn_id);
                    self.maybe_request_shutdown_after_root_closed();
                }
            }
            MessagePayload::ChannelMessage(c) => {
                let state = match self.conns.get(&conn_id) {
                    Some(ConnectionSlot::Active(state)) => state,
                    _ => return,
                };
                let conn_tx = state.conn_tx.clone();
                let recv_msg = RecvMessage {
                    schemas: Arc::clone(&state.schema_recv_tracker),
                    msg: c.map(ConnectionMessage::Channel),
                };
                // The connection's receiver is gone: treat it as closed.
                if conn_tx.send(recv_msg).await.is_err() {
                    self.remove_connection(&conn_id);
                    self.maybe_request_shutdown_after_root_closed();
                }
            }
        })
    }
1506
1507 fn make_keepalive_runtime(&self) -> Option<KeepaliveRuntime> {
1508 let config = self.keepalive?;
1509 if config.ping_interval.is_zero() || config.pong_timeout.is_zero() {
1510 warn!("keepalive disabled due to non-positive interval/timeout");
1511 return None;
1512 }
1513 let now = tokio::time::Instant::now();
1514 Some(KeepaliveRuntime {
1515 ping_interval: config.ping_interval,
1516 pong_timeout: config.pong_timeout,
1517 next_ping_at: now + config.ping_interval,
1518 waiting_pong_nonce: None,
1519 pong_deadline: now,
1520 next_ping_nonce: 1,
1521 })
1522 }
1523
1524 fn handle_keepalive_pong(&self, nonce: u64, keepalive_runtime: &mut Option<KeepaliveRuntime>) {
1525 let Some(runtime) = keepalive_runtime.as_mut() else {
1526 return;
1527 };
1528 if runtime.waiting_pong_nonce != Some(nonce) {
1529 return;
1530 }
1531 runtime.waiting_pong_nonce = None;
1532 runtime.next_ping_at = tokio::time::Instant::now() + runtime.ping_interval;
1533 }
1534
1535 async fn handle_keepalive_tick(
1536 &mut self,
1537 keepalive_runtime: &mut Option<KeepaliveRuntime>,
1538 ) -> bool {
1539 let Some(runtime) = keepalive_runtime.as_mut() else {
1540 return true;
1541 };
1542 let now = tokio::time::Instant::now();
1543
1544 if let Some(waiting_nonce) = runtime.waiting_pong_nonce {
1545 if now >= runtime.pong_deadline {
1546 warn!(
1547 nonce = waiting_nonce,
1548 timeout_ms = runtime.pong_timeout.as_millis(),
1549 "keepalive timeout waiting for pong"
1550 );
1551 return false;
1552 }
1553 return true;
1554 }
1555
1556 if now < runtime.next_ping_at {
1557 return true;
1558 }
1559
1560 let nonce = runtime.next_ping_nonce;
1561 if self
1562 .sess_core
1563 .send(
1564 Message {
1565 connection_id: ConnectionId::ROOT,
1566 payload: MessagePayload::Ping(vox_types::Ping { nonce }),
1567 },
1568 None,
1569 None,
1570 )
1571 .await
1572 .is_err()
1573 {
1574 warn!("failed to send keepalive ping");
1575 return false;
1576 }
1577
1578 runtime.waiting_pong_nonce = Some(nonce);
1579 runtime.pong_deadline = now + runtime.pong_timeout;
1580 runtime.next_ping_at = now + runtime.ping_interval;
1581 runtime.next_ping_nonce = runtime.next_ping_nonce.wrapping_add(1);
1582 true
1583 }
1584
1585 async fn handle_inbound_open(
1586 &mut self,
1587 conn_id: ConnectionId,
1588 open: SelfRef<ConnectionOpen<'static>>,
1589 ) {
1590 let peer_parity = self.parity.other();
1592 if !conn_id.has_parity(peer_parity) {
1593 let _ = self
1595 .sess_core
1596 .send(
1597 Message {
1598 connection_id: conn_id,
1599 payload: MessagePayload::ConnectionReject(vox_types::ConnectionReject {
1600 metadata: vec![],
1601 }),
1602 },
1603 None,
1604 None,
1605 )
1606 .await;
1607 return;
1608 }
1609
1610 if self.conns.contains_key(&conn_id) {
1612 let _ = self
1614 .sess_core
1615 .send(
1616 Message {
1617 connection_id: conn_id,
1618 payload: MessagePayload::ConnectionReject(vox_types::ConnectionReject {
1619 metadata: vec![],
1620 }),
1621 },
1622 None,
1623 None,
1624 )
1625 .await;
1626 return;
1627 }
1628
1629 if self.on_connection.is_none() {
1632 let _ = self
1633 .sess_core
1634 .send(
1635 Message {
1636 connection_id: conn_id,
1637 payload: MessagePayload::ConnectionReject(vox_types::ConnectionReject {
1638 metadata: vec![],
1639 }),
1640 },
1641 None,
1642 None,
1643 )
1644 .await;
1645 return;
1646 }
1647
1648 let open = open.get();
1650 let our_settings = ConnectionSettings {
1651 parity: open.connection_settings.parity.other(),
1652 max_concurrent_requests: open.connection_settings.max_concurrent_requests,
1653 };
1654
1655 let handle = self.make_connection_handle(
1657 conn_id,
1658 our_settings.clone(),
1659 open.connection_settings.clone(),
1660 );
1661
1662 let mut metadata: Vec<vox_types::MetadataEntry<'_>> = open.metadata.to_vec();
1664 metadata.push(vox_types::MetadataEntry::str(
1665 "vox-connection-kind",
1666 "virtual",
1667 ));
1668 let request = match ConnectionRequest::new(&metadata) {
1669 Ok(r) => r,
1670 Err(e) => {
1671 trace!(%conn_id, %e, "rejecting virtual connection");
1672 self.conns.remove(&conn_id);
1673 let _ = self
1674 .sess_core
1675 .send(
1676 Message {
1677 connection_id: conn_id,
1678 payload: MessagePayload::ConnectionReject(
1679 vox_types::ConnectionReject {
1680 metadata: vec![vox_types::MetadataEntry::str(
1681 "error",
1682 e.to_string(),
1683 )],
1684 },
1685 ),
1686 },
1687 None,
1688 None,
1689 )
1690 .await;
1691 return;
1692 }
1693 };
1694 let pending = PendingConnection::new(handle);
1695 let acceptor = self.on_connection.as_ref().unwrap();
1696 trace!(%conn_id, "calling acceptor for virtual connection");
1697 match acceptor.accept(&request, pending) {
1698 Ok(()) => {
1699 trace!(%conn_id, "acceptor accepted virtual connection, sending ConnectionAccept");
1700 let _ = self
1701 .sess_core
1702 .send(
1703 Message {
1704 connection_id: conn_id,
1705 payload: MessagePayload::ConnectionAccept(
1706 vox_types::ConnectionAccept {
1707 connection_settings: our_settings,
1708 metadata: vec![],
1709 },
1710 ),
1711 },
1712 None,
1713 None,
1714 )
1715 .await;
1716 }
1717 Err(reject_metadata) => {
1718 trace!(%conn_id, "acceptor rejected, removing conn slot");
1720 self.conns.remove(&conn_id);
1721 let _ = self
1722 .sess_core
1723 .send(
1724 Message {
1725 connection_id: conn_id,
1726 payload: MessagePayload::ConnectionReject(
1727 vox_types::ConnectionReject {
1728 metadata: reject_metadata,
1729 },
1730 ),
1731 },
1732 None,
1733 None,
1734 )
1735 .await;
1736 }
1737 }
1738 }
1739
1740 fn handle_inbound_accept(
1741 &mut self,
1742 conn_id: ConnectionId,
1743 accept: SelfRef<ConnectionAccept<'static>>,
1744 ) {
1745 let accept = accept.get();
1746 let slot = self.remove_connection(&conn_id);
1747 match slot {
1748 Some(ConnectionSlot::PendingOutbound(mut pending)) => {
1749 let handle = self.make_connection_handle(
1750 conn_id,
1751 pending.local_settings.clone(),
1752 accept.connection_settings.clone(),
1753 );
1754
1755 if let Some(tx) = pending.result_tx.take() {
1756 let _ = tx.send(Ok(handle));
1757 }
1758 }
1759 Some(other) => {
1760 self.conns.insert(conn_id, other);
1762 }
1763 None => {
1764 }
1766 }
1767 }
1768
1769 fn handle_inbound_reject(
1770 &mut self,
1771 conn_id: ConnectionId,
1772 reject: SelfRef<ConnectionReject<'static>>,
1773 ) {
1774 let reject = reject.get();
1775 let slot = self.remove_connection(&conn_id);
1776 match slot {
1777 Some(ConnectionSlot::PendingOutbound(mut pending)) => {
1778 if let Some(tx) = pending.result_tx.take() {
1779 let _ = tx.send(Err(SessionError::Rejected(vox_types::metadata_into_owned(
1780 reject.metadata.to_vec(),
1781 ))));
1782 }
1783 }
1784 Some(other) => {
1785 self.conns.insert(conn_id, other);
1786 }
1787 None => {}
1788 }
1789 }
1790
1791 async fn handle_open_request(&mut self, req: OpenRequest) {
1793 let conn_id = self.conn_ids.alloc();
1794
1795 let send_result = self
1797 .sess_core
1798 .send(
1799 Message {
1800 connection_id: conn_id,
1801 payload: MessagePayload::ConnectionOpen(ConnectionOpen {
1802 connection_settings: req.settings.clone(),
1803 metadata: req.metadata,
1804 }),
1805 },
1806 None,
1807 None,
1808 )
1809 .await;
1810
1811 if send_result.is_err() {
1812 let _ = req.result_tx.send(Err(SessionError::Protocol(
1813 "failed to send ConnectionOpen".into(),
1814 )));
1815 return;
1816 }
1817
1818 self.conns.insert(
1821 conn_id,
1822 ConnectionSlot::PendingOutbound(PendingOutboundData {
1823 local_settings: req.settings,
1824 result_tx: Some(req.result_tx),
1825 }),
1826 );
1827 }
1828
1829 async fn handle_close_request(&mut self, req: CloseRequest) {
1831 if req.conn_id.is_root() {
1832 let _ = req.result_tx.send(Err(SessionError::Protocol(
1833 "cannot close root connection".into(),
1834 )));
1835 return;
1836 }
1837
1838 if self.remove_connection(&req.conn_id).is_none() {
1841 let _ = req
1842 .result_tx
1843 .send(Err(SessionError::Protocol("connection not found".into())));
1844 return;
1845 }
1846
1847 let send_result = self
1849 .sess_core
1850 .send(
1851 Message {
1852 connection_id: req.conn_id,
1853 payload: MessagePayload::ConnectionClose(ConnectionClose {
1854 metadata: req.metadata,
1855 }),
1856 },
1857 None,
1858 None,
1859 )
1860 .await;
1861
1862 if send_result.is_err() {
1863 let _ = req.result_tx.send(Err(SessionError::Protocol(
1864 "failed to send ConnectionClose".into(),
1865 )));
1866 return;
1867 }
1868
1869 let _ = req.result_tx.send(Ok(()));
1870 self.maybe_request_shutdown_after_root_closed();
1871 }
1872
1873 async fn handle_drop_control_request(&mut self, req: DropControlRequest) -> bool {
1874 match req {
1875 DropControlRequest::Shutdown => {
1876 trace!("session shutdown requested");
1877 false
1878 }
1879 DropControlRequest::Close(conn_id) => {
1880 if conn_id.is_root() {
1882 trace!("root callers dropped; internally closing root connection");
1884 self.root_closed_internal = true;
1885 return self.has_virtual_connections();
1887 }
1888
1889 if self.remove_connection(&conn_id).is_some() {
1890 let _ = self
1891 .sess_core
1892 .send(
1893 Message {
1894 connection_id: conn_id,
1895 payload: MessagePayload::ConnectionClose(ConnectionClose {
1896 metadata: vec![],
1897 }),
1898 },
1899 None,
1900 None,
1901 )
1902 .await;
1903 }
1904
1905 !self.root_closed_internal || self.has_virtual_connections()
1906 }
1907 }
1908 }
1909
1910 fn has_virtual_connections(&self) -> bool {
1911 self.conns.keys().any(|id| !id.is_root())
1912 }
1913
1914 fn remove_connection(&mut self, conn_id: &ConnectionId) -> Option<ConnectionSlot> {
1915 trace!(%conn_id, "remove_connection called");
1916 let slot = self.conns.remove(conn_id);
1917 if let Some(ConnectionSlot::Active(state)) = &slot {
1918 let _ = state.closed_tx.send(true);
1919 }
1920 slot
1921 }
1922
1923 fn close_all_connections(&mut self) {
1924 trace!(role = ?self.role, count = self.conns.len(), "close_all_connections");
1925 vox_types::dlog!(
1926 "[session {:?}] close_all_connections: {} slots",
1927 self.role,
1928 self.conns.len()
1929 );
1930 for (conn_id, slot) in self.conns.iter() {
1931 if let ConnectionSlot::Active(state) = slot {
1932 vox_types::dlog!("[session {:?}] closing connection {:?}", self.role, conn_id);
1933 let _ = state.closed_tx.send(true);
1934 }
1935 }
1936 self.conns.clear();
1937 }
1938
1939 fn maybe_request_shutdown_after_root_closed(&self) {
1940 if self.root_closed_internal && !self.has_virtual_connections() {
1941 let _ = send_drop_control(&self.control_tx, DropControlRequest::Shutdown);
1942 }
1943 }
1944}
1945
/// Send side of a session.
///
/// `inner` holds the current conduit and the per-connection send state behind
/// a session-wide mutex. `outbound_tx` hands prepared [`OutboundBatch`]es to
/// the outbound worker that performs the actual writes; presumably this is
/// the [`run_outbound_worker`] task started by `spawn_outbound_worker` (TODO
/// confirm at construction).
pub(crate) struct SessionCore {
    inner: std::sync::Mutex<SessionCoreInner>,
    outbound_tx: tokio_mpsc::Sender<OutboundBatch>,
}
1950
/// A `'static` future that performs one conduit write and yields its I/O result.
///
/// `MaybeSend` presumably requires `Send` on native targets so the write can
/// be driven by a spawned task. TODO confirm against `vox_types::MaybeSend`.
pub trait OutboundSendFuture: Future<Output = std::io::Result<()>> + MaybeSend + 'static {}
// Blanket impl: every future with the right output and bounds qualifies.
impl<T> OutboundSendFuture for T where T: Future<Output = std::io::Result<()>> + MaybeSend + 'static {}

/// Type-erased pending conduit write, as returned by `DynConduitTx::prepare_msg`.
type OutboundSend = Pin<Box<dyn OutboundSendFuture>>;
1955
/// A schema plan that must be sent as a standalone `SchemaMessage` before a
/// request's payload.
#[derive(Clone)]
struct PendingSchemaSend {
    /// Method the schemas are bound to.
    method_id: vox_types::MethodId,
    /// Which side of the method the plan describes (`Args` or `Response` in
    /// this file).
    direction: vox_types::BindingDirection,
    /// The planned schemas (and root) for that binding.
    prepared: vox_types::PreparedSchemaPlan,
}
1962
/// One unit of work for the outbound worker: optional schema messages, then a
/// payload, all written to the same conduit.
struct OutboundBatch {
    /// Connection the schema messages are addressed to.
    conn_id: ConnectionId,
    /// That connection's send state; updated once each schema message is written.
    conn_state: Arc<std::sync::Mutex<SendConnState>>,
    /// Conduit snapshot taken when the batch was built.
    tx: Arc<dyn DynConduitTx>,
    /// Schema plans to write before the payload, in order.
    schema_sends: Vec<PendingSchemaSend>,
    /// The already-prepared payload write.
    payload_send: OutboundSend,
    /// Receives the overall outcome: the first error, or `Ok(())`.
    result_tx: tokio_oneshot::Sender<std::io::Result<()>>,
}
1971
1972async fn run_outbound_worker(mut rx: tokio_mpsc::Receiver<OutboundBatch>) {
1973 while let Some(batch) = rx.recv().await {
1974 let mut result = Ok(());
1975 for schema_send in batch.schema_sends {
1976 let schemas = {
1977 let mut conn_state = batch
1978 .conn_state
1979 .lock()
1980 .expect("send conn state mutex poisoned");
1981 conn_state.send_tracker.preview_prepared_plan(
1982 schema_send.method_id,
1983 schema_send.direction,
1984 &schema_send.prepared,
1985 )
1986 };
1987 if schemas.is_empty() {
1988 continue;
1989 }
1990
1991 let schema_msg = Message {
1992 connection_id: batch.conn_id,
1993 payload: MessagePayload::SchemaMessage(SchemaMessage {
1994 method_id: schema_send.method_id,
1995 direction: schema_send.direction,
1996 schemas,
1997 }),
1998 };
1999 let send = match batch.tx.clone().prepare_msg(schema_msg, None) {
2000 Ok(send) => send,
2001 Err(error) => {
2002 result = Err(error);
2003 break;
2004 }
2005 };
2006 if let Err(error) = send.await {
2007 result = Err(error);
2008 break;
2009 }
2010 let mut conn_state = batch
2011 .conn_state
2012 .lock()
2013 .expect("send conn state mutex poisoned");
2014 conn_state.send_tracker.mark_prepared_plan_sent(
2015 schema_send.method_id,
2016 schema_send.direction,
2017 &schema_send.prepared,
2018 );
2019 conn_state
2020 .planned_bindings
2021 .remove(&(schema_send.direction, schema_send.method_id));
2022 }
2023 if result.is_ok()
2024 && let Err(error) = batch.payload_send.await
2025 {
2026 result = Err(error);
2027 }
2028 let _ = batch.result_tx.send(result);
2029 }
2030}
2031
2032#[cfg(not(target_arch = "wasm32"))]
2033fn spawn_outbound_worker(rx: tokio_mpsc::Receiver<OutboundBatch>) {
2034 if tokio::runtime::Handle::try_current().is_ok() {
2035 tokio::spawn(run_outbound_worker(rx));
2036 return;
2037 }
2038
2039 std::thread::spawn(move || {
2040 let runtime = tokio::runtime::Builder::new_current_thread()
2041 .enable_all()
2042 .build()
2043 .expect("build outbound worker runtime");
2044 runtime.block_on(run_outbound_worker(rx));
2045 });
2046}
2047
/// Starts [`run_outbound_worker`] for `rx` on wasm32, as a local (non-`Send`)
/// task on the current thread.
#[cfg(target_arch = "wasm32")]
fn spawn_outbound_worker(rx: tokio_mpsc::Receiver<OutboundBatch>) {
    wasm_bindgen_futures::spawn_local(run_outbound_worker(rx));
}
2052
/// Send-side schema bookkeeping for one connection.
struct SendConnState {
    /// Tracks which schema bindings the peer has already been sent.
    send_tracker: vox_types::SchemaSendTracker,

    /// Calls received from the peer, keyed by request id → method id. An
    /// entry is consumed when our response to that request is sent.
    inflight_incoming: HashMap<RequestId, vox_types::MethodId>,

    /// Calls we sent, keyed by request id → method id. An entry is consumed
    /// when the peer's response arrives.
    inflight_outgoing: HashMap<RequestId, vox_types::MethodId>,

    /// Cache of schema plans per (direction, method). An entry is dropped once
    /// the outbound worker has written that plan.
    planned_bindings:
        HashMap<(vox_types::BindingDirection, vox_types::MethodId), vox_types::PreparedSchemaPlan>,
}
2069
2070impl SendConnState {
2071 fn new() -> Self {
2072 SendConnState {
2073 send_tracker: vox_types::SchemaSendTracker::new(),
2074 inflight_incoming: HashMap::new(),
2075 inflight_outgoing: HashMap::new(),
2076 planned_bindings: HashMap::new(),
2077 }
2078 }
2079}
2080
/// State guarded by `SessionCore::inner`.
struct SessionCoreInner {
    /// Current outbound conduit; replaced wholesale on session resume.
    tx: Arc<dyn DynConduitTx>,

    /// Per-connection send state. Entries are created lazily on first use and
    /// all cleared on session resume.
    conns: HashMap<ConnectionId, Arc<std::sync::Mutex<SendConnState>>>,
}
2088
2089fn get_or_create_send_conn_state(
2090 inner: &mut SessionCoreInner,
2091 conn_id: ConnectionId,
2092) -> Arc<std::sync::Mutex<SendConnState>> {
2093 inner
2094 .conns
2095 .entry(conn_id)
2096 .or_insert_with(|| Arc::new(std::sync::Mutex::new(SendConnState::new())))
2097 .clone()
2098}
2099
2100impl SessionCore {
2101 pub(crate) async fn send<'a>(
2103 &self,
2104 mut msg: Message<'a>,
2105 binder: Option<&'a dyn vox_types::ChannelBinder>,
2106 forwarded_schemas: Option<&vox_types::SchemaRecvTracker>,
2107 ) -> Result<(), ()> {
2108 let conn_id = msg.connection_id;
2109 let (tx, conn_state, schema_sends) = {
2110 let mut inner = self.inner.lock().expect("session core mutex poisoned");
2111 let tx = inner.tx.clone();
2112 let conn_state = get_or_create_send_conn_state(&mut inner, conn_id);
2113 drop(inner);
2114
2115 if let MessagePayload::RequestMessage(req) = &mut msg.payload {
2116 vox_types::dlog!(
2117 "[session-core] send request: conn={:?} req={:?} body={} forwarded={}",
2118 conn_id,
2119 req.id,
2120 match &req.body {
2121 RequestBody::Call(_) => "Call",
2122 RequestBody::Response(_) => "Response",
2123 RequestBody::Cancel(_) => "Cancel",
2124 },
2125 forwarded_schemas.is_some()
2126 );
2127 let schema_sends = {
2128 let mut conn_state_guard =
2129 conn_state.lock().expect("send conn state mutex poisoned");
2130 let mut schema_sends = Vec::new();
2131 match &mut req.body {
2132 RequestBody::Call(call) => {
2133 if let Some(schema_send) = Self::plan_call_schema_send(
2134 &mut conn_state_guard,
2135 req.id,
2136 call.method_id,
2137 call,
2138 forwarded_schemas,
2139 ) {
2140 schema_sends.push(schema_send);
2141 }
2142 call.schemas = Default::default();
2143 }
2144 RequestBody::Response(resp) => {
2145 if let Some(method_id) =
2146 conn_state_guard.inflight_incoming.remove(&req.id)
2147 && let Some(schema_send) = Self::plan_response_schema_send(
2148 &mut conn_state_guard,
2149 req.id,
2150 method_id,
2151 resp,
2152 forwarded_schemas,
2153 )
2154 {
2155 schema_sends.push(schema_send);
2156 }
2157 resp.schemas = Default::default();
2158 }
2159 RequestBody::Cancel(_) => {}
2160 }
2161 schema_sends
2162 };
2163 (tx, conn_state, schema_sends)
2164 } else {
2165 (tx, conn_state, Vec::new())
2166 }
2167 };
2168 let payload_send = tx.clone().prepare_msg(msg, binder).map_err(|_| ())?;
2169
2170 let (result_tx, result_rx) = tokio_oneshot::channel();
2171 self.outbound_tx
2172 .send(OutboundBatch {
2173 conn_id,
2174 conn_state,
2175 tx,
2176 schema_sends,
2177 payload_send,
2178 result_tx,
2179 })
2180 .await
2181 .map_err(|_| ())?;
2182 result_rx.await.map_err(|_| ())?.map_err(|_| ())
2183 }
2184
2185 pub(crate) fn record_incoming_call(
2188 &self,
2189 conn_id: ConnectionId,
2190 request_id: RequestId,
2191 method_id: vox_types::MethodId,
2192 ) {
2193 let mut inner = self.inner.lock().expect("session core mutex poisoned");
2194 let conn_state = get_or_create_send_conn_state(&mut inner, conn_id);
2195 vox_types::dlog!(
2196 "[schema] record_incoming_call: conn={:?} req={:?} method={:?}",
2197 conn_id,
2198 request_id,
2199 method_id
2200 );
2201 conn_state
2202 .lock()
2203 .expect("send conn state mutex poisoned")
2204 .inflight_incoming
2205 .insert(request_id, method_id);
2206 }
2207
2208 pub(crate) fn take_outgoing_call_method(
2209 &self,
2210 conn_id: ConnectionId,
2211 request_id: RequestId,
2212 ) -> Option<vox_types::MethodId> {
2213 let inner = self.inner.lock().expect("session core mutex poisoned");
2214 inner.conns.get(&conn_id).and_then(|conn_state| {
2215 conn_state
2216 .lock()
2217 .expect("send conn state mutex poisoned")
2218 .inflight_outgoing
2219 .remove(&request_id)
2220 })
2221 }
2222
2223 pub(crate) fn prepare_response_for_method(
2224 &self,
2225 conn_id: ConnectionId,
2226 request_id: RequestId,
2227 method_id: vox_types::MethodId,
2228 response: &mut RequestResponse<'_>,
2229 ) {
2230 let mut inner = self.inner.lock().expect("session core mutex poisoned");
2231 let conn_state = get_or_create_send_conn_state(&mut inner, conn_id);
2232 let mut conn_state = conn_state.lock().expect("send conn state mutex poisoned");
2233 let key = (vox_types::BindingDirection::Response, method_id);
2234 if conn_state
2235 .send_tracker
2236 .has_sent_binding(method_id, vox_types::BindingDirection::Response)
2237 {
2238 response.schemas = Default::default();
2239 return;
2240 }
2241
2242 let prepared = match &response.ret {
2243 vox_types::Payload::Value { shape, .. } => {
2244 match Self::get_or_plan_binding_for_shape(
2245 &mut conn_state,
2246 key,
2247 request_id,
2248 "response",
2249 shape,
2250 ) {
2251 Some(prepared) => prepared,
2252 None => return,
2253 }
2254 }
2255 vox_types::Payload::PostcardBytes(_) => {
2256 tracing::error!(
2257 "schema attachment failed: missing forwarded response schemas for method {:?}",
2258 method_id
2259 );
2260 return;
2261 }
2262 };
2263 response.schemas = prepared.to_cbor();
2264 }
2265
2266 pub(crate) fn schema_registry(&self, conn_id: ConnectionId) -> vox_types::SchemaRegistry {
2269 let inner = self.inner.lock().expect("session core mutex poisoned");
2270 inner
2271 .conns
2272 .get(&conn_id)
2273 .map(|cs| {
2274 cs.lock()
2275 .expect("send conn state mutex poisoned")
2276 .send_tracker
2277 .registry()
2278 .clone()
2279 })
2280 .unwrap_or_default()
2281 }
2282
2283 pub(crate) fn prepare_response_from_source(
2285 &self,
2286 conn_id: ConnectionId,
2287 _request_id: RequestId,
2288 method_id: vox_types::MethodId,
2289 root_type: &vox_types::TypeRef,
2290 source: &dyn vox_types::SchemaSource,
2291 response: &mut RequestResponse<'_>,
2292 ) {
2293 let mut inner = self.inner.lock().expect("session core mutex poisoned");
2294 let conn_state = get_or_create_send_conn_state(&mut inner, conn_id);
2295 let mut conn_state = conn_state.lock().expect("send conn state mutex poisoned");
2296 let key = (vox_types::BindingDirection::Response, method_id);
2297 if conn_state
2298 .send_tracker
2299 .has_sent_binding(method_id, vox_types::BindingDirection::Response)
2300 {
2301 response.schemas = Default::default();
2302 return;
2303 }
2304 let prepared =
2305 Self::get_or_plan_binding_from_source(&mut conn_state, key, root_type, source);
2306 response.schemas = prepared.to_cbor();
2307 }
2308
2309 fn get_or_plan_binding_for_shape(
2310 conn_state: &mut SendConnState,
2311 key: (vox_types::BindingDirection, vox_types::MethodId),
2312 request_id: RequestId,
2313 kind: &str,
2314 shape: &'static Shape,
2315 ) -> Option<vox_types::PreparedSchemaPlan> {
2316 if let Some(prepared) = conn_state.planned_bindings.get(&key) {
2317 return Some(prepared.clone());
2318 }
2319 match vox_types::SchemaSendTracker::plan_for_shape(shape) {
2320 Ok(prepared) => {
2321 vox_types::dlog!(
2322 "[schema] planned {} {} schemas for method {:?} (req {:?})",
2323 prepared.schemas.len(),
2324 kind,
2325 key.1,
2326 request_id
2327 );
2328 conn_state.planned_bindings.insert(key, prepared.clone());
2329 Some(prepared)
2330 }
2331 Err(e) => {
2332 tracing::error!("schema extraction failed: {e}");
2333 None
2334 }
2335 }
2336 }
2337
2338 fn get_or_plan_binding_from_source(
2339 conn_state: &mut SendConnState,
2340 key: (vox_types::BindingDirection, vox_types::MethodId),
2341 root_type: &vox_types::TypeRef,
2342 source: &dyn vox_types::SchemaSource,
2343 ) -> vox_types::PreparedSchemaPlan {
2344 if let Some(prepared) = conn_state.planned_bindings.get(&key) {
2345 return prepared.clone();
2346 }
2347 let prepared = vox_types::SchemaSendTracker::plan_from_source(root_type, source);
2348 conn_state.planned_bindings.insert(key, prepared.clone());
2349 prepared
2350 }
2351
2352 fn plan_response_schema_send(
2353 conn_state: &mut SendConnState,
2354 request_id: RequestId,
2355 method_id: vox_types::MethodId,
2356 response: &mut RequestResponse<'_>,
2357 forwarded_schemas: Option<&vox_types::SchemaRecvTracker>,
2358 ) -> Option<PendingSchemaSend> {
2359 if conn_state
2360 .send_tracker
2361 .has_sent_binding(method_id, vox_types::BindingDirection::Response)
2362 {
2363 response.schemas = Default::default();
2364 return None;
2365 }
2366
2367 let key = (vox_types::BindingDirection::Response, method_id);
2368 let prepared = if !response.schemas.is_empty() {
2369 conn_state
2370 .planned_bindings
2371 .get(&key)
2372 .cloned()
2373 .unwrap_or_else(|| {
2374 let prepared_payload = vox_types::SchemaPayload::from_cbor(&response.schemas.0)
2375 .expect("prepared schema payloads must be valid CBOR");
2376 vox_types::PreparedSchemaPlan {
2377 schemas: prepared_payload.schemas,
2378 root: prepared_payload.root,
2379 }
2380 })
2381 } else {
2382 match &response.ret {
2383 vox_types::Payload::Value { shape, .. } => Self::get_or_plan_binding_for_shape(
2384 conn_state, key, request_id, "response", shape,
2385 )?,
2386 vox_types::Payload::PostcardBytes(_) => {
2387 let Some(source) = forwarded_schemas else {
2388 tracing::error!(
2389 "schema attachment failed: missing forwarded response schemas for method {:?}",
2390 method_id
2391 );
2392 return None;
2393 };
2394 let Some(root) = source.get_remote_response_root(method_id) else {
2395 tracing::error!(
2396 "schema attachment failed: missing forwarded response root for method {:?}",
2397 method_id
2398 );
2399 return None;
2400 };
2401 Self::get_or_plan_binding_from_source(conn_state, key, &root, source)
2402 }
2403 }
2404 };
2405
2406 Some(PendingSchemaSend {
2407 method_id,
2408 direction: vox_types::BindingDirection::Response,
2409 prepared,
2410 })
2411 }
2412
2413 fn plan_call_schema_send(
2414 conn_state: &mut SendConnState,
2415 request_id: RequestId,
2416 method_id: vox_types::MethodId,
2417 call: &mut vox_types::RequestCall<'_>,
2418 forwarded_schemas: Option<&vox_types::SchemaRecvTracker>,
2419 ) -> Option<PendingSchemaSend> {
2420 conn_state.inflight_outgoing.insert(request_id, method_id);
2421 if conn_state
2422 .send_tracker
2423 .has_sent_binding(method_id, vox_types::BindingDirection::Args)
2424 {
2425 call.schemas = Default::default();
2426 return None;
2427 }
2428
2429 let key = (vox_types::BindingDirection::Args, method_id);
2430 let prepared = match &call.args {
2431 vox_types::Payload::Value { shape, .. } => {
2432 Self::get_or_plan_binding_for_shape(conn_state, key, request_id, "args", shape)?
2433 }
2434 vox_types::Payload::PostcardBytes(_) => {
2435 let Some(source) = forwarded_schemas else {
2436 tracing::error!(
2437 "schema attachment failed: missing forwarded args schemas for method {:?}",
2438 method_id
2439 );
2440 return None;
2441 };
2442 let Some(root) = source.get_remote_args_root(method_id) else {
2443 tracing::error!(
2444 "schema attachment failed: missing forwarded args root for method {:?}",
2445 method_id
2446 );
2447 return None;
2448 };
2449 Self::get_or_plan_binding_from_source(conn_state, key, &root, source)
2450 }
2451 };
2452
2453 Some(PendingSchemaSend {
2454 method_id,
2455 direction: vox_types::BindingDirection::Args,
2456 prepared,
2457 })
2458 }
2459
2460 fn replace_tx_and_reset_schemas(&self, tx: Arc<dyn DynConduitTx>) {
2461 let mut inner = self.inner.lock().expect("session core mutex poisoned");
2462 inner.tx = tx;
2463 inner.conns.clear();
2464 }
2465}
2466
/// A conduit handed back by [`ConduitRecoverer::next_conduit`], split into its
/// two halves together with the handshake outcome that accompanies it.
pub(crate) struct RecoveredConduit {
    /// Outbound half; shared behind an `Arc` so send futures can own a handle.
    pub tx: Arc<dyn DynConduitTx>,
    /// Inbound half.
    pub rx: Box<dyn DynConduitRx>,
    /// Handshake result for this conduit (presumably from the handshake run
    /// on the new conduit — confirm against implementors).
    pub handshake: HandshakeResult,
}
2472
/// Source of replacement conduits for a session.
///
/// NOTE(review): presumably invoked to re-establish the transport after the
/// current conduit fails; confirm against callers.
pub(crate) trait ConduitRecoverer: MaybeSend {
    /// Asynchronously produces the next conduit.
    ///
    /// `resume_key` is the session resume key to present, if any; `None` when
    /// there is no key to offer. Failures are reported as [`SessionError`].
    fn next_conduit<'a>(
        &'a mut self,
        resume_key: Option<&'a SessionResumeKey>,
    ) -> BoxFut<'a, Result<RecoveredConduit, SessionError>>;
}
2479
/// Object-safe outbound conduit half carrying [`Message`]s.
///
/// Used as `Arc<dyn DynConduitTx>`, so any concrete [`ConduitTx`] over
/// [`MessageFamily`] can be stored without exposing its type.
pub trait DynConduitTx: MaybeSend + MaybeSync {
    /// Prepares `msg` for sending and returns the future that sends it.
    ///
    /// Sending is split into two steps. Errors from the preparation step are
    /// returned immediately as `Err`. Errors from the actual send surface only
    /// when the returned [`OutboundSend`] is awaited.
    ///
    /// In the blanket implementation for [`ConduitTx`] types, preparation runs
    /// under `binder` (via `vox_types::with_channel_binder`) when one is given.
    /// Both kinds of error are converted to [`std::io::Error`] carrying only
    /// their text.
    fn prepare_msg<'a>(
        self: Arc<Self>,
        msg: Message<'a>,
        binder: Option<&'a dyn vox_types::ChannelBinder>,
    ) -> std::io::Result<OutboundSend>;
}
/// Object-safe inbound conduit half yielding [`Message`]s.
///
/// Used as `Box<dyn DynConduitRx>`, so any concrete [`ConduitRx`] over
/// [`MessageFamily`] can be stored without exposing its type.
pub trait DynConduitRx: MaybeSend {
    /// Receives the next message.
    ///
    /// In the blanket implementation for [`ConduitRx`] types, `Ok(None)` is
    /// passed through unchanged from the underlying receiver (presumably end
    /// of stream — confirm). Receive errors are converted to
    /// [`std::io::Error`] carrying only their text.
    fn recv_msg<'a>(&'a mut self)
    -> BoxFut<'a, std::io::Result<Option<SelfRef<Message<'static>>>>>;
}
2491
2492impl<T> DynConduitTx for T
2495where
2496 T: ConduitTx<Msg = MessageFamily> + MaybeSend + MaybeSync + 'static,
2497{
2498 fn prepare_msg<'a>(
2499 self: Arc<Self>,
2500 msg: Message<'a>,
2501 binder: Option<&'a dyn vox_types::ChannelBinder>,
2502 ) -> std::io::Result<OutboundSend> {
2503 let prepared = if let Some(binder) = binder {
2504 vox_types::with_channel_binder(binder, || self.prepare_send(msg))
2505 } else {
2506 self.prepare_send(msg)
2507 };
2508 let prepared = prepared.map_err(|e| std::io::Error::other(e.to_string()))?;
2509 Ok(Box::pin(async move {
2510 self.send_prepared(prepared)
2511 .await
2512 .map_err(|e| std::io::Error::other(e.to_string()))
2513 }))
2514 }
2515}
2516
2517impl<T> DynConduitRx for T
2518where
2519 T: ConduitRx<Msg = MessageFamily> + MaybeSend,
2520{
2521 fn recv_msg<'a>(
2522 &'a mut self,
2523 ) -> BoxFut<'a, std::io::Result<Option<SelfRef<Message<'static>>>>> {
2524 Box::pin(async move {
2525 self.recv()
2526 .await
2527 .map_err(|error| std::io::Error::other(error.to_string()))
2528 })
2529 }
2530}
2531
// Unit tests for connection accept/reject/close handling and for the
// root-settings checks performed on session resume.
#[cfg(test)]
mod tests {
    use moire::sync::mpsc;
    use vox_types::{
        Backing, Conduit, ConnectionAccept, ConnectionReject, HandshakeResult, SelfRef,
    };

    use super::*;

    /// Builds a `Session` (via `Session::pre_handshake`) over one end of an
    /// in-memory link.
    ///
    /// The open/close/resume senders go out of scope when this function
    /// returns. The tests below drive the session by calling its handler
    /// methods directly instead.
    fn make_session() -> Session {
        let (a, b) = crate::memory_link_pair(32);
        // Leak the other end instead of dropping it (presumably so the link is
        // not observed as closed by the peer side — confirm).
        std::mem::forget(b);
        let conduit = crate::BareConduit::new(a);
        let (tx, rx) = conduit.split();
        let (_open_tx, open_rx) = mpsc::channel::<OpenRequest>("session.open.test", 4);
        let (_close_tx, close_rx) = mpsc::channel::<CloseRequest>("session.close.test", 4);
        let (_resume_tx, resume_rx) = mpsc::channel::<ResumeRequest>("session.resume.test", 1);
        let (control_tx, control_rx) = mpsc::unbounded_channel("session.control.test");
        Session::pre_handshake(
            tx, rx, None, open_rx, close_rx, resume_rx, control_tx, control_rx, None, false, None,
            None,
        )
    }

    /// Initiator-side handshake result with the given settings. It has retry
    /// support enabled, a fixed session resume key (`[7; 16]`), and no
    /// schemas or metadata.
    fn resumed_handshake(
        our_settings: ConnectionSettings,
        peer_settings: ConnectionSettings,
    ) -> HandshakeResult {
        HandshakeResult {
            role: SessionRole::Initiator,
            our_settings,
            peer_settings,
            peer_supports_retry: true,
            session_resume_key: Some(SessionResumeKey([7; 16])),
            peer_resume_key: None,
            our_schema: vec![],
            peer_schema: vec![],
            peer_metadata: vec![],
        }
    }

    /// A `ConnectionAccept` (even parity, 64 concurrent requests, no metadata)
    /// owned by an empty boxed backing.
    fn accept_ref() -> SelfRef<ConnectionAccept<'static>> {
        SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ConnectionAccept {
                connection_settings: ConnectionSettings {
                    parity: Parity::Even,
                    max_concurrent_requests: 64,
                },
                metadata: vec![],
            },
        )
    }

    /// A `ConnectionReject` with no metadata, owned by an empty boxed backing.
    fn reject_ref() -> SelfRef<ConnectionReject<'static>> {
        SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ConnectionReject { metadata: vec![] },
        )
    }

    // The first accept resolves the pending open with a handle for the
    // connection. A second accept must leave the slot `Active` with the same id.
    #[tokio::test]
    async fn duplicate_connection_accept_is_ignored_after_first() {
        let mut session = make_session();
        let conn_id = ConnectionId(1);
        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");

        session.conns.insert(
            conn_id,
            ConnectionSlot::PendingOutbound(PendingOutboundData {
                local_settings: ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 64,
                },
                result_tx: Some(result_tx),
            }),
        );

        session.handle_inbound_accept(conn_id, accept_ref());
        let handle = result_rx
            .await
            .expect("pending outbound result should resolve")
            .expect("accept should resolve as Ok");
        assert_eq!(handle.connection_id(), conn_id);

        session.handle_inbound_accept(conn_id, accept_ref());
        assert!(
            matches!(
                session.conns.get(&conn_id),
                Some(ConnectionSlot::Active(ConnectionState { id, .. })) if *id == conn_id
            ),
            "duplicate accept should keep existing active connection state"
        );
    }

    // The first reject resolves the pending open with `SessionError::Rejected`.
    // A second reject must not put any slot back for the connection.
    #[tokio::test]
    async fn duplicate_connection_reject_is_ignored_after_first() {
        let mut session = make_session();
        let conn_id = ConnectionId(1);
        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");

        session.conns.insert(
            conn_id,
            ConnectionSlot::PendingOutbound(PendingOutboundData {
                local_settings: ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 64,
                },
                result_tx: Some(result_tx),
            }),
        );

        session.handle_inbound_reject(conn_id, reject_ref());
        let result = result_rx
            .await
            .expect("pending outbound result should resolve");
        assert!(
            matches!(result, Err(SessionError::Rejected(_))),
            "expected rejection, got: {result:?}"
        );

        session.handle_inbound_reject(conn_id, reject_ref());
        assert!(
            !session.conns.contains_key(&conn_id),
            "duplicate reject should not recreate connection state"
        );
    }

    // Accept/reject for a connection id with no pending slot must leave the
    // connection table empty.
    #[test]
    fn out_of_order_accept_or_reject_without_pending_is_ignored() {
        let mut session = make_session();
        let conn_id = ConnectionId(99);

        session.handle_inbound_accept(conn_id, accept_ref());
        session.handle_inbound_reject(conn_id, reject_ref());

        assert!(
            session.conns.is_empty(),
            "out-of-order accept/reject should not mutate empty connection table"
        );
    }

    // Closing a connection that is still pending outbound reports success to
    // the close requester. It also drops the pending open's sender, so its
    // receiver sees an error.
    #[tokio::test]
    async fn close_request_clears_pending_outbound_open() {
        let mut session = make_session();
        let (open_result_tx, open_result_rx) = moire::sync::oneshot::channel("session.open.result");
        let (close_result_tx, close_result_rx) =
            moire::sync::oneshot::channel("session.close.result");

        session.conns.insert(
            ConnectionId(1),
            ConnectionSlot::PendingOutbound(PendingOutboundData {
                local_settings: ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 64,
                },
                result_tx: Some(open_result_tx),
            }),
        );

        session
            .handle_close_request(CloseRequest {
                conn_id: ConnectionId(1),
                metadata: vec![],
                result_tx: close_result_tx,
            })
            .await;

        let close_result = close_result_rx
            .await
            .expect("close result should be delivered");
        assert!(
            close_result.is_ok(),
            "close should succeed for pending outbound connection"
        );

        assert!(
            open_result_rx.await.is_err(),
            "pending open result channel should be closed once the pending slot is removed"
        );
    }

    // Resuming with different local root settings (max_concurrent_requests
    // 64 -> 65) must fail with the exact local-root protocol error.
    #[test]
    fn resume_rejects_changed_local_root_settings() {
        let mut session = make_session();
        let local_settings = ConnectionSettings {
            parity: Parity::Odd,
            max_concurrent_requests: 64,
        };
        let peer_settings = ConnectionSettings {
            parity: Parity::Even,
            max_concurrent_requests: 64,
        };
        let _root = session
            .establish_from_handshake(resumed_handshake(
                local_settings.clone(),
                peer_settings.clone(),
            ))
            .expect("initial handshake should establish session");

        // `_link_b` stays bound until the end of the test, keeping the far end alive.
        let (link_a, _link_b) = crate::memory_link_pair(32);
        let conduit = crate::BareConduit::new(link_a);
        let (tx, rx) = conduit.split();

        let result = session.resume_from_handshake(
            Arc::new(tx),
            Box::new(rx),
            resumed_handshake(
                ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 65,
                },
                peer_settings,
            ),
        );

        assert!(
            matches!(
                &result,
                Err(SessionError::Protocol(message))
                    if message == "local root settings changed across session resume"
            ),
            "expected local-root-settings mismatch, got: {result:?}"
        );
    }

    // Resuming with different peer root settings (max_concurrent_requests
    // 64 -> 65) must fail with the exact peer-root protocol error.
    #[test]
    fn resume_rejects_changed_peer_root_settings() {
        let mut session = make_session();
        let local_settings = ConnectionSettings {
            parity: Parity::Odd,
            max_concurrent_requests: 64,
        };
        let peer_settings = ConnectionSettings {
            parity: Parity::Even,
            max_concurrent_requests: 64,
        };
        let _root = session
            .establish_from_handshake(resumed_handshake(
                local_settings.clone(),
                peer_settings.clone(),
            ))
            .expect("initial handshake should establish session");

        // `_link_b` stays bound until the end of the test, keeping the far end alive.
        let (link_a, _link_b) = crate::memory_link_pair(32);
        let conduit = crate::BareConduit::new(link_a);
        let (tx, rx) = conduit.split();

        let result = session.resume_from_handshake(
            Arc::new(tx),
            Box::new(rx),
            resumed_handshake(
                local_settings,
                ConnectionSettings {
                    parity: Parity::Even,
                    max_concurrent_requests: 65,
                },
            ),
        );

        assert!(
            matches!(
                &result,
                Err(SessionError::Protocol(message))
                    if message == "peer root settings changed across session resume"
            ),
            "expected peer-root-settings mismatch, got: {result:?}"
        );
    }
}