1use stakpak_mcp_client::McpClient;
20use stakpak_mcp_proxy::client::{ClientPoolConfig, ServerConfig};
21use stakpak_mcp_proxy::server::start_proxy_server;
22use stakpak_shared::cert_utils::{CertificateChain, MtlsIdentity};
23use std::collections::HashMap;
24use std::path::Path;
25use std::sync::Arc;
26use tokio::io::AsyncBufReadExt;
27use tokio::net::TcpListener;
28use tokio::process::Child;
29use tokio::sync::broadcast;
30
/// Name of the environment variable through which the client CA certificate (PEM) is
/// handed to the sandboxed `stakpak mcp start` process (see `spawn_warden_container`).
/// Presumably the in-container server uses it to trust this process's mTLS client
/// certificate — confirm against the server side.
const TRUSTED_CLIENT_CA_ENV: &str = "STAKPAK_MCP_CLIENT_CA";
33
/// Settings needed to launch a sandboxed MCP server through warden.
#[derive(Debug, Clone)]
pub struct SandboxConfig {
    /// Path (or command name) of the warden executable that is run as `<warden_path> wrap ...`.
    pub warden_path: String,
    /// Container image passed to `warden wrap`.
    pub image: String,
    /// Volume mappings in `host:container` form. A leading `~/` or `~:` on the host side is
    /// expanded to `$HOME`, and mappings whose host path does not exist are skipped.
    pub volumes: Vec<String>,
}
44
/// A running sandboxed MCP server: the warden-managed container, the local proxy in front
/// of it, and an MCP client connected to that proxy. Created by `SandboxedMcpServer::spawn`;
/// stop it with `SandboxedMcpServer::shutdown`.
pub struct SandboxedMcpServer {
    /// MCP client connected (over HTTPS) to the local proxy.
    pub client: Arc<McpClient>,
    /// Tools listed through the proxy at startup, converted to `stakai::Tool`.
    pub tools: Vec<stakai::Tool>,
    /// Sending `()` asks the proxy task to stop.
    proxy_shutdown_tx: broadcast::Sender<()>,
    /// The warden process hosting the container; killed on `shutdown`.
    container_process: Child,
}
58
59impl SandboxedMcpServer {
60 pub async fn spawn(config: &SandboxConfig) -> Result<Self, String> {
70 let client_identity = MtlsIdentity::generate_client()
72 .map_err(|e| format!("Failed to generate client identity: {e}"))?;
73
74 let client_ca_pem = client_identity
75 .ca_cert_pem()
76 .map_err(|e| format!("Failed to get client CA PEM: {e}"))?;
77
78 let container_host_port = find_free_port()
80 .await
81 .map_err(|e| format!("Failed to find free port for sandbox: {e}"))?;
82
83 let mut container_process =
85 spawn_warden_container(config, container_host_port, &client_ca_pem)
86 .await
87 .map_err(|e| format!("Failed to spawn sandbox container: {e}"))?;
88
89 let server_ca_pem = parse_server_ca_from_stdout(&mut container_process).await?;
91 tracing::info!(
92 "Parsed server CA from container stdout ({} bytes)",
93 server_ca_pem.len()
94 );
95
96 let container_client_config = client_identity
98 .create_client_config(&server_ca_pem)
99 .map_err(|e| format!("Failed to create client TLS config: {e}"))?;
100
101 let server_url = format!("https://127.0.0.1:{container_host_port}/mcp");
103 tracing::info!(url = %server_url, "Waiting for sandbox MCP server to be ready");
104 wait_for_server_ready(&server_url, &container_client_config).await?;
105 tracing::info!("Sandbox MCP server is ready");
106
107 let (proxy_shutdown_tx, proxy_shutdown_rx) = broadcast::channel::<()>(1);
109
110 let proxy_binding = find_available_binding("sandbox proxy").await?;
111 let proxy_url = format!("https://{}/mcp", proxy_binding.address);
112
113 let proxy_cert_chain = Arc::new(
114 CertificateChain::generate()
115 .map_err(|e| format!("Failed to generate proxy certificates: {e}"))?,
116 );
117
118 let pool_config = build_sandbox_proxy_config(server_url, Arc::new(container_client_config));
119
120 let proxy_chain_for_server = proxy_cert_chain.clone();
121 let proxy_listener = proxy_binding.listener;
122 tokio::spawn(async move {
123 if let Err(e) = start_proxy_server(
124 pool_config,
125 proxy_listener,
126 proxy_chain_for_server,
127 true, false, Some(proxy_shutdown_rx),
130 )
131 .await
132 {
133 tracing::error!("Sandbox proxy error: {e}");
134 }
135 });
136
137 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
139
140 let client = connect_to_proxy(&proxy_url, proxy_cert_chain).await?;
142
143 let mcp_tools = stakpak_mcp_client::get_tools(&client)
145 .await
146 .map_err(|e| format!("Failed to get sandbox tools: {e}"))?;
147
148 let tools = mcp_tools
149 .into_iter()
150 .map(|tool| stakai::Tool {
151 tool_type: "function".to_string(),
152 function: stakai::ToolFunction {
153 name: tool.name.as_ref().to_string(),
154 description: tool
155 .description
156 .as_ref()
157 .map(std::string::ToString::to_string)
158 .unwrap_or_default(),
159 parameters: serde_json::Value::Object((*tool.input_schema).clone()),
160 },
161 provider_options: None,
162 })
163 .collect();
164
165 Ok(Self {
166 client,
167 tools,
168 proxy_shutdown_tx,
169 container_process,
170 })
171 }
172
173 pub async fn shutdown(mut self) {
175 let _ = self.proxy_shutdown_tx.send(());
176
177 let _ = self.container_process.kill().await;
178 let _ = self.container_process.wait().await;
179 }
180}
181
/// Expands a leading `~` in a volume spec to the value of `$HOME`.
///
/// Only `~/...` and `~:...` are expanded. `~user/...`, a bare `~`, and any other spec are
/// returned unchanged, as is everything when `HOME` is unset or not valid Unicode.
fn expand_volume_path(volume: &str) -> String {
    let remainder = match volume.strip_prefix('~') {
        Some(rest) if rest.starts_with('/') || rest.starts_with(':') => rest,
        _ => return volume.to_string(),
    };

    match std::env::var("HOME") {
        Ok(home) => format!("{home}{remainder}"),
        Err(_) => volume.to_string(),
    }
}
191
192async fn spawn_warden_container(
193 config: &SandboxConfig,
194 host_port: u16,
195 client_ca_pem: &str,
196) -> Result<Child, String> {
197 let mut cmd = tokio::process::Command::new(&config.warden_path);
198 cmd.arg("wrap");
199 cmd.arg(&config.image);
200
201 for vol in &config.volumes {
203 let expanded = expand_volume_path(vol);
204 let host_path = expanded.split(':').next().unwrap_or(&expanded);
206 if Path::new(host_path).exists() {
207 cmd.args(["--volume", &expanded]);
208 }
209 }
210
211 cmd.args(["-p", &format!("127.0.0.1:{host_port}:8080")]);
214
215 cmd.args(["--env", "STAKPAK_SKIP_WARDEN=1"]);
217
218 cmd.args(["--env", "STAKPAK_MCP_PORT=8080"]);
221
222 cmd.args(["--env", &format!("{TRUSTED_CLIENT_CA_ENV}={client_ca_pem}")]);
224
225 if let Ok(api_key) = std::env::var("STAKPAK_API_KEY") {
227 cmd.args(["--env", &format!("STAKPAK_API_KEY={api_key}")]);
228 }
229 if let Ok(profile) = std::env::var("STAKPAK_PROFILE") {
230 cmd.args(["--env", &format!("STAKPAK_PROFILE={profile}")]);
231 }
232 if let Ok(endpoint) = std::env::var("STAKPAK_API_ENDPOINT") {
233 cmd.args(["--env", &format!("STAKPAK_API_ENDPOINT={endpoint}")]);
234 }
235
236 cmd.args(["--", "stakpak", "mcp", "start"]);
239
240 cmd.stdout(std::process::Stdio::piped());
241 cmd.stderr(std::process::Stdio::piped());
242 cmd.stdin(std::process::Stdio::null());
243
244 let child = cmd
245 .spawn()
246 .map_err(|e| format!("Failed to spawn warden process: {e}"))?;
247
248 Ok(child)
249}
250
251async fn parse_server_ca_from_stdout(process: &mut Child) -> Result<String, String> {
262 let stdout = process
263 .stdout
264 .take()
265 .ok_or_else(|| "Container stdout not captured".to_string())?;
266
267 let mut reader = tokio::io::BufReader::new(stdout);
268 let mut server_ca_pem = String::new();
269 let mut in_server_ca = false;
270 let mut line = String::new();
271
272 let timeout_duration = tokio::time::Duration::from_secs(60);
273 let deadline = tokio::time::Instant::now() + timeout_duration;
274
275 tracing::debug!("Starting to read container stdout for server CA...");
276
277 loop {
278 line.clear();
279 let bytes_read = tokio::time::timeout_at(deadline, reader.read_line(&mut line))
280 .await
281 .map_err(|_| {
282 "Timed out waiting for container to output server CA certificate".to_string()
283 })?
284 .map_err(|e| format!("Failed to read container stdout: {e}"))?;
285
286 if bytes_read == 0 {
287 tracing::error!("Container stdout EOF before server CA was found");
288 return Err("Container exited before outputting server CA certificate".to_string());
289 }
290
291 let trimmed = line.trim();
292 tracing::debug!(line = %trimmed, bytes = bytes_read, "Read line from container stdout");
293
294 if trimmed == "---BEGIN STAKPAK SERVER CA---" {
295 in_server_ca = true;
296 continue;
297 }
298
299 if trimmed == "---END STAKPAK SERVER CA---" {
300 tracing::debug!("Found end of server CA block");
301 break;
302 }
303
304 if in_server_ca {
305 server_ca_pem.push_str(trimmed);
306 server_ca_pem.push('\n');
307 }
308 }
309
310 let server_ca_pem = server_ca_pem.trim().to_string();
311
312 if server_ca_pem.is_empty() {
313 return Err("Failed to parse server CA certificate from container output".to_string());
314 }
315
316 Ok(server_ca_pem)
317}
318
319async fn wait_for_server_ready(
320 url: &str,
321 client_config: &rustls::ClientConfig,
322) -> Result<(), String> {
323 let http_client = reqwest::Client::builder()
324 .use_preconfigured_tls(client_config.clone())
325 .build()
326 .map_err(|e| format!("Failed to build readiness check client: {e}"))?;
327
328 let mut last_error = String::new();
329 for attempt in 0..30 {
330 tokio::time::sleep(tokio::time::Duration::from_millis(if attempt < 5 {
331 500
332 } else {
333 1000
334 }))
335 .await;
336
337 match http_client.get(url).send().await {
338 Ok(_) => {
339 tracing::info!(attempt, "Sandbox MCP server ready");
340 return Ok(());
341 }
342 Err(e) => {
343 last_error = format!("{e:?}");
344 tracing::debug!(attempt, error = %last_error, "Readiness check failed");
345 }
346 }
347 }
348
349 Err(format!(
350 "Sandbox MCP server failed to become ready after 30 attempts: {last_error}"
351 ))
352}
353
354struct ProxyBinding {
355 address: String,
356 listener: TcpListener,
357}
358
359async fn find_available_binding(purpose: &str) -> Result<ProxyBinding, String> {
360 let listener = TcpListener::bind("127.0.0.1:0")
361 .await
362 .map_err(|e| format!("Failed to bind port for {purpose}: {e}"))?;
363 let addr = listener
364 .local_addr()
365 .map_err(|e| format!("Failed to get address for {purpose}: {e}"))?;
366 Ok(ProxyBinding {
367 address: addr.to_string(),
368 listener,
369 })
370}
371
372async fn find_free_port() -> Result<u16, String> {
373 let listener = TcpListener::bind("127.0.0.1:0")
374 .await
375 .map_err(|e| format!("Failed to bind ephemeral port: {e}"))?;
376 let port = listener
377 .local_addr()
378 .map_err(|e| format!("Failed to get ephemeral port: {e}"))?
379 .port();
380 drop(listener);
382 Ok(port)
383}
384
385fn build_sandbox_proxy_config(
386 sandbox_server_url: String,
387 client_tls_config: Arc<rustls::ClientConfig>,
388) -> ClientPoolConfig {
389 let mut servers: HashMap<String, ServerConfig> = HashMap::new();
390
391 servers.insert(
394 "stakpak".to_string(),
395 ServerConfig::Http {
396 url: sandbox_server_url,
397 headers: None,
398 certificate_chain: Arc::new(None),
399 client_tls_config: Some(client_tls_config),
400 },
401 );
402
403 servers.insert(
405 "paks".to_string(),
406 ServerConfig::Http {
407 url: "https://apiv2.stakpak.dev/v1/paks/mcp".to_string(),
408 headers: None,
409 certificate_chain: Arc::new(None),
410 client_tls_config: None,
411 },
412 );
413
414 ClientPoolConfig::with_servers(servers)
415}
416
417async fn connect_to_proxy(
418 proxy_url: &str,
419 cert_chain: Arc<CertificateChain>,
420) -> Result<Arc<McpClient>, String> {
421 const MAX_RETRIES: u32 = 5;
422 let mut retry_delay = tokio::time::Duration::from_millis(50);
423 let mut last_error = None;
424
425 for attempt in 1..=MAX_RETRIES {
426 match stakpak_mcp_client::connect_https(proxy_url, Some(cert_chain.clone()), None).await {
427 Ok(client) => return Ok(Arc::new(client)),
428 Err(e) => {
429 last_error = Some(e);
430 if attempt < MAX_RETRIES {
431 tokio::time::sleep(retry_delay).await;
432 retry_delay *= 2;
433 }
434 }
435 }
436 }
437
438 Err(format!(
439 "Failed to connect to sandbox proxy after {MAX_RETRIES} retries: {}",
440 last_error.map(|e| e.to_string()).unwrap_or_default()
441 ))
442}
443
#[cfg(test)]
mod tests {
    /// Representative stdout of `stakpak mcp start` in sandbox mode: a banner, the server CA
    /// between marker lines, then further log output.
    #[cfg(unix)]
    const SAMPLE_OUTPUT: &str = "\
🔐 mTLS enabled - independent identity (sandbox mode)
---BEGIN STAKPAK SERVER CA---
-----BEGIN CERTIFICATE-----
MIIB0zCCAXmgAwIBAgIUFAKE=
-----END CERTIFICATE-----
---END STAKPAK SERVER CA---
MCP server started at https://0.0.0.0:8080/mcp
";

    /// Runs the real `parse_server_ca_from_stdout` against a child process whose stdout
    /// is exactly `output`. The output is passed via an env var, so no shell quoting is needed.
    #[cfg(unix)]
    fn parse_child_output(output: &str) -> Result<String, String> {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("build tokio runtime");
        runtime.block_on(async {
            let mut child = tokio::process::Command::new("sh")
                .arg("-c")
                .arg("printf '%s' \"$SANDBOX_TEST_OUTPUT\"")
                .env("SANDBOX_TEST_OUTPUT", output)
                .stdin(std::process::Stdio::null())
                .stdout(std::process::Stdio::piped())
                .spawn()
                .expect("spawn sh");
            let parsed = super::parse_server_ca_from_stdout(&mut child).await;
            let _ = child.wait().await;
            parsed
        })
    }

    #[cfg(unix)]
    #[test]
    fn parse_server_ca_from_structured_output() {
        let expected_ca = "\
-----BEGIN CERTIFICATE-----
MIIB0zCCAXmgAwIBAgIUFAKE=
-----END CERTIFICATE-----";

        let server_ca_pem =
            parse_child_output(SAMPLE_OUTPUT).expect("server CA should be parsed");
        assert_eq!(server_ca_pem, expected_ca);
    }

    #[cfg(unix)]
    #[test]
    fn parse_server_ca_fails_when_stdout_ends_without_ca_block() {
        let err = parse_child_output("MCP server started at https://0.0.0.0:8080/mcp\n")
            .expect_err("missing CA block must be an error");
        assert!(
            err.contains("before outputting server CA"),
            "unexpected error: {err}"
        );
    }

    #[cfg(unix)]
    #[test]
    fn parse_server_ca_rejects_empty_ca_block() {
        let err = parse_child_output(
            "---BEGIN STAKPAK SERVER CA---\n---END STAKPAK SERVER CA---\n",
        )
        .expect_err("empty CA block must be an error");
        assert!(
            err.contains("Failed to parse server CA"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn expand_volume_path_only_expands_leading_home_tilde() {
        assert_eq!(super::expand_volume_path("/data:/data"), "/data:/data");
        assert_eq!(super::expand_volume_path("~user/x:/y"), "~user/x:/y");
        if let Ok(home) = std::env::var("HOME") {
            assert_eq!(
                super::expand_volume_path("~/.aws:/root/.aws"),
                format!("{home}/.aws:/root/.aws")
            );
        }
    }

    #[test]
    fn mtls_identity_cross_trust() {
        use stakpak_shared::cert_utils::MtlsIdentity;

        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

        let client_identity = MtlsIdentity::generate_client().expect("generate client identity");
        let server_identity = MtlsIdentity::generate_server().expect("generate server identity");

        let client_ca_pem = client_identity.ca_cert_pem().expect("client CA PEM");
        let server_ca_pem = server_identity.ca_cert_pem().expect("server CA PEM");

        let _server_config = server_identity
            .create_server_config(&client_ca_pem)
            .expect("server config with client CA trust");
        let _client_config = client_identity
            .create_client_config(&server_ca_pem)
            .expect("client config with server CA trust");

        // Only public CA certificates may cross the process boundary, never private keys.
        assert!(client_ca_pem.contains("BEGIN CERTIFICATE"));
        assert!(server_ca_pem.contains("BEGIN CERTIFICATE"));
        assert!(!client_ca_pem.contains("PRIVATE KEY"));
        assert!(!server_ca_pem.contains("PRIVATE KEY"));
    }
}