Skip to main content

stakpak_server/
sandbox.rs

1//! Per-session sandboxed MCP server management.
2//!
3//! When sandbox mode is enabled for a session, a `stakpak mcp start` server
4//! is spawned inside a warden container. The host-side proxy connects to it
5//! via HTTPS/mTLS, and tool calls from the agent loop are routed through the
6//! containerized server — executing `run_command`, file I/O, etc. inside the
7//! sandbox.
8//!
9//! **mTLS key exchange** — each side generates its own identity independently:
10//!
11//! 1. Host generates a client identity (CA + leaf cert + key, all in memory)
12//! 2. Host passes the client **CA cert** (public only) to the container via env var
13//! 3. Container generates a server identity (CA + leaf cert + key, all in memory)
14//! 4. Container outputs the server **CA cert** (public only) to stdout
15//! 5. Host parses the server CA cert and builds a client TLS config
16//!
17//! Private keys never leave their respective processes.
18
19use stakpak_mcp_client::McpClient;
20use stakpak_mcp_proxy::client::{ClientPoolConfig, ServerConfig};
21use stakpak_mcp_proxy::server::start_proxy_server;
22use stakpak_shared::cert_utils::{CertificateChain, MtlsIdentity};
23use std::collections::HashMap;
24use std::path::Path;
25use std::sync::Arc;
26use tokio::io::AsyncBufReadExt;
27use tokio::net::TcpListener;
28use tokio::process::Child;
29use tokio::sync::broadcast;
30
/// Environment variable used to pass the client CA cert PEM to the container.
///
/// Only the public CA certificate travels this way; the matching private key
/// never leaves the host process. The `stakpak mcp start` server inside the
/// container reads it so it can trust the host's client certificate.
const TRUSTED_CLIENT_CA_ENV: &str = "STAKPAK_MCP_CLIENT_CA";
33
/// Configuration for spawning sandboxed MCP servers.
///
/// `SandboxedMcpServer::spawn` only borrows this config, so one value can be
/// reused across sessions; every spawn launches its own warden process.
#[derive(Clone, Debug)]
pub struct SandboxConfig {
    /// Path to the warden binary (the program invoked as `warden wrap <image> ...`).
    pub warden_path: String,
    /// Container image for the sandbox (e.g., `ghcr.io/stakpak/agent:v1.2.3`).
    pub image: String,
    /// Volume mounts for the container (e.g., `["./:/agent:ro"]`).
    ///
    /// Each entry is forwarded to warden as `--volume`. A leading `~/` or `~:`
    /// is expanded to `$HOME`, and entries whose host part (the text before
    /// the first `:`) does not exist on the host are skipped silently.
    pub volumes: Vec<String>,
}
44
45/// A running sandboxed MCP server with its associated proxy and client.
46///
47/// Drop this struct to shut down the sandbox.
48pub struct SandboxedMcpServer {
49    /// MCP client connected via the per-session proxy.
50    pub client: Arc<McpClient>,
51    /// Tools available from the sandboxed server.
52    pub tools: Vec<stakai::Tool>,
53    /// Channel to shut down the per-session proxy.
54    proxy_shutdown_tx: broadcast::Sender<()>,
55    /// The warden container child process.
56    container_process: Child,
57}
58
59impl SandboxedMcpServer {
60    /// Spawn a sandboxed MCP server inside a warden container and connect to it.
61    ///
62    /// 1. Generates a client mTLS identity (private key stays in host memory)
63    /// 2. Passes the client CA cert (public) to the container via env var
64    /// 3. Spawns `warden wrap <image> -- stakpak mcp start`
65    /// 4. Parses the server CA cert (public) from the container's stdout
66    /// 5. Builds a client TLS config trusting the server CA, using the client key
67    /// 6. Starts a per-session MCP proxy pointing to the container
68    /// 7. Connects a client to the proxy
69    pub async fn spawn(config: &SandboxConfig) -> Result<Self, String> {
70        // 1. Generate client identity — private key stays in host memory
71        let client_identity = MtlsIdentity::generate_client()
72            .map_err(|e| format!("Failed to generate client identity: {e}"))?;
73
74        let client_ca_pem = client_identity
75            .ca_cert_pem()
76            .map_err(|e| format!("Failed to get client CA PEM: {e}"))?;
77
78        // 2. Find a free port for the container's MCP server to expose
79        let container_host_port = find_free_port()
80            .await
81            .map_err(|e| format!("Failed to find free port for sandbox: {e}"))?;
82
83        // 3. Spawn warden container, passing client CA cert (public) via env var
84        let mut container_process =
85            spawn_warden_container(config, container_host_port, &client_ca_pem)
86                .await
87                .map_err(|e| format!("Failed to spawn sandbox container: {e}"))?;
88
89        // 4. Parse the server CA cert (public) from the container's stdout
90        let server_ca_pem = parse_server_ca_from_stdout(&mut container_process).await?;
91        tracing::info!(
92            "Parsed server CA from container stdout ({} bytes)",
93            server_ca_pem.len()
94        );
95
96        // 5. Build client TLS config — trusts server CA, authenticates with our key
97        let container_client_config = client_identity
98            .create_client_config(&server_ca_pem)
99            .map_err(|e| format!("Failed to create client TLS config: {e}"))?;
100
101        // 6. Wait for the MCP server inside the container to be ready
102        let server_url = format!("https://127.0.0.1:{container_host_port}/mcp");
103        tracing::info!(url = %server_url, "Waiting for sandbox MCP server to be ready");
104        wait_for_server_ready(&server_url, &container_client_config).await?;
105        tracing::info!("Sandbox MCP server is ready");
106
107        // 7. Start a per-session proxy connecting to the sandboxed server
108        let (proxy_shutdown_tx, proxy_shutdown_rx) = broadcast::channel::<()>(1);
109
110        let proxy_binding = find_available_binding("sandbox proxy").await?;
111        let proxy_url = format!("https://{}/mcp", proxy_binding.address);
112
113        let proxy_cert_chain = Arc::new(
114            CertificateChain::generate()
115                .map_err(|e| format!("Failed to generate proxy certificates: {e}"))?,
116        );
117
118        let pool_config = build_sandbox_proxy_config(server_url, Arc::new(container_client_config));
119
120        let proxy_chain_for_server = proxy_cert_chain.clone();
121        let proxy_listener = proxy_binding.listener;
122        tokio::spawn(async move {
123            if let Err(e) = start_proxy_server(
124                pool_config,
125                proxy_listener,
126                proxy_chain_for_server,
127                true,  // redact_secrets
128                false, // privacy_mode
129                Some(proxy_shutdown_rx),
130            )
131            .await
132            {
133                tracing::error!("Sandbox proxy error: {e}");
134            }
135        });
136
137        // Small delay for proxy to start
138        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
139
140        // 8. Connect client to proxy
141        let client = connect_to_proxy(&proxy_url, proxy_cert_chain).await?;
142
143        // 9. Get tools
144        let mcp_tools = stakpak_mcp_client::get_tools(&client)
145            .await
146            .map_err(|e| format!("Failed to get sandbox tools: {e}"))?;
147
148        let tools = mcp_tools
149            .into_iter()
150            .map(|tool| stakai::Tool {
151                tool_type: "function".to_string(),
152                function: stakai::ToolFunction {
153                    name: tool.name.as_ref().to_string(),
154                    description: tool
155                        .description
156                        .as_ref()
157                        .map(std::string::ToString::to_string)
158                        .unwrap_or_default(),
159                    parameters: serde_json::Value::Object((*tool.input_schema).clone()),
160                },
161                provider_options: None,
162            })
163            .collect();
164
165        Ok(Self {
166            client,
167            tools,
168            proxy_shutdown_tx,
169            container_process,
170        })
171    }
172
173    /// Shut down the sandbox: stop the proxy and kill the container.
174    pub async fn shutdown(mut self) {
175        let _ = self.proxy_shutdown_tx.send(());
176
177        let _ = self.container_process.kill().await;
178        let _ = self.container_process.wait().await;
179    }
180}
181
/// Replace a leading `~` in a volume spec with the value of `$HOME`.
///
/// Only the `~/...` and `~:...` forms are expanded. Anything else, including
/// `~user/...` or a `~` that is not the first character, is returned unchanged.
/// The spec is also returned unchanged when `HOME` is unset or not valid Unicode.
fn expand_volume_path(volume: &str) -> String {
    let Some(rest) = volume.strip_prefix('~') else {
        return volume.to_string();
    };
    if !(rest.starts_with('/') || rest.starts_with(':')) {
        return volume.to_string();
    }
    match std::env::var("HOME") {
        Ok(home_dir) => format!("{home_dir}{rest}"),
        Err(_) => volume.to_string(),
    }
}
191
192async fn spawn_warden_container(
193    config: &SandboxConfig,
194    host_port: u16,
195    client_ca_pem: &str,
196) -> Result<Child, String> {
197    let mut cmd = tokio::process::Command::new(&config.warden_path);
198    cmd.arg("wrap");
199    cmd.arg(&config.image);
200
201    // Mount configured volumes
202    for vol in &config.volumes {
203        let expanded = expand_volume_path(vol);
204        // Only mount volumes where the host path exists
205        let host_path = expanded.split(':').next().unwrap_or(&expanded);
206        if Path::new(host_path).exists() {
207            cmd.args(["--volume", &expanded]);
208        }
209    }
210
211    // Port forwarding for the MCP server — publish on the sidecar so the
212    // host can reach the container's MCP server port directly.
213    cmd.args(["-p", &format!("127.0.0.1:{host_port}:8080")]);
214
215    // Prevent warden re-entry
216    cmd.args(["--env", "STAKPAK_SKIP_WARDEN=1"]);
217
218    // Tell the MCP server to bind to a fixed port inside the container
219    // so it matches the published port on the sidecar.
220    cmd.args(["--env", "STAKPAK_MCP_PORT=8080"]);
221
222    // Pass the client CA cert (public only) so the server can trust the client.
223    cmd.args(["--env", &format!("{TRUSTED_CLIENT_CA_ENV}={client_ca_pem}")]);
224
225    // Pass through API credentials if set
226    if let Ok(api_key) = std::env::var("STAKPAK_API_KEY") {
227        cmd.args(["--env", &format!("STAKPAK_API_KEY={api_key}")]);
228    }
229    if let Ok(profile) = std::env::var("STAKPAK_PROFILE") {
230        cmd.args(["--env", &format!("STAKPAK_PROFILE={profile}")]);
231    }
232    if let Ok(endpoint) = std::env::var("STAKPAK_API_ENDPOINT") {
233        cmd.args(["--env", &format!("STAKPAK_API_ENDPOINT={endpoint}")]);
234    }
235
236    // The MCP server detects STAKPAK_MCP_CLIENT_CA and generates its own
237    // server identity, outputting the server CA cert to stdout.
238    cmd.args(["--", "stakpak", "mcp", "start"]);
239
240    cmd.stdout(std::process::Stdio::piped());
241    cmd.stderr(std::process::Stdio::piped());
242    cmd.stdin(std::process::Stdio::null());
243
244    let child = cmd
245        .spawn()
246        .map_err(|e| format!("Failed to spawn warden process: {e}"))?;
247
248    Ok(child)
249}
250
251/// Parse the server CA certificate PEM from the container's stdout.
252///
253/// The MCP server outputs the server CA cert between structured delimiters:
254/// ```text
255/// ---BEGIN STAKPAK SERVER CA---
256/// -----BEGIN CERTIFICATE-----
257/// ...
258/// -----END CERTIFICATE-----
259/// ---END STAKPAK SERVER CA---
260/// ```
261async fn parse_server_ca_from_stdout(process: &mut Child) -> Result<String, String> {
262    let stdout = process
263        .stdout
264        .take()
265        .ok_or_else(|| "Container stdout not captured".to_string())?;
266
267    let mut reader = tokio::io::BufReader::new(stdout);
268    let mut server_ca_pem = String::new();
269    let mut in_server_ca = false;
270    let mut line = String::new();
271
272    let timeout_duration = tokio::time::Duration::from_secs(60);
273    let deadline = tokio::time::Instant::now() + timeout_duration;
274
275    tracing::debug!("Starting to read container stdout for server CA...");
276
277    loop {
278        line.clear();
279        let bytes_read = tokio::time::timeout_at(deadline, reader.read_line(&mut line))
280            .await
281            .map_err(|_| {
282                "Timed out waiting for container to output server CA certificate".to_string()
283            })?
284            .map_err(|e| format!("Failed to read container stdout: {e}"))?;
285
286        if bytes_read == 0 {
287            tracing::error!("Container stdout EOF before server CA was found");
288            return Err("Container exited before outputting server CA certificate".to_string());
289        }
290
291        let trimmed = line.trim();
292        tracing::debug!(line = %trimmed, bytes = bytes_read, "Read line from container stdout");
293
294        if trimmed == "---BEGIN STAKPAK SERVER CA---" {
295            in_server_ca = true;
296            continue;
297        }
298
299        if trimmed == "---END STAKPAK SERVER CA---" {
300            tracing::debug!("Found end of server CA block");
301            break;
302        }
303
304        if in_server_ca {
305            server_ca_pem.push_str(trimmed);
306            server_ca_pem.push('\n');
307        }
308    }
309
310    let server_ca_pem = server_ca_pem.trim().to_string();
311
312    if server_ca_pem.is_empty() {
313        return Err("Failed to parse server CA certificate from container output".to_string());
314    }
315
316    Ok(server_ca_pem)
317}
318
319async fn wait_for_server_ready(
320    url: &str,
321    client_config: &rustls::ClientConfig,
322) -> Result<(), String> {
323    let http_client = reqwest::Client::builder()
324        .use_preconfigured_tls(client_config.clone())
325        .build()
326        .map_err(|e| format!("Failed to build readiness check client: {e}"))?;
327
328    let mut last_error = String::new();
329    for attempt in 0..30 {
330        tokio::time::sleep(tokio::time::Duration::from_millis(if attempt < 5 {
331            500
332        } else {
333            1000
334        }))
335        .await;
336
337        match http_client.get(url).send().await {
338            Ok(_) => {
339                tracing::info!(attempt, "Sandbox MCP server ready");
340                return Ok(());
341            }
342            Err(e) => {
343                last_error = format!("{e:?}");
344                tracing::debug!(attempt, error = %last_error, "Readiness check failed");
345            }
346        }
347    }
348
349    Err(format!(
350        "Sandbox MCP server failed to become ready after 30 attempts: {last_error}"
351    ))
352}
353
/// A loopback listener reserved for a per-session proxy, plus its address.
struct ProxyBinding {
    /// `ip:port` string of the bound socket, used to build the proxy URL.
    address: String,
    /// The listener itself. It stays bound so the port cannot be taken by
    /// another process before the proxy starts serving on it.
    listener: TcpListener,
}
358
359async fn find_available_binding(purpose: &str) -> Result<ProxyBinding, String> {
360    let listener = TcpListener::bind("127.0.0.1:0")
361        .await
362        .map_err(|e| format!("Failed to bind port for {purpose}: {e}"))?;
363    let addr = listener
364        .local_addr()
365        .map_err(|e| format!("Failed to get address for {purpose}: {e}"))?;
366    Ok(ProxyBinding {
367        address: addr.to_string(),
368        listener,
369    })
370}
371
372async fn find_free_port() -> Result<u16, String> {
373    let listener = TcpListener::bind("127.0.0.1:0")
374        .await
375        .map_err(|e| format!("Failed to bind ephemeral port: {e}"))?;
376    let port = listener
377        .local_addr()
378        .map_err(|e| format!("Failed to get ephemeral port: {e}"))?
379        .port();
380    // Drop the listener to free the port for Docker to use
381    drop(listener);
382    Ok(port)
383}
384
385fn build_sandbox_proxy_config(
386    sandbox_server_url: String,
387    client_tls_config: Arc<rustls::ClientConfig>,
388) -> ClientPoolConfig {
389    let mut servers: HashMap<String, ServerConfig> = HashMap::new();
390
391    // Register the sandboxed MCP server under the same name ("stakpak") so
392    // tool names like `stakpak__run_command` route correctly through the proxy.
393    servers.insert(
394        "stakpak".to_string(),
395        ServerConfig::Http {
396            url: sandbox_server_url,
397            headers: None,
398            certificate_chain: Arc::new(None),
399            client_tls_config: Some(client_tls_config),
400        },
401    );
402
403    // Keep the external paks server accessible
404    servers.insert(
405        "paks".to_string(),
406        ServerConfig::Http {
407            url: "https://apiv2.stakpak.dev/v1/paks/mcp".to_string(),
408            headers: None,
409            certificate_chain: Arc::new(None),
410            client_tls_config: None,
411        },
412    );
413
414    ClientPoolConfig::with_servers(servers)
415}
416
417async fn connect_to_proxy(
418    proxy_url: &str,
419    cert_chain: Arc<CertificateChain>,
420) -> Result<Arc<McpClient>, String> {
421    const MAX_RETRIES: u32 = 5;
422    let mut retry_delay = tokio::time::Duration::from_millis(50);
423    let mut last_error = None;
424
425    for attempt in 1..=MAX_RETRIES {
426        match stakpak_mcp_client::connect_https(proxy_url, Some(cert_chain.clone()), None).await {
427            Ok(client) => return Ok(Arc::new(client)),
428            Err(e) => {
429                last_error = Some(e);
430                if attempt < MAX_RETRIES {
431                    tokio::time::sleep(retry_delay).await;
432                    retry_delay *= 2;
433                }
434            }
435        }
436    }
437
438    Err(format!(
439        "Failed to connect to sandbox proxy after {MAX_RETRIES} retries: {}",
440        last_error.map(|e| e.to_string()).unwrap_or_default()
441    ))
442}
443
#[cfg(test)]
mod tests {
    use super::expand_volume_path;
    #[cfg(unix)]
    use super::parse_server_ca_from_stdout;

    /// Run `future` to completion on a fresh single-threaded Tokio runtime, so
    /// async helpers can be exercised from plain `#[test]` functions.
    #[cfg(unix)]
    fn block_on<F: std::future::Future>(future: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("build tokio runtime")
            .block_on(future)
    }

    /// Emit `output` on a real child process's stdout via `sh`/`printf`, then
    /// run the production parser on it. Must be called inside a Tokio runtime.
    #[cfg(unix)]
    async fn parse_via_child_stdout(output: &str) -> Result<String, String> {
        let mut child = tokio::process::Command::new("sh")
            .arg("-c")
            .arg("printf '%s' \"$1\"")
            .arg("sh")
            .arg(output)
            .stdin(std::process::Stdio::null())
            .stdout(std::process::Stdio::piped())
            .spawn()
            .expect("spawn sh");
        let parsed = parse_server_ca_from_stdout(&mut child).await;
        // Reap the child; its exit status does not matter to the parse result.
        let _ = child.wait().await;
        parsed
    }

    /// Exercises the real `parse_server_ca_from_stdout` (not a copy of its
    /// loop) against representative `stakpak mcp start` output.
    #[cfg(unix)]
    #[test]
    fn parse_server_ca_from_structured_output() {
        let output = "\
🔐 mTLS enabled - independent identity (sandbox mode)
---BEGIN STAKPAK SERVER CA---
-----BEGIN CERTIFICATE-----
MIIB0zCCAXmgAwIBAgIUFAKE=
-----END CERTIFICATE-----
---END STAKPAK SERVER CA---
MCP server started at https://0.0.0.0:8080/mcp
";

        let expected_ca = "\
-----BEGIN CERTIFICATE-----
MIIB0zCCAXmgAwIBAgIUFAKE=
-----END CERTIFICATE-----";

        let parsed = block_on(parse_via_child_stdout(output));
        assert_eq!(parsed.as_deref(), Ok(expected_ca));
    }

    #[cfg(unix)]
    #[test]
    fn parse_server_ca_errors_when_markers_missing() {
        let parsed = block_on(parse_via_child_stdout(
            "MCP server started at https://0.0.0.0:8080/mcp\n",
        ));
        let err = parsed.expect_err("output without a CA block must fail");
        assert!(err.contains("before outputting server CA"), "{err}");
    }

    #[test]
    fn expand_volume_path_only_expands_leading_tilde() {
        assert_eq!(expand_volume_path("./:/agent:ro"), "./:/agent:ro");
        assert_eq!(expand_volume_path("~user/x:/y"), "~user/x:/y");
        assert_eq!(expand_volume_path("/a/~/b:/c"), "/a/~/b:/c");
        if let Ok(home) = std::env::var("HOME") {
            assert_eq!(
                expand_volume_path("~/.aws:/root/.aws:ro"),
                format!("{home}/.aws:/root/.aws:ro")
            );
            assert_eq!(expand_volume_path("~:/home"), format!("{home}:/home"));
        }
    }

    #[test]
    fn mtls_identity_cross_trust() {
        use stakpak_shared::cert_utils::MtlsIdentity;

        // Ensure a crypto provider is installed for rustls
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

        // Simulate the sandbox mTLS exchange
        let client_identity = MtlsIdentity::generate_client().expect("generate client identity");
        let server_identity = MtlsIdentity::generate_server().expect("generate server identity");

        let client_ca_pem = client_identity.ca_cert_pem().expect("client CA PEM");
        let server_ca_pem = server_identity.ca_cert_pem().expect("server CA PEM");

        // Server trusts client CA, client trusts server CA
        let _server_config = server_identity
            .create_server_config(&client_ca_pem)
            .expect("server config with client CA trust");
        let _client_config = client_identity
            .create_client_config(&server_ca_pem)
            .expect("client config with server CA trust");

        // Only public CA certs were exchanged — private keys stayed in their
        // respective MtlsIdentity structs.
        assert!(client_ca_pem.contains("BEGIN CERTIFICATE"));
        assert!(server_ca_pem.contains("BEGIN CERTIFICATE"));
        assert!(!client_ca_pem.contains("PRIVATE KEY"));
        assert!(!server_ca_pem.contains("PRIVATE KEY"));
    }
}