1use std::{
2 collections::{BTreeMap, HashMap, HashSet},
3 sync::Arc,
4 time::Duration,
5};
6
7use moire::sync::mpsc;
8use tokio::sync::watch;
9use tracing::{trace, warn};
10use vox_types::{
11 BoxFut, ChannelMessage, Conduit, ConduitRx, ConduitTx, ConduitTxPermit, ConnectionAccept,
12 ConnectionClose, ConnectionId, ConnectionOpen, ConnectionReject, ConnectionSettings,
13 HandshakeResult, IdAllocator, MaybeSend, MaybeSync, Message, MessageFamily, MessagePayload,
14 Metadata, Parity, RequestBody, RequestId, RequestMessage, RequestResponse, SelfRef,
15 SessionResumeKey, SessionRole,
16};
17
18mod builders;
19pub use builders::*;
20
/// Keepalive settings for a session's root connection.
///
/// When enabled, the session sends a `Ping` on the root connection once
/// `ping_interval` has elapsed since the last ping sent (or matching pong
/// received), and ends its run loop if the matching `Pong` does not arrive
/// within `pong_timeout`. A zero value in either field disables keepalive.
#[derive(Debug, Clone, Copy)]
pub struct SessionKeepaliveConfig {
    /// Delay before the next keepalive ping is due.
    pub ping_interval: Duration,
    /// How long to wait for the pong answering an outstanding ping.
    pub pong_timeout: Duration,
}
27
/// Decides whether to accept virtual connections opened by the peer.
///
/// Consulted for each inbound `ConnectionOpen` whose id has the peer's parity
/// and is not already in use. If a session has no acceptor, every inbound open
/// is rejected.
pub trait ConnectionAcceptor: Send + 'static {
    /// Accepts or rejects the inbound connection `conn_id`.
    ///
    /// `peer_settings` and `metadata` are what the peer sent in its
    /// `ConnectionOpen`. Return `Ok` to accept; return `Err(metadata)` to
    /// reject (NOTE(review): presumably the metadata is sent back in the
    /// `ConnectionReject` — that path is outside this view, confirm).
    ///
    /// Called synchronously from the session's run loop, so slow work here
    /// stalls the whole session.
    fn accept(
        &self,
        conn_id: ConnectionId,
        peer_settings: &ConnectionSettings,
        metadata: &[vox_types::MetadataEntry],
    ) -> Result<AcceptedConnection, Metadata<'static>>;
}
48
/// Result of a successful [`ConnectionAcceptor::accept`].
pub struct AcceptedConnection {
    /// Our settings for the new connection; also sent to the peer in `ConnectionAccept`.
    pub settings: ConnectionSettings,
    /// Metadata included in the `ConnectionAccept` reply.
    pub metadata: Metadata<'static>,
    /// Receives the new connection's handle. Invoked on the session task after
    /// the `ConnectionAccept` send has been attempted (its result is ignored).
    pub setup: Box<dyn FnOnce(ConnectionHandle) + Send>,
}
58
/// Request from [`SessionHandle::open_connection`] to the session task.
struct OpenRequest {
    /// Our settings for the connection being opened.
    settings: ConnectionSettings,
    /// Metadata to send along with the open.
    metadata: Metadata<'static>,
    /// Where the session task reports the outcome.
    result_tx: moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>,
}

/// Request from [`SessionHandle::close_connection`] to the session task.
struct CloseRequest {
    /// Connection to close.
    conn_id: ConnectionId,
    /// Metadata to send along with the close.
    metadata: Metadata<'static>,
    /// Where the session task reports the outcome.
    result_tx: moire::sync::oneshot::Sender<Result<(), SessionError>>,
}

/// Request to install a replacement conduit, sent by `SessionHandle::resume`
/// and `SessionHandle::resume_parts`.
struct ResumeRequest {
    /// Outbound half of the new conduit.
    tx: Arc<dyn DynConduitTx>,
    /// Inbound half of the new conduit.
    rx: Box<dyn DynConduitRx>,
    /// Handshake performed on the new conduit; checked against the original root settings.
    handshake_result: HandshakeResult,
    /// Where the session task reports whether the resume succeeded.
    result_tx: moire::sync::oneshot::Sender<Result<(), SessionError>>,
}
81
/// Fire-and-forget control messages sent to the session task over an
/// unbounded channel. They need no `.await`, so they can be sent from sync
/// code (e.g. `SessionHandle::shutdown`, or the teardown at the end of
/// `proxy_connections`).
#[derive(Debug, Clone, Copy)]
pub(crate) enum DropControlRequest {
    /// Ask the session to shut down.
    Shutdown,
    /// Ask the session to close the given connection.
    Close(ConnectionId),
}

/// How a request failure is classified when reported via
/// `ConnectionSender::mark_failure`.
///
/// NOTE(review): presumably `Cancelled` means the request is known not to have
/// taken effect and `Indeterminate` means its outcome is unknown — confirm
/// against the consumer of `failures_rx`.
#[derive(Clone, Copy, Debug)]
pub(crate) enum FailureDisposition {
    Cancelled,
    Indeterminate,
}
93
94#[cfg(not(target_arch = "wasm32"))]
95fn send_drop_control(
96 tx: &mpsc::UnboundedSender<DropControlRequest>,
97 req: DropControlRequest,
98) -> Result<(), ()> {
99 tx.send(req).map_err(|_| ())
100}
101
102#[cfg(target_arch = "wasm32")]
103fn send_drop_control(
104 tx: &mpsc::UnboundedSender<DropControlRequest>,
105 req: DropControlRequest,
106) -> Result<(), ()> {
107 tx.try_send(req).map_err(|_| ())
108}
109
/// Cloneable handle for driving a running [`Session`] from other tasks.
///
/// Each method forwards a request to the session's run loop over a channel.
/// Once the session task is gone, they fail with
/// `SessionError::Protocol("session closed")`.
#[derive(Clone)]
pub struct SessionHandle {
    /// Requests to open outbound virtual connections.
    open_tx: mpsc::Sender<OpenRequest>,
    /// Requests to close connections.
    close_tx: mpsc::Sender<CloseRequest>,
    /// Requests to resume the session on a new conduit.
    resume_tx: mpsc::Sender<ResumeRequest>,
    /// Unbounded control channel (used by `shutdown`).
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Resume key captured for this handle, if any.
    resume_key: Option<SessionResumeKey>,
}
128
129impl SessionHandle {
130 pub async fn open_connection(
137 &self,
138 settings: ConnectionSettings,
139 metadata: Metadata<'static>,
140 ) -> Result<ConnectionHandle, SessionError> {
141 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.open_result");
142 self.open_tx
143 .send(OpenRequest {
144 settings,
145 metadata,
146 result_tx,
147 })
148 .await
149 .map_err(|_| SessionError::Protocol("session closed".into()))?;
150 result_rx
151 .await
152 .map_err(|_| SessionError::Protocol("session closed".into()))?
153 }
154
155 pub async fn close_connection(
162 &self,
163 conn_id: ConnectionId,
164 metadata: Metadata<'static>,
165 ) -> Result<(), SessionError> {
166 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.close_result");
167 self.close_tx
168 .send(CloseRequest {
169 conn_id,
170 metadata,
171 result_tx,
172 })
173 .await
174 .map_err(|_| SessionError::Protocol("session closed".into()))?;
175 result_rx
176 .await
177 .map_err(|_| SessionError::Protocol("session closed".into()))?
178 }
179
180 pub async fn resume<I: crate::IntoConduit>(
181 &self,
182 into_conduit: I,
183 handshake_result: HandshakeResult,
184 ) -> Result<(), SessionError>
185 where
186 I::Conduit: Conduit<Msg = MessageFamily> + 'static,
187 <I::Conduit as Conduit>::Tx: MaybeSend + MaybeSync + 'static,
188 for<'p> <<I::Conduit as Conduit>::Tx as ConduitTx>::Permit<'p>: MaybeSend,
189 <I::Conduit as Conduit>::Rx: MaybeSend + 'static,
190 {
191 let (tx, rx) = into_conduit.into_conduit().split();
192 self.resume_parts(Arc::new(tx), Box::new(rx), handshake_result)
193 .await
194 }
195
196 pub(crate) async fn resume_parts(
197 &self,
198 tx: Arc<dyn DynConduitTx>,
199 rx: Box<dyn DynConduitRx>,
200 handshake_result: HandshakeResult,
201 ) -> Result<(), SessionError> {
202 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.resume_result");
203 self.resume_tx
204 .send(ResumeRequest {
205 tx,
206 rx,
207 handshake_result,
208 result_tx,
209 })
210 .await
211 .map_err(|_| SessionError::Protocol("session closed".into()))?;
212 result_rx
213 .await
214 .map_err(|_| SessionError::Protocol("session closed".into()))?
215 }
216
217 pub fn resume_key(&self) -> Option<&SessionResumeKey> {
219 self.resume_key.as_ref()
220 }
221
222 pub fn shutdown(&self) -> Result<(), SessionError> {
224 send_drop_control(&self.control_tx, DropControlRequest::Shutdown)
225 .map_err(|_| SessionError::Protocol("session closed".into()))
226 }
227}
228
/// Session state machine, owned by the session task and driven by [`Session::run`].
///
/// Multiplexes a root connection plus virtual connections (opened by either
/// side) over one conduit. When `resumable`, the session can swap in a
/// replacement conduit after the current one breaks.
pub struct Session {
    /// Inbound half of the current conduit (replaced on resume).
    rx: Box<dyn DynConduitRx>,

    /// Our role, taken from the handshake result.
    role: SessionRole,

    /// Our connection-id parity; peer-opened ids must have the other parity.
    parity: Parity,

    /// State shared with every `ConnectionSender`, including the outbound conduit half.
    sess_core: Arc<SessionCore>,
    /// Whether the peer advertised retry support in the latest handshake.
    peer_supports_retry: bool,
    /// Our root-connection settings; a resume must present identical settings.
    local_root_settings: ConnectionSettings,
    /// Peer root-connection settings; `None` until the handshake is established.
    peer_root_settings: Option<ConnectionSettings>,
    /// Whether a broken conduit may be replaced instead of ending the session.
    resumable: bool,
    /// Key identifying this session for resumption, if the handshake provided one.
    session_resume_key: Option<SessionResumeKey>,

    /// Known connections: active, or opened by us and awaiting accept/reject.
    conns: BTreeMap<ConnectionId, ConnectionSlot>,
    /// NOTE(review): presumably records that the root connection was closed
    /// locally; it is only initialised (to `false`) in this view — confirm.
    root_closed_internal: bool,

    /// Allocator for ids of connections we open (uses our parity).
    conn_ids: IdAllocator<ConnectionId>,

    /// Decides on peer-opened connections; `None` means every open is rejected.
    on_connection: Option<Box<dyn ConnectionAcceptor>>,

    /// Requests from `SessionHandle::open_connection`.
    open_rx: mpsc::Receiver<OpenRequest>,

    /// Requests from `SessionHandle::close_connection`.
    close_rx: mpsc::Receiver<CloseRequest>,

    /// Requests from `SessionHandle::resume` / `resume_parts`.
    resume_rx: mpsc::Receiver<ResumeRequest>,

    /// Sender half of the control channel; a clone goes into every `ConnectionHandle`.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Control requests (shutdown / close a connection).
    control_rx: mpsc::UnboundedReceiver<DropControlRequest>,

    /// Root-connection keepalive configuration, if enabled.
    keepalive: Option<SessionKeepaliveConfig>,
    /// Resume generation counter observed by every `ConnectionHandle`.
    resume_notifier: watch::Sender<u64>,
    /// Optional automatic source of replacement conduits after a break.
    recoverer: Option<Box<dyn ConduitRecoverer>>,
}
284
/// Mutable keepalive bookkeeping, created from [`SessionKeepaliveConfig`].
#[derive(Debug)]
struct KeepaliveRuntime {
    /// Copied from the config.
    ping_interval: Duration,
    /// Copied from the config.
    pong_timeout: Duration,
    /// Earliest time the next ping may be sent.
    next_ping_at: tokio::time::Instant,
    /// Nonce of the ping awaiting a pong; `None` when no ping is outstanding.
    waiting_pong_nonce: Option<u64>,
    /// Deadline for the outstanding pong; only meaningful while
    /// `waiting_pong_nonce` is `Some`.
    pong_deadline: tokio::time::Instant,
    /// Nonce to use for the next ping (wraps on overflow).
    next_ping_nonce: u64,
}
294
/// Session-side state of an active connection.
#[derive(Debug)]
pub struct ConnectionState {
    /// Connection id.
    pub id: ConnectionId,

    /// Our settings for this connection.
    pub local_settings: ConnectionSettings,

    /// The peer's settings for this connection.
    pub peer_settings: ConnectionSettings,

    /// Delivers inbound request/channel messages to the `ConnectionHandle`.
    conn_tx: mpsc::Sender<RecvMessage>,
    /// Closed flag observed via `ConnectionHandle::closed_rx`.
    closed_tx: watch::Sender<bool>,

    /// Schemas received from the peer on this connection; an `Arc` clone
    /// travels with every `RecvMessage`.
    schema_recv_tracker: Arc<vox_types::SchemaRecvTracker>,
}
315
/// Entry in `Session::conns`.
#[derive(Debug)]
enum ConnectionSlot {
    /// Fully established connection.
    Active(ConnectionState),
    /// Connection we opened that is still waiting for the peer's accept/reject.
    PendingOutbound(PendingOutboundData),
}

/// Bookkeeping for an outbound open awaiting the peer's answer.
struct PendingOutboundData {
    /// Our settings for the connection being opened.
    local_settings: ConnectionSettings,
    /// Where to report the outcome of the open; `Option` so it can be taken once.
    result_tx: Option<moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>>,
}
327
328impl std::fmt::Debug for PendingOutboundData {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 f.debug_struct("PendingOutbound")
331 .field("local_settings", &self.local_settings)
332 .finish()
333 }
334}
335
/// Cloneable outbound side of one connection.
#[derive(Clone)]
pub(crate) struct ConnectionSender {
    /// Connection every message is tagged with.
    connection_id: ConnectionId,
    /// Shared session core that performs the actual conduit send.
    pub(crate) sess_core: Arc<SessionCore>,
    /// Failure reports delivered to the owning `ConnectionHandle::failures_rx`.
    failures: Arc<mpsc::UnboundedSender<(RequestId, FailureDisposition)>>,
}
342
343fn forwarded_payload<'a>(payload: &'a vox_types::Payload<'static>) -> vox_types::Payload<'a> {
344 let vox_types::Payload::PostcardBytes(bytes) = payload else {
345 unreachable!("proxy forwarding expects decoded incoming payload bytes")
346 };
347 vox_types::Payload::PostcardBytes(bytes)
348}
349
350fn forwarded_request_body<'a>(body: &'a RequestBody<'static>) -> RequestBody<'a> {
351 match body {
352 RequestBody::Call(call) => RequestBody::Call(vox_types::RequestCall {
353 method_id: call.method_id,
354 metadata: call.metadata.clone(),
355 args: forwarded_payload(&call.args),
356 schemas: call.schemas.clone(),
357 }),
358 RequestBody::Response(response) => RequestBody::Response(RequestResponse {
359 metadata: response.metadata.clone(),
360 ret: forwarded_payload(&response.ret),
361 schemas: response.schemas.clone(),
362 }),
363 RequestBody::Cancel(cancel) => RequestBody::Cancel(vox_types::RequestCancel {
364 metadata: cancel.metadata.clone(),
365 }),
366 }
367}
368
369fn forwarded_channel_body<'a>(
370 body: &'a vox_types::ChannelBody<'static>,
371) -> vox_types::ChannelBody<'a> {
372 match body {
373 vox_types::ChannelBody::Item(item) => {
374 vox_types::ChannelBody::Item(vox_types::ChannelItem {
375 item: forwarded_payload(&item.item),
376 })
377 }
378 vox_types::ChannelBody::Close(close) => {
379 vox_types::ChannelBody::Close(vox_types::ChannelClose {
380 metadata: close.metadata.clone(),
381 })
382 }
383 vox_types::ChannelBody::Reset(reset) => {
384 vox_types::ChannelBody::Reset(vox_types::ChannelReset {
385 metadata: reset.metadata.clone(),
386 })
387 }
388 vox_types::ChannelBody::GrantCredit(credit) => {
389 vox_types::ChannelBody::GrantCredit(vox_types::ChannelGrantCredit {
390 additional: credit.additional,
391 })
392 }
393 }
394}
395
396impl ConnectionSender {
397 pub(crate) fn connection_id(&self) -> ConnectionId {
398 self.connection_id
399 }
400
401 pub(crate) async fn send_with_binder<'a>(
402 &self,
403 msg: ConnectionMessage<'a>,
404 binder: Option<&'a dyn vox_types::ChannelBinder>,
405 ) -> Result<(), ()> {
406 let payload = match msg {
407 ConnectionMessage::Request(r) => MessagePayload::RequestMessage(r),
408 ConnectionMessage::Channel(c) => MessagePayload::ChannelMessage(c),
409 };
410 let message = Message {
411 connection_id: self.connection_id,
412 payload,
413 };
414 self.sess_core
415 .send(message, binder, None)
416 .await
417 .map_err(|_| ())
418 }
419
420 pub async fn send<'a>(&self, msg: ConnectionMessage<'a>) -> Result<(), ()> {
422 self.send_with_binder(msg, None).await
423 }
424
425 pub(crate) async fn send_owned(
427 &self,
428 schemas: Arc<vox_types::SchemaRecvTracker>,
429 msg: SelfRef<ConnectionMessage<'static>>,
430 ) -> Result<(), ()> {
431 let payload = match &*msg {
432 ConnectionMessage::Request(request) => MessagePayload::RequestMessage(RequestMessage {
433 id: request.id,
434 body: forwarded_request_body(&request.body),
435 }),
436 ConnectionMessage::Channel(channel) => MessagePayload::ChannelMessage(ChannelMessage {
437 id: channel.id,
438 body: forwarded_channel_body(&channel.body),
439 }),
440 };
441
442 self.sess_core
443 .send(
444 Message {
445 connection_id: self.connection_id,
446 payload,
447 },
448 None,
449 Some(&*schemas),
450 )
451 .await
452 .map_err(|_| ())
453 }
454
455 pub async fn send_response<'a>(
457 &self,
458 request_id: RequestId,
459 response: RequestResponse<'a>,
460 ) -> Result<(), ()> {
461 self.send(ConnectionMessage::Request(RequestMessage {
462 id: request_id,
463 body: RequestBody::Response(response),
464 }))
465 .await
466 }
467
468 pub async fn send_response_for_method<'a>(
470 &self,
471 request_id: RequestId,
472 method_id: vox_types::MethodId,
473 mut response: RequestResponse<'a>,
474 ) -> Result<(), ()> {
475 self.prepare_response_for_method(request_id, method_id, &mut response);
476 self.send(ConnectionMessage::Request(RequestMessage {
477 id: request_id,
478 body: RequestBody::Response(response),
479 }))
480 .await
481 }
482
483 pub(crate) fn prepare_response_for_method(
485 &self,
486 request_id: RequestId,
487 method_id: vox_types::MethodId,
488 response: &mut RequestResponse<'_>,
489 ) {
490 self.sess_core.prepare_response_for_method(
491 self.connection_id,
492 request_id,
493 method_id,
494 response,
495 );
496 }
497
498 pub(crate) fn prepare_response_from_source(
500 &self,
501 request_id: RequestId,
502 method_id: vox_types::MethodId,
503 root_type: &vox_types::TypeRef,
504 source: &dyn vox_types::SchemaSource,
505 response: &mut RequestResponse<'_>,
506 ) {
507 self.sess_core.prepare_response_from_source(
508 self.connection_id,
509 request_id,
510 method_id,
511 root_type,
512 source,
513 response,
514 );
515 }
516
517 pub fn mark_failure(&self, request_id: RequestId, disposition: FailureDisposition) {
520 let _ = self.failures.send((request_id, disposition));
521 }
522
523 pub fn schema_registry(&self) -> vox_types::SchemaRegistry {
525 self.sess_core.schema_registry(self.connection_id)
526 }
527
528 pub fn prepare_replay_schemas(
530 &self,
531 request_id: RequestId,
532 method_id: vox_types::MethodId,
533 root_type: &vox_types::TypeRef,
534 store: &dyn crate::OperationStore,
535 response: &mut RequestResponse<'_>,
536 ) {
537 self.prepare_response_from_source(
538 request_id,
539 method_id,
540 root_type,
541 store.schema_source(),
542 response,
543 );
544 }
545}
546
/// User-facing endpoint of one (root or virtual) connection in a session.
pub struct ConnectionHandle {
    /// Outbound side; sends through the shared session core.
    pub(crate) sender: ConnectionSender,
    /// Inbound request/channel messages routed here by the session (queue capacity 64).
    pub(crate) rx: mpsc::Receiver<RecvMessage>,
    /// Failure reports produced via `ConnectionSender::mark_failure`.
    pub(crate) failures_rx: mpsc::UnboundedReceiver<(RequestId, FailureDisposition)>,
    /// Control channel back to the session task (used e.g. to request a close on teardown).
    pub(crate) control_tx: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// Closed flag for this connection (NOTE(review): the code that sets it
    /// to `true` is outside this view).
    pub(crate) closed_rx: watch::Receiver<bool>,
    /// Resume generation counter; bumped each time the session resumes.
    pub(crate) resumed_rx: watch::Receiver<u64>,
    /// Parity from our local settings for this connection.
    pub parity: Parity,
    /// Peer retry support as known when this handle was created; not
    /// refreshed if a later resume changes it.
    pub(crate) peer_supports_retry: bool,
}
558
559impl std::fmt::Debug for ConnectionHandle {
560 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
561 f.debug_struct("ConnectionHandle")
562 .field("connection_id", &self.sender.connection_id)
563 .finish()
564 }
565}
566
/// Per-connection message: the subset of `MessagePayload` that is routed to a
/// `ConnectionHandle`.
pub(crate) enum ConnectionMessage<'payload> {
    Request(RequestMessage<'payload>),
    Channel(ChannelMessage<'payload>),
}

/// Inbound message delivered to a `ConnectionHandle`.
pub(crate) struct RecvMessage {
    /// The receiving connection's schema tracker at dispatch time; it is
    /// swapped for a fresh one on resume, so older messages keep the old one.
    pub schemas: Arc<vox_types::SchemaRecvTracker>,
    /// The message itself, borrowing from its owned receive buffer.
    pub msg: SelfRef<ConnectionMessage<'static>>,
}
579
580impl ConnectionHandle {
581 pub fn connection_id(&self) -> ConnectionId {
583 self.sender.connection_id
584 }
585
586 pub async fn closed(&self) {
588 if *self.closed_rx.borrow() {
589 return;
590 }
591 let mut rx = self.closed_rx.clone();
592 while rx.changed().await.is_ok() {
593 if *rx.borrow() {
594 return;
595 }
596 }
597 }
598
599 pub fn is_connected(&self) -> bool {
601 !*self.closed_rx.borrow()
602 }
603
604 pub fn peer_supports_retry(&self) -> bool {
605 self.peer_supports_retry
606 }
607}
608
609pub async fn proxy_connections(left: ConnectionHandle, right: ConnectionHandle) {
615 let left_conn_id = left.connection_id();
616 let right_conn_id = right.connection_id();
617 let ConnectionHandle {
618 sender: left_sender,
619 rx: mut left_rx,
620 failures_rx: _left_failures_rx,
621 control_tx: left_control_tx,
622 closed_rx: _left_closed_rx,
623 resumed_rx: _left_resumed_rx,
624 parity: _left_parity,
625 peer_supports_retry: _left_peer_supports_retry,
626 } = left;
627 let ConnectionHandle {
628 sender: right_sender,
629 rx: mut right_rx,
630 failures_rx: _right_failures_rx,
631 control_tx: right_control_tx,
632 closed_rx: _right_closed_rx,
633 resumed_rx: _right_resumed_rx,
634 parity: _right_parity,
635 peer_supports_retry: _right_peer_supports_retry,
636 } = right;
637
638 loop {
639 tokio::select! {
640 recv = left_rx.recv() => {
641 let Some(recv) = recv else {
642 break;
643 };
644 if right_sender.send_owned(recv.schemas, recv.msg).await.is_err() {
645 break;
646 }
647 }
648 recv = right_rx.recv() => {
649 let Some(recv) = recv else {
650 break;
651 };
652 if left_sender.send_owned(recv.schemas, recv.msg).await.is_err() {
653 break;
654 }
655 }
656 }
657 }
658
659 if let Some(tx) = left_control_tx.as_ref() {
660 let _ = send_drop_control(tx, DropControlRequest::Close(left_conn_id));
661 }
662 if let Some(tx) = right_control_tx.as_ref() {
663 let _ = send_drop_control(tx, DropControlRequest::Close(right_conn_id));
664 }
665}
666
/// Errors produced by session operations.
#[derive(Debug)]
pub enum SessionError {
    /// Underlying I/O failure.
    Io(std::io::Error),
    /// Protocol violation or unusable session state (including "session closed").
    Protocol(String),
    /// The peer rejected a connection; carries the peer's metadata.
    Rejected(Metadata<'static>),
    /// The session was configured as resumable but the handshake yielded no resume key.
    NotResumable,
}
675
676impl std::fmt::Display for SessionError {
677 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
678 match self {
679 Self::Io(e) => write!(f, "io error: {e}"),
680 Self::Protocol(msg) => write!(f, "protocol error: {msg}"),
681 Self::Rejected(_) => write!(f, "connection rejected"),
682 Self::NotResumable => write!(f, "session is not resumable"),
683 }
684 }
685}
686
687impl std::error::Error for SessionError {}
688
689impl Session {
690 #[allow(clippy::too_many_arguments)]
691 fn pre_handshake<Tx, Rx>(
692 tx: Tx,
693 rx: Rx,
694 on_connection: Option<Box<dyn ConnectionAcceptor>>,
695 open_rx: mpsc::Receiver<OpenRequest>,
696 close_rx: mpsc::Receiver<CloseRequest>,
697 resume_rx: mpsc::Receiver<ResumeRequest>,
698 control_tx: mpsc::UnboundedSender<DropControlRequest>,
699 control_rx: mpsc::UnboundedReceiver<DropControlRequest>,
700 keepalive: Option<SessionKeepaliveConfig>,
701 resumable: bool,
702 recoverer: Option<Box<dyn ConduitRecoverer>>,
703 ) -> Self
704 where
705 Tx: ConduitTx<Msg = MessageFamily> + MaybeSend + MaybeSync + 'static,
706 for<'p> <Tx as ConduitTx>::Permit<'p>: MaybeSend,
707 Rx: ConduitRx<Msg = MessageFamily> + MaybeSend + 'static,
708 {
709 let sess_core = Arc::new(SessionCore {
710 inner: std::sync::Mutex::new(SessionCoreInner {
711 tx: Arc::new(tx) as Arc<dyn DynConduitTx>,
712 conns: HashMap::new(),
713 }),
714 });
715 let (resume_notifier, _resume_rx) = watch::channel(0_u64);
716 Session {
717 rx: Box::new(rx),
718 role: SessionRole::Initiator, parity: Parity::Odd, sess_core,
721 peer_supports_retry: false,
722 local_root_settings: ConnectionSettings {
723 parity: Parity::Odd,
724 max_concurrent_requests: 64,
725 },
726 peer_root_settings: None,
727 resumable,
728 session_resume_key: None,
729 conns: BTreeMap::new(),
730 root_closed_internal: false,
731 conn_ids: IdAllocator::new(Parity::Odd), on_connection,
733 open_rx,
734 close_rx,
735 resume_rx,
736 control_tx,
737 control_rx,
738 keepalive,
739 resume_notifier,
740 recoverer,
741 }
742 }
743
    /// Returns the current session resume key (copied), if one was negotiated.
    pub(crate) fn resume_key(&self) -> Option<SessionResumeKey> {
        self.session_resume_key
    }
747
748 fn establish_from_handshake(
750 &mut self,
751 result: HandshakeResult,
752 ) -> Result<ConnectionHandle, SessionError> {
753 self.role = result.role;
754 self.parity = result.our_settings.parity;
755 self.conn_ids = IdAllocator::new(result.our_settings.parity);
756 self.local_root_settings = result.our_settings.clone();
757 self.peer_root_settings = Some(result.peer_settings.clone());
758 self.peer_supports_retry = result.peer_supports_retry;
759 self.session_resume_key = result.session_resume_key;
760
761 if self.resumable && self.session_resume_key.is_none() {
762 return Err(SessionError::NotResumable);
763 }
764
765 Ok(self.make_root_handle(result.our_settings, result.peer_settings))
766 }
767
    /// Registers the root connection (`ConnectionId::ROOT`) and returns its handle.
    fn make_root_handle(
        &mut self,
        local_settings: ConnectionSettings,
        peer_settings: ConnectionSettings,
    ) -> ConnectionHandle {
        self.make_connection_handle(ConnectionId::ROOT, local_settings, peer_settings)
    }
775
776 fn make_connection_handle(
777 &mut self,
778 conn_id: ConnectionId,
779 local_settings: ConnectionSettings,
780 peer_settings: ConnectionSettings,
781 ) -> ConnectionHandle {
782 let label = format!("session.conn{}", conn_id.0);
783 let (conn_tx, conn_rx) = mpsc::channel::<RecvMessage>(&label, 64);
784 let (failures_tx, failures_rx) = mpsc::unbounded_channel(format!("{label}.failures"));
785 let (closed_tx, closed_rx) = watch::channel(false);
786 let resumed_rx = self.resume_notifier.subscribe();
787
788 let sender = ConnectionSender {
789 connection_id: conn_id,
790 sess_core: Arc::clone(&self.sess_core),
791 failures: Arc::new(failures_tx),
792 };
793
794 let parity = local_settings.parity;
795 self.conns.insert(
796 conn_id,
797 ConnectionSlot::Active(ConnectionState {
798 id: conn_id,
799 local_settings,
800 peer_settings,
801 conn_tx,
802 closed_tx,
803 schema_recv_tracker: Arc::new(vox_types::SchemaRecvTracker::new()),
804 }),
805 );
806
807 ConnectionHandle {
808 sender,
809 rx: conn_rx,
810 failures_rx,
811 control_tx: Some(self.control_tx.clone()),
812 closed_rx,
813 resumed_rx,
814 parity,
815 peer_supports_retry: self.peer_supports_retry,
816 }
817 }
818
    /// Drives the session until it terminates.
    ///
    /// Multiplexes inbound conduit messages, open/close/resume requests from
    /// `SessionHandle`s, drop-control requests, and keepalive ticks. The loop
    /// ends when:
    /// - the conduit breaks (EOF or recv error) and `handle_conduit_break`
    ///   cannot resume,
    /// - a drop-control request asks to stop, or
    /// - a keepalive ping times out or cannot be sent.
    ///
    /// On exit, every connection is closed via `close_all_connections`.
    ///
    /// NOTE(review): a keepalive timeout ends the session even when it is
    /// resumable; it does not go through `handle_conduit_break` — confirm this
    /// is intended.
    pub async fn run(&mut self) {
        let mut keepalive_runtime = self.make_keepalive_runtime();
        // Keepalive deadlines are polled every 10 ms rather than slept on
        // exactly. `Delay` avoids a burst of catch-up ticks after a stall.
        let mut keepalive_tick = keepalive_runtime.as_ref().map(|_| {
            let mut interval = tokio::time::interval(Duration::from_millis(10));
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            interval
        });

        loop {
            tokio::select! {
                msg = self.rx.recv_msg() => {
                    vox_types::dlog!("[session {:?}] recv_msg returned", self.role);
                    match msg {
                        Ok(Some(msg)) => {
                            self.handle_message(msg, &mut keepalive_runtime).await;
                        }
                        // Conduit EOF: try to resume (if resumable), else stop.
                        Ok(None) => {
                            vox_types::dlog!("[session {:?}] recv loop: conduit returned EOF", self.role);
                            if !self.handle_conduit_break(&mut keepalive_runtime).await {
                                vox_types::dlog!("[session {:?}] recv loop: breaking (not resumable)", self.role);
                                break;
                            }
                        }
                        // Conduit error: handled the same way as EOF.
                        Err(error) => {
                            vox_types::dlog!("[session {:?}] recv loop: conduit recv error: {}", self.role, error);
                            if !self.handle_conduit_break(&mut keepalive_runtime).await {
                                vox_types::dlog!("[session {:?}] recv loop: breaking (not resumable)", self.role);
                                break;
                            }
                        }
                    }
                }
                Some(req) = self.open_rx.recv() => {
                    self.handle_open_request(req).await;
                }
                Some(req) = self.close_rx.recv() => {
                    self.handle_close_request(req).await;
                }
                // Resuming is only meaningful while disconnected
                // (see `handle_conduit_break`).
                Some(req) = self.resume_rx.recv() => {
                    let _ = req.result_tx.send(Err(SessionError::Protocol(
                        "resume is only valid while the session is disconnected".into(),
                    )));
                }
                Some(req) = self.control_rx.recv() => {
                    if !self.handle_drop_control_request(req).await {
                        break;
                    }
                }
                // Disabled entirely when keepalive is off.
                _ = async {
                    if let Some(interval) = keepalive_tick.as_mut() {
                        interval.tick().await;
                    }
                }, if keepalive_tick.is_some() => {
                    if !self.handle_keepalive_tick(&mut keepalive_runtime).await {
                        break;
                    }
                }
            }
        }

        self.close_all_connections();
        trace!("session recv loop exited");
    }
887
888 async fn handle_conduit_break(
889 &mut self,
890 keepalive_runtime: &mut Option<KeepaliveRuntime>,
891 ) -> bool {
892 if !self.resumable {
893 return false;
894 }
895
896 if let Some(recoverer) = self.recoverer.as_mut() {
897 match recoverer
898 .next_conduit(self.session_resume_key.as_ref())
899 .await
900 {
901 Ok(recovered) => {
902 let result =
903 self.resume_from_handshake(recovered.tx, recovered.rx, recovered.handshake);
904 match result {
905 Ok(()) => {
906 let next_generation = self.resume_notifier.borrow().wrapping_add(1);
907 let _ = self.resume_notifier.send(next_generation);
908 *keepalive_runtime = self.make_keepalive_runtime();
909 return true;
910 }
911 Err(_) => return false,
912 }
913 }
914 Err(_) => return false,
915 }
916 }
917
918 loop {
919 tokio::select! {
920 Some(req) = self.resume_rx.recv() => {
921 let result =
922 self.resume_from_handshake(req.tx, req.rx, req.handshake_result);
923 let ok = result.is_ok();
924 let _ = req.result_tx.send(result);
925 if ok {
926 let next_generation = self.resume_notifier.borrow().wrapping_add(1);
927 let _ = self.resume_notifier.send(next_generation);
928 *keepalive_runtime = self.make_keepalive_runtime();
929 return true;
930 }
931 }
932 Some(req) = self.control_rx.recv() => {
933 if !self.handle_drop_control_request(req).await {
934 return false;
935 }
936 }
937 Some(req) = self.open_rx.recv() => {
938 let _ = req.result_tx.send(Err(SessionError::Protocol(
939 "session is disconnected; resume before opening connections".into(),
940 )));
941 }
942 Some(req) = self.close_rx.recv() => {
943 let _ = req.result_tx.send(Err(SessionError::Protocol(
944 "session is disconnected; resume before closing connections".into(),
945 )));
946 }
947 else => return false,
948 }
949 }
950 }
951
    /// Installs a replacement conduit after validating its handshake.
    ///
    /// The handshake must present exactly the same root settings (ours and the
    /// peer's) as the original session. On success:
    /// - retry support is refreshed,
    /// - the resume key is replaced (or kept if the new handshake has none),
    /// - the outbound half is swapped via `replace_tx_and_reset_schemas`,
    /// - the inbound half is swapped,
    /// - the root connection gets a fresh schema receive tracker.
    ///
    /// # Errors
    /// `SessionError::Protocol` if the session was never established or the
    /// root settings differ.
    ///
    /// NOTE(review): only the ROOT connection's receive tracker is reset, while
    /// `replace_tx_and_reset_schemas` is not scoped to a connection. If virtual
    /// connections survive a resume and the peer re-sends their schemas,
    /// `record_received` might reject duplicates — confirm.
    /// `ConnectionHandle::peer_supports_retry` is also not refreshed here.
    fn resume_from_handshake(
        &mut self,
        tx: Arc<dyn DynConduitTx>,
        rx: Box<dyn DynConduitRx>,
        result: HandshakeResult,
    ) -> Result<(), SessionError> {
        let Some(peer_settings) = self.peer_root_settings.clone() else {
            return Err(SessionError::Protocol("missing peer root settings".into()));
        };

        if result.our_settings != self.local_root_settings {
            return Err(SessionError::Protocol(
                "local root settings changed across session resume".into(),
            ));
        }

        if result.peer_settings != peer_settings {
            return Err(SessionError::Protocol(
                "peer root settings changed across session resume".into(),
            ));
        }

        self.peer_supports_retry = result.peer_supports_retry;
        // Keep the previous key if the resume handshake did not supply one.
        self.session_resume_key = result.session_resume_key.or(self.session_resume_key);

        self.sess_core.replace_tx_and_reset_schemas(tx);
        self.rx = rx;
        if let Some(ConnectionSlot::Active(state)) = self.conns.get_mut(&ConnectionId::ROOT) {
            // Messages already dispatched keep their own `Arc` to the old tracker.
            state.schema_recv_tracker = Arc::new(vox_types::SchemaRecvTracker::new());
        }
        Ok(())
    }
987
    /// Dispatches one inbound conduit message.
    ///
    /// - Connection lifecycle messages (close/open/accept/reject) update the
    ///   connection table.
    /// - Request and channel messages are forwarded to the owning connection's
    ///   queue; if the handle's receiver is gone, the connection is removed.
    /// - Pings are answered with a pong on the same connection.
    /// - Pongs on the root connection feed keepalive.
    ///
    /// Messages for connections that are not `Active` are dropped.
    ///
    /// NOTE(review): the `.expect(...)` calls below act on peer-supplied data
    /// (schema CBOR, response ids, schema contents). A malformed or malicious
    /// peer can therefore panic the session task — consider turning these into
    /// protocol errors.
    async fn handle_message(
        &mut self,
        msg: SelfRef<Message<'static>>,
        keepalive_runtime: &mut Option<KeepaliveRuntime>,
    ) {
        let conn_id = msg.connection_id;
        vox_types::selfref_match!(msg, payload {
            MessagePayload::ConnectionClose(_) => {
                if conn_id.is_root() {
                    warn!("received ConnectionClose for root connection");
                } else {
                    trace!(conn_id = conn_id.0, "received ConnectionClose for virtual connection");
                }
                self.remove_connection(&conn_id);
                self.maybe_request_shutdown_after_root_closed();
            }
            MessagePayload::ConnectionOpen(open) => {
                self.handle_inbound_open(conn_id, open).await;
            }
            MessagePayload::ConnectionAccept(accept) => {
                self.handle_inbound_accept(conn_id, accept);
            }
            MessagePayload::ConnectionReject(reject) => {
                self.handle_inbound_reject(conn_id, reject);
            }
            MessagePayload::RequestMessage(r) => {
                vox_types::dlog!(
                    "[session {:?}] recv request: conn={:?} req={:?} body={} method={:?}",
                    self.role,
                    conn_id,
                    r.id,
                    match &r.body {
                        RequestBody::Call(_) => "Call",
                        RequestBody::Response(_) => "Response",
                        RequestBody::Cancel(_) => "Cancel",
                    },
                    match &r.body {
                        RequestBody::Call(call) => Some(call.method_id),
                        RequestBody::Response(_) | RequestBody::Cancel(_) => None,
                    }
                );
                // Remembered so the method binding of a schema-less response can
                // be released below (a response with schemas consumes it here).
                let response_had_schema_payload = matches!(&r.body, RequestBody::Response(resp) if !resp.schemas.is_empty());
                {
                    let schemas_cbor = match &r.body {
                        RequestBody::Call(call) => Some(&call.schemas),
                        RequestBody::Response(resp) => Some(&resp.schemas),
                        _ => None,
                    };
                    vox_types::dlog!(
                        "[schema] recv ({:?}): req={:?} body={} schemas_len={:?}",
                        self.role,
                        r.id,
                        match &r.body {
                            RequestBody::Call(_) => "Call",
                            RequestBody::Response(_) => "Response",
                            RequestBody::Cancel(_) => "Cancel",
                        },
                        schemas_cbor.map(|s| s.0.len())
                    );
                    // Unknown or still-pending connection: drop the message.
                    let state = match self.conns.get(&conn_id) {
                        Some(ConnectionSlot::Active(state)) => state,
                        _ => return,
                    };
                    // Record inline schemas against the method they describe:
                    // call args are bound to the called method; response
                    // schemas to the method of our outgoing call with this id.
                    if let Some(schemas_cbor) = schemas_cbor
                        && !schemas_cbor.is_empty()
                    {
                        let payload = vox_types::SchemaPayload::from_cbor(&schemas_cbor.0)
                            .expect("inlined schemas must be valid CBOR");
                        let (method_id, direction) = match &r.body {
                            RequestBody::Call(call) => {
                                (call.method_id, vox_types::BindingDirection::Args)
                            }
                            RequestBody::Response(_) => {
                                let method_id = self
                                    .sess_core
                                    .take_outgoing_call_method(conn_id, r.id)
                                    .expect("response schemas require an inflight method binding");
                                (method_id, vox_types::BindingDirection::Response)
                            }
                            RequestBody::Cancel(_) => unreachable!(),
                        };
                        state
                            .schema_recv_tracker
                            .record_received(method_id, direction, payload)
                            .expect("received schemas must not contain duplicate type IDs");
                    }
                }
                // A schema-less response still completes our outgoing call, so
                // release its method binding here.
                if matches!(&r.body, RequestBody::Response(_)) && !response_had_schema_payload {
                    let _ = self.sess_core.take_outgoing_call_method(conn_id, r.id);
                }
                if let RequestBody::Call(call) = &r.body {
                    self.sess_core.record_incoming_call(conn_id, r.id, call.method_id);
                }
                let state = match self.conns.get(&conn_id) {
                    Some(ConnectionSlot::Active(state)) => state,
                    _ => return,
                };
                let conn_tx = state.conn_tx.clone();
                let request_id = r.id;
                let body_kind = match &r.body {
                    RequestBody::Call(_) => "Call",
                    RequestBody::Response(_) => "Response",
                    RequestBody::Cancel(_) => "Cancel",
                };
                let recv_msg = RecvMessage {
                    schemas: Arc::clone(&state.schema_recv_tracker),
                    msg: r.map(ConnectionMessage::Request),
                };
                vox_types::dlog!(
                    "[session {:?}] dispatch request: conn={:?} req={:?} body={}",
                    self.role,
                    conn_id,
                    request_id,
                    body_kind
                );
                // Bounded queue: this await applies backpressure to the whole session loop.
                if conn_tx.send(recv_msg).await.is_err() {
                    self.remove_connection(&conn_id);
                    self.maybe_request_shutdown_after_root_closed();
                }
            }
            MessagePayload::ChannelMessage(c) => {
                let state = match self.conns.get(&conn_id) {
                    Some(ConnectionSlot::Active(state)) => state,
                    _ => return,
                };
                let conn_tx = state.conn_tx.clone();
                let recv_msg = RecvMessage {
                    schemas: Arc::clone(&state.schema_recv_tracker),
                    msg: c.map(ConnectionMessage::Channel),
                };
                if conn_tx.send(recv_msg).await.is_err() {
                    self.remove_connection(&conn_id);
                    self.maybe_request_shutdown_after_root_closed();
                }
            }
            MessagePayload::Ping(ping) => {
                // Best-effort echo; a send failure is ignored here.
                let _ = self
                    .sess_core
                    .send(Message {
                        connection_id: conn_id,
                        payload: MessagePayload::Pong(vox_types::Pong { nonce: ping.nonce }),
                    }, None, None)
                    .await;
            }
            MessagePayload::Pong(pong) => {
                // Keepalive only runs on the root connection.
                if conn_id.is_root() {
                    self.handle_keepalive_pong(pong.nonce, keepalive_runtime);
                }
            }
        })
    }
1147
1148 fn make_keepalive_runtime(&self) -> Option<KeepaliveRuntime> {
1149 let config = self.keepalive?;
1150 if config.ping_interval.is_zero() || config.pong_timeout.is_zero() {
1151 warn!("keepalive disabled due to non-positive interval/timeout");
1152 return None;
1153 }
1154 let now = tokio::time::Instant::now();
1155 Some(KeepaliveRuntime {
1156 ping_interval: config.ping_interval,
1157 pong_timeout: config.pong_timeout,
1158 next_ping_at: now + config.ping_interval,
1159 waiting_pong_nonce: None,
1160 pong_deadline: now,
1161 next_ping_nonce: 1,
1162 })
1163 }
1164
1165 fn handle_keepalive_pong(&self, nonce: u64, keepalive_runtime: &mut Option<KeepaliveRuntime>) {
1166 let Some(runtime) = keepalive_runtime.as_mut() else {
1167 return;
1168 };
1169 if runtime.waiting_pong_nonce != Some(nonce) {
1170 return;
1171 }
1172 runtime.waiting_pong_nonce = None;
1173 runtime.next_ping_at = tokio::time::Instant::now() + runtime.ping_interval;
1174 }
1175
1176 async fn handle_keepalive_tick(
1177 &mut self,
1178 keepalive_runtime: &mut Option<KeepaliveRuntime>,
1179 ) -> bool {
1180 let Some(runtime) = keepalive_runtime.as_mut() else {
1181 return true;
1182 };
1183 let now = tokio::time::Instant::now();
1184
1185 if let Some(waiting_nonce) = runtime.waiting_pong_nonce {
1186 if now >= runtime.pong_deadline {
1187 warn!(
1188 nonce = waiting_nonce,
1189 timeout_ms = runtime.pong_timeout.as_millis(),
1190 "keepalive timeout waiting for pong"
1191 );
1192 return false;
1193 }
1194 return true;
1195 }
1196
1197 if now < runtime.next_ping_at {
1198 return true;
1199 }
1200
1201 let nonce = runtime.next_ping_nonce;
1202 if self
1203 .sess_core
1204 .send(
1205 Message {
1206 connection_id: ConnectionId::ROOT,
1207 payload: MessagePayload::Ping(vox_types::Ping { nonce }),
1208 },
1209 None,
1210 None,
1211 )
1212 .await
1213 .is_err()
1214 {
1215 warn!("failed to send keepalive ping");
1216 return false;
1217 }
1218
1219 runtime.waiting_pong_nonce = Some(nonce);
1220 runtime.pong_deadline = now + runtime.pong_timeout;
1221 runtime.next_ping_at = now + runtime.ping_interval;
1222 runtime.next_ping_nonce = runtime.next_ping_nonce.wrapping_add(1);
1223 true
1224 }
1225
1226 async fn handle_inbound_open(
1227 &mut self,
1228 conn_id: ConnectionId,
1229 open: SelfRef<ConnectionOpen<'static>>,
1230 ) {
1231 let peer_parity = self.parity.other();
1233 if !conn_id.has_parity(peer_parity) {
1234 let _ = self
1236 .sess_core
1237 .send(
1238 Message {
1239 connection_id: conn_id,
1240 payload: MessagePayload::ConnectionReject(vox_types::ConnectionReject {
1241 metadata: vec![],
1242 }),
1243 },
1244 None,
1245 None,
1246 )
1247 .await;
1248 return;
1249 }
1250
1251 if self.conns.contains_key(&conn_id) {
1253 let _ = self
1255 .sess_core
1256 .send(
1257 Message {
1258 connection_id: conn_id,
1259 payload: MessagePayload::ConnectionReject(vox_types::ConnectionReject {
1260 metadata: vec![],
1261 }),
1262 },
1263 None,
1264 None,
1265 )
1266 .await;
1267 return;
1268 }
1269
1270 let acceptor = match &self.on_connection {
1273 Some(a) => a,
1274 None => {
1275 let _ = self
1276 .sess_core
1277 .send(
1278 Message {
1279 connection_id: conn_id,
1280 payload: MessagePayload::ConnectionReject(
1281 vox_types::ConnectionReject { metadata: vec![] },
1282 ),
1283 },
1284 None,
1285 None,
1286 )
1287 .await;
1288 return;
1289 }
1290 };
1291
1292 match acceptor.accept(conn_id, &open.connection_settings, &open.metadata) {
1293 Ok(accepted) => {
1294 let handle = self.make_connection_handle(
1296 conn_id,
1297 accepted.settings.clone(),
1298 open.connection_settings.clone(),
1299 );
1300
1301 let _ = self
1303 .sess_core
1304 .send(
1305 Message {
1306 connection_id: conn_id,
1307 payload: MessagePayload::ConnectionAccept(
1308 vox_types::ConnectionAccept {
1309 connection_settings: accepted.settings,
1310 metadata: accepted.metadata,
1311 },
1312 ),
1313 },
1314 None,
1315 None,
1316 )
1317 .await;
1318
1319 (accepted.setup)(handle);
1321 }
1322 Err(reject_metadata) => {
1323 let _ = self
1324 .sess_core
1325 .send(
1326 Message {
1327 connection_id: conn_id,
1328 payload: MessagePayload::ConnectionReject(
1329 vox_types::ConnectionReject {
1330 metadata: reject_metadata,
1331 },
1332 ),
1333 },
1334 None,
1335 None,
1336 )
1337 .await;
1338 }
1339 }
1340 }
1341
1342 fn handle_inbound_accept(
1343 &mut self,
1344 conn_id: ConnectionId,
1345 accept: SelfRef<ConnectionAccept<'static>>,
1346 ) {
1347 let slot = self.remove_connection(&conn_id);
1348 match slot {
1349 Some(ConnectionSlot::PendingOutbound(mut pending)) => {
1350 let handle = self.make_connection_handle(
1351 conn_id,
1352 pending.local_settings.clone(),
1353 accept.connection_settings.clone(),
1354 );
1355
1356 if let Some(tx) = pending.result_tx.take() {
1357 let _ = tx.send(Ok(handle));
1358 }
1359 }
1360 Some(other) => {
1361 self.conns.insert(conn_id, other);
1363 }
1364 None => {
1365 }
1367 }
1368 }
1369
1370 fn handle_inbound_reject(
1371 &mut self,
1372 conn_id: ConnectionId,
1373 reject: SelfRef<ConnectionReject<'static>>,
1374 ) {
1375 let slot = self.remove_connection(&conn_id);
1376 match slot {
1377 Some(ConnectionSlot::PendingOutbound(mut pending)) => {
1378 if let Some(tx) = pending.result_tx.take() {
1379 let _ = tx.send(Err(SessionError::Rejected(reject.metadata.to_vec())));
1380 }
1381 }
1382 Some(other) => {
1383 self.conns.insert(conn_id, other);
1384 }
1385 None => {}
1386 }
1387 }
1388
1389 async fn handle_open_request(&mut self, req: OpenRequest) {
1391 let conn_id = self.conn_ids.alloc();
1392
1393 let send_result = self
1395 .sess_core
1396 .send(
1397 Message {
1398 connection_id: conn_id,
1399 payload: MessagePayload::ConnectionOpen(ConnectionOpen {
1400 connection_settings: req.settings.clone(),
1401 metadata: req.metadata,
1402 }),
1403 },
1404 None,
1405 None,
1406 )
1407 .await;
1408
1409 if send_result.is_err() {
1410 let _ = req.result_tx.send(Err(SessionError::Protocol(
1411 "failed to send ConnectionOpen".into(),
1412 )));
1413 return;
1414 }
1415
1416 self.conns.insert(
1419 conn_id,
1420 ConnectionSlot::PendingOutbound(PendingOutboundData {
1421 local_settings: req.settings,
1422 result_tx: Some(req.result_tx),
1423 }),
1424 );
1425 }
1426
1427 async fn handle_close_request(&mut self, req: CloseRequest) {
1429 if req.conn_id.is_root() {
1430 let _ = req.result_tx.send(Err(SessionError::Protocol(
1431 "cannot close root connection".into(),
1432 )));
1433 return;
1434 }
1435
1436 if self.remove_connection(&req.conn_id).is_none() {
1439 let _ = req
1440 .result_tx
1441 .send(Err(SessionError::Protocol("connection not found".into())));
1442 return;
1443 }
1444
1445 let send_result = self
1447 .sess_core
1448 .send(
1449 Message {
1450 connection_id: req.conn_id,
1451 payload: MessagePayload::ConnectionClose(ConnectionClose {
1452 metadata: req.metadata,
1453 }),
1454 },
1455 None,
1456 None,
1457 )
1458 .await;
1459
1460 if send_result.is_err() {
1461 let _ = req.result_tx.send(Err(SessionError::Protocol(
1462 "failed to send ConnectionClose".into(),
1463 )));
1464 return;
1465 }
1466
1467 let _ = req.result_tx.send(Ok(()));
1468 self.maybe_request_shutdown_after_root_closed();
1469 }
1470
1471 async fn handle_drop_control_request(&mut self, req: DropControlRequest) -> bool {
1472 match req {
1473 DropControlRequest::Shutdown => {
1474 trace!("session shutdown requested");
1475 false
1476 }
1477 DropControlRequest::Close(conn_id) => {
1478 if conn_id.is_root() {
1480 trace!("root callers dropped; internally closing root connection");
1482 self.root_closed_internal = true;
1483 return self.has_virtual_connections();
1485 }
1486
1487 if self.remove_connection(&conn_id).is_some() {
1488 let _ = self
1489 .sess_core
1490 .send(
1491 Message {
1492 connection_id: conn_id,
1493 payload: MessagePayload::ConnectionClose(ConnectionClose {
1494 metadata: vec![],
1495 }),
1496 },
1497 None,
1498 None,
1499 )
1500 .await;
1501 }
1502
1503 !self.root_closed_internal || self.has_virtual_connections()
1504 }
1505 }
1506 }
1507
1508 fn has_virtual_connections(&self) -> bool {
1509 self.conns.keys().any(|id| !id.is_root())
1510 }
1511
1512 fn remove_connection(&mut self, conn_id: &ConnectionId) -> Option<ConnectionSlot> {
1513 let slot = self.conns.remove(conn_id);
1514 if let Some(ConnectionSlot::Active(state)) = &slot {
1515 let _ = state.closed_tx.send(true);
1516 }
1517 slot
1518 }
1519
1520 fn close_all_connections(&mut self) {
1521 vox_types::dlog!(
1522 "[session {:?}] close_all_connections: {} slots",
1523 self.role,
1524 self.conns.len()
1525 );
1526 for (conn_id, slot) in self.conns.iter() {
1527 if let ConnectionSlot::Active(state) = slot {
1528 vox_types::dlog!("[session {:?}] closing connection {:?}", self.role, conn_id);
1529 let _ = state.closed_tx.send(true);
1530 }
1531 }
1532 self.conns.clear();
1533 }
1534
1535 fn maybe_request_shutdown_after_root_closed(&self) {
1536 if self.root_closed_internal && !self.has_virtual_connections() {
1537 let _ = send_drop_control(&self.control_tx, DropControlRequest::Shutdown);
1538 }
1539 }
1540}
1541
/// Shared send path of a session: the current conduit transmitter plus per-connection
/// schema bookkeeping, guarded by a `std::sync::Mutex`.
pub(crate) struct SessionCore {
    // The guard is only held for synchronous bookkeeping; `send` clones the
    // transmitter and releases the lock before awaiting.
    inner: std::sync::Mutex<SessionCoreInner>,
}
1545
/// Per-connection bookkeeping used when sending requests and responses.
struct SendConnState {
    /// `(direction, method)` pairs whose schemas have already been prepared on this
    /// connection. Later messages for the same pair skip schema attachment.
    method_tracker: HashSet<(vox_types::BindingDirection, vox_types::MethodId)>,

    /// Schema send tracker. It prepares or attaches outgoing schema payloads and owns the
    /// registry returned by `SessionCore::schema_registry`.
    send_tracker: vox_types::SchemaSendTracker,

    /// Method id of each incoming call, keyed by request id. Entries are added by
    /// `record_incoming_call` and removed when the response is prepared or sent.
    inflight_incoming: HashMap<RequestId, vox_types::MethodId>,

    /// Method id of each outgoing call, keyed by request id. Entries are added when the
    /// call is sent and taken back via `take_outgoing_call_method`.
    inflight_outgoing: HashMap<RequestId, vox_types::MethodId>,
}
1562
1563impl SendConnState {
1564 fn new() -> Self {
1565 SendConnState {
1566 method_tracker: HashSet::new(),
1567 send_tracker: vox_types::SchemaSendTracker::new(),
1568 inflight_incoming: HashMap::new(),
1569 inflight_outgoing: HashMap::new(),
1570 }
1571 }
1572}
1573
/// Mutex-protected state of `SessionCore`.
struct SessionCoreInner {
    /// Transmitter that `SessionCore::send` writes to. Swapped by
    /// `replace_tx_and_reset_schemas`.
    tx: Arc<dyn DynConduitTx>,

    /// Per-connection send bookkeeping, created lazily on first use.
    conns: HashMap<ConnectionId, SendConnState>,
}
1581
1582impl SessionCore {
1583 pub(crate) async fn send<'a>(
1585 &self,
1586 mut msg: Message<'a>,
1587 binder: Option<&'a dyn vox_types::ChannelBinder>,
1588 forwarded_schemas: Option<&vox_types::SchemaRecvTracker>,
1589 ) -> Result<(), ()> {
1590 let tx = {
1591 let mut inner = self.inner.lock().expect("session core mutex poisoned");
1592 let conn_id = msg.connection_id;
1593
1594 if let MessagePayload::RequestMessage(req) = &mut msg.payload {
1595 vox_types::dlog!(
1596 "[session-core] send request: conn={:?} req={:?} body={} forwarded={}",
1597 conn_id,
1598 req.id,
1599 match &req.body {
1600 RequestBody::Call(_) => "Call",
1601 RequestBody::Response(_) => "Response",
1602 RequestBody::Cancel(_) => "Cancel",
1603 },
1604 forwarded_schemas.is_some()
1605 );
1606 let conn_state = inner
1607 .conns
1608 .entry(conn_id)
1609 .or_insert_with(SendConnState::new);
1610 match &mut req.body {
1611 RequestBody::Call(call) => {
1612 Self::prepare_call_schemas(
1613 conn_state,
1614 req.id,
1615 call.method_id,
1616 call,
1617 forwarded_schemas,
1618 );
1619 }
1620 RequestBody::Response(resp) => {
1621 if let Some(method_id) = conn_state.inflight_incoming.remove(&req.id) {
1622 Self::prepare_response_schemas(
1623 conn_state,
1624 req.id,
1625 method_id,
1626 resp,
1627 forwarded_schemas,
1628 );
1629 }
1630 }
1631 RequestBody::Cancel(_) => {}
1632 }
1633 }
1634
1635 inner.tx.clone()
1636 };
1637 tx.send_msg(msg, binder).await.map_err(|_| ())
1638 }
1639
1640 pub(crate) fn record_incoming_call(
1643 &self,
1644 conn_id: ConnectionId,
1645 request_id: RequestId,
1646 method_id: vox_types::MethodId,
1647 ) {
1648 let mut inner = self.inner.lock().expect("session core mutex poisoned");
1649 let conn_state = inner
1650 .conns
1651 .entry(conn_id)
1652 .or_insert_with(SendConnState::new);
1653 vox_types::dlog!(
1654 "[schema] record_incoming_call: conn={:?} req={:?} method={:?}",
1655 conn_id,
1656 request_id,
1657 method_id
1658 );
1659 conn_state.inflight_incoming.insert(request_id, method_id);
1660 }
1661
1662 pub(crate) fn take_outgoing_call_method(
1663 &self,
1664 conn_id: ConnectionId,
1665 request_id: RequestId,
1666 ) -> Option<vox_types::MethodId> {
1667 let mut inner = self.inner.lock().expect("session core mutex poisoned");
1668 inner
1669 .conns
1670 .get_mut(&conn_id)
1671 .and_then(|conn_state| conn_state.inflight_outgoing.remove(&request_id))
1672 }
1673
1674 pub(crate) fn prepare_response_for_method(
1675 &self,
1676 conn_id: ConnectionId,
1677 request_id: RequestId,
1678 method_id: vox_types::MethodId,
1679 response: &mut RequestResponse<'_>,
1680 ) {
1681 let mut inner = self.inner.lock().expect("session core mutex poisoned");
1682 let conn_state = inner
1683 .conns
1684 .entry(conn_id)
1685 .or_insert_with(SendConnState::new);
1686 conn_state.inflight_incoming.remove(&request_id);
1687 Self::prepare_response_schemas(conn_state, request_id, method_id, response, None);
1688 }
1689
1690 pub(crate) fn schema_registry(&self, conn_id: ConnectionId) -> vox_types::SchemaRegistry {
1693 let inner = self.inner.lock().expect("session core mutex poisoned");
1694 inner
1695 .conns
1696 .get(&conn_id)
1697 .map(|cs| cs.send_tracker.registry().clone())
1698 .unwrap_or_default()
1699 }
1700
1701 pub(crate) fn prepare_response_from_source(
1703 &self,
1704 conn_id: ConnectionId,
1705 request_id: RequestId,
1706 method_id: vox_types::MethodId,
1707 root_type: &vox_types::TypeRef,
1708 source: &dyn vox_types::SchemaSource,
1709 response: &mut RequestResponse<'_>,
1710 ) {
1711 let mut inner = self.inner.lock().expect("session core mutex poisoned");
1712 let conn_state = inner
1713 .conns
1714 .entry(conn_id)
1715 .or_insert_with(SendConnState::new);
1716 conn_state.inflight_incoming.remove(&request_id);
1717 let key = (vox_types::BindingDirection::Response, method_id);
1718 if conn_state.method_tracker.contains(&key) {
1719 return;
1720 }
1721 let cbor = conn_state.send_tracker.prepare_send(
1722 method_id,
1723 vox_types::BindingDirection::Response,
1724 root_type,
1725 source,
1726 );
1727 if !cbor.is_empty() {
1728 response.schemas = cbor;
1729 }
1730 conn_state.method_tracker.insert(key);
1731 }
1732
1733 fn prepare_response_schemas(
1734 conn_state: &mut SendConnState,
1735 request_id: RequestId,
1736 method_id: vox_types::MethodId,
1737 response: &mut RequestResponse<'_>,
1738 forwarded_schemas: Option<&vox_types::SchemaRecvTracker>,
1739 ) {
1740 let key = (vox_types::BindingDirection::Response, method_id);
1741 if conn_state.method_tracker.contains(&key) {
1742 return;
1743 }
1744
1745 let prepared = match &response.ret {
1746 vox_types::Payload::Value { shape, .. } => {
1747 match conn_state
1748 .send_tracker
1749 .attach_schemas_for_shape_if_needed(method_id, shape, response)
1750 {
1751 Ok(schemas) => {
1752 vox_types::dlog!(
1753 "[schema] prepared {} bytes of response schemas for method {:?} (req {:?})",
1754 schemas.0.len(),
1755 method_id,
1756 request_id
1757 );
1758 true
1759 }
1760 Err(e) => {
1761 tracing::error!("schema extraction failed: {e}");
1762 false
1763 }
1764 }
1765 }
1766 vox_types::Payload::PostcardBytes(_) => {
1767 let Some(source) = forwarded_schemas else {
1768 tracing::error!(
1769 "schema attachment failed: missing forwarded response schemas for method {:?}",
1770 method_id
1771 );
1772 return;
1773 };
1774 let Some(root) = source.get_remote_response_root(method_id) else {
1775 tracing::error!(
1776 "schema attachment failed: missing forwarded response root for method {:?}",
1777 method_id
1778 );
1779 return;
1780 };
1781 let schemas = conn_state.send_tracker.prepare_send(
1782 method_id,
1783 vox_types::BindingDirection::Response,
1784 &root,
1785 source,
1786 );
1787 response.schemas = schemas.clone();
1788 vox_types::dlog!(
1789 "[schema] prepared {} bytes of forwarded response schemas for method {:?} (req {:?})",
1790 schemas.0.len(),
1791 method_id,
1792 request_id
1793 );
1794 true
1795 }
1796 };
1797
1798 if prepared {
1799 conn_state.method_tracker.insert(key);
1800 }
1801 }
1802
1803 fn prepare_call_schemas(
1804 conn_state: &mut SendConnState,
1805 request_id: RequestId,
1806 method_id: vox_types::MethodId,
1807 call: &mut vox_types::RequestCall<'_>,
1808 forwarded_schemas: Option<&vox_types::SchemaRecvTracker>,
1809 ) {
1810 conn_state.inflight_outgoing.insert(request_id, method_id);
1811 let key = (vox_types::BindingDirection::Args, method_id);
1812 if conn_state.method_tracker.contains(&key) {
1813 return;
1814 }
1815
1816 let prepared = match &call.args {
1817 vox_types::Payload::Value { shape, .. } => {
1818 match conn_state
1819 .send_tracker
1820 .attach_schemas_for_shape_if_needed(method_id, shape, call)
1821 {
1822 Ok(_) => true,
1823 Err(e) => {
1824 tracing::error!("schema extraction failed: {e}");
1825 false
1826 }
1827 }
1828 }
1829 vox_types::Payload::PostcardBytes(_) => {
1830 let Some(source) = forwarded_schemas else {
1831 tracing::error!(
1832 "schema attachment failed: missing forwarded args schemas for method {:?}",
1833 method_id
1834 );
1835 return;
1836 };
1837 let Some(root) = source.get_remote_args_root(method_id) else {
1838 tracing::error!(
1839 "schema attachment failed: missing forwarded args root for method {:?}",
1840 method_id
1841 );
1842 return;
1843 };
1844 call.schemas = conn_state.send_tracker.prepare_send(
1845 method_id,
1846 vox_types::BindingDirection::Args,
1847 &root,
1848 source,
1849 );
1850 true
1851 }
1852 };
1853
1854 if prepared {
1855 conn_state.method_tracker.insert(key);
1856 }
1857 }
1858
1859 fn replace_tx_and_reset_schemas(&self, tx: Arc<dyn DynConduitTx>) {
1860 let mut inner = self.inner.lock().expect("session core mutex poisoned");
1861 inner.tx = tx;
1862 inner.conns.clear();
1863 }
1864}
1865
/// A conduit produced by a `ConduitRecoverer`. It holds the conduit's send and receive
/// halves together with the handshake result negotiated on it.
pub(crate) struct RecoveredConduit {
    pub tx: Arc<dyn DynConduitTx>,
    pub rx: Box<dyn DynConduitRx>,
    pub handshake: HandshakeResult,
}
1871
/// Source of replacement conduits for a session.
pub(crate) trait ConduitRecoverer: MaybeSend {
    /// Produces the next conduit and its handshake result. The session's resume key is
    /// passed when one is available. Presumably this lets the peer resume the existing
    /// session; confirm against implementors.
    fn next_conduit<'a>(
        &'a mut self,
        resume_key: Option<&'a SessionResumeKey>,
    ) -> BoxFut<'a, Result<RecoveredConduit, SessionError>>;
}
1878
/// Object-safe sending half of a conduit, so it can be stored as `Arc<dyn DynConduitTx>`.
pub trait DynConduitTx: MaybeSend + MaybeSync {
    /// Sends one message and resolves once it has been handed to the conduit.
    ///
    /// When `binder` is `Some`, the generic implementation for `ConduitTx` types runs
    /// the permit's `send` inside `vox_types::with_channel_binder`.
    fn send_msg<'a>(
        &'a self,
        msg: Message<'a>,
        binder: Option<&'a dyn vox_types::ChannelBinder>,
    ) -> BoxFut<'a, std::io::Result<()>>;
}
/// Object-safe receiving half of a conduit, so it can be stored as `Box<dyn DynConduitRx>`.
pub trait DynConduitRx: MaybeSend {
    /// Receives the next message. `Ok(None)` is passed through from the underlying
    /// receiver; presumably it means the conduit has ended (confirm with `ConduitRx::recv`).
    fn recv_msg<'a>(&'a mut self)
    -> BoxFut<'a, std::io::Result<Option<SelfRef<Message<'static>>>>>;
}
1890
1891impl<T> DynConduitTx for T
1894where
1895 T: ConduitTx<Msg = MessageFamily> + MaybeSend + MaybeSync,
1896 for<'p> <T as ConduitTx>::Permit<'p>: MaybeSend,
1897{
1898 fn send_msg<'a>(
1899 &'a self,
1900 msg: Message<'a>,
1901 binder: Option<&'a dyn vox_types::ChannelBinder>,
1902 ) -> BoxFut<'a, std::io::Result<()>> {
1903 Box::pin(async move {
1904 let permit = self.reserve().await?;
1905 let result = if let Some(binder) = binder {
1906 vox_types::with_channel_binder(binder, || permit.send(msg))
1907 } else {
1908 permit.send(msg)
1909 };
1910 result.map_err(|e| std::io::Error::other(e.to_string()))
1911 })
1912 }
1913}
1914
1915impl<T> DynConduitRx for T
1916where
1917 T: ConduitRx<Msg = MessageFamily> + MaybeSend,
1918{
1919 fn recv_msg<'a>(
1920 &'a mut self,
1921 ) -> BoxFut<'a, std::io::Result<Option<SelfRef<Message<'static>>>>> {
1922 Box::pin(async move {
1923 self.recv()
1924 .await
1925 .map_err(|error| std::io::Error::other(error.to_string()))
1926 })
1927 }
1928}
1929
#[cfg(test)]
mod tests {
    use moire::sync::mpsc;
    use vox_types::{
        Backing, Conduit, ConnectionAccept, ConnectionReject, HandshakeResult, SelfRef,
    };

    use super::*;

    /// Builds a `Session` via `Session::pre_handshake` over one half of an in-memory
    /// link. The tests below drive its handlers directly.
    fn make_session() -> Session {
        let (a, b) = crate::memory_link_pair(32);
        // The peer half is leaked instead of dropped. Presumably this keeps the link
        // open so sends from the session under test do not fail; confirm against
        // `memory_link_pair`.
        std::mem::forget(b);
        let conduit = crate::BareConduit::new(a);
        let (tx, rx) = conduit.split();
        // Request/control channels required by `pre_handshake`.
        let (_open_tx, open_rx) = mpsc::channel::<OpenRequest>("session.open.test", 4);
        let (_close_tx, close_rx) = mpsc::channel::<CloseRequest>("session.close.test", 4);
        let (_resume_tx, resume_rx) = mpsc::channel::<ResumeRequest>("session.resume.test", 1);
        let (control_tx, control_rx) = mpsc::unbounded_channel("session.control.test");
        Session::pre_handshake(
            tx, rx, None, open_rx, close_rx, resume_rx, control_tx, control_rx, None, false, None,
        )
    }

    /// Handshake result for an initiator with a fixed session resume key and the given
    /// root settings for both sides.
    fn resumed_handshake(
        our_settings: ConnectionSettings,
        peer_settings: ConnectionSettings,
    ) -> HandshakeResult {
        HandshakeResult {
            role: SessionRole::Initiator,
            our_settings,
            peer_settings,
            peer_supports_retry: true,
            session_resume_key: Some(SessionResumeKey([7; 16])),
            peer_resume_key: None,
            our_schema: vec![],
            peer_schema: vec![],
        }
    }

    /// An owned `ConnectionAccept` (even parity, 64 max concurrent requests, no
    /// metadata) backed by an empty buffer.
    fn accept_ref() -> SelfRef<ConnectionAccept<'static>> {
        SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ConnectionAccept {
                connection_settings: ConnectionSettings {
                    parity: Parity::Even,
                    max_concurrent_requests: 64,
                },
                metadata: vec![],
            },
        )
    }

    /// An owned `ConnectionReject` with no metadata, backed by an empty buffer.
    fn reject_ref() -> SelfRef<ConnectionReject<'static>> {
        SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ConnectionReject { metadata: vec![] },
        )
    }

    /// The first accept resolves the pending open with a handle. A second accept for
    /// the same id must leave the resulting active slot in place.
    #[tokio::test]
    async fn duplicate_connection_accept_is_ignored_after_first() {
        let mut session = make_session();
        let conn_id = ConnectionId(1);
        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");

        session.conns.insert(
            conn_id,
            ConnectionSlot::PendingOutbound(PendingOutboundData {
                local_settings: ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 64,
                },
                result_tx: Some(result_tx),
            }),
        );

        session.handle_inbound_accept(conn_id, accept_ref());
        let handle = result_rx
            .await
            .expect("pending outbound result should resolve")
            .expect("accept should resolve as Ok");
        assert_eq!(handle.connection_id(), conn_id);

        session.handle_inbound_accept(conn_id, accept_ref());
        assert!(
            matches!(
                session.conns.get(&conn_id),
                Some(ConnectionSlot::Active(ConnectionState { id, .. })) if *id == conn_id
            ),
            "duplicate accept should keep existing active connection state"
        );
    }

    /// The first reject resolves the pending open with `SessionError::Rejected`. A
    /// second reject must not recreate a slot for the id.
    #[tokio::test]
    async fn duplicate_connection_reject_is_ignored_after_first() {
        let mut session = make_session();
        let conn_id = ConnectionId(1);
        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");

        session.conns.insert(
            conn_id,
            ConnectionSlot::PendingOutbound(PendingOutboundData {
                local_settings: ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 64,
                },
                result_tx: Some(result_tx),
            }),
        );

        session.handle_inbound_reject(conn_id, reject_ref());
        let result = result_rx
            .await
            .expect("pending outbound result should resolve");
        assert!(
            matches!(result, Err(SessionError::Rejected(_))),
            "expected rejection, got: {result:?}"
        );

        session.handle_inbound_reject(conn_id, reject_ref());
        assert!(
            !session.conns.contains_key(&conn_id),
            "duplicate reject should not recreate connection state"
        );
    }

    /// Accept/reject for an id with no slot must not insert anything.
    #[test]
    fn out_of_order_accept_or_reject_without_pending_is_ignored() {
        let mut session = make_session();
        let conn_id = ConnectionId(99);

        session.handle_inbound_accept(conn_id, accept_ref());
        session.handle_inbound_reject(conn_id, reject_ref());

        assert!(
            session.conns.is_empty(),
            "out-of-order accept/reject should not mutate empty connection table"
        );
    }

    /// Closing a still-pending outbound open succeeds. The opener's result channel is
    /// then closed without a value, because its sender was dropped with the slot.
    #[tokio::test]
    async fn close_request_clears_pending_outbound_open() {
        let mut session = make_session();
        let (open_result_tx, open_result_rx) = moire::sync::oneshot::channel("session.open.result");
        let (close_result_tx, close_result_rx) =
            moire::sync::oneshot::channel("session.close.result");

        session.conns.insert(
            ConnectionId(1),
            ConnectionSlot::PendingOutbound(PendingOutboundData {
                local_settings: ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 64,
                },
                result_tx: Some(open_result_tx),
            }),
        );

        session
            .handle_close_request(CloseRequest {
                conn_id: ConnectionId(1),
                metadata: vec![],
                result_tx: close_result_tx,
            })
            .await;

        let close_result = close_result_rx
            .await
            .expect("close result should be delivered");
        assert!(
            close_result.is_ok(),
            "close should succeed for pending outbound connection"
        );

        assert!(
            open_result_rx.await.is_err(),
            "pending open result channel should be closed once the pending slot is removed"
        );
    }

    /// Resuming with different local root settings (65 vs 64 max concurrent requests)
    /// must fail with the local-settings protocol error.
    #[test]
    fn resume_rejects_changed_local_root_settings() {
        let mut session = make_session();
        let local_settings = ConnectionSettings {
            parity: Parity::Odd,
            max_concurrent_requests: 64,
        };
        let peer_settings = ConnectionSettings {
            parity: Parity::Even,
            max_concurrent_requests: 64,
        };
        let _root = session
            .establish_from_handshake(resumed_handshake(
                local_settings.clone(),
                peer_settings.clone(),
            ))
            .expect("initial handshake should establish session");

        // A second in-memory link stands in for the resumed conduit.
        let (link_a, _link_b) = crate::memory_link_pair(32);
        let conduit = crate::BareConduit::new(link_a);
        let (tx, rx) = conduit.split();

        let result = session.resume_from_handshake(
            Arc::new(tx),
            Box::new(rx),
            resumed_handshake(
                ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 65,
                },
                peer_settings,
            ),
        );

        assert!(
            matches!(
                &result,
                Err(SessionError::Protocol(message))
                    if message == "local root settings changed across session resume"
            ),
            "expected local-root-settings mismatch, got: {result:?}"
        );
    }

    /// Resuming with different peer root settings (65 vs 64 max concurrent requests)
    /// must fail with the peer-settings protocol error.
    #[test]
    fn resume_rejects_changed_peer_root_settings() {
        let mut session = make_session();
        let local_settings = ConnectionSettings {
            parity: Parity::Odd,
            max_concurrent_requests: 64,
        };
        let peer_settings = ConnectionSettings {
            parity: Parity::Even,
            max_concurrent_requests: 64,
        };
        let _root = session
            .establish_from_handshake(resumed_handshake(
                local_settings.clone(),
                peer_settings.clone(),
            ))
            .expect("initial handshake should establish session");

        // A second in-memory link stands in for the resumed conduit.
        let (link_a, _link_b) = crate::memory_link_pair(32);
        let conduit = crate::BareConduit::new(link_a);
        let (tx, rx) = conduit.split();

        let result = session.resume_from_handshake(
            Arc::new(tx),
            Box::new(rx),
            resumed_handshake(
                local_settings,
                ConnectionSettings {
                    parity: Parity::Even,
                    max_concurrent_requests: 65,
                },
            ),
        );

        assert!(
            matches!(
                &result,
                Err(SessionError::Protocol(message))
                    if message == "peer root settings changed across session resume"
            ),
            "expected peer-root-settings mismatch, got: {result:?}"
        );
    }
}