Skip to main content

stakpak_server/
sandbox.rs

1//! Per-session sandboxed MCP server management.
2//!
3//! When sandbox mode is enabled for a session, a `stakpak mcp start` server
4//! is spawned inside a warden container. The host-side proxy connects to it
5//! via HTTPS/mTLS, and tool calls from the agent loop are routed through the
6//! containerized server — executing `run_command`, file I/O, etc. inside the
7//! sandbox.
8//!
9//! **mTLS key exchange** — each side generates its own identity independently:
10//!
11//! 1. Host generates a client identity (CA + leaf cert + key, all in memory)
12//! 2. Host passes the client **CA cert** (public only) to the container via env var
13//! 3. Container generates a server identity (CA + leaf cert + key, all in memory)
14//! 4. Container outputs the server **CA cert** (public only) to stdout
15//! 5. Host parses the server CA cert and builds a client TLS config
16//!
17//! Private keys never leave their respective processes.
18
19use stakpak_mcp_client::McpClient;
20use stakpak_mcp_proxy::client::{ClientPoolConfig, ServerConfig};
21use stakpak_mcp_proxy::server::start_proxy_server;
22use stakpak_shared::cert_utils::{CertificateChain, MtlsIdentity};
23use std::collections::HashMap;
24use std::path::Path;
25use std::sync::Arc;
26use tokio::io::AsyncBufReadExt;
27use tokio::net::TcpListener;
28use tokio::process::Child;
29use tokio::sync::broadcast;
30
/// Name of the environment variable through which the host hands its client CA
/// certificate (PEM, public part only) to the sandboxed MCP server.
const TRUSTED_CLIENT_CA_ENV: &str = "STAKPAK_MCP_CLIENT_CA";
33
/// Settings used to launch a sandboxed MCP server through `warden`.
#[derive(Clone, Debug)]
pub struct SandboxConfig {
    /// Filesystem path of the `warden` executable.
    pub warden_path: String,
    /// Image the sandbox container runs, e.g. `ghcr.io/stakpak/agent:v1.2.3`.
    pub image: String,
    /// Volume specs forwarded to the container, e.g. `["./:/agent:ro"]`.
    pub volumes: Vec<String>,
}
44
45/// A running sandboxed MCP server with its associated proxy and client.
46///
47/// Drop this struct to shut down the sandbox.
48pub struct SandboxedMcpServer {
49    /// MCP client connected via the per-session proxy.
50    pub client: Arc<McpClient>,
51    /// Tools available from the sandboxed server.
52    pub tools: Vec<stakai::Tool>,
53    /// Channel to shut down the per-session proxy.
54    proxy_shutdown_tx: broadcast::Sender<()>,
55    /// The warden container child process.
56    container_process: Child,
57}
58
59impl SandboxedMcpServer {
60    /// Spawn a sandboxed MCP server inside a warden container and connect to it.
61    ///
62    /// 1. Generates a client mTLS identity (private key stays in host memory)
63    /// 2. Passes the client CA cert (public) to the container via env var
64    /// 3. Spawns `warden wrap <image> -- stakpak mcp start`
65    /// 4. Parses the server CA cert (public) from the container's stdout
66    /// 5. Builds a client TLS config trusting the server CA, using the client key
67    /// 6. Starts a per-session MCP proxy pointing to the container
68    /// 7. Connects a client to the proxy
69    pub async fn spawn(config: &SandboxConfig) -> Result<Self, String> {
70        // 1. Generate client identity — private key stays in host memory
71        let client_identity = MtlsIdentity::generate_client()
72            .map_err(|e| format!("Failed to generate client identity: {e}"))?;
73
74        let client_ca_pem = client_identity
75            .ca_cert_pem()
76            .map_err(|e| format!("Failed to get client CA PEM: {e}"))?;
77
78        // 2. Find a free port for the container's MCP server to expose
79        let container_host_port = find_free_port()
80            .await
81            .map_err(|e| format!("Failed to find free port for sandbox: {e}"))?;
82
83        // 3. Spawn warden container, passing client CA cert (public) via env var
84        let mut container_process =
85            spawn_warden_container(config, container_host_port, &client_ca_pem)
86                .await
87                .map_err(|e| format!("Failed to spawn sandbox container: {e}"))?;
88
89        // 4. Parse the server CA cert (public) from the container's stdout
90        let server_ca_pem = parse_server_ca_from_stdout(&mut container_process).await?;
91        tracing::info!(
92            "Parsed server CA from container stdout ({} bytes)",
93            server_ca_pem.len()
94        );
95
96        // 5. Build client TLS config — trusts server CA, authenticates with our key
97        let container_client_config = client_identity
98            .create_client_config(&server_ca_pem)
99            .map_err(|e| format!("Failed to create client TLS config: {e}"))?;
100
101        // 6. Wait for the MCP server inside the container to be ready
102        let server_url = format!("https://127.0.0.1:{container_host_port}/mcp");
103        tracing::info!(url = %server_url, "Waiting for sandbox MCP server to be ready");
104        wait_for_server_ready(&server_url, &container_client_config).await?;
105        tracing::info!("Sandbox MCP server is ready");
106
107        // 7. Start a per-session proxy connecting to the sandboxed server
108        let (proxy_shutdown_tx, proxy_shutdown_rx) = broadcast::channel::<()>(1);
109
110        let proxy_binding = find_available_binding("sandbox proxy").await?;
111        let proxy_url = format!("https://{}/mcp", proxy_binding.address);
112
113        let proxy_cert_chain = Arc::new(
114            CertificateChain::generate()
115                .map_err(|e| format!("Failed to generate proxy certificates: {e}"))?,
116        );
117
118        let pool_config = build_sandbox_proxy_config(server_url, Arc::new(container_client_config));
119
120        let proxy_chain_for_server = proxy_cert_chain.clone();
121        let proxy_listener = proxy_binding.listener;
122        tokio::spawn(async move {
123            if let Err(e) = start_proxy_server(
124                pool_config,
125                proxy_listener,
126                proxy_chain_for_server,
127                true,  // redact_secrets
128                false, // privacy_mode
129                Some(proxy_shutdown_rx),
130            )
131            .await
132            {
133                tracing::error!("Sandbox proxy error: {e}");
134            }
135        });
136
137        // Small delay for proxy to start
138        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
139
140        // 8. Connect client to proxy
141        let client = connect_to_proxy(&proxy_url, proxy_cert_chain).await?;
142
143        // 9. Get tools
144        let mcp_tools = stakpak_mcp_client::get_tools(&client)
145            .await
146            .map_err(|e| format!("Failed to get sandbox tools: {e}"))?;
147
148        let tools = mcp_tools
149            .into_iter()
150            .map(|tool| stakai::Tool {
151                tool_type: "function".to_string(),
152                function: stakai::ToolFunction {
153                    name: tool.name.as_ref().to_string(),
154                    description: tool
155                        .description
156                        .as_ref()
157                        .map(std::string::ToString::to_string)
158                        .unwrap_or_default(),
159                    parameters: serde_json::Value::Object((*tool.input_schema).clone()),
160                },
161                provider_options: None,
162            })
163            .collect();
164
165        Ok(Self {
166            client,
167            tools,
168            proxy_shutdown_tx,
169            container_process,
170        })
171    }
172
173    /// Shut down the sandbox: stop the proxy and kill the container.
174    pub async fn shutdown(mut self) {
175        let _ = self.proxy_shutdown_tx.send(());
176
177        let _ = self.container_process.kill().await;
178        let _ = self.container_process.wait().await;
179    }
180}
181
182async fn spawn_warden_container(
183    config: &SandboxConfig,
184    host_port: u16,
185    client_ca_pem: &str,
186) -> Result<Child, String> {
187    use stakpak_shared::container::{expand_volume_path, is_named_volume};
188
189    let mut cmd = tokio::process::Command::new(&config.warden_path);
190    cmd.arg("wrap");
191    cmd.arg(&config.image);
192
193    // Mount configured volumes
194    for vol in &config.volumes {
195        let expanded = expand_volume_path(vol);
196        let host_path = expanded.split(':').next().unwrap_or(&expanded);
197        // Named volumes (e.g. "stakpak-aqua-cache:/container/path") don't have a
198        // host filesystem path — mount them unconditionally. Bind mounts are only
199        // added when the host path actually exists.
200        if is_named_volume(host_path) || Path::new(host_path).exists() {
201            cmd.args(["--volume", &expanded]);
202        }
203    }
204
205    // Port forwarding for the MCP server — publish on the sidecar so the
206    // host can reach the container's MCP server port directly.
207    cmd.args(["-p", &format!("127.0.0.1:{host_port}:8080")]);
208
209    // Prevent warden re-entry
210    cmd.args(["--env", "STAKPAK_SKIP_WARDEN=1"]);
211
212    // Tell the MCP server to bind to a fixed port inside the container
213    // so it matches the published port on the sidecar.
214    cmd.args(["--env", "STAKPAK_MCP_PORT=8080"]);
215
216    // Pass the client CA cert (public only) so the server can trust the client.
217    cmd.args(["--env", &format!("{TRUSTED_CLIENT_CA_ENV}={client_ca_pem}")]);
218
219    // Pass through API credentials if set
220    if let Ok(api_key) = std::env::var("STAKPAK_API_KEY") {
221        cmd.args(["--env", &format!("STAKPAK_API_KEY={api_key}")]);
222    }
223    if let Ok(profile) = std::env::var("STAKPAK_PROFILE") {
224        cmd.args(["--env", &format!("STAKPAK_PROFILE={profile}")]);
225    }
226    if let Ok(endpoint) = std::env::var("STAKPAK_API_ENDPOINT") {
227        cmd.args(["--env", &format!("STAKPAK_API_ENDPOINT={endpoint}")]);
228    }
229
230    // The MCP server detects STAKPAK_MCP_CLIENT_CA and generates its own
231    // server identity, outputting the server CA cert to stdout.
232    cmd.args(["--", "stakpak", "mcp", "start"]);
233
234    cmd.stdout(std::process::Stdio::piped());
235    cmd.stderr(std::process::Stdio::piped());
236    cmd.stdin(std::process::Stdio::null());
237
238    let child = cmd
239        .spawn()
240        .map_err(|e| format!("Failed to spawn warden process: {e}"))?;
241
242    Ok(child)
243}
244
245/// Parse the server CA certificate PEM from the container's stdout.
246///
247/// The MCP server outputs the server CA cert between structured delimiters:
248/// ```text
249/// ---BEGIN STAKPAK SERVER CA---
250/// -----BEGIN CERTIFICATE-----
251/// ...
252/// -----END CERTIFICATE-----
253/// ---END STAKPAK SERVER CA---
254/// ```
255async fn parse_server_ca_from_stdout(process: &mut Child) -> Result<String, String> {
256    let stdout = process
257        .stdout
258        .take()
259        .ok_or_else(|| "Container stdout not captured".to_string())?;
260
261    let mut reader = tokio::io::BufReader::new(stdout);
262    let mut server_ca_pem = String::new();
263    let mut in_server_ca = false;
264    let mut line = String::new();
265
266    let timeout_duration = tokio::time::Duration::from_secs(60);
267    let deadline = tokio::time::Instant::now() + timeout_duration;
268
269    tracing::debug!("Starting to read container stdout for server CA...");
270
271    loop {
272        line.clear();
273        let bytes_read = tokio::time::timeout_at(deadline, reader.read_line(&mut line))
274            .await
275            .map_err(|_| {
276                "Timed out waiting for container to output server CA certificate".to_string()
277            })?
278            .map_err(|e| format!("Failed to read container stdout: {e}"))?;
279
280        if bytes_read == 0 {
281            tracing::error!("Container stdout EOF before server CA was found");
282            return Err("Container exited before outputting server CA certificate".to_string());
283        }
284
285        let trimmed = line.trim();
286        tracing::debug!(line = %trimmed, bytes = bytes_read, "Read line from container stdout");
287
288        if trimmed == "---BEGIN STAKPAK SERVER CA---" {
289            in_server_ca = true;
290            continue;
291        }
292
293        if trimmed == "---END STAKPAK SERVER CA---" {
294            tracing::debug!("Found end of server CA block");
295            break;
296        }
297
298        if in_server_ca {
299            server_ca_pem.push_str(trimmed);
300            server_ca_pem.push('\n');
301        }
302    }
303
304    let server_ca_pem = server_ca_pem.trim().to_string();
305
306    if server_ca_pem.is_empty() {
307        return Err("Failed to parse server CA certificate from container output".to_string());
308    }
309
310    Ok(server_ca_pem)
311}
312
313async fn wait_for_server_ready(
314    url: &str,
315    client_config: &rustls::ClientConfig,
316) -> Result<(), String> {
317    let http_client = reqwest::Client::builder()
318        .use_preconfigured_tls(client_config.clone())
319        .build()
320        .map_err(|e| format!("Failed to build readiness check client: {e}"))?;
321
322    let mut last_error = String::new();
323    for attempt in 0..30 {
324        tokio::time::sleep(tokio::time::Duration::from_millis(if attempt < 5 {
325            500
326        } else {
327            1000
328        }))
329        .await;
330
331        match http_client.get(url).send().await {
332            Ok(_) => {
333                tracing::info!(attempt, "Sandbox MCP server ready");
334                return Ok(());
335            }
336            Err(e) => {
337                last_error = format!("{e:?}");
338                tracing::debug!(attempt, error = %last_error, "Readiness check failed");
339            }
340        }
341    }
342
343    Err(format!(
344        "Sandbox MCP server failed to become ready after 30 attempts: {last_error}"
345    ))
346}
347
348struct ProxyBinding {
349    address: String,
350    listener: TcpListener,
351}
352
353async fn find_available_binding(purpose: &str) -> Result<ProxyBinding, String> {
354    let listener = TcpListener::bind("127.0.0.1:0")
355        .await
356        .map_err(|e| format!("Failed to bind port for {purpose}: {e}"))?;
357    let addr = listener
358        .local_addr()
359        .map_err(|e| format!("Failed to get address for {purpose}: {e}"))?;
360    Ok(ProxyBinding {
361        address: addr.to_string(),
362        listener,
363    })
364}
365
366async fn find_free_port() -> Result<u16, String> {
367    let listener = TcpListener::bind("127.0.0.1:0")
368        .await
369        .map_err(|e| format!("Failed to bind ephemeral port: {e}"))?;
370    let port = listener
371        .local_addr()
372        .map_err(|e| format!("Failed to get ephemeral port: {e}"))?
373        .port();
374    // Drop the listener to free the port for Docker to use
375    drop(listener);
376    Ok(port)
377}
378
379fn build_sandbox_proxy_config(
380    sandbox_server_url: String,
381    client_tls_config: Arc<rustls::ClientConfig>,
382) -> ClientPoolConfig {
383    let mut servers: HashMap<String, ServerConfig> = HashMap::new();
384
385    // Register the sandboxed MCP server under the same name ("stakpak") so
386    // tool names like `stakpak__run_command` route correctly through the proxy.
387    servers.insert(
388        "stakpak".to_string(),
389        ServerConfig::Http {
390            url: sandbox_server_url,
391            headers: None,
392            certificate_chain: Arc::new(None),
393            client_tls_config: Some(client_tls_config),
394        },
395    );
396
397    // Keep the external paks server accessible
398    servers.insert(
399        "paks".to_string(),
400        ServerConfig::Http {
401            url: "https://apiv2.stakpak.dev/v1/paks/mcp".to_string(),
402            headers: None,
403            certificate_chain: Arc::new(None),
404            client_tls_config: None,
405        },
406    );
407
408    ClientPoolConfig::with_servers(servers)
409}
410
411async fn connect_to_proxy(
412    proxy_url: &str,
413    cert_chain: Arc<CertificateChain>,
414) -> Result<Arc<McpClient>, String> {
415    const MAX_RETRIES: u32 = 5;
416    let mut retry_delay = tokio::time::Duration::from_millis(50);
417    let mut last_error = None;
418
419    for attempt in 1..=MAX_RETRIES {
420        match stakpak_mcp_client::connect_https(proxy_url, Some(cert_chain.clone()), None).await {
421            Ok(client) => return Ok(Arc::new(client)),
422            Err(e) => {
423                last_error = Some(e);
424                if attempt < MAX_RETRIES {
425                    tokio::time::sleep(retry_delay).await;
426                    retry_delay *= 2;
427                }
428            }
429        }
430    }
431
432    Err(format!(
433        "Failed to connect to sandbox proxy after {MAX_RETRIES} retries: {}",
434        last_error.map(|e| e.to_string()).unwrap_or_default()
435    ))
436}
437
#[cfg(test)]
mod tests {
    /// Copy of the delimiter-scanning loop in `parse_server_ca_from_stdout`, run
    /// over an in-memory string instead of a child's stdout.
    fn extract_server_ca(output: &str) -> String {
        let mut collected = String::new();
        let mut capturing = false;

        for raw in output.lines() {
            let line = raw.trim();
            match line {
                "---BEGIN STAKPAK SERVER CA---" => capturing = true,
                "---END STAKPAK SERVER CA---" => break,
                _ if capturing => {
                    collected.push_str(line);
                    collected.push('\n');
                }
                _ => {}
            }
        }

        collected.trim().to_string()
    }

    #[test]
    fn parse_server_ca_from_structured_output() {
        let output = "\
🔐 mTLS enabled - independent identity (sandbox mode)
---BEGIN STAKPAK SERVER CA---
-----BEGIN CERTIFICATE-----
MIIB0zCCAXmgAwIBAgIUFAKE=
-----END CERTIFICATE-----
---END STAKPAK SERVER CA---
MCP server started at https://0.0.0.0:8080/mcp
";

        let expected_ca = "\
-----BEGIN CERTIFICATE-----
MIIB0zCCAXmgAwIBAgIUFAKE=
-----END CERTIFICATE-----";

        assert_eq!(extract_server_ca(output), expected_ca);
    }

    #[test]
    fn mtls_identity_cross_trust() {
        use stakpak_shared::cert_utils::MtlsIdentity;

        // rustls needs a process-wide crypto provider. Installing fails harmlessly
        // if one is already present, so the result is ignored.
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

        // Each side generates its own identity, as in the sandbox exchange.
        let host = MtlsIdentity::generate_client().expect("generate client identity");
        let container = MtlsIdentity::generate_server().expect("generate server identity");

        let host_ca = host.ca_cert_pem().expect("client CA PEM");
        let container_ca = container.ca_cert_pem().expect("server CA PEM");

        // The server trusts the client CA; the client trusts the server CA.
        let _server_config = container
            .create_server_config(&host_ca)
            .expect("server config with client CA trust");
        let _client_config = host
            .create_client_config(&container_ca)
            .expect("client config with server CA trust");

        // Only public CA certificates crossed the boundary; private keys stayed in
        // their respective MtlsIdentity values.
        for pem in [&host_ca, &container_ca] {
            assert!(pem.contains("BEGIN CERTIFICATE"));
            assert!(!pem.contains("PRIVATE KEY"));
        }
    }

    // ── Named volume handling in expand_volume_path / the mount filter ────

    #[test]
    fn expand_volume_path_leaves_named_volumes_unchanged() {
        use stakpak_shared::container::expand_volume_path;
        let spec = "stakpak-aqua-cache:/home/agent/.local/share/aquaproj-aqua";
        assert_eq!(expand_volume_path(spec), spec);
    }

    /// Named volumes (no `/` or `.` prefix in the host part) must pass the mount
    /// filter even though they do not exist on the host filesystem.
    #[test]
    fn named_volume_is_detected_correctly() {
        use stakpak_shared::container::is_named_volume;
        let cases = [
            ("stakpak-aqua-cache", true),
            ("my-volume", true),
            ("./relative/path", false),
            ("/absolute/path", false),
            ("relative/with/slash", false),
            (".", false),
        ];
        for (host_part, expected) in cases {
            assert_eq!(
                is_named_volume(host_part),
                expected,
                "host_part={host_part:?} expected named={expected}"
            );
        }
    }
}