1use std::{
2 collections::{BTreeMap, HashMap, HashSet},
3 sync::Arc,
4 time::Duration,
5};
6
7use moire::sync::mpsc;
8use vox_types::{
9 BoxFut, ChannelMessage, Conduit, ConduitRx, ConduitTx, ConduitTxPermit, ConnectionAccept,
10 ConnectionClose, ConnectionId, ConnectionOpen, ConnectionReject, ConnectionSettings,
11 HandshakeResult, IdAllocator, MaybeSend, MaybeSync, Message, MessageFamily, MessagePayload,
12 Metadata, Parity, RequestBody, RequestId, RequestMessage, RequestResponse, SelfRef,
13 SessionResumeKey, SessionRole,
14};
15use tokio::sync::watch;
16use tracing::{debug, warn};
17
18mod builders;
19pub use builders::*;
20
/// Keepalive tuning for a session.
///
/// When configured, the session periodically sends `Ping` on the root
/// connection and stops if the matching `Pong` does not arrive within
/// `pong_timeout` (see `Session::handle_keepalive_tick`). A zero value for
/// either field disables keepalive (see `Session::make_keepalive_runtime`).
#[derive(Debug, Clone, Copy)]
pub struct SessionKeepaliveConfig {
    /// Delay before the next ping, measured from the last ping sent or the
    /// last matching pong received.
    pub ping_interval: Duration,
    /// How long to wait for the matching pong before ending the session.
    pub pong_timeout: Duration,
}
27
/// Decides whether to accept virtual connections opened by the peer.
///
/// `Session::handle_inbound_open` consults the acceptor only for
/// `ConnectionOpen` messages whose id has the peer's parity and is not already
/// in use. A session without an acceptor rejects every inbound open.
pub trait ConnectionAcceptor: Send + 'static {
    /// Inspects an inbound open request.
    ///
    /// * `conn_id` - id the peer chose for the new connection.
    /// * `peer_settings` - the peer's settings for that connection.
    /// * `metadata` - metadata attached to the peer's `ConnectionOpen`.
    ///
    /// On `Ok`, the session registers the connection using the returned
    /// settings as its local settings and answers with `ConnectionAccept`
    /// carrying those settings and metadata. On `Err`, the connection is
    /// declined; the metadata is presumably sent back in a `ConnectionReject`
    /// (TODO confirm - that path is outside this view).
    fn accept(
        &self,
        conn_id: ConnectionId,
        peer_settings: &ConnectionSettings,
        metadata: &[vox_types::MetadataEntry],
    ) -> Result<AcceptedConnection, Metadata<'static>>;
}
48
/// Positive answer from a [`ConnectionAcceptor`].
pub struct AcceptedConnection {
    /// Local settings for the new connection; also sent to the peer in
    /// `ConnectionAccept`.
    pub settings: ConnectionSettings,
    /// Metadata sent to the peer in `ConnectionAccept`.
    pub metadata: Metadata<'static>,
    /// Callback that receives the new connection's handle.
    /// NOTE(review): invocation happens past the visible part of
    /// `handle_inbound_open` - presumably right after the accept is sent; confirm.
    pub setup: Box<dyn FnOnce(ConnectionHandle) + Send>,
}
58
/// Request from `SessionHandle::open_connection` to open a virtual connection.
struct OpenRequest {
    /// Local settings for the new connection.
    settings: ConnectionSettings,
    /// Metadata to attach to the open.
    metadata: Metadata<'static>,
    /// Receives the new handle or the reason it could not be opened.
    result_tx: moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>,
}
68
/// Request from `SessionHandle::close_connection` to close a connection.
struct CloseRequest {
    /// Connection to close.
    conn_id: ConnectionId,
    /// Metadata to attach to the close.
    metadata: Metadata<'static>,
    /// Receives the outcome of the close.
    result_tx: moire::sync::oneshot::Sender<Result<(), SessionError>>,
}
74
/// Request from `SessionHandle::resume` to continue the session on a new
/// conduit. Only honored while the session is disconnected
/// (`Session::handle_conduit_break`); rejected otherwise by `Session::run`.
struct ResumeRequest {
    /// Sending half of the new conduit.
    tx: Arc<dyn DynConduitTx>,
    /// Receiving half of the new conduit.
    rx: Box<dyn DynConduitRx>,
    /// Result of the handshake already performed on the new conduit.
    handshake_result: HandshakeResult,
    /// Receives whether the resume was accepted.
    result_tx: moire::sync::oneshot::Sender<Result<(), SessionError>>,
}
81
/// Fire-and-forget control messages for the session loop, delivered through
/// an unbounded channel so they can be sent from synchronous code.
#[derive(Debug, Clone, Copy)]
pub(crate) enum DropControlRequest {
    /// Stop the session (sent by `SessionHandle::shutdown`).
    Shutdown,
    /// Close the given connection (sent e.g. by `proxy_connections` when it
    /// finishes).
    Close(ConnectionId),
}
87
/// How a failed request should be reported; paired with a `RequestId` in
/// `ConnectionSender::mark_failure`.
///
/// NOTE(review): semantics inferred from names only - presumably `Cancelled`
/// means the request definitely did not take effect and `Indeterminate` that
/// its outcome is unknown; confirm with the consumer of `failures_rx`.
#[derive(Clone, Copy, Debug)]
pub(crate) enum FailureDisposition {
    Cancelled,
    Indeterminate,
}
93
94#[cfg(not(target_arch = "wasm32"))]
95fn send_drop_control(
96 tx: &mpsc::UnboundedSender<DropControlRequest>,
97 req: DropControlRequest,
98) -> Result<(), ()> {
99 tx.send(req).map_err(|_| ())
100}
101
102#[cfg(target_arch = "wasm32")]
103fn send_drop_control(
104 tx: &mpsc::UnboundedSender<DropControlRequest>,
105 req: DropControlRequest,
106) -> Result<(), ()> {
107 tx.try_send(req).map_err(|_| ())
108}
109
/// Cloneable handle for issuing requests to a running [`Session`].
///
/// Every method forwards a request over a channel to the session task. When
/// that task is gone, the methods fail with
/// `SessionError::Protocol("session closed")`.
#[derive(Clone)]
pub struct SessionHandle {
    /// Feeds `Session::open_rx`.
    open_tx: mpsc::Sender<OpenRequest>,
    /// Feeds `Session::close_rx`.
    close_tx: mpsc::Sender<CloseRequest>,
    /// Feeds `Session::resume_rx`.
    resume_tx: mpsc::Sender<ResumeRequest>,
    /// Used by `shutdown` to send `DropControlRequest::Shutdown`.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Resume key of the session, if it has one.
    resume_key: Option<SessionResumeKey>,
}
128
129impl SessionHandle {
130 pub async fn open_connection(
137 &self,
138 settings: ConnectionSettings,
139 metadata: Metadata<'static>,
140 ) -> Result<ConnectionHandle, SessionError> {
141 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.open_result");
142 self.open_tx
143 .send(OpenRequest {
144 settings,
145 metadata,
146 result_tx,
147 })
148 .await
149 .map_err(|_| SessionError::Protocol("session closed".into()))?;
150 result_rx
151 .await
152 .map_err(|_| SessionError::Protocol("session closed".into()))?
153 }
154
155 pub async fn close_connection(
162 &self,
163 conn_id: ConnectionId,
164 metadata: Metadata<'static>,
165 ) -> Result<(), SessionError> {
166 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.close_result");
167 self.close_tx
168 .send(CloseRequest {
169 conn_id,
170 metadata,
171 result_tx,
172 })
173 .await
174 .map_err(|_| SessionError::Protocol("session closed".into()))?;
175 result_rx
176 .await
177 .map_err(|_| SessionError::Protocol("session closed".into()))?
178 }
179
180 pub async fn resume<I: crate::IntoConduit>(
181 &self,
182 into_conduit: I,
183 handshake_result: HandshakeResult,
184 ) -> Result<(), SessionError>
185 where
186 I::Conduit: Conduit<Msg = MessageFamily> + 'static,
187 <I::Conduit as Conduit>::Tx: MaybeSend + MaybeSync + 'static,
188 for<'p> <<I::Conduit as Conduit>::Tx as ConduitTx>::Permit<'p>: MaybeSend,
189 <I::Conduit as Conduit>::Rx: MaybeSend + 'static,
190 {
191 let (tx, rx) = into_conduit.into_conduit().split();
192 self.resume_parts(Arc::new(tx), Box::new(rx), handshake_result)
193 .await
194 }
195
196 pub(crate) async fn resume_parts(
197 &self,
198 tx: Arc<dyn DynConduitTx>,
199 rx: Box<dyn DynConduitRx>,
200 handshake_result: HandshakeResult,
201 ) -> Result<(), SessionError> {
202 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.resume_result");
203 self.resume_tx
204 .send(ResumeRequest {
205 tx,
206 rx,
207 handshake_result,
208 result_tx,
209 })
210 .await
211 .map_err(|_| SessionError::Protocol("session closed".into()))?;
212 result_rx
213 .await
214 .map_err(|_| SessionError::Protocol("session closed".into()))?
215 }
216
217 pub fn resume_key(&self) -> Option<&SessionResumeKey> {
219 self.resume_key.as_ref()
220 }
221
222 pub fn shutdown(&self) -> Result<(), SessionError> {
224 send_drop_control(&self.control_tx, DropControlRequest::Shutdown)
225 .map_err(|_| SessionError::Protocol("session closed".into()))
226 }
227}
228
/// State owned by the task that drives one session: the current conduit
/// receiver, the connection table, and the request channels fed by
/// [`SessionHandle`].
pub struct Session {
    /// Receiving half of the current conduit; replaced by `resume_from_handshake`.
    rx: Box<dyn DynConduitRx>,

    /// Role from the handshake (`Initiator` placeholder until then).
    role: SessionRole,

    /// Parity of connection ids opened by this side; the peer uses the other parity.
    parity: Parity,

    /// Shared outbound state, also held by every `ConnectionSender`.
    sess_core: Arc<SessionCore>,
    /// Whether the peer advertised retry support in the latest handshake.
    peer_supports_retry: bool,
    /// This side's root-connection settings; must not change across resumes.
    local_root_settings: ConnectionSettings,
    /// The peer's root-connection settings; `None` until the handshake completes.
    peer_root_settings: Option<ConnectionSettings>,
    /// When `false`, losing the conduit ends the session.
    resumable: bool,
    /// Key for resuming this session, if the handshake provided one.
    session_resume_key: Option<SessionResumeKey>,

    /// All known connections, including pending outbound opens.
    conns: BTreeMap<ConnectionId, ConnectionSlot>,
    /// NOTE(review): used outside this view - presumably records that the root
    /// connection was closed locally; confirm.
    root_closed_internal: bool,

    /// Allocator for locally initiated connection ids, seeded with `parity`.
    conn_ids: IdAllocator<ConnectionId>,

    /// Acceptor for peer-initiated connections; `None` rejects them all.
    on_connection: Option<Box<dyn ConnectionAcceptor>>,

    /// Open requests from `SessionHandle::open_connection`.
    open_rx: mpsc::Receiver<OpenRequest>,

    /// Close requests from `SessionHandle::close_connection`.
    close_rx: mpsc::Receiver<CloseRequest>,

    /// Resume requests; only honored while the conduit is broken.
    resume_rx: mpsc::Receiver<ResumeRequest>,

    /// Cloned into every `ConnectionHandle` so handles can post control requests.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Receives `DropControlRequest`s for the session loop.
    control_rx: mpsc::UnboundedReceiver<DropControlRequest>,

    /// Keepalive configuration; `None` disables keepalive.
    keepalive: Option<SessionKeepaliveConfig>,
    /// Generation counter bumped after each successful resume; every handle subscribes.
    resume_notifier: watch::Sender<u64>,
    /// Optional source of replacement conduits for automatic recovery.
    recoverer: Option<Box<dyn ConduitRecoverer>>,
}
284
/// Mutable keepalive bookkeeping for one conduit; rebuilt after every resume.
#[derive(Debug)]
struct KeepaliveRuntime {
    /// Copied from `SessionKeepaliveConfig::ping_interval`.
    ping_interval: Duration,
    /// Copied from `SessionKeepaliveConfig::pong_timeout`.
    pong_timeout: Duration,
    /// Earliest instant at which the next ping may be sent.
    next_ping_at: tokio::time::Instant,
    /// Nonce of the outstanding ping, if one is in flight.
    waiting_pong_nonce: Option<u64>,
    /// Deadline for the outstanding ping; only meaningful while
    /// `waiting_pong_nonce` is `Some`.
    pong_deadline: tokio::time::Instant,
    /// Nonce for the next ping (starts at 1, wraps on overflow).
    next_ping_nonce: u64,
}
294
/// Session-side bookkeeping for an established connection.
#[derive(Debug)]
pub struct ConnectionState {
    /// Identifier of this connection.
    pub id: ConnectionId,

    /// Settings this side uses for the connection.
    pub local_settings: ConnectionSettings,

    /// Settings the peer announced for the connection.
    pub peer_settings: ConnectionSettings,

    /// Delivers inbound request and channel messages to the `ConnectionHandle`.
    conn_tx: mpsc::Sender<RecvMessage>,
    /// Drives `ConnectionHandle::closed` / `is_connected`; starts as `false`.
    closed_tx: watch::Sender<bool>,

    /// Schemas received from the peer on this connection; shared with every
    /// `RecvMessage` dispatched from it.
    schema_recv_tracker: Arc<vox_types::SchemaRecvTracker>,
}
315
/// Entry in `Session::conns`.
#[derive(Debug)]
enum ConnectionSlot {
    /// Established connection.
    Active(ConnectionState),
    /// Locally initiated open that presumably still awaits the peer's
    /// accept/reject (TODO confirm against the open/accept handlers).
    PendingOutbound(PendingOutboundData),
}
321
/// State kept for an outbound open until the peer answers.
struct PendingOutboundData {
    /// Settings this side proposed for the connection.
    local_settings: ConnectionSettings,
    /// Caller waiting for the opened handle; `Option` so it can be taken
    /// exactly once.
    result_tx: Option<moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>>,
}
327
328impl std::fmt::Debug for PendingOutboundData {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 f.debug_struct("PendingOutbound")
331 .field("local_settings", &self.local_settings)
332 .finish()
333 }
334}
335
/// Outbound side of one connection: stamps messages with the connection id and
/// hands them to the shared `SessionCore`.
#[derive(Clone)]
pub(crate) struct ConnectionSender {
    /// Connection every outgoing message is addressed to.
    connection_id: ConnectionId,
    /// Shared session core that performs the actual send.
    pub(crate) sess_core: Arc<SessionCore>,
    /// Reports request failures to the owning `ConnectionHandle`'s `failures_rx`.
    failures: Arc<mpsc::UnboundedSender<(RequestId, FailureDisposition)>>,
}
342
343fn forwarded_payload<'a>(payload: &'a vox_types::Payload<'static>) -> vox_types::Payload<'a> {
344 let vox_types::Payload::PostcardBytes(bytes) = payload else {
345 unreachable!("proxy forwarding expects decoded incoming payload bytes")
346 };
347 vox_types::Payload::PostcardBytes(bytes)
348}
349
350fn forwarded_request_body<'a>(body: &'a RequestBody<'static>) -> RequestBody<'a> {
351 match body {
352 RequestBody::Call(call) => RequestBody::Call(vox_types::RequestCall {
353 method_id: call.method_id,
354 metadata: call.metadata.clone(),
355 args: forwarded_payload(&call.args),
356 schemas: call.schemas.clone(),
357 }),
358 RequestBody::Response(response) => RequestBody::Response(RequestResponse {
359 metadata: response.metadata.clone(),
360 ret: forwarded_payload(&response.ret),
361 schemas: response.schemas.clone(),
362 }),
363 RequestBody::Cancel(cancel) => RequestBody::Cancel(vox_types::RequestCancel {
364 metadata: cancel.metadata.clone(),
365 }),
366 }
367}
368
369fn forwarded_channel_body<'a>(
370 body: &'a vox_types::ChannelBody<'static>,
371) -> vox_types::ChannelBody<'a> {
372 match body {
373 vox_types::ChannelBody::Item(item) => {
374 vox_types::ChannelBody::Item(vox_types::ChannelItem {
375 item: forwarded_payload(&item.item),
376 })
377 }
378 vox_types::ChannelBody::Close(close) => {
379 vox_types::ChannelBody::Close(vox_types::ChannelClose {
380 metadata: close.metadata.clone(),
381 })
382 }
383 vox_types::ChannelBody::Reset(reset) => {
384 vox_types::ChannelBody::Reset(vox_types::ChannelReset {
385 metadata: reset.metadata.clone(),
386 })
387 }
388 vox_types::ChannelBody::GrantCredit(credit) => {
389 vox_types::ChannelBody::GrantCredit(vox_types::ChannelGrantCredit {
390 additional: credit.additional,
391 })
392 }
393 }
394}
395
396impl ConnectionSender {
397 pub(crate) fn connection_id(&self) -> ConnectionId {
398 self.connection_id
399 }
400
401 pub(crate) async fn send_with_binder<'a>(
402 &self,
403 msg: ConnectionMessage<'a>,
404 binder: Option<&'a dyn vox_types::ChannelBinder>,
405 ) -> Result<(), ()> {
406 let payload = match msg {
407 ConnectionMessage::Request(r) => MessagePayload::RequestMessage(r),
408 ConnectionMessage::Channel(c) => MessagePayload::ChannelMessage(c),
409 };
410 let message = Message {
411 connection_id: self.connection_id,
412 payload,
413 };
414 self.sess_core
415 .send(message, binder, None)
416 .await
417 .map_err(|_| ())
418 }
419
420 pub async fn send<'a>(&self, msg: ConnectionMessage<'a>) -> Result<(), ()> {
422 self.send_with_binder(msg, None).await
423 }
424
425 pub(crate) async fn send_owned(
427 &self,
428 schemas: Arc<vox_types::SchemaRecvTracker>,
429 msg: SelfRef<ConnectionMessage<'static>>,
430 ) -> Result<(), ()> {
431 let payload = match &*msg {
432 ConnectionMessage::Request(request) => MessagePayload::RequestMessage(RequestMessage {
433 id: request.id,
434 body: forwarded_request_body(&request.body),
435 }),
436 ConnectionMessage::Channel(channel) => MessagePayload::ChannelMessage(ChannelMessage {
437 id: channel.id,
438 body: forwarded_channel_body(&channel.body),
439 }),
440 };
441
442 self.sess_core
443 .send(
444 Message {
445 connection_id: self.connection_id,
446 payload,
447 },
448 None,
449 Some(&*schemas),
450 )
451 .await
452 .map_err(|_| ())
453 }
454
455 pub async fn send_response<'a>(
457 &self,
458 request_id: RequestId,
459 response: RequestResponse<'a>,
460 ) -> Result<(), ()> {
461 self.send(ConnectionMessage::Request(RequestMessage {
462 id: request_id,
463 body: RequestBody::Response(response),
464 }))
465 .await
466 }
467
468 pub async fn send_response_for_method<'a>(
470 &self,
471 request_id: RequestId,
472 method_id: vox_types::MethodId,
473 mut response: RequestResponse<'a>,
474 ) -> Result<(), ()> {
475 self.prepare_response_for_method(request_id, method_id, &mut response);
476 self.send(ConnectionMessage::Request(RequestMessage {
477 id: request_id,
478 body: RequestBody::Response(response),
479 }))
480 .await
481 }
482
483 pub(crate) fn prepare_response_for_method(
485 &self,
486 request_id: RequestId,
487 method_id: vox_types::MethodId,
488 response: &mut RequestResponse<'_>,
489 ) {
490 self.sess_core.prepare_response_for_method(
491 self.connection_id,
492 request_id,
493 method_id,
494 response,
495 );
496 }
497
498 pub(crate) fn prepare_response_from_source(
500 &self,
501 request_id: RequestId,
502 method_id: vox_types::MethodId,
503 root_type: &vox_types::TypeRef,
504 source: &dyn vox_types::SchemaSource,
505 response: &mut RequestResponse<'_>,
506 ) {
507 self.sess_core.prepare_response_from_source(
508 self.connection_id,
509 request_id,
510 method_id,
511 root_type,
512 source,
513 response,
514 );
515 }
516
517 pub fn mark_failure(&self, request_id: RequestId, disposition: FailureDisposition) {
520 let _ = self.failures.send((request_id, disposition));
521 }
522
523 pub fn schema_registry(&self) -> vox_types::SchemaRegistry {
525 self.sess_core.schema_registry(self.connection_id)
526 }
527
528 pub fn prepare_replay_schemas(
530 &self,
531 request_id: RequestId,
532 method_id: vox_types::MethodId,
533 root_type: &vox_types::TypeRef,
534 store: &dyn crate::OperationStore,
535 response: &mut RequestResponse<'_>,
536 ) {
537 self.prepare_response_from_source(
538 request_id,
539 method_id,
540 root_type,
541 store.schema_source(),
542 response,
543 );
544 }
545}
546
/// Caller-side endpoint of one connection within a session.
pub struct ConnectionHandle {
    /// Outbound side of the connection.
    pub(crate) sender: ConnectionSender,
    /// Inbound request/channel messages dispatched by the session.
    pub(crate) rx: mpsc::Receiver<RecvMessage>,
    /// Failure reports posted through `ConnectionSender::mark_failure`.
    pub(crate) failures_rx: mpsc::UnboundedReceiver<(RequestId, FailureDisposition)>,
    /// Channel to the session's control loop; `Some` for handles created by
    /// `Session::make_connection_handle`.
    pub(crate) control_tx: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// Becomes `true` once the session marks the connection closed.
    pub(crate) closed_rx: watch::Receiver<bool>,
    /// Subscribed to the session's resume generation counter.
    pub(crate) resumed_rx: watch::Receiver<u64>,
    /// Local parity for this connection.
    pub parity: Parity,
    /// Snapshot of the peer's retry support when the handle was created.
    pub(crate) peer_supports_retry: bool,
}
558
559impl std::fmt::Debug for ConnectionHandle {
560 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
561 f.debug_struct("ConnectionHandle")
562 .field("connection_id", &self.sender.connection_id)
563 .finish()
564 }
565}
566
/// Per-connection message: either a request-level or a channel-level message.
pub(crate) enum ConnectionMessage<'payload> {
    Request(RequestMessage<'payload>),
    Channel(ChannelMessage<'payload>),
}
571
/// A message delivered from the session to a `ConnectionHandle`.
pub(crate) struct RecvMessage {
    /// The connection's received-schema tracker at dispatch time.
    pub schemas: Arc<vox_types::SchemaRecvTracker>,
    /// The message itself, still backed by its receive buffer.
    pub msg: SelfRef<ConnectionMessage<'static>>,
}
579
580impl ConnectionHandle {
581 pub fn connection_id(&self) -> ConnectionId {
583 self.sender.connection_id
584 }
585
586 pub async fn closed(&self) {
588 if *self.closed_rx.borrow() {
589 return;
590 }
591 let mut rx = self.closed_rx.clone();
592 while rx.changed().await.is_ok() {
593 if *rx.borrow() {
594 return;
595 }
596 }
597 }
598
599 pub fn is_connected(&self) -> bool {
601 !*self.closed_rx.borrow()
602 }
603
604 pub fn peer_supports_retry(&self) -> bool {
605 self.peer_supports_retry
606 }
607}
608
609pub async fn proxy_connections(left: ConnectionHandle, right: ConnectionHandle) {
615 let left_conn_id = left.connection_id();
616 let right_conn_id = right.connection_id();
617 let ConnectionHandle {
618 sender: left_sender,
619 rx: mut left_rx,
620 failures_rx: _left_failures_rx,
621 control_tx: left_control_tx,
622 closed_rx: _left_closed_rx,
623 resumed_rx: _left_resumed_rx,
624 parity: _left_parity,
625 peer_supports_retry: _left_peer_supports_retry,
626 } = left;
627 let ConnectionHandle {
628 sender: right_sender,
629 rx: mut right_rx,
630 failures_rx: _right_failures_rx,
631 control_tx: right_control_tx,
632 closed_rx: _right_closed_rx,
633 resumed_rx: _right_resumed_rx,
634 parity: _right_parity,
635 peer_supports_retry: _right_peer_supports_retry,
636 } = right;
637
638 loop {
639 tokio::select! {
640 recv = left_rx.recv() => {
641 let Some(recv) = recv else {
642 break;
643 };
644 if right_sender.send_owned(recv.schemas, recv.msg).await.is_err() {
645 break;
646 }
647 }
648 recv = right_rx.recv() => {
649 let Some(recv) = recv else {
650 break;
651 };
652 if left_sender.send_owned(recv.schemas, recv.msg).await.is_err() {
653 break;
654 }
655 }
656 }
657 }
658
659 if let Some(tx) = left_control_tx.as_ref() {
660 let _ = send_drop_control(tx, DropControlRequest::Close(left_conn_id));
661 }
662 if let Some(tx) = right_control_tx.as_ref() {
663 let _ = send_drop_control(tx, DropControlRequest::Close(right_conn_id));
664 }
665}
666
667#[derive(Debug)]
669pub enum SessionError {
670 Io(std::io::Error),
671 Protocol(String),
672 Rejected(Metadata<'static>),
673 NotResumable,
674}
675
676impl std::fmt::Display for SessionError {
677 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
678 match self {
679 Self::Io(e) => write!(f, "io error: {e}"),
680 Self::Protocol(msg) => write!(f, "protocol error: {msg}"),
681 Self::Rejected(_) => write!(f, "connection rejected"),
682 Self::NotResumable => write!(f, "session is not resumable"),
683 }
684 }
685}
686
687impl std::error::Error for SessionError {}
688
689impl Session {
690 #[allow(clippy::too_many_arguments)]
691 fn pre_handshake<Tx, Rx>(
692 tx: Tx,
693 rx: Rx,
694 on_connection: Option<Box<dyn ConnectionAcceptor>>,
695 open_rx: mpsc::Receiver<OpenRequest>,
696 close_rx: mpsc::Receiver<CloseRequest>,
697 resume_rx: mpsc::Receiver<ResumeRequest>,
698 control_tx: mpsc::UnboundedSender<DropControlRequest>,
699 control_rx: mpsc::UnboundedReceiver<DropControlRequest>,
700 keepalive: Option<SessionKeepaliveConfig>,
701 resumable: bool,
702 recoverer: Option<Box<dyn ConduitRecoverer>>,
703 ) -> Self
704 where
705 Tx: ConduitTx<Msg = MessageFamily> + MaybeSend + MaybeSync + 'static,
706 for<'p> <Tx as ConduitTx>::Permit<'p>: MaybeSend,
707 Rx: ConduitRx<Msg = MessageFamily> + MaybeSend + 'static,
708 {
709 let sess_core = Arc::new(SessionCore {
710 inner: std::sync::Mutex::new(SessionCoreInner {
711 tx: Arc::new(tx) as Arc<dyn DynConduitTx>,
712 conns: HashMap::new(),
713 }),
714 });
715 let (resume_notifier, _resume_rx) = watch::channel(0_u64);
716 Session {
717 rx: Box::new(rx),
718 role: SessionRole::Initiator, parity: Parity::Odd, sess_core,
721 peer_supports_retry: false,
722 local_root_settings: ConnectionSettings {
723 parity: Parity::Odd,
724 max_concurrent_requests: 64,
725 },
726 peer_root_settings: None,
727 resumable,
728 session_resume_key: None,
729 conns: BTreeMap::new(),
730 root_closed_internal: false,
731 conn_ids: IdAllocator::new(Parity::Odd), on_connection,
733 open_rx,
734 close_rx,
735 resume_rx,
736 control_tx,
737 control_rx,
738 keepalive,
739 resume_notifier,
740 recoverer,
741 }
742 }
743
744 pub(crate) fn resume_key(&self) -> Option<SessionResumeKey> {
745 self.session_resume_key
746 }
747
748 fn establish_from_handshake(
750 &mut self,
751 result: HandshakeResult,
752 ) -> Result<ConnectionHandle, SessionError> {
753 self.role = result.role;
754 self.parity = result.our_settings.parity;
755 self.conn_ids = IdAllocator::new(result.our_settings.parity);
756 self.local_root_settings = result.our_settings.clone();
757 self.peer_root_settings = Some(result.peer_settings.clone());
758 self.peer_supports_retry = result.peer_supports_retry;
759 self.session_resume_key = result.session_resume_key;
760
761 if self.resumable && self.session_resume_key.is_none() {
762 return Err(SessionError::NotResumable);
763 }
764
765 Ok(self.make_root_handle(result.our_settings, result.peer_settings))
766 }
767
768 fn make_root_handle(
769 &mut self,
770 local_settings: ConnectionSettings,
771 peer_settings: ConnectionSettings,
772 ) -> ConnectionHandle {
773 self.make_connection_handle(ConnectionId::ROOT, local_settings, peer_settings)
774 }
775
776 fn make_connection_handle(
777 &mut self,
778 conn_id: ConnectionId,
779 local_settings: ConnectionSettings,
780 peer_settings: ConnectionSettings,
781 ) -> ConnectionHandle {
782 let label = format!("session.conn{}", conn_id.0);
783 let (conn_tx, conn_rx) = mpsc::channel::<RecvMessage>(&label, 64);
784 let (failures_tx, failures_rx) = mpsc::unbounded_channel(format!("{label}.failures"));
785 let (closed_tx, closed_rx) = watch::channel(false);
786 let resumed_rx = self.resume_notifier.subscribe();
787
788 let sender = ConnectionSender {
789 connection_id: conn_id,
790 sess_core: Arc::clone(&self.sess_core),
791 failures: Arc::new(failures_tx),
792 };
793
794 let parity = local_settings.parity;
795 self.conns.insert(
796 conn_id,
797 ConnectionSlot::Active(ConnectionState {
798 id: conn_id,
799 local_settings,
800 peer_settings,
801 conn_tx,
802 closed_tx,
803 schema_recv_tracker: Arc::new(vox_types::SchemaRecvTracker::new()),
804 }),
805 );
806
807 ConnectionHandle {
808 sender,
809 rx: conn_rx,
810 failures_rx,
811 control_tx: Some(self.control_tx.clone()),
812 closed_rx,
813 resumed_rx,
814 parity,
815 peer_supports_retry: self.peer_supports_retry,
816 }
817 }
818
    /// Drives the session until it ends.
    ///
    /// Multiplexes inbound conduit messages, `SessionHandle` open/close/resume
    /// requests, drop-control requests and, when configured, the keepalive
    /// timer. The loop exits when the conduit breaks and cannot be resumed,
    /// when a drop-control request asks it to stop, or when keepalive fails.
    /// All remaining connections are then closed.
    pub async fn run(&mut self) {
        let mut keepalive_runtime = self.make_keepalive_runtime();
        // Keepalive is polled on a fixed 10 ms tick. The real ping cadence and
        // pong deadline live in `KeepaliveRuntime`.
        let mut keepalive_tick = keepalive_runtime.as_ref().map(|_| {
            let mut interval = tokio::time::interval(Duration::from_millis(10));
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            interval
        });

        loop {
            tokio::select! {
                msg = self.rx.recv_msg() => {
                    vox_types::dlog!("[session {:?}] recv_msg returned", self.role);
                    match msg {
                        Ok(Some(msg)) => {
                            tracing::debug!(conn_id = msg.connection_id.0, "session received message");
                            self.handle_message(msg, &mut keepalive_runtime).await;
                        }
                        // EOF and receive errors both count as a conduit break;
                        // the session survives only if it can be resumed.
                        Ok(None) => {
                            vox_types::dlog!("[session {:?}] recv loop: conduit returned EOF", self.role);
                            if !self.handle_conduit_break(&mut keepalive_runtime).await {
                                vox_types::dlog!("[session {:?}] recv loop: breaking (not resumable)", self.role);
                                break;
                            }
                        }
                        Err(error) => {
                            vox_types::dlog!("[session {:?}] recv loop: conduit recv error: {}", self.role, error);
                            if !self.handle_conduit_break(&mut keepalive_runtime).await {
                                vox_types::dlog!("[session {:?}] recv loop: breaking (not resumable)", self.role);
                                break;
                            }
                        }
                    }
                }
                Some(req) = self.open_rx.recv() => {
                    self.handle_open_request(req).await;
                }
                Some(req) = self.close_rx.recv() => {
                    self.handle_close_request(req).await;
                }
                // Resumes are only accepted after the conduit breaks
                // (see `handle_conduit_break`); while connected they are refused.
                Some(req) = self.resume_rx.recv() => {
                    let _ = req.result_tx.send(Err(SessionError::Protocol(
                        "resume is only valid while the session is disconnected".into(),
                    )));
                }
                Some(req) = self.control_rx.recv() => {
                    if !self.handle_drop_control_request(req).await {
                        break;
                    }
                }
                // This branch is disabled entirely when keepalive is off.
                _ = async {
                    if let Some(interval) = keepalive_tick.as_mut() {
                        interval.tick().await;
                    }
                }, if keepalive_tick.is_some() => {
                    if !self.handle_keepalive_tick(&mut keepalive_runtime).await {
                        break;
                    }
                }
            }
        }

        self.close_all_connections();
        debug!("session recv loop exited");
    }
888
889 async fn handle_conduit_break(
890 &mut self,
891 keepalive_runtime: &mut Option<KeepaliveRuntime>,
892 ) -> bool {
893 if !self.resumable {
894 return false;
895 }
896
897 if let Some(recoverer) = self.recoverer.as_mut() {
898 match recoverer
899 .next_conduit(self.session_resume_key.as_ref())
900 .await
901 {
902 Ok(recovered) => {
903 let result =
904 self.resume_from_handshake(recovered.tx, recovered.rx, recovered.handshake);
905 match result {
906 Ok(()) => {
907 let next_generation = self.resume_notifier.borrow().wrapping_add(1);
908 let _ = self.resume_notifier.send(next_generation);
909 *keepalive_runtime = self.make_keepalive_runtime();
910 return true;
911 }
912 Err(_) => return false,
913 }
914 }
915 Err(_) => return false,
916 }
917 }
918
919 loop {
920 tokio::select! {
921 Some(req) = self.resume_rx.recv() => {
922 let result =
923 self.resume_from_handshake(req.tx, req.rx, req.handshake_result);
924 let ok = result.is_ok();
925 let _ = req.result_tx.send(result);
926 if ok {
927 let next_generation = self.resume_notifier.borrow().wrapping_add(1);
928 let _ = self.resume_notifier.send(next_generation);
929 *keepalive_runtime = self.make_keepalive_runtime();
930 return true;
931 }
932 }
933 Some(req) = self.control_rx.recv() => {
934 if !self.handle_drop_control_request(req).await {
935 return false;
936 }
937 }
938 Some(req) = self.open_rx.recv() => {
939 let _ = req.result_tx.send(Err(SessionError::Protocol(
940 "session is disconnected; resume before opening connections".into(),
941 )));
942 }
943 Some(req) = self.close_rx.recv() => {
944 let _ = req.result_tx.send(Err(SessionError::Protocol(
945 "session is disconnected; resume before closing connections".into(),
946 )));
947 }
948 else => return false,
949 }
950 }
951 }
952
953 fn resume_from_handshake(
955 &mut self,
956 tx: Arc<dyn DynConduitTx>,
957 rx: Box<dyn DynConduitRx>,
958 result: HandshakeResult,
959 ) -> Result<(), SessionError> {
960 let Some(peer_settings) = self.peer_root_settings.clone() else {
961 return Err(SessionError::Protocol("missing peer root settings".into()));
962 };
963
964 if result.our_settings != self.local_root_settings {
965 return Err(SessionError::Protocol(
966 "local root settings changed across session resume".into(),
967 ));
968 }
969
970 if result.peer_settings != peer_settings {
971 return Err(SessionError::Protocol(
972 "peer root settings changed across session resume".into(),
973 ));
974 }
975
976 self.peer_supports_retry = result.peer_supports_retry;
977 self.session_resume_key = result.session_resume_key.or(self.session_resume_key);
978
979 self.sess_core.replace_tx_and_reset_schemas(tx);
980 self.rx = rx;
981 if let Some(ConnectionSlot::Active(state)) = self.conns.get_mut(&ConnectionId::ROOT) {
984 state.schema_recv_tracker = Arc::new(vox_types::SchemaRecvTracker::new());
985 }
986 Ok(())
987 }
988
    /// Dispatches one message received from the conduit.
    ///
    /// * Connection lifecycle messages (`ConnectionOpen`/`Accept`/`Reject`/
    ///   `Close`) are handled by the session itself.
    /// * Request and channel messages are forwarded to the owning
    ///   `ConnectionHandle`; those addressed to an id without an active
    ///   connection are dropped.
    /// * `Ping` is answered with a `Pong` carrying the same nonce.
    /// * A `Pong` on the root connection feeds keepalive.
    async fn handle_message(
        &mut self,
        msg: SelfRef<Message<'static>>,
        keepalive_runtime: &mut Option<KeepaliveRuntime>,
    ) {
        let conn_id = msg.connection_id;
        vox_types::selfref_match!(msg, payload {
            MessagePayload::ConnectionClose(_) => {
                if conn_id.is_root() {
                    warn!("received ConnectionClose for root connection");
                } else {
                    debug!(conn_id = conn_id.0, "received ConnectionClose for virtual connection");
                }
                self.remove_connection(&conn_id);
                self.maybe_request_shutdown_after_root_closed();
            }
            MessagePayload::ConnectionOpen(open) => {
                self.handle_inbound_open(conn_id, open).await;
            }
            MessagePayload::ConnectionAccept(accept) => {
                self.handle_inbound_accept(conn_id, accept);
            }
            MessagePayload::ConnectionReject(reject) => {
                self.handle_inbound_reject(conn_id, reject);
            }
            MessagePayload::RequestMessage(r) => {
                vox_types::dlog!(
                    "[session {:?}] recv request: conn={:?} req={:?} body={} method={:?}",
                    self.role,
                    conn_id,
                    r.id,
                    match &r.body {
                        RequestBody::Call(_) => "Call",
                        RequestBody::Response(_) => "Response",
                        RequestBody::Cancel(_) => "Cancel",
                    },
                    match &r.body {
                        RequestBody::Call(call) => Some(call.method_id),
                        RequestBody::Response(_) | RequestBody::Cancel(_) => None,
                    }
                );
                let response_had_schema_payload = matches!(&r.body, RequestBody::Response(resp) if !resp.schemas.is_empty());
                // Record any inlined schemas on the connection's tracker before
                // the message is dispatched, so the receiver can decode with them.
                {
                    let schemas_cbor = match &r.body {
                        RequestBody::Call(call) => Some(&call.schemas),
                        RequestBody::Response(resp) => Some(&resp.schemas),
                        _ => None,
                    };
                    vox_types::dlog!(
                        "[schema] recv ({:?}): req={:?} body={} schemas_len={:?}",
                        self.role,
                        r.id,
                        match &r.body {
                            RequestBody::Call(_) => "Call",
                            RequestBody::Response(_) => "Response",
                            RequestBody::Cancel(_) => "Cancel",
                        },
                        schemas_cbor.map(|s| s.0.len())
                    );
                    let state = match self.conns.get(&conn_id) {
                        Some(ConnectionSlot::Active(state)) => state,
                        _ => return,
                    };
                    if let Some(schemas_cbor) = schemas_cbor
                        && !schemas_cbor.is_empty()
                    {
                        // NOTE(review): the three `.expect`s below fire on
                        // peer-supplied data (malformed CBOR, a response for an
                        // unknown request, duplicate type IDs). A misbehaving peer
                        // can therefore panic the session task. Consider tearing
                        // down the connection instead.
                        let payload = vox_types::SchemaPayload::from_cbor(&schemas_cbor.0)
                            .expect("inlined schemas must be valid CBOR");
                        // Calls carry argument schemas; responses carry return
                        // schemas bound to the method of the matching outgoing call.
                        let (method_id, direction) = match &r.body {
                            RequestBody::Call(call) => {
                                (call.method_id, vox_types::BindingDirection::Args)
                            }
                            RequestBody::Response(_) => {
                                let method_id = self
                                    .sess_core
                                    .take_outgoing_call_method(conn_id, r.id)
                                    .expect("response schemas require an inflight method binding");
                                (method_id, vox_types::BindingDirection::Response)
                            }
                            RequestBody::Cancel(_) => unreachable!(),
                        };
                        state
                            .schema_recv_tracker
                            .record_received(method_id, direction, payload)
                            .expect("received schemas must not contain duplicate type IDs");
                    }
                }
                // A response without schemas still completes the outgoing call,
                // so its method binding is discarded here.
                if matches!(&r.body, RequestBody::Response(_)) && !response_had_schema_payload {
                    let _ = self.sess_core.take_outgoing_call_method(conn_id, r.id);
                }
                // Remember which method an incoming call targets (presumably for
                // `prepare_response_for_method` later - confirm in SessionCore).
                if let RequestBody::Call(call) = &r.body {
                    self.sess_core.record_incoming_call(conn_id, r.id, call.method_id);
                }
                let state = match self.conns.get(&conn_id) {
                    Some(ConnectionSlot::Active(state)) => state,
                    _ => return,
                };
                let conn_tx = state.conn_tx.clone();
                let request_id = r.id;
                let body_kind = match &r.body {
                    RequestBody::Call(_) => "Call",
                    RequestBody::Response(_) => "Response",
                    RequestBody::Cancel(_) => "Cancel",
                };
                let recv_msg = RecvMessage {
                    schemas: Arc::clone(&state.schema_recv_tracker),
                    msg: r.map(ConnectionMessage::Request),
                };
                vox_types::dlog!(
                    "[session {:?}] dispatch request: conn={:?} req={:?} body={}",
                    self.role,
                    conn_id,
                    request_id,
                    body_kind
                );
                // The handle's receiver is gone: treat the connection as closed.
                if conn_tx.send(recv_msg).await.is_err() {
                    self.remove_connection(&conn_id);
                    self.maybe_request_shutdown_after_root_closed();
                }
            }
            MessagePayload::ChannelMessage(c) => {
                let state = match self.conns.get(&conn_id) {
                    Some(ConnectionSlot::Active(state)) => state,
                    _ => return,
                };
                let conn_tx = state.conn_tx.clone();
                let recv_msg = RecvMessage {
                    schemas: Arc::clone(&state.schema_recv_tracker),
                    msg: c.map(ConnectionMessage::Channel),
                };
                // The handle's receiver is gone: treat the connection as closed.
                if conn_tx.send(recv_msg).await.is_err() {
                    self.remove_connection(&conn_id);
                    self.maybe_request_shutdown_after_root_closed();
                }
            }
            MessagePayload::Ping(ping) => {
                // Echo the nonce back on the same connection; send errors are ignored.
                let _ = self
                    .sess_core
                    .send(Message {
                        connection_id: conn_id,
                        payload: MessagePayload::Pong(vox_types::Pong { nonce: ping.nonce }),
                    }, None, None)
                    .await;
            }
            MessagePayload::Pong(pong) => {
                // Keepalive pings are only sent on the root connection.
                if conn_id.is_root() {
                    self.handle_keepalive_pong(pong.nonce, keepalive_runtime);
                }
            }
        })
    }
1148
1149 fn make_keepalive_runtime(&self) -> Option<KeepaliveRuntime> {
1150 let config = self.keepalive?;
1151 if config.ping_interval.is_zero() || config.pong_timeout.is_zero() {
1152 warn!("keepalive disabled due to non-positive interval/timeout");
1153 return None;
1154 }
1155 let now = tokio::time::Instant::now();
1156 Some(KeepaliveRuntime {
1157 ping_interval: config.ping_interval,
1158 pong_timeout: config.pong_timeout,
1159 next_ping_at: now + config.ping_interval,
1160 waiting_pong_nonce: None,
1161 pong_deadline: now,
1162 next_ping_nonce: 1,
1163 })
1164 }
1165
1166 fn handle_keepalive_pong(&self, nonce: u64, keepalive_runtime: &mut Option<KeepaliveRuntime>) {
1167 let Some(runtime) = keepalive_runtime.as_mut() else {
1168 return;
1169 };
1170 if runtime.waiting_pong_nonce != Some(nonce) {
1171 return;
1172 }
1173 runtime.waiting_pong_nonce = None;
1174 runtime.next_ping_at = tokio::time::Instant::now() + runtime.ping_interval;
1175 }
1176
1177 async fn handle_keepalive_tick(
1178 &mut self,
1179 keepalive_runtime: &mut Option<KeepaliveRuntime>,
1180 ) -> bool {
1181 let Some(runtime) = keepalive_runtime.as_mut() else {
1182 return true;
1183 };
1184 let now = tokio::time::Instant::now();
1185
1186 if let Some(waiting_nonce) = runtime.waiting_pong_nonce {
1187 if now >= runtime.pong_deadline {
1188 warn!(
1189 nonce = waiting_nonce,
1190 timeout_ms = runtime.pong_timeout.as_millis(),
1191 "keepalive timeout waiting for pong"
1192 );
1193 return false;
1194 }
1195 return true;
1196 }
1197
1198 if now < runtime.next_ping_at {
1199 return true;
1200 }
1201
1202 let nonce = runtime.next_ping_nonce;
1203 if self
1204 .sess_core
1205 .send(
1206 Message {
1207 connection_id: ConnectionId::ROOT,
1208 payload: MessagePayload::Ping(vox_types::Ping { nonce }),
1209 },
1210 None,
1211 None,
1212 )
1213 .await
1214 .is_err()
1215 {
1216 warn!("failed to send keepalive ping");
1217 return false;
1218 }
1219
1220 runtime.waiting_pong_nonce = Some(nonce);
1221 runtime.pong_deadline = now + runtime.pong_timeout;
1222 runtime.next_ping_at = now + runtime.ping_interval;
1223 runtime.next_ping_nonce = runtime.next_ping_nonce.wrapping_add(1);
1224 true
1225 }
1226
1227 async fn handle_inbound_open(
1228 &mut self,
1229 conn_id: ConnectionId,
1230 open: SelfRef<ConnectionOpen<'static>>,
1231 ) {
1232 let peer_parity = self.parity.other();
1234 if !conn_id.has_parity(peer_parity) {
1235 let _ = self
1237 .sess_core
1238 .send(
1239 Message {
1240 connection_id: conn_id,
1241 payload: MessagePayload::ConnectionReject(vox_types::ConnectionReject {
1242 metadata: vec![],
1243 }),
1244 },
1245 None,
1246 None,
1247 )
1248 .await;
1249 return;
1250 }
1251
1252 if self.conns.contains_key(&conn_id) {
1254 let _ = self
1256 .sess_core
1257 .send(
1258 Message {
1259 connection_id: conn_id,
1260 payload: MessagePayload::ConnectionReject(vox_types::ConnectionReject {
1261 metadata: vec![],
1262 }),
1263 },
1264 None,
1265 None,
1266 )
1267 .await;
1268 return;
1269 }
1270
1271 let acceptor = match &self.on_connection {
1274 Some(a) => a,
1275 None => {
1276 let _ = self
1277 .sess_core
1278 .send(
1279 Message {
1280 connection_id: conn_id,
1281 payload: MessagePayload::ConnectionReject(
1282 vox_types::ConnectionReject { metadata: vec![] },
1283 ),
1284 },
1285 None,
1286 None,
1287 )
1288 .await;
1289 return;
1290 }
1291 };
1292
1293 match acceptor.accept(conn_id, &open.connection_settings, &open.metadata) {
1294 Ok(accepted) => {
1295 let handle = self.make_connection_handle(
1297 conn_id,
1298 accepted.settings.clone(),
1299 open.connection_settings.clone(),
1300 );
1301
1302 let _ = self
1304 .sess_core
1305 .send(
1306 Message {
1307 connection_id: conn_id,
1308 payload: MessagePayload::ConnectionAccept(
1309 vox_types::ConnectionAccept {
1310 connection_settings: accepted.settings,
1311 metadata: accepted.metadata,
1312 },
1313 ),
1314 },
1315 None,
1316 None,
1317 )
1318 .await;
1319
1320 (accepted.setup)(handle);
1322 }
1323 Err(reject_metadata) => {
1324 let _ = self
1325 .sess_core
1326 .send(
1327 Message {
1328 connection_id: conn_id,
1329 payload: MessagePayload::ConnectionReject(
1330 vox_types::ConnectionReject {
1331 metadata: reject_metadata,
1332 },
1333 ),
1334 },
1335 None,
1336 None,
1337 )
1338 .await;
1339 }
1340 }
1341 }
1342
1343 fn handle_inbound_accept(
1344 &mut self,
1345 conn_id: ConnectionId,
1346 accept: SelfRef<ConnectionAccept<'static>>,
1347 ) {
1348 let slot = self.remove_connection(&conn_id);
1349 match slot {
1350 Some(ConnectionSlot::PendingOutbound(mut pending)) => {
1351 let handle = self.make_connection_handle(
1352 conn_id,
1353 pending.local_settings.clone(),
1354 accept.connection_settings.clone(),
1355 );
1356
1357 if let Some(tx) = pending.result_tx.take() {
1358 let _ = tx.send(Ok(handle));
1359 }
1360 }
1361 Some(other) => {
1362 self.conns.insert(conn_id, other);
1364 }
1365 None => {
1366 }
1368 }
1369 }
1370
1371 fn handle_inbound_reject(
1372 &mut self,
1373 conn_id: ConnectionId,
1374 reject: SelfRef<ConnectionReject<'static>>,
1375 ) {
1376 let slot = self.remove_connection(&conn_id);
1377 match slot {
1378 Some(ConnectionSlot::PendingOutbound(mut pending)) => {
1379 if let Some(tx) = pending.result_tx.take() {
1380 let _ = tx.send(Err(SessionError::Rejected(reject.metadata.to_vec())));
1381 }
1382 }
1383 Some(other) => {
1384 self.conns.insert(conn_id, other);
1385 }
1386 None => {}
1387 }
1388 }
1389
1390 async fn handle_open_request(&mut self, req: OpenRequest) {
1392 let conn_id = self.conn_ids.alloc();
1393
1394 let send_result = self
1396 .sess_core
1397 .send(
1398 Message {
1399 connection_id: conn_id,
1400 payload: MessagePayload::ConnectionOpen(ConnectionOpen {
1401 connection_settings: req.settings.clone(),
1402 metadata: req.metadata,
1403 }),
1404 },
1405 None,
1406 None,
1407 )
1408 .await;
1409
1410 if send_result.is_err() {
1411 let _ = req.result_tx.send(Err(SessionError::Protocol(
1412 "failed to send ConnectionOpen".into(),
1413 )));
1414 return;
1415 }
1416
1417 self.conns.insert(
1420 conn_id,
1421 ConnectionSlot::PendingOutbound(PendingOutboundData {
1422 local_settings: req.settings,
1423 result_tx: Some(req.result_tx),
1424 }),
1425 );
1426 }
1427
1428 async fn handle_close_request(&mut self, req: CloseRequest) {
1430 if req.conn_id.is_root() {
1431 let _ = req.result_tx.send(Err(SessionError::Protocol(
1432 "cannot close root connection".into(),
1433 )));
1434 return;
1435 }
1436
1437 if self.remove_connection(&req.conn_id).is_none() {
1440 let _ = req
1441 .result_tx
1442 .send(Err(SessionError::Protocol("connection not found".into())));
1443 return;
1444 }
1445
1446 let send_result = self
1448 .sess_core
1449 .send(
1450 Message {
1451 connection_id: req.conn_id,
1452 payload: MessagePayload::ConnectionClose(ConnectionClose {
1453 metadata: req.metadata,
1454 }),
1455 },
1456 None,
1457 None,
1458 )
1459 .await;
1460
1461 if send_result.is_err() {
1462 let _ = req.result_tx.send(Err(SessionError::Protocol(
1463 "failed to send ConnectionClose".into(),
1464 )));
1465 return;
1466 }
1467
1468 let _ = req.result_tx.send(Ok(()));
1469 self.maybe_request_shutdown_after_root_closed();
1470 }
1471
1472 async fn handle_drop_control_request(&mut self, req: DropControlRequest) -> bool {
1473 match req {
1474 DropControlRequest::Shutdown => {
1475 debug!("session shutdown requested");
1476 false
1477 }
1478 DropControlRequest::Close(conn_id) => {
1479 if conn_id.is_root() {
1481 debug!("root callers dropped; internally closing root connection");
1483 self.root_closed_internal = true;
1484 return self.has_virtual_connections();
1486 }
1487
1488 if self.remove_connection(&conn_id).is_some() {
1489 let _ = self
1490 .sess_core
1491 .send(
1492 Message {
1493 connection_id: conn_id,
1494 payload: MessagePayload::ConnectionClose(ConnectionClose {
1495 metadata: vec![],
1496 }),
1497 },
1498 None,
1499 None,
1500 )
1501 .await;
1502 }
1503
1504 !self.root_closed_internal || self.has_virtual_connections()
1505 }
1506 }
1507 }
1508
1509 fn has_virtual_connections(&self) -> bool {
1510 self.conns.keys().any(|id| !id.is_root())
1511 }
1512
1513 fn remove_connection(&mut self, conn_id: &ConnectionId) -> Option<ConnectionSlot> {
1514 let slot = self.conns.remove(conn_id);
1515 if let Some(ConnectionSlot::Active(state)) = &slot {
1516 let _ = state.closed_tx.send(true);
1517 }
1518 slot
1519 }
1520
1521 fn close_all_connections(&mut self) {
1522 vox_types::dlog!(
1523 "[session {:?}] close_all_connections: {} slots",
1524 self.role,
1525 self.conns.len()
1526 );
1527 for (conn_id, slot) in self.conns.iter() {
1528 if let ConnectionSlot::Active(state) = slot {
1529 vox_types::dlog!("[session {:?}] closing connection {:?}", self.role, conn_id);
1530 let _ = state.closed_tx.send(true);
1531 }
1532 }
1533 self.conns.clear();
1534 }
1535
1536 fn maybe_request_shutdown_after_root_closed(&self) {
1537 if self.root_closed_internal && !self.has_virtual_connections() {
1538 let _ = send_drop_control(&self.control_tx, DropControlRequest::Shutdown);
1539 }
1540 }
1541}
1542
/// Shared outbound path for a session.
///
/// Keeps the current conduit sender and the per-connection schema bookkeeping
/// behind a mutex, so they can be used through `&self`.
pub(crate) struct SessionCore {
    inner: std::sync::Mutex<SessionCoreInner>,
}
1546
/// Send-side state for one connection: which schemas the peer has already been
/// given, and which requests are in flight on this connection.
struct SendConnState {
    /// `(direction, method)` pairs whose schemas have already been prepared on
    /// this connection. It is checked before attaching, and a pair is added only
    /// after a successful preparation.
    method_tracker: HashSet<(vox_types::BindingDirection, vox_types::MethodId)>,

    /// Produces the schema payloads attached to outgoing calls and responses. Its
    /// registry is what `SessionCore::schema_registry` returns.
    send_tracker: vox_types::SchemaSendTracker,

    /// Calls received from the peer that have not been answered yet, keyed by
    /// request ID. The method ID is needed to attach response schemas.
    inflight_incoming: HashMap<RequestId, vox_types::MethodId>,

    /// Calls sent on this connection, keyed by request ID. Entries are consumed by
    /// `SessionCore::take_outgoing_call_method`.
    inflight_outgoing: HashMap<RequestId, vox_types::MethodId>,
}
1563
1564impl SendConnState {
1565 fn new() -> Self {
1566 SendConnState {
1567 method_tracker: HashSet::new(),
1568 send_tracker: vox_types::SchemaSendTracker::new(),
1569 inflight_incoming: HashMap::new(),
1570 inflight_outgoing: HashMap::new(),
1571 }
1572 }
1573}
1574
/// Mutable state guarded by `SessionCore`'s mutex.
struct SessionCoreInner {
    /// Sender for the current conduit; swapped out by `replace_tx_and_reset_schemas`.
    tx: Arc<dyn DynConduitTx>,

    /// Send-side state per connection. Entries are created lazily on first use
    /// and all cleared when the conduit is replaced.
    conns: HashMap<ConnectionId, SendConnState>,
}
1582
1583impl SessionCore {
1584 pub(crate) async fn send<'a>(
1586 &self,
1587 mut msg: Message<'a>,
1588 binder: Option<&'a dyn vox_types::ChannelBinder>,
1589 forwarded_schemas: Option<&vox_types::SchemaRecvTracker>,
1590 ) -> Result<(), ()> {
1591 let tx = {
1592 let mut inner = self.inner.lock().expect("session core mutex poisoned");
1593 let conn_id = msg.connection_id;
1594
1595 if let MessagePayload::RequestMessage(req) = &mut msg.payload {
1596 vox_types::dlog!(
1597 "[session-core] send request: conn={:?} req={:?} body={} forwarded={}",
1598 conn_id,
1599 req.id,
1600 match &req.body {
1601 RequestBody::Call(_) => "Call",
1602 RequestBody::Response(_) => "Response",
1603 RequestBody::Cancel(_) => "Cancel",
1604 },
1605 forwarded_schemas.is_some()
1606 );
1607 let conn_state = inner
1608 .conns
1609 .entry(conn_id)
1610 .or_insert_with(SendConnState::new);
1611 match &mut req.body {
1612 RequestBody::Call(call) => {
1613 Self::prepare_call_schemas(
1614 conn_state,
1615 req.id,
1616 call.method_id,
1617 call,
1618 forwarded_schemas,
1619 );
1620 }
1621 RequestBody::Response(resp) => {
1622 if let Some(method_id) = conn_state.inflight_incoming.remove(&req.id) {
1623 Self::prepare_response_schemas(
1624 conn_state,
1625 req.id,
1626 method_id,
1627 resp,
1628 forwarded_schemas,
1629 );
1630 }
1631 }
1632 RequestBody::Cancel(_) => {}
1633 }
1634 }
1635
1636 inner.tx.clone()
1637 };
1638 tx.send_msg(msg, binder).await.map_err(|_| ())
1639 }
1640
1641 pub(crate) fn record_incoming_call(
1644 &self,
1645 conn_id: ConnectionId,
1646 request_id: RequestId,
1647 method_id: vox_types::MethodId,
1648 ) {
1649 let mut inner = self.inner.lock().expect("session core mutex poisoned");
1650 let conn_state = inner
1651 .conns
1652 .entry(conn_id)
1653 .or_insert_with(SendConnState::new);
1654 vox_types::dlog!(
1655 "[schema] record_incoming_call: conn={:?} req={:?} method={:?}",
1656 conn_id,
1657 request_id,
1658 method_id
1659 );
1660 conn_state.inflight_incoming.insert(request_id, method_id);
1661 }
1662
1663 pub(crate) fn take_outgoing_call_method(
1664 &self,
1665 conn_id: ConnectionId,
1666 request_id: RequestId,
1667 ) -> Option<vox_types::MethodId> {
1668 let mut inner = self.inner.lock().expect("session core mutex poisoned");
1669 inner
1670 .conns
1671 .get_mut(&conn_id)
1672 .and_then(|conn_state| conn_state.inflight_outgoing.remove(&request_id))
1673 }
1674
1675 pub(crate) fn prepare_response_for_method(
1676 &self,
1677 conn_id: ConnectionId,
1678 request_id: RequestId,
1679 method_id: vox_types::MethodId,
1680 response: &mut RequestResponse<'_>,
1681 ) {
1682 let mut inner = self.inner.lock().expect("session core mutex poisoned");
1683 let conn_state = inner
1684 .conns
1685 .entry(conn_id)
1686 .or_insert_with(SendConnState::new);
1687 conn_state.inflight_incoming.remove(&request_id);
1688 Self::prepare_response_schemas(conn_state, request_id, method_id, response, None);
1689 }
1690
1691 pub(crate) fn schema_registry(&self, conn_id: ConnectionId) -> vox_types::SchemaRegistry {
1694 let inner = self.inner.lock().expect("session core mutex poisoned");
1695 inner
1696 .conns
1697 .get(&conn_id)
1698 .map(|cs| cs.send_tracker.registry().clone())
1699 .unwrap_or_default()
1700 }
1701
1702 pub(crate) fn prepare_response_from_source(
1704 &self,
1705 conn_id: ConnectionId,
1706 request_id: RequestId,
1707 method_id: vox_types::MethodId,
1708 root_type: &vox_types::TypeRef,
1709 source: &dyn vox_types::SchemaSource,
1710 response: &mut RequestResponse<'_>,
1711 ) {
1712 let mut inner = self.inner.lock().expect("session core mutex poisoned");
1713 let conn_state = inner
1714 .conns
1715 .entry(conn_id)
1716 .or_insert_with(SendConnState::new);
1717 conn_state.inflight_incoming.remove(&request_id);
1718 let key = (vox_types::BindingDirection::Response, method_id);
1719 if conn_state.method_tracker.contains(&key) {
1720 return;
1721 }
1722 let cbor = conn_state.send_tracker.prepare_send(
1723 method_id,
1724 vox_types::BindingDirection::Response,
1725 root_type,
1726 source,
1727 );
1728 if !cbor.is_empty() {
1729 response.schemas = cbor;
1730 }
1731 conn_state.method_tracker.insert(key);
1732 }
1733
1734 fn prepare_response_schemas(
1735 conn_state: &mut SendConnState,
1736 request_id: RequestId,
1737 method_id: vox_types::MethodId,
1738 response: &mut RequestResponse<'_>,
1739 forwarded_schemas: Option<&vox_types::SchemaRecvTracker>,
1740 ) {
1741 let key = (vox_types::BindingDirection::Response, method_id);
1742 if conn_state.method_tracker.contains(&key) {
1743 return;
1744 }
1745
1746 let prepared = match &response.ret {
1747 vox_types::Payload::Value { shape, .. } => {
1748 match conn_state
1749 .send_tracker
1750 .attach_schemas_for_shape_if_needed(method_id, shape, response)
1751 {
1752 Ok(schemas) => {
1753 vox_types::dlog!(
1754 "[schema] prepared {} bytes of response schemas for method {:?} (req {:?})",
1755 schemas.0.len(),
1756 method_id,
1757 request_id
1758 );
1759 true
1760 }
1761 Err(e) => {
1762 tracing::error!("schema extraction failed: {e}");
1763 false
1764 }
1765 }
1766 }
1767 vox_types::Payload::PostcardBytes(_) => {
1768 let Some(source) = forwarded_schemas else {
1769 tracing::error!(
1770 "schema attachment failed: missing forwarded response schemas for method {:?}",
1771 method_id
1772 );
1773 return;
1774 };
1775 let Some(root) = source.get_remote_response_root(method_id) else {
1776 tracing::error!(
1777 "schema attachment failed: missing forwarded response root for method {:?}",
1778 method_id
1779 );
1780 return;
1781 };
1782 let schemas = conn_state.send_tracker.prepare_send(
1783 method_id,
1784 vox_types::BindingDirection::Response,
1785 &root,
1786 source,
1787 );
1788 response.schemas = schemas.clone();
1789 vox_types::dlog!(
1790 "[schema] prepared {} bytes of forwarded response schemas for method {:?} (req {:?})",
1791 schemas.0.len(),
1792 method_id,
1793 request_id
1794 );
1795 true
1796 }
1797 };
1798
1799 if prepared {
1800 conn_state.method_tracker.insert(key);
1801 }
1802 }
1803
1804 fn prepare_call_schemas(
1805 conn_state: &mut SendConnState,
1806 request_id: RequestId,
1807 method_id: vox_types::MethodId,
1808 call: &mut vox_types::RequestCall<'_>,
1809 forwarded_schemas: Option<&vox_types::SchemaRecvTracker>,
1810 ) {
1811 conn_state.inflight_outgoing.insert(request_id, method_id);
1812 let key = (vox_types::BindingDirection::Args, method_id);
1813 if conn_state.method_tracker.contains(&key) {
1814 return;
1815 }
1816
1817 let prepared = match &call.args {
1818 vox_types::Payload::Value { shape, .. } => {
1819 match conn_state
1820 .send_tracker
1821 .attach_schemas_for_shape_if_needed(method_id, shape, call)
1822 {
1823 Ok(_) => true,
1824 Err(e) => {
1825 tracing::error!("schema extraction failed: {e}");
1826 false
1827 }
1828 }
1829 }
1830 vox_types::Payload::PostcardBytes(_) => {
1831 let Some(source) = forwarded_schemas else {
1832 tracing::error!(
1833 "schema attachment failed: missing forwarded args schemas for method {:?}",
1834 method_id
1835 );
1836 return;
1837 };
1838 let Some(root) = source.get_remote_args_root(method_id) else {
1839 tracing::error!(
1840 "schema attachment failed: missing forwarded args root for method {:?}",
1841 method_id
1842 );
1843 return;
1844 };
1845 call.schemas = conn_state.send_tracker.prepare_send(
1846 method_id,
1847 vox_types::BindingDirection::Args,
1848 &root,
1849 source,
1850 );
1851 true
1852 }
1853 };
1854
1855 if prepared {
1856 conn_state.method_tracker.insert(key);
1857 }
1858 }
1859
1860 fn replace_tx_and_reset_schemas(&self, tx: Arc<dyn DynConduitTx>) {
1861 let mut inner = self.inner.lock().expect("session core mutex poisoned");
1862 inner.tx = tx;
1863 inner.conns.clear();
1864 }
1865}
1866
/// A freshly obtained conduit, split into its halves, together with the
/// handshake result for it. Produced by [`ConduitRecoverer::next_conduit`].
pub(crate) struct RecoveredConduit {
    pub tx: Arc<dyn DynConduitTx>,
    pub rx: Box<dyn DynConduitRx>,
    pub handshake: HandshakeResult,
}
1872
/// Source of replacement conduits for a session.
pub(crate) trait ConduitRecoverer: MaybeSend {
    /// Obtains the next conduit and its handshake result.
    ///
    /// NOTE(review): `resume_key` is presumably the key of the session being
    /// resumed (`None` when there is none). Confirm against implementors.
    fn next_conduit<'a>(
        &'a mut self,
        resume_key: Option<&'a SessionResumeKey>,
    ) -> BoxFut<'a, Result<RecoveredConduit, SessionError>>;
}
1879
/// Object-safe sending half of a conduit, so a session can hold any sender as
/// `Arc<dyn DynConduitTx>`.
///
/// Implemented for every `ConduitTx` whose message family is `MessageFamily`.
pub trait DynConduitTx: MaybeSend + MaybeSync {
    /// Sends one message. When `binder` is given, the underlying send runs
    /// inside `vox_types::with_channel_binder` for that binder.
    fn send_msg<'a>(
        &'a self,
        msg: Message<'a>,
        binder: Option<&'a dyn vox_types::ChannelBinder>,
    ) -> BoxFut<'a, std::io::Result<()>>;
}
/// Object-safe receiving half of a conduit, held as `Box<dyn DynConduitRx>`.
///
/// Implemented for every `ConduitRx` whose message family is `MessageFamily`.
pub trait DynConduitRx: MaybeSend {
    /// Receives the next message. Errors are converted to `io::Error`.
    ///
    /// NOTE(review): `Ok(None)` presumably means end of stream. Confirm against
    /// the underlying `ConduitRx::recv` contract.
    fn recv_msg<'a>(&'a mut self)
        -> BoxFut<'a, std::io::Result<Option<SelfRef<Message<'static>>>>>;
}
1891
1892impl<T> DynConduitTx for T
1895where
1896 T: ConduitTx<Msg = MessageFamily> + MaybeSend + MaybeSync,
1897 for<'p> <T as ConduitTx>::Permit<'p>: MaybeSend,
1898{
1899 fn send_msg<'a>(
1900 &'a self,
1901 msg: Message<'a>,
1902 binder: Option<&'a dyn vox_types::ChannelBinder>,
1903 ) -> BoxFut<'a, std::io::Result<()>> {
1904 Box::pin(async move {
1905 let permit = self.reserve().await?;
1906 let result = if let Some(binder) = binder {
1907 vox_types::with_channel_binder(binder, || permit.send(msg))
1908 } else {
1909 permit.send(msg)
1910 };
1911 result.map_err(|e| std::io::Error::other(e.to_string()))
1912 })
1913 }
1914}
1915
1916impl<T> DynConduitRx for T
1917where
1918 T: ConduitRx<Msg = MessageFamily> + MaybeSend,
1919{
1920 fn recv_msg<'a>(
1921 &'a mut self,
1922 ) -> BoxFut<'a, std::io::Result<Option<SelfRef<Message<'static>>>>> {
1923 Box::pin(async move {
1924 self.recv()
1925 .await
1926 .map_err(|error| std::io::Error::other(error.to_string()))
1927 })
1928 }
1929}
1930
#[cfg(test)]
mod tests {
    use moire::sync::mpsc;
    use vox_types::{
        Backing, Conduit, ConnectionAccept, ConnectionReject, HandshakeResult, SelfRef,
    };

    use super::*;

    /// Odd-parity settings with a limit of 64 concurrent requests.
    fn odd_settings() -> ConnectionSettings {
        ConnectionSettings {
            parity: Parity::Odd,
            max_concurrent_requests: 64,
        }
    }

    /// Even-parity settings with a limit of 64 concurrent requests.
    fn even_settings() -> ConnectionSettings {
        ConnectionSettings {
            parity: Parity::Even,
            max_concurrent_requests: 64,
        }
    }

    /// A pending-outbound slot (with odd local settings) whose open result is
    /// delivered on `result_tx`.
    fn pending_outbound(
        result_tx: moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>,
    ) -> ConnectionSlot {
        ConnectionSlot::PendingOutbound(PendingOutboundData {
            local_settings: odd_settings(),
            result_tx: Some(result_tx),
        })
    }

    /// Builds a pre-handshake session over one end of an in-memory link.
    fn make_session() -> Session {
        let (near, far) = crate::memory_link_pair(32);
        // Intentionally leaked so the other end of the link is never dropped.
        std::mem::forget(far);
        let conduit = crate::BareConduit::new(near);
        let (tx, rx) = conduit.split();
        let (_open_tx, open_rx) = mpsc::channel::<OpenRequest>("session.open.test", 4);
        let (_close_tx, close_rx) = mpsc::channel::<CloseRequest>("session.close.test", 4);
        let (_resume_tx, resume_rx) = mpsc::channel::<ResumeRequest>("session.resume.test", 1);
        let (control_tx, control_rx) = mpsc::unbounded_channel("session.control.test");
        Session::pre_handshake(
            tx, rx, None, open_rx, close_rx, resume_rx, control_tx, control_rx, None, false, None,
        )
    }

    /// Initiator-side handshake result that carries a session resume key.
    fn resumed_handshake(
        our_settings: ConnectionSettings,
        peer_settings: ConnectionSettings,
    ) -> HandshakeResult {
        HandshakeResult {
            role: SessionRole::Initiator,
            our_settings,
            peer_settings,
            peer_supports_retry: true,
            session_resume_key: Some(SessionResumeKey([7; 16])),
            peer_resume_key: None,
            our_schema: vec![],
            peer_schema: vec![],
        }
    }

    /// A `ConnectionAccept` with even settings and no metadata.
    fn accept_ref() -> SelfRef<ConnectionAccept<'static>> {
        SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ConnectionAccept {
                connection_settings: even_settings(),
                metadata: vec![],
            },
        )
    }

    /// A `ConnectionReject` with no metadata.
    fn reject_ref() -> SelfRef<ConnectionReject<'static>> {
        SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ConnectionReject { metadata: vec![] },
        )
    }

    #[tokio::test]
    async fn duplicate_connection_accept_is_ignored_after_first() {
        let mut session = make_session();
        let conn_id = ConnectionId(1);
        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");
        session.conns.insert(conn_id, pending_outbound(result_tx));

        session.handle_inbound_accept(conn_id, accept_ref());
        let handle = result_rx
            .await
            .expect("pending outbound result should resolve")
            .expect("accept should resolve as Ok");
        assert_eq!(handle.connection_id(), conn_id);

        session.handle_inbound_accept(conn_id, accept_ref());
        assert!(
            matches!(
                session.conns.get(&conn_id),
                Some(ConnectionSlot::Active(ConnectionState { id, .. })) if *id == conn_id
            ),
            "duplicate accept should keep existing active connection state"
        );
    }

    #[tokio::test]
    async fn duplicate_connection_reject_is_ignored_after_first() {
        let mut session = make_session();
        let conn_id = ConnectionId(1);
        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");
        session.conns.insert(conn_id, pending_outbound(result_tx));

        session.handle_inbound_reject(conn_id, reject_ref());
        let result = result_rx
            .await
            .expect("pending outbound result should resolve");
        assert!(
            matches!(result, Err(SessionError::Rejected(_))),
            "expected rejection, got: {result:?}"
        );

        session.handle_inbound_reject(conn_id, reject_ref());
        assert!(
            !session.conns.contains_key(&conn_id),
            "duplicate reject should not recreate connection state"
        );
    }

    #[test]
    fn out_of_order_accept_or_reject_without_pending_is_ignored() {
        let mut session = make_session();
        let conn_id = ConnectionId(99);

        session.handle_inbound_accept(conn_id, accept_ref());
        session.handle_inbound_reject(conn_id, reject_ref());

        assert!(
            session.conns.is_empty(),
            "out-of-order accept/reject should not mutate empty connection table"
        );
    }

    #[tokio::test]
    async fn close_request_clears_pending_outbound_open() {
        let mut session = make_session();
        let (open_result_tx, open_result_rx) = moire::sync::oneshot::channel("session.open.result");
        let (close_result_tx, close_result_rx) =
            moire::sync::oneshot::channel("session.close.result");
        session
            .conns
            .insert(ConnectionId(1), pending_outbound(open_result_tx));

        session
            .handle_close_request(CloseRequest {
                conn_id: ConnectionId(1),
                metadata: vec![],
                result_tx: close_result_tx,
            })
            .await;

        let close_result = close_result_rx
            .await
            .expect("close result should be delivered");
        assert!(
            close_result.is_ok(),
            "close should succeed for pending outbound connection"
        );

        assert!(
            open_result_rx.await.is_err(),
            "pending open result channel should be closed once the pending slot is removed"
        );
    }

    #[test]
    fn resume_rejects_changed_local_root_settings() {
        let mut session = make_session();
        let peer_settings = even_settings();
        let _root = session
            .establish_from_handshake(resumed_handshake(odd_settings(), peer_settings.clone()))
            .expect("initial handshake should establish session");

        let (link_a, _link_b) = crate::memory_link_pair(32);
        let conduit = crate::BareConduit::new(link_a);
        let (tx, rx) = conduit.split();

        let changed_local = ConnectionSettings {
            max_concurrent_requests: 65,
            ..odd_settings()
        };
        let result = session.resume_from_handshake(
            Arc::new(tx),
            Box::new(rx),
            resumed_handshake(changed_local, peer_settings),
        );

        assert!(
            matches!(
                &result,
                Err(SessionError::Protocol(message))
                    if message == "local root settings changed across session resume"
            ),
            "expected local-root-settings mismatch, got: {result:?}"
        );
    }

    #[test]
    fn resume_rejects_changed_peer_root_settings() {
        let mut session = make_session();
        let local_settings = odd_settings();
        let _root = session
            .establish_from_handshake(resumed_handshake(local_settings.clone(), even_settings()))
            .expect("initial handshake should establish session");

        let (link_a, _link_b) = crate::memory_link_pair(32);
        let conduit = crate::BareConduit::new(link_a);
        let (tx, rx) = conduit.split();

        let changed_peer = ConnectionSettings {
            max_concurrent_requests: 65,
            ..even_settings()
        };
        let result = session.resume_from_handshake(
            Arc::new(tx),
            Box::new(rx),
            resumed_handshake(local_settings, changed_peer),
        );

        assert!(
            matches!(
                &result,
                Err(SessionError::Protocol(message))
                    if message == "peer root settings changed across session resume"
            ),
            "expected peer-root-settings mismatch, got: {result:?}"
        );
    }
}