1use std::process::Stdio;
12use std::sync::atomic::{AtomicI64, AtomicUsize, Ordering};
13use std::sync::{Arc, Weak};
14
15use scc::HashMap as SccHashMap;
16use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
17use tokio::process::{Child, ChildStdout, Command};
18use tokio::sync::{broadcast, mpsc, oneshot};
19
20use crate::wire::{self, WireFrameCodec};
21use crate::TransportError;
22
/// Capacity of the broadcast channel that fans sidecar events out to subscribers.
///
/// Per tokio `broadcast` semantics, a subscriber that falls this far behind starts losing events.
const EVENT_CHANNEL_CAPACITY: usize = 4_096;

/// Capacity of the queue feeding encoded request frames to the writer task; senders wait when
/// it is full.
const REQUEST_FRAME_QUEUE_CAPACITY: usize = 4_096;

/// Capacity of the queue feeding responses to sidecar-initiated requests to the writer task.
const CONTROL_FRAME_QUEUE_CAPACITY: usize = 1_024;

/// Maximum number of host requests that may await a response at the same time.
const PENDING_REQUEST_LIMIT: usize = 4_096;

/// Environment variable consulted for the sidecar executable when no explicit path is given.
const SIDECAR_BIN_ENV: &str = "SECURE_EXEC_SIDECAR_BIN";
38
39pub type WireSidecarCallback = Arc<
41 dyn Fn(
42 wire::SidecarRequestPayload,
43 wire::OwnershipScope,
44 ) -> futures::future::BoxFuture<
45 'static,
46 Result<wire::SidecarResponsePayload, TransportError>,
47 > + Send
48 + Sync,
49>;
50
/// Host-side transport to a sidecar process.
///
/// A writer task writes encoded frames to the child's stdin. It is fed from two bounded queues:
/// host requests, and responses to sidecar-initiated requests.
///
/// A reader task reads frames from the child's stdout. It completes pending requests, broadcasts
/// events and dispatches sidecar-initiated requests to registered callbacks.
///
/// Fields drop in declaration order. `child` comes first, so dropping the transport kills the
/// sidecar (it is spawned with `kill_on_drop`) before the writer queues are closed.
pub struct SidecarTransport {
    /// The spawned sidecar process; `None` once `kill_child` has taken it.
    child: parking_lot::Mutex<Option<Child>>,
    /// Response slots for in-flight host requests, keyed by request id. The reader completes a
    /// slot when the matching response frame arrives.
    pending: SccHashMap<wire::RequestId, oneshot::Sender<wire::ResponsePayload>>,
    /// Serializes the count-then-insert in `register_pending_request`. This keeps concurrent
    /// callers from overshooting `PENDING_REQUEST_LIMIT`.
    pending_request_lock: parking_lot::Mutex<()>,
    /// Source of the ids handed out by `next_request_id`.
    request_counter: AtomicI64,
    /// Frame size limit used when encoding outgoing frames and validating incoming ones.
    max_frame_bytes: AtomicUsize,
    /// Fan-out of events received from the sidecar, tagged with their ownership scope.
    event_tx: broadcast::Sender<(wire::OwnershipScope, wire::EventPayload)>,
    /// Handlers for sidecar-initiated requests, keyed by `sidecar_request_key`.
    callbacks: SccHashMap<&'static str, WireSidecarCallback>,
    /// Queue of encoded host request frames for the writer task.
    request_writer_tx: mpsc::Sender<Vec<u8>>,
    /// Queue of encoded responses to sidecar-initiated requests for the writer task.
    control_writer_tx: mpsc::Sender<Vec<u8>>,
}
72
73impl SidecarTransport {
74 pub async fn spawn(binary_path: Option<String>) -> Result<Arc<Self>, TransportError> {
79 let bin = resolve_sidecar_binary_path(binary_path);
80 let mut child = Command::new(&bin)
81 .stdin(Stdio::piped())
82 .stdout(Stdio::piped())
83 .stderr(Stdio::inherit())
84 .kill_on_drop(true)
85 .spawn()
86 .map_err(|error| {
87 TransportError::Sidecar(format!("failed to spawn sidecar '{bin}': {error}"))
88 })?;
89
90 let stdin = child
91 .stdin
92 .take()
93 .ok_or_else(|| TransportError::Sidecar("sidecar stdin was not piped".to_string()))?;
94 let stdout = child
95 .stdout
96 .take()
97 .ok_or_else(|| TransportError::Sidecar("sidecar stdout was not piped".to_string()))?;
98
99 let (request_writer_tx, request_writer_rx) = mpsc::channel(REQUEST_FRAME_QUEUE_CAPACITY);
100 let (control_writer_tx, control_writer_rx) = mpsc::channel(CONTROL_FRAME_QUEUE_CAPACITY);
101 let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
102
103 let transport = Arc::new(Self {
104 child: parking_lot::Mutex::new(Some(child)),
105 pending: SccHashMap::new(),
106 pending_request_lock: parking_lot::Mutex::new(()),
107 request_counter: AtomicI64::new(1),
108 max_frame_bytes: AtomicUsize::new(wire::DEFAULT_MAX_FRAME_BYTES),
109 event_tx,
110 callbacks: SccHashMap::new(),
111 request_writer_tx,
112 control_writer_tx,
113 });
114
115 tokio::spawn(run_writer(stdin, control_writer_rx, request_writer_rx));
116 tokio::spawn(run_reader(Arc::downgrade(&transport), stdout));
117
118 Ok(transport)
119 }
120
121 pub fn next_request_id(&self) -> wire::RequestId {
123 self.request_counter.fetch_add(1, Ordering::SeqCst)
124 }
125
126 pub async fn request_wire(
128 &self,
129 ownership: wire::OwnershipScope,
130 payload: wire::RequestPayload,
131 ) -> Result<wire::ResponsePayload, TransportError> {
132 self.request_wire_with_frame_limit(ownership, payload, None)
133 .await
134 }
135
136 pub async fn request_wire_bounded(
138 &self,
139 ownership: wire::OwnershipScope,
140 payload: wire::RequestPayload,
141 max_frame_bytes: usize,
142 ) -> Result<wire::ResponsePayload, TransportError> {
143 self.request_wire_with_frame_limit(ownership, payload, Some(max_frame_bytes))
144 .await
145 }
146
147 async fn request_wire_with_frame_limit(
148 &self,
149 ownership: wire::OwnershipScope,
150 payload: wire::RequestPayload,
151 max_frame_bytes: Option<usize>,
152 ) -> Result<wire::ResponsePayload, TransportError> {
153 let request_id = self.next_request_id();
154 let frame = wire::ProtocolFrame::RequestFrame(wire::RequestFrame {
155 schema: wire::protocol_schema(),
156 request_id,
157 ownership,
158 payload,
159 });
160 let bytes = self.encode_wire_frame(&frame, max_frame_bytes)?;
161
162 let (tx, rx) = oneshot::channel();
163 self.register_pending_request(request_id, tx)?;
164 let _pending_guard = PendingRequestGuard::new(self, request_id);
165
166 if self.request_writer_tx.send(bytes).await.is_err() {
167 self.pending.remove(&request_id);
168 return Err(TransportError::Sidecar(
169 "sidecar transport closed".to_string(),
170 ));
171 }
172
173 rx.await
174 .map_err(|_| TransportError::Sidecar("sidecar transport disconnected".to_string()))
175 }
176
177 pub fn subscribe_wire_events(
179 &self,
180 ) -> broadcast::Receiver<(wire::OwnershipScope, wire::EventPayload)> {
181 self.event_tx.subscribe()
182 }
183
184 pub fn register_wire_callback(&self, key: &'static str, callback: WireSidecarCallback) {
187 let _ = self.callbacks.insert(key, callback);
188 }
189
190 pub fn max_frame_bytes(&self) -> usize {
192 self.max_frame_bytes.load(Ordering::Relaxed)
193 }
194
195 pub fn set_max_frame_bytes(&self, max_frame_bytes: usize) {
197 self.max_frame_bytes
198 .store(max_frame_bytes, Ordering::SeqCst);
199 }
200
201 pub fn kill_child(&self) {
203 if let Some(mut child) = self.child.lock().take() {
204 let _ = child.start_kill();
205 }
206 }
207
208 fn encode_wire_frame(
209 &self,
210 frame: &wire::ProtocolFrame,
211 max_frame_bytes: Option<usize>,
212 ) -> Result<Vec<u8>, TransportError> {
213 let transport_limit = self.max_frame_bytes.load(Ordering::Relaxed);
214 let max_frame_bytes = max_frame_bytes
215 .map(|limit| limit.min(transport_limit))
216 .unwrap_or(transport_limit);
217 let codec = WireFrameCodec::new(max_frame_bytes);
218 Ok(codec.encode(frame)?)
219 }
220
221 async fn handle_wire_frame(self: &Arc<Self>, frame: wire::ProtocolFrame) {
224 match frame {
225 wire::ProtocolFrame::ResponseFrame(response) => {
226 match self.pending.remove(&response.request_id) {
227 Some((_, tx)) => {
228 let _ = tx.send(response.payload);
229 }
230 None => {
231 tracing::warn!(
232 request_id = response.request_id,
233 "response for unknown request id"
234 )
235 }
236 }
237 }
238 wire::ProtocolFrame::EventFrame(event) => {
239 let _ = self.event_tx.send((event.ownership, event.payload));
240 }
241 wire::ProtocolFrame::SidecarRequestFrame(request) => {
242 self.dispatch_sidecar_request(request).await
243 }
244 wire::ProtocolFrame::SidecarResponseFrame(_) | wire::ProtocolFrame::RequestFrame(_) => {
245 tracing::warn!("unexpected inbound frame on host transport")
246 }
247 }
248 }
249
250 async fn dispatch_sidecar_request(self: &Arc<Self>, frame: wire::SidecarRequestFrame) {
255 let key = sidecar_request_key(&frame.payload);
256 let callback = self.callbacks.read(&key, |_, value| value.clone());
257 match callback {
258 Some(callback) => {
259 let transport = Arc::downgrade(self);
260 tokio::spawn(async move {
261 match callback(frame.payload, frame.ownership.clone()).await {
262 Ok(payload) => {
263 let response = wire::ProtocolFrame::SidecarResponseFrame(
264 wire::SidecarResponseFrame {
265 schema: wire::protocol_schema(),
266 request_id: frame.request_id,
267 ownership: frame.ownership,
268 payload,
269 },
270 );
271 let Some(transport) = transport.upgrade() else {
273 return;
274 };
275 if let Ok(bytes) = transport.encode_wire_frame(&response, None) {
276 let _ = transport.control_writer_tx.send(bytes).await;
277 }
278 }
279 Err(error) => tracing::warn!(?error, key, "sidecar callback failed"),
280 }
281 });
282 }
283 None => tracing::warn!(key, "no callback registered for sidecar request"),
284 }
285 }
286
287 fn fail_all_pending(&self) {
289 self.pending.clear();
290 }
291
292 fn register_pending_request(
293 &self,
294 request_id: wire::RequestId,
295 tx: oneshot::Sender<wire::ResponsePayload>,
296 ) -> Result<(), TransportError> {
297 let _guard = self.pending_request_lock.lock();
298 if pending_request_count(self) >= PENDING_REQUEST_LIMIT {
299 return Err(TransportError::Sidecar(format!(
300 "sidecar pending request limit exceeded: at most {PENDING_REQUEST_LIMIT} requests can be in flight"
301 )));
302 }
303 let _ = self.pending.insert(request_id, tx);
304 Ok(())
305 }
306}
307
308struct PendingRequestGuard<'a> {
309 transport: &'a SidecarTransport,
310 request_id: wire::RequestId,
311}
312
313impl<'a> PendingRequestGuard<'a> {
314 fn new(transport: &'a SidecarTransport, request_id: wire::RequestId) -> Self {
315 Self {
316 transport,
317 request_id,
318 }
319 }
320}
321
322impl Drop for PendingRequestGuard<'_> {
323 fn drop(&mut self) {
324 let _ = self.transport.pending.remove(&self.request_id);
325 }
326}
327
328fn pending_request_count(transport: &SidecarTransport) -> usize {
329 let mut count = 0;
330 transport.pending.scan(|_, _| {
331 count += 1;
332 });
333 count
334}
335
336fn sidecar_request_key(payload: &wire::SidecarRequestPayload) -> &'static str {
338 match payload {
339 wire::SidecarRequestPayload::HostCallbackRequest(_) => "host_callback",
340 wire::SidecarRequestPayload::JsBridgeCallRequest(_) => "js_bridge_call",
341 wire::SidecarRequestPayload::ExtEnvelope(_) => "ext",
342 }
343}
344
345async fn run_writer<W>(
348 mut stdin: W,
349 mut control_rx: mpsc::Receiver<Vec<u8>>,
350 mut request_rx: mpsc::Receiver<Vec<u8>>,
351) where
352 W: AsyncWrite + Unpin,
353{
354 let mut prefer_control = true;
355 loop {
356 let (bytes, wrote_control) = if prefer_control {
357 tokio::select! {
358 biased;
359 bytes = control_rx.recv() => match bytes {
360 Some(bytes) => (bytes, true),
361 None => match request_rx.recv().await {
362 Some(bytes) => (bytes, false),
363 None => break,
364 },
365 },
366 bytes = request_rx.recv() => match bytes {
367 Some(bytes) => (bytes, false),
368 None => match control_rx.recv().await {
369 Some(bytes) => (bytes, true),
370 None => break,
371 },
372 },
373 }
374 } else {
375 tokio::select! {
376 biased;
377 bytes = request_rx.recv() => match bytes {
378 Some(bytes) => (bytes, false),
379 None => match control_rx.recv().await {
380 Some(bytes) => (bytes, true),
381 None => break,
382 },
383 },
384 bytes = control_rx.recv() => match bytes {
385 Some(bytes) => (bytes, true),
386 None => match request_rx.recv().await {
387 Some(bytes) => (bytes, false),
388 None => break,
389 },
390 },
391 }
392 };
393 if stdin.write_all(&bytes).await.is_err() {
394 break;
395 }
396 if stdin.flush().await.is_err() {
397 break;
398 }
399 prefer_control = !wrote_control;
400 }
401}
402
403async fn run_reader(transport: Weak<SidecarTransport>, mut stdout: ChildStdout) {
407 loop {
408 let mut length_buf = [0u8; 4];
409 if stdout.read_exact(&mut length_buf).await.is_err() {
410 break;
411 }
412 let length = u32::from_be_bytes(length_buf) as usize;
413
414 let Some(transport) = transport.upgrade() else {
415 break;
416 };
417 let max_frame_bytes = transport.max_frame_bytes.load(Ordering::Relaxed);
418 if frame_length_exceeds_limit(length, max_frame_bytes) {
419 tracing::warn!(
420 size = length,
421 max = max_frame_bytes,
422 "sidecar frame exceeds negotiated limit"
423 );
424 break;
425 }
426
427 let mut frame_bytes = vec![0u8; 4 + length];
428 frame_bytes[..4].copy_from_slice(&length_buf);
429 if stdout.read_exact(&mut frame_bytes[4..]).await.is_err() {
430 break;
431 }
432
433 let codec = WireFrameCodec::new(max_frame_bytes);
434 match codec.decode(&frame_bytes) {
435 Ok(frame) => transport.handle_wire_frame(frame).await,
436 Err(error) => tracing::warn!(?error, "failed to decode sidecar frame"),
437 }
438 }
439
440 if let Some(transport) = transport.upgrade() {
441 transport.fail_all_pending();
442 }
443}
444
/// Reports whether a declared frame body length is larger than the allowed maximum.
///
/// A length equal to the limit is still accepted.
fn frame_length_exceeds_limit(length: usize, max_frame_bytes: usize) -> bool {
    max_frame_bytes < length
}
448
449fn resolve_sidecar_binary_path(binary_path: Option<String>) -> String {
450 binary_path
451 .or_else(|| std::env::var(SIDECAR_BIN_ENV).ok())
452 .unwrap_or_else(|| "secure-exec-sidecar".to_string())
453}
454
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes tests that mutate process-wide environment variables.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ENV_LOCK` and restores `key` to its previous value when dropped.
    ///
    /// Restoring in `Drop` means a failing assertion cannot leak a modified environment into
    /// later tests. Recovering from a poisoned lock keeps one failure from cascading into every
    /// other environment test.
    struct ScopedEnv {
        key: &'static str,
        previous: Option<String>,
        _lock: MutexGuard<'static, ()>,
    }

    impl ScopedEnv {
        fn new(key: &'static str) -> Self {
            let lock = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);
            let previous = std::env::var(key).ok();
            Self {
                key,
                previous,
                _lock: lock,
            }
        }
    }

    impl Drop for ScopedEnv {
        fn drop(&mut self) {
            // `drop` runs before the fields are dropped, so the lock is still held here.
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    fn test_transport() -> SidecarTransport {
        let (request_writer_tx, _request_writer_rx) = mpsc::channel(REQUEST_FRAME_QUEUE_CAPACITY);
        let (control_writer_tx, _control_writer_rx) = mpsc::channel(CONTROL_FRAME_QUEUE_CAPACITY);
        let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
        SidecarTransport {
            child: parking_lot::Mutex::new(None),
            pending: SccHashMap::new(),
            pending_request_lock: parking_lot::Mutex::new(()),
            request_counter: AtomicI64::new(1),
            max_frame_bytes: AtomicUsize::new(wire::DEFAULT_MAX_FRAME_BYTES),
            event_tx,
            callbacks: SccHashMap::new(),
            request_writer_tx,
            control_writer_tx,
        }
    }

    #[test]
    fn binary_path_prefers_explicit_path_over_env() {
        let _env = ScopedEnv::new(SIDECAR_BIN_ENV);
        std::env::set_var(SIDECAR_BIN_ENV, "/tmp/from-env");

        assert_eq!(
            resolve_sidecar_binary_path(Some("/tmp/from-config".to_string())),
            "/tmp/from-config"
        );
    }

    #[test]
    fn binary_path_uses_secure_exec_env_fallback() {
        let _env = ScopedEnv::new(SIDECAR_BIN_ENV);
        std::env::set_var(SIDECAR_BIN_ENV, "/tmp/secure-exec-sidecar");

        assert_eq!(
            resolve_sidecar_binary_path(None),
            "/tmp/secure-exec-sidecar"
        );
    }

    #[test]
    fn binary_path_defaults_to_secure_exec_sidecar() {
        let _env = ScopedEnv::new(SIDECAR_BIN_ENV);
        std::env::remove_var(SIDECAR_BIN_ENV);

        assert_eq!(resolve_sidecar_binary_path(None), "secure-exec-sidecar");
    }

    #[test]
    fn frame_length_limit_rejects_oversized_declared_length() {
        assert!(!frame_length_exceeds_limit(1024, 1024));
        assert!(frame_length_exceeds_limit(1025, 1024));
    }

    #[test]
    fn transport_encodes_requests_with_generated_wire_codec() {
        let transport = test_transport();
        let frame = wire::ProtocolFrame::RequestFrame(wire::RequestFrame {
            schema: wire::protocol_schema(),
            request_id: 7,
            ownership: wire::OwnershipScope::ConnectionOwnership(wire::ConnectionOwnership {
                connection_id: "conn-1".to_string(),
            }),
            payload: wire::RequestPayload::AuthenticateRequest(wire::AuthenticateRequest {
                client_name: "transport-test".to_string(),
                auth_token: "token".to_string(),
                protocol_version: wire::PROTOCOL_VERSION,
                bridge_version: 1,
            }),
        });

        let encoded = transport
            .encode_wire_frame(&frame, None)
            .expect("encode transport frame");
        let decoded = WireFrameCodec::default()
            .decode(&encoded)
            .expect("decode generated wire frame");

        assert!(matches!(
            decoded,
            wire::ProtocolFrame::RequestFrame(wire::RequestFrame {
                payload: wire::RequestPayload::AuthenticateRequest(_),
                ..
            })
        ));
    }

    #[tokio::test]
    async fn transport_fans_out_generated_wire_events() {
        let transport = Arc::new(test_transport());
        let mut wire_events = transport.subscribe_wire_events();

        transport
            .handle_wire_frame(wire::ProtocolFrame::EventFrame(wire::EventFrame {
                schema: wire::protocol_schema(),
                ownership: wire::OwnershipScope::VmOwnership(wire::VmOwnership {
                    connection_id: "conn-1".to_string(),
                    session_id: "session-1".to_string(),
                    vm_id: "vm-1".to_string(),
                }),
                payload: wire::EventPayload::ProcessOutputEvent(wire::ProcessOutputEvent {
                    process_id: "proc-1".to_string(),
                    channel: wire::StreamChannel::Stdout,
                    chunk: b"hello".to_vec(),
                }),
            }))
            .await;

        let (ownership, payload) = wire_events.recv().await.expect("wire event");
        assert!(matches!(
            ownership,
            wire::OwnershipScope::VmOwnership(wire::VmOwnership {
                connection_id,
                session_id,
                vm_id,
            }) if connection_id == "conn-1" && session_id == "session-1" && vm_id == "vm-1"
        ));
        assert!(matches!(
            payload,
            wire::EventPayload::ProcessOutputEvent(wire::ProcessOutputEvent {
                process_id,
                channel: wire::StreamChannel::Stdout,
                chunk,
            }) if process_id == "proc-1" && chunk == b"hello".to_vec()
        ));
    }

    #[test]
    fn pending_request_guard_removes_registered_slot_on_drop() {
        let transport = test_transport();
        let (tx, _rx) = oneshot::channel();
        transport
            .register_pending_request(1, tx)
            .expect("register pending request");

        {
            let _guard = PendingRequestGuard::new(&transport, 1);
            assert_eq!(pending_request_count(&transport), 1);
        }

        assert_eq!(pending_request_count(&transport), 0);
    }

    #[test]
    fn pending_request_limit_rejects_full_transport() {
        let transport = test_transport();
        for request_id in 1..=PENDING_REQUEST_LIMIT as wire::RequestId {
            let (tx, _rx) = oneshot::channel();
            transport
                .register_pending_request(request_id, tx)
                .expect("register pending request");
        }
        let (tx, _rx) = oneshot::channel();
        let error = transport
            .register_pending_request((PENDING_REQUEST_LIMIT + 1) as wire::RequestId, tx)
            .expect_err("full pending map should reject");

        assert!(
            error
                .to_string()
                .contains("sidecar pending request limit exceeded"),
            "unexpected error: {error}"
        );
    }

    #[tokio::test]
    async fn writer_prioritizes_control_frames_over_request_backlog() {
        let (client, mut server) = tokio::io::duplex(64);
        let (control_tx, control_rx) = mpsc::channel(CONTROL_FRAME_QUEUE_CAPACITY);
        let (request_tx, request_rx) = mpsc::channel(REQUEST_FRAME_QUEUE_CAPACITY);
        request_tx
            .send(vec![b'r'])
            .await
            .expect("send request frame");
        control_tx
            .send(vec![b'c'])
            .await
            .expect("send control frame");
        drop(control_tx);
        drop(request_tx);

        let writer = tokio::spawn(run_writer(client, control_rx, request_rx));
        let mut first = [0u8; 1];
        server
            .read_exact(&mut first)
            .await
            .expect("read first byte");
        writer.await.expect("writer task");

        assert_eq!(first, [b'c']);
    }

    #[tokio::test]
    async fn writer_alternates_when_control_and_request_are_ready() {
        let (client, mut server) = tokio::io::duplex(64);
        let (control_tx, control_rx) = mpsc::channel(CONTROL_FRAME_QUEUE_CAPACITY);
        let (request_tx, request_rx) = mpsc::channel(REQUEST_FRAME_QUEUE_CAPACITY);
        control_tx.send(vec![b'c']).await.expect("control one");
        control_tx.send(vec![b'C']).await.expect("control two");
        request_tx.send(vec![b'r']).await.expect("request one");
        request_tx.send(vec![b'R']).await.expect("request two");
        drop(control_tx);
        drop(request_tx);

        let writer = tokio::spawn(run_writer(client, control_rx, request_rx));
        let mut output = [0u8; 4];
        server.read_exact(&mut output).await.expect("read output");
        writer.await.expect("writer task");

        assert_eq!(output, [b'c', b'r', b'C', b'R']);
    }
}