Skip to main content

sparrow/capabilities/
mcp.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::path::PathBuf;
5use std::sync::Arc;
6use tokio::io::{AsyncWriteExt, BufReader};
7use tokio::process::Command as TokioCommand;
8
9use crate::event::RiskLevel;
10use crate::tools::{Tool, ToolCtx, ToolResult};
11
12// ─── MCP server config ──────────────────────────────────────────────────────────
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct McpServer {
16    pub name: String,
17    #[serde(default)]
18    pub transport: Transport,
19    /// For stdio: command + args
20    #[serde(default)]
21    pub command: Option<String>,
22    #[serde(default)]
23    pub args: Vec<String>,
24    /// For url/sse: endpoint URL
25    #[serde(default)]
26    pub url: Option<String>,
27    /// Environment variables
28    #[serde(default)]
29    pub env: std::collections::HashMap<String, String>,
30    /// Tool allow-list (empty = allow all)
31    #[serde(default)]
32    pub allow_tools: Vec<String>,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
36pub enum Transport {
37    #[serde(rename = "stdio")]
38    Stdio,
39    #[serde(rename = "sse")]
40    Sse,
41    #[serde(rename = "url")]
42    Url,
43}
44
45impl Default for Transport {
46    fn default() -> Self {
47        Transport::Stdio
48    }
49}
50
51// ─── JSON-RPC types ─────────────────────────────────────────────────────────────
52
53#[derive(Debug, Serialize, Deserialize)]
54struct JsonRpcRequest {
55    jsonrpc: String,
56    id: u64,
57    method: String,
58    #[serde(default)]
59    params: Value,
60}
61
62#[derive(Debug, Deserialize)]
63struct ToolsListResult {
64    tools: Vec<McpToolDef>,
65}
66
67#[derive(Debug, Deserialize)]
68struct McpToolDef {
69    name: String,
70    #[serde(default)]
71    description: String,
72    #[serde(default)]
73    #[serde(rename = "inputSchema")]
74    input_schema: Value,
75}
76
77// ─── MCP Tool wrapper (real JSON-RPC execution) ────────────────────────────────
78
79use tokio::sync::mpsc;
80
81/// Wraps an MCP server tool. Two transports are supported, picked at connect time.
82struct McpToolWrapper {
83    tool_def: McpToolDef,
84    backend: McpBackend,
85}
86
87enum McpBackend {
88    /// Long-lived child process talking JSON-RPC over stdio.
89    Stdio {
90        request_tx: mpsc::Sender<McpRequest>,
91    },
92    /// One POST per call against a JSON-RPC HTTP endpoint.
93    Http {
94        url: String,
95        client: reqwest::Client,
96    },
97}
98
99struct McpRequest {
100    tool_name: String,
101    args: Value,
102    response_tx: tokio::sync::oneshot::Sender<anyhow::Result<ToolResult>>,
103}
104
105#[async_trait]
106impl Tool for McpToolWrapper {
107    fn name(&self) -> &str {
108        &self.tool_def.name
109    }
110    fn description(&self) -> &str {
111        &self.tool_def.description
112    }
113    fn schema(&self) -> Value {
114        self.tool_def.input_schema.clone()
115    }
116    fn risk(&self) -> RiskLevel {
117        RiskLevel::Exec
118    }
119
120    async fn call(&self, args: Value, _ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
121        match &self.backend {
122            McpBackend::Stdio { request_tx } => {
123                let (tx, rx) = tokio::sync::oneshot::channel();
124                request_tx
125                    .send(McpRequest {
126                        tool_name: self.tool_def.name.clone(),
127                        args,
128                        response_tx: tx,
129                    })
130                    .await
131                    .map_err(|_| anyhow::anyhow!("MCP server process has stopped"))?;
132
133                tokio::time::timeout(std::time::Duration::from_secs(30), rx)
134                    .await
135                    .map_err(|_| anyhow::anyhow!("MCP tool call timed out"))?
136                    .map_err(|_| anyhow::anyhow!("MCP tool call channel closed"))?
137            }
138            McpBackend::Http { url, client } => {
139                static NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
140                let id = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
141                let body = serde_json::json!({
142                    "jsonrpc": "2.0",
143                    "id": id,
144                    "method": "tools/call",
145                    "params": {
146                        "name": self.tool_def.name,
147                        "arguments": args,
148                    }
149                });
150                let resp = tokio::time::timeout(
151                    std::time::Duration::from_secs(30),
152                    client.post(url).json(&body).send(),
153                )
154                .await
155                .map_err(|_| anyhow::anyhow!("MCP HTTP call timed out"))??;
156                if !resp.status().is_success() {
157                    let status = resp.status();
158                    let body = resp.text().await.unwrap_or_default();
159                    return Ok(ToolResult::error(format!(
160                        "MCP HTTP error {}: {}",
161                        status, body
162                    )));
163                }
164                let value: Value = resp.json().await?;
165                if let Some(err) = value.get("error") {
166                    return Ok(ToolResult::error(format!("MCP error: {}", err)));
167                }
168                if let Some(result) = value.get("result") {
169                    Ok(ToolResult::text(result.to_string()))
170                } else {
171                    Ok(ToolResult::text("(empty MCP response)"))
172                }
173            }
174        }
175    }
176}
177
178// ─── THE MCP CLIENT TRAIT ───────────────────────────────────────────────────────
179
180#[async_trait]
181pub trait McpClient: Send + Sync {
182    async fn connect(&self, server: &McpServer) -> anyhow::Result<Vec<Arc<dyn Tool>>>;
183    async fn disconnect(&self, server_name: &str) -> anyhow::Result<()>;
184    async fn list_servers(&self) -> Vec<McpServer>;
185}
186
187// ─── Basic MCP client implementation ────────────────────────────────────────────
188
189pub struct BasicMcpClient {
190    config_dir: PathBuf,
191}
192
193impl BasicMcpClient {
194    pub fn new(config_dir: PathBuf) -> Self {
195        Self { config_dir }
196    }
197
198    fn servers_file(&self) -> PathBuf {
199        self.config_dir.join("mcp_servers.json")
200    }
201
202    fn load_servers(&self) -> Vec<McpServer> {
203        let path = self.servers_file();
204        if !path.exists() {
205            return vec![];
206        }
207        std::fs::read_to_string(&path)
208            .ok()
209            .and_then(|s| serde_json::from_str(&s).ok())
210            .unwrap_or_default()
211    }
212
213    fn save_servers(&self, servers: &[McpServer]) -> anyhow::Result<()> {
214        std::fs::create_dir_all(&self.config_dir)?;
215        let json = serde_json::to_string_pretty(servers)?;
216        std::fs::write(self.servers_file(), json)?;
217        Ok(())
218    }
219
220    pub fn add_server(&self, server: McpServer) -> anyhow::Result<()> {
221        let mut servers = self.load_servers();
222        servers.retain(|s| s.name != server.name);
223        servers.push(server);
224        self.save_servers(&servers)
225    }
226
227    pub fn remove_server(&self, name: &str) -> anyhow::Result<()> {
228        let mut servers = self.load_servers();
229        servers.retain(|s| s.name != name);
230        self.save_servers(&servers)
231    }
232
233    pub fn get_server(&self, name: &str) -> Option<McpServer> {
234        self.load_servers().into_iter().find(|s| s.name == name)
235    }
236}
237
238#[async_trait]
239impl McpClient for BasicMcpClient {
240    async fn connect(&self, server: &McpServer) -> anyhow::Result<Vec<Arc<dyn Tool>>> {
241        match server.transport {
242            Transport::Stdio => self.connect_stdio(server).await,
243            Transport::Url | Transport::Sse => self.connect_http(server).await,
244        }
245    }
246
247    async fn disconnect(&self, _server_name: &str) -> anyhow::Result<()> {
248        // In a full implementation, we'd track active connections
249        // For M3, connections are ephemeral per-connect
250        Ok(())
251    }
252
253    async fn list_servers(&self) -> Vec<McpServer> {
254        self.load_servers()
255    }
256}
257
258impl BasicMcpClient {
259    /// Connect via stdio (spawn process, JSON-RPC over stdin/stdout)
260    async fn connect_stdio(&self, server: &McpServer) -> anyhow::Result<Vec<Arc<dyn Tool>>> {
261        let command = server
262            .command
263            .as_ref()
264            .ok_or_else(|| anyhow::anyhow!("stdio transport requires 'command'"))?;
265
266        let mut child = TokioCommand::new(command)
267            .args(&server.args)
268            .envs(&server.env)
269            .stdin(std::process::Stdio::piped())
270            .stdout(std::process::Stdio::piped())
271            .stderr(std::process::Stdio::piped())
272            .kill_on_drop(true)
273            .spawn()?;
274
275        let stdin = child.stdin.take().unwrap();
276        let stdout = child.stdout.take().unwrap();
277
278        let (mut writer, mut reader) = (tokio::io::BufWriter::new(stdin), BufReader::new(stdout));
279
280        // Send initialize request
281        let init_req = JsonRpcRequest {
282            jsonrpc: "2.0".into(),
283            id: 1,
284            method: "initialize".into(),
285            params: serde_json::json!({
286                "protocolVersion": "2024-11-05",
287                "capabilities": {},
288                "clientInfo": {
289                    "name": "sparrow",
290                    "version": "0.1.0"
291                }
292            }),
293        };
294
295        let req_json = serde_json::to_string(&init_req)? + "\n";
296        writer.write_all(req_json.as_bytes()).await?;
297        writer.flush().await?;
298
299        // Read the initialize response. Skip any notifications (no `id`).
300        let _ = read_jsonrpc_response(&mut reader, 1).await?;
301
302        // Send initialized notification
303        let notif = serde_json::json!({
304            "jsonrpc": "2.0",
305            "method": "notifications/initialized",
306            "params": {}
307        });
308        writer
309            .write_all((serde_json::to_string(&notif)? + "\n").as_bytes())
310            .await?;
311        writer.flush().await?;
312
313        // Request tools/list
314        let list_req = JsonRpcRequest {
315            jsonrpc: "2.0".into(),
316            id: 2,
317            method: "tools/list".into(),
318            params: Value::Null,
319        };
320
321        writer
322            .write_all((serde_json::to_string(&list_req)? + "\n").as_bytes())
323            .await?;
324        writer.flush().await?;
325
326        let tools_resp_value = read_jsonrpc_response(&mut reader, 2).await?;
327
328        // Spawn background task for JSON-RPC request handling
329        let (request_tx, mut request_rx) = mpsc::channel::<McpRequest>(32);
330
331        tokio::spawn(async move {
332            // Keep the child alive for the lifetime of the channel: dropping `child`
333            // when the writer task exits also kills the process (kill_on_drop).
334            let _child_guard = child;
335            let mut call_id: u64 = 3; // Start after initialize(1) and tools/list(2)
336            while let Some(req) = request_rx.recv().await {
337                call_id += 1;
338                let call_req = serde_json::json!({
339                    "jsonrpc": "2.0",
340                    "id": call_id,
341                    "method": "tools/call",
342                    "params": {
343                        "name": req.tool_name,
344                        "arguments": req.args,
345                    }
346                });
347                if writer
348                    .write_all((serde_json::to_string(&call_req).unwrap() + "\n").as_bytes())
349                    .await
350                    .is_err()
351                    || writer.flush().await.is_err()
352                {
353                    let _ = req
354                        .response_tx
355                        .send(Err(anyhow::anyhow!("MCP stdin closed")));
356                    break;
357                }
358
359                // Drain lines until we see one with the matching id (or an error).
360                match read_jsonrpc_response(&mut reader, call_id).await {
361                    Ok(value) => {
362                        let result = if let Some(err) = value.get("error") {
363                            Ok(ToolResult::error(format!("MCP error: {}", err)))
364                        } else if let Some(val) = value.get("result") {
365                            Ok(ToolResult::text(val.to_string()))
366                        } else {
367                            Ok(ToolResult::text("(empty MCP response)"))
368                        };
369                        let _ = req.response_tx.send(result);
370                    }
371                    Err(e) => {
372                        let _ = req
373                            .response_tx
374                            .send(Err(anyhow::anyhow!("MCP read error: {}", e)));
375                        break;
376                    }
377                }
378            }
379        });
380
381        // Parse tools
382        let server_name = server.name.clone();
383        let allow_list = server.allow_tools.clone();
384
385        let tools: Vec<Arc<dyn Tool>> = if let Some(result) = tools_resp_value.get("result") {
386            if let Ok(list) = serde_json::from_value::<ToolsListResult>(result.clone()) {
387                list.tools
388                    .into_iter()
389                    .filter(|t| allow_list.is_empty() || allow_list.contains(&t.name))
390                    .map(|t| {
391                        let _srv = server_name.clone();
392                        Arc::new(McpToolWrapper {
393                            tool_def: t,
394                            backend: McpBackend::Stdio {
395                                request_tx: request_tx.clone(),
396                            },
397                        }) as Arc<dyn Tool>
398                    })
399                    .collect()
400            } else {
401                vec![]
402            }
403        } else {
404            tracing::warn!("MCP server {} returned no tools/list result", server.name);
405            vec![]
406        };
407
408        Ok(tools)
409    }
410
411    /// Connect via HTTP/SSE
412    async fn connect_http(&self, server: &McpServer) -> anyhow::Result<Vec<Arc<dyn Tool>>> {
413        let url = server
414            .url
415            .as_ref()
416            .ok_or_else(|| anyhow::anyhow!("url/sse transport requires 'url'"))?;
417
418        let client = reqwest::Client::new();
419
420        // Send initialize via HTTP POST
421        let _init_resp: Value = client
422            .post(url)
423            .json(&serde_json::json!({
424                "jsonrpc": "2.0",
425                "id": 1,
426                "method": "initialize",
427                "params": {
428                    "protocolVersion": "2024-11-05",
429                    "capabilities": {},
430                    "clientInfo": { "name": "sparrow", "version": "0.1.0" }
431                }
432            }))
433            .send()
434            .await?
435            .json()
436            .await?;
437
438        // List tools
439        let tools_resp: Value = client
440            .post(url)
441            .json(&serde_json::json!({
442                "jsonrpc": "2.0",
443                "id": 2,
444                "method": "tools/list",
445                "params": {}
446            }))
447            .send()
448            .await?
449            .json()
450            .await?;
451
452        let server_name = server.name.clone();
453        let allow_list = server.allow_tools.clone();
454
455        let tools: Vec<Arc<dyn Tool>> = if let Some(result) = tools_resp.get("result") {
456            if let Ok(list) = serde_json::from_value::<ToolsListResult>(result.clone()) {
457                list.tools
458                    .into_iter()
459                    .filter(|t| allow_list.is_empty() || allow_list.contains(&t.name))
460                    .map(|t| {
461                        let _srv = server_name.clone();
462                        Arc::new(McpToolWrapper {
463                            tool_def: t,
464                            backend: McpBackend::Http {
465                                url: url.clone(),
466                                client: client.clone(),
467                            },
468                        }) as Arc<dyn Tool>
469                    })
470                    .collect()
471            } else {
472                vec![]
473            }
474        } else {
475            vec![]
476        };
477
478        Ok(tools)
479    }
480}
481
482/// Read JSON-RPC frames (one per line) until we find one whose `id` matches
483/// `expected_id`. Notifications (no `id`) are skipped. A read failure or EOF
484/// surfaces as an error so callers can stop the loop.
485async fn read_jsonrpc_response<R: tokio::io::AsyncBufRead + Unpin>(
486    reader: &mut R,
487    expected_id: u64,
488) -> anyhow::Result<Value> {
489    use tokio::io::AsyncBufReadExt;
490    let mut line = String::new();
491    for _ in 0..64 {
492        line.clear();
493        let n = reader.read_line(&mut line).await?;
494        if n == 0 {
495            anyhow::bail!("MCP server closed stdout");
496        }
497        let trimmed = line.trim();
498        if trimmed.is_empty() {
499            continue;
500        }
501        let value: Value = match serde_json::from_str(trimmed) {
502            Ok(v) => v,
503            Err(_) => {
504                // Not JSON — likely a stderr leak on stdout, ignore and keep reading.
505                tracing::debug!("MCP non-JSON stdout line: {}", trimmed);
506                continue;
507            }
508        };
509        // Notifications (no "id") are skipped.
510        match value.get("id").and_then(|v| v.as_u64()) {
511            Some(id) if id == expected_id => return Ok(value),
512            Some(_) => continue, // response to an earlier/later request, drop
513            None => continue,    // notification
514        }
515    }
516    anyhow::bail!(
517        "MCP server did not respond to id={} within 64 frames",
518        expected_id
519    )
520}