Skip to main content

zagens_runtime_adapters/mcp/
connection.rs

1use std::sync::atomic::{AtomicU64, Ordering};
2use std::time::Duration;
3
4use anyhow::{Context, Result};
5use tracing::Instrument;
6
7use crate::http_client::apply_env_proxy;
8use crate::network_policy::{Decision, NetworkPolicyDecider, host_from_url};
9
10use super::auth::apply_default_headers;
11use super::config::{McpServerConfig, McpTimeouts, McpTransportKind};
12use super::transport::{McpTransport, SseTransport, StdioTransport, StreamableHttpTransport};
13
/// Protocol versions this client knows how to speak, newest first.
///
/// The first entry doubles as [`PREFERRED_PROTOCOL_VERSION`], so supporting a
/// newer spec revision only requires prepending it here.
const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &["2025-06-18", "2025-03-26", "2024-11-05"];

/// Protocol version we advertise on `initialize` (latest spec we target).
///
/// Derived from the head of [`SUPPORTED_PROTOCOL_VERSIONS`] so the two cannot
/// drift apart (which would mean advertising a version we don't list as
/// supported).
const PREFERRED_PROTOCOL_VERSION: &str = SUPPORTED_PROTOCOL_VERSIONS[0];
19use super::types::{ConnectionState, McpPrompt, McpResource, McpResourceTemplate, McpTool};
20
/// A client session with a single MCP server.
///
/// Built by `McpConnection::connect_with_policy`, which opens the transport,
/// runs the `initialize` handshake, and caches the server's tools, resources,
/// resource templates, and prompts in the fields below.
pub struct McpConnection {
    /// Server name from config; used in error messages and tracing fields.
    name: String,
    /// Message channel to the server (stdio, SSE, or Streamable HTTP).
    pub(super) transport: Box<dyn McpTransport>,
    /// Tools discovered at connect time.
    tools: Vec<McpTool>,
    /// Resources discovered at connect time.
    resources: Vec<McpResource>,
    /// Resource templates discovered at connect time.
    resource_templates: Vec<McpResourceTemplate>,
    /// Prompts discovered at connect time.
    prompts: Vec<McpPrompt>,
    /// Source of JSON-RPC request ids; starts at 1 and is bumped per request.
    request_id: AtomicU64,
    /// Lifecycle state; method calls are refused unless this is `Ready`.
    state: ConnectionState,
    /// Configuration this connection was created from.
    config: McpServerConfig,
    /// Cancelled by `close` and on drop; the SSE transport is handed a clone.
    cancel_token: tokio_util::sync::CancellationToken,
}
33
34impl McpConnection {
35    /// Connect to an MCP server and initialize it.
36    ///
37    /// `network_policy` (added in v0.7.0 for #135) is consulted for HTTP/SSE
38    /// transports only — STDIO transports are unaffected. Pass `None` to
39    /// match pre-v0.7.0 permissive behavior.
40    pub async fn connect_with_policy(
41        name: String,
42        config: McpServerConfig,
43        global_timeouts: &McpTimeouts,
44        network_policy: Option<&NetworkPolicyDecider>,
45    ) -> Result<Self> {
46        let connect_timeout_secs = config.effective_connect_timeout(global_timeouts);
47        let read_timeout_secs = config.effective_read_timeout(global_timeouts);
48        let cancel_token = tokio_util::sync::CancellationToken::new();
49
50        let transport_kind = config
51            .transport_kind()
52            .with_context(|| format!("MCP server '{name}' has an invalid transport config"))?;
53
54        let transport: Box<dyn McpTransport> = match transport_kind {
55            McpTransportKind::Sse | McpTransportKind::Http => {
56                let url = config
57                    .url
58                    .as_ref()
59                    .ok_or_else(|| anyhow::anyhow!("MCP server '{name}' requires a 'url'"))?;
60
61                // Per-domain network policy gate (#135). Only HTTP/SSE
62                // transports are gated; STDIO MCP servers run as local
63                // subprocesses and never touch the network here.
64                if let Some(decider) = network_policy
65                    && let Some(host) = host_from_url(url)
66                {
67                    match decider.evaluate(&host, "mcp") {
68                        Decision::Allow => {}
69                        Decision::Deny => {
70                            anyhow::bail!(
71                                "MCP server '{name}' connection to '{host}' blocked by network policy"
72                            );
73                        }
74                        Decision::Prompt => {
75                            anyhow::bail!(
76                                "MCP server '{name}' connection to '{host}' requires approval; \
77                                 re-run after `/network allow {host}` or set network.default = \"allow\" in config"
78                            );
79                        }
80                    }
81                }
82
83                let http_headers = config.resolve_http_headers(&name)?;
84
85                if transport_kind == McpTransportKind::Http {
86                    // Streamable HTTP returns each response in the POST body,
87                    // so a single request can span an entire tool execution.
88                    // Bound only the TCP connect with connect_timeout and use
89                    // read_timeout as the overall backstop; the per-call
90                    // `tokio::time::timeout` in `call_method` is the real cap.
91                    let builder = apply_env_proxy(
92                        reqwest::Client::builder()
93                            .connect_timeout(Duration::from_secs(connect_timeout_secs))
94                            .timeout(Duration::from_secs(read_timeout_secs)),
95                    );
96                    let builder = apply_default_headers(builder, &http_headers)?;
97                    let client = builder.build()?;
98                    Box::new(StreamableHttpTransport::new(client, url.clone()))
99                } else {
100                    let builder = apply_env_proxy(
101                        reqwest::Client::builder()
102                            .timeout(Duration::from_secs(connect_timeout_secs)),
103                    );
104                    let builder = apply_default_headers(builder, &http_headers)?;
105                    let client = builder.build()?;
106                    Box::new(
107                        SseTransport::connect(client, url.clone(), cancel_token.clone()).await?,
108                    )
109                }
110            }
111            McpTransportKind::Stdio => {
112                let command = config
113                    .command
114                    .as_ref()
115                    .ok_or_else(|| anyhow::anyhow!("MCP server '{name}' requires a 'command'"))?;
116                let mut cmd = super::stdio_spawn::build_stdio_command(
117                    command,
118                    &config.args,
119                    &config.env,
120                )
121                .with_context(|| {
122                    format!(
123                        "MCP stdio command resolution failed (server={name} cmd={command:?} args={:?})",
124                        config.args,
125                    )
126                })?;
127
128                let mut child = cmd.spawn().with_context(|| {
129                    let env_keys: Vec<&str> = config.env.keys().map(String::as_str).collect();
130                    format!(
131                        "MCP stdio spawn failed (transport=stdio server={name} cmd={command:?} args={:?} env_keys={env_keys:?}). \
132                         On Windows ensure Node.js is installed; try full path to npx.cmd in mcp.json.",
133                        config.args,
134                    )
135                })?;
136
137                let stdin = child.stdin.take().context("Failed to get MCP stdin")?;
138                let stdout = child.stdout.take().context("Failed to get MCP stdout")?;
139
140                Box::new(StdioTransport {
141                    child,
142                    stdin,
143                    reader: tokio::io::BufReader::new(stdout),
144                })
145            }
146        };
147
148        let mut conn = Self {
149            name: name.clone(),
150            transport,
151            tools: Vec::new(),
152            resources: Vec::new(),
153            resource_templates: Vec::new(),
154            prompts: Vec::new(),
155            request_id: AtomicU64::new(1),
156            state: ConnectionState::Connecting,
157            config,
158            cancel_token,
159        };
160
161        // Initialize with timeout
162        tokio::time::timeout(Duration::from_secs(connect_timeout_secs), conn.initialize())
163            .await
164            .with_context(|| format!("MCP server '{name}' initialization timed out"))??;
165
166        // Discover tools, resources, and prompts with timeout
167        tokio::time::timeout(
168            Duration::from_secs(connect_timeout_secs),
169            conn.discover_all(),
170        )
171        .await
172        .with_context(|| format!("MCP server '{name}' discovery timed out"))??;
173
174        conn.state = ConnectionState::Ready;
175        Ok(conn)
176    }
177
178    /// Send initialize request, negotiate the protocol version, and complete
179    /// the handshake with the `initialized` notification.
180    ///
181    /// We advertise [`PREFERRED_PROTOCOL_VERSION`] and adopt whatever version
182    /// the server reports back: if it's one we support we use it directly; if
183    /// it's unknown we still proceed with the server's value (best-effort
184    /// interop) but log a warning. The negotiated version is handed to the
185    /// transport so Streamable HTTP can echo it via `MCP-Protocol-Version`.
186    async fn initialize(&mut self) -> Result<()> {
187        let init_id = self.next_id();
188        self.send(serde_json::json!({
189            "jsonrpc": "2.0",
190            "id": init_id,
191            "method": "initialize",
192            "params": {
193                "protocolVersion": PREFERRED_PROTOCOL_VERSION,
194                "clientInfo": {
195                    "name": "deepseek-runtime",
196                    "version": env!("CARGO_PKG_VERSION")
197                },
198                "capabilities": {
199                    "tools": {},
200                    "resources": {},
201                    "prompts": {}
202                }
203            }
204        }))
205        .await?;
206
207        let response = self.recv(init_id).await?;
208        let negotiated = self.negotiate_protocol_version(&response);
209        self.transport.set_protocol_version(&negotiated);
210
211        // Send initialized notification (no id, no response expected)
212        self.send(serde_json::json!({
213            "jsonrpc": "2.0",
214            "method": "notifications/initialized"
215        }))
216        .await?;
217
218        Ok(())
219    }
220
221    /// Resolve the effective protocol version from the `initialize` response,
222    /// defaulting to our preferred version when the server omits it.
223    fn negotiate_protocol_version(&self, response: &serde_json::Value) -> String {
224        let server_version = response
225            .get("result")
226            .and_then(|r| r.get("protocolVersion"))
227            .and_then(serde_json::Value::as_str);
228
229        match server_version {
230            Some(version) if SUPPORTED_PROTOCOL_VERSIONS.contains(&version) => version.to_string(),
231            Some(version) => {
232                tracing::warn!(
233                    server = %self.name,
234                    server_version = version,
235                    preferred = PREFERRED_PROTOCOL_VERSION,
236                    "MCP server reported an unsupported protocol version; proceeding best-effort"
237                );
238                version.to_string()
239            }
240            None => PREFERRED_PROTOCOL_VERSION.to_string(),
241        }
242    }
243
244    /// Discover tools, resources, and prompts
245    async fn discover_all(&mut self) -> Result<()> {
246        // We use join! to discover everything concurrently if possible,
247        // but for now let's keep it sequential for simplicity in error handling
248        self.discover_tools().await?;
249        self.discover_resources().await?;
250        self.discover_resource_templates().await?;
251        self.discover_prompts().await?;
252        Ok(())
253    }
254
255    /// Discover available tools from the MCP server
256    async fn discover_tools(&mut self) -> Result<()> {
257        let list_id = self.next_id();
258        self.send(serde_json::json!({
259            "jsonrpc": "2.0",
260            "id": list_id,
261            "method": "tools/list",
262            "params": {}
263        }))
264        .await?;
265
266        let response = self.recv(list_id).await?;
267
268        if let Some(result) = response.get("result")
269            && let Some(tools) = result.get("tools")
270        {
271            self.tools = serde_json::from_value(tools.clone()).unwrap_or_default();
272        }
273
274        Ok(())
275    }
276
277    /// Discover available resources from the MCP server
278    async fn discover_resources(&mut self) -> Result<()> {
279        let list_id = self.next_id();
280        self.send(serde_json::json!({
281            "jsonrpc": "2.0",
282            "id": list_id,
283            "method": "resources/list",
284            "params": {}
285        }))
286        .await?;
287
288        let response = self.recv(list_id).await?;
289
290        if let Some(result) = response.get("result")
291            && let Some(resources) = result.get("resources")
292        {
293            self.resources = serde_json::from_value(resources.clone()).unwrap_or_default();
294        }
295
296        Ok(())
297    }
298
299    /// Discover available resource templates from the MCP server
300    async fn discover_resource_templates(&mut self) -> Result<()> {
301        let list_id = self.next_id();
302        self.send(serde_json::json!({
303            "jsonrpc": "2.0",
304            "id": list_id,
305            "method": "resources/templates/list",
306            "params": {}
307        }))
308        .await?;
309
310        let response = self.recv(list_id).await?;
311
312        if let Some(result) = response.get("result") {
313            let templates = result
314                .get("resourceTemplates")
315                .or_else(|| result.get("templates"))
316                .or_else(|| result.get("resource_templates"));
317            if let Some(templates) = templates {
318                self.resource_templates =
319                    serde_json::from_value(templates.clone()).unwrap_or_default();
320            }
321        }
322
323        Ok(())
324    }
325
326    /// Discover available prompts from the MCP server
327    async fn discover_prompts(&mut self) -> Result<()> {
328        let list_id = self.next_id();
329        self.send(serde_json::json!({
330            "jsonrpc": "2.0",
331            "id": list_id,
332            "method": "prompts/list",
333            "params": {}
334        }))
335        .await?;
336
337        let response = self.recv(list_id).await?;
338
339        if let Some(result) = response.get("result")
340            && let Some(prompts) = result.get("prompts")
341        {
342            self.prompts = serde_json::from_value(prompts.clone()).unwrap_or_default();
343        }
344
345        Ok(())
346    }
347
348    /// Call a tool on this MCP server
349    pub async fn call_tool(
350        &mut self,
351        tool_name: &str,
352        arguments: serde_json::Value,
353        timeout_secs: u64,
354    ) -> Result<serde_json::Value> {
355        self.call_method(
356            "tools/call",
357            serde_json::json!({
358                "name": tool_name,
359                "arguments": arguments
360            }),
361            timeout_secs,
362        )
363        .await
364    }
365
366    /// Read a resource from this MCP server
367    pub async fn read_resource(
368        &mut self,
369        uri: &str,
370        timeout_secs: u64,
371    ) -> Result<serde_json::Value> {
372        self.call_method(
373            "resources/read",
374            serde_json::json!({
375                "uri": uri
376            }),
377            timeout_secs,
378        )
379        .await
380    }
381
382    /// Get a prompt from this MCP server
383    pub async fn get_prompt(
384        &mut self,
385        prompt_name: &str,
386        arguments: serde_json::Value,
387        timeout_secs: u64,
388    ) -> Result<serde_json::Value> {
389        self.call_method(
390            "prompts/get",
391            serde_json::json!({
392                "name": prompt_name,
393                "arguments": arguments
394            }),
395            timeout_secs,
396        )
397        .await
398    }
399
400    /// Generic method to call an MCP method
401    async fn call_method(
402        &mut self,
403        method: &str,
404        params: serde_json::Value,
405        timeout_secs: u64,
406    ) -> Result<serde_json::Value> {
407        let started = std::time::Instant::now();
408        let server = self.name.clone();
409        let method_name = method.to_string();
410        let span = tracing::info_span!(
411            "mcp.rpc",
412            server = %server,
413            method = %method_name,
414            timeout_secs
415        );
416
417        let outcome = self
418            .call_method_inner(method, params, timeout_secs)
419            .instrument(span)
420            .await;
421        let duration_ms = started.elapsed().as_millis() as u64;
422        let (success, err_msg, result_bytes) = match &outcome {
423            Ok(value) => (
424                true,
425                None,
426                serde_json::to_string(value).map(|s| s.len()).unwrap_or(0),
427            ),
428            Err(err) => (false, Some(err.to_string()), 0),
429        };
430        super::observability::record_mcp_call(
431            &server,
432            &method_name,
433            duration_ms,
434            success,
435            err_msg,
436            result_bytes,
437        );
438        outcome
439    }
440
441    async fn call_method_inner(
442        &mut self,
443        method: &str,
444        params: serde_json::Value,
445        timeout_secs: u64,
446    ) -> Result<serde_json::Value> {
447        if self.state != ConnectionState::Ready {
448            anyhow::bail!(
449                "Failed to call MCP method '{}': connection '{}' is not ready",
450                method,
451                self.name
452            );
453        }
454
455        let call_id = self.next_id();
456        let request = serde_json::json!({
457            "jsonrpc": "2.0",
458            "id": call_id,
459            "method": method,
460            "params": params
461        });
462
463        // Bound the whole exchange: for Streamable HTTP the response is
464        // delivered in the POST body during `send`, so timing only `recv`
465        // would leave long tool calls effectively unbounded.
466        let response = tokio::time::timeout(Duration::from_secs(timeout_secs), async {
467            self.send(request).await?;
468            self.recv(call_id).await
469        })
470        .await
471        .with_context(|| {
472            format!(
473                "MCP method '{}' on server '{}' timed out after {}s",
474                method, self.name, timeout_secs
475            )
476        })??;
477
478        if let Some(error) = response.get("error") {
479            return Err(anyhow::anyhow!(
480                "MCP error in '{}': {}",
481                method,
482                serde_json::to_string_pretty(error)?
483            ));
484        }
485
486        Ok(response
487            .get("result")
488            .cloned()
489            .unwrap_or(serde_json::json!(null)))
490    }
491
492    /// Get discovered tools
493    pub fn tools(&self) -> &[McpTool] {
494        &self.tools
495    }
496
497    /// Get discovered resources
498    pub fn resources(&self) -> &[McpResource] {
499        &self.resources
500    }
501
502    /// Get discovered resource templates
503    pub fn resource_templates(&self) -> &[McpResourceTemplate] {
504        &self.resource_templates
505    }
506
507    /// Get discovered prompts
508    pub fn prompts(&self) -> &[McpPrompt] {
509        &self.prompts
510    }
511
512    /// Get server name
513    #[allow(dead_code)] // Public API for MCP consumers
514    pub fn name(&self) -> &str {
515        &self.name
516    }
517
518    /// Check if connection is ready
519    pub fn is_ready(&self) -> bool {
520        self.state == ConnectionState::Ready
521    }
522
523    /// Get server config
524    pub fn config(&self) -> &McpServerConfig {
525        &self.config
526    }
527
528    /// Get connection state
529    #[allow(dead_code)] // Public API for MCP consumers
530    pub fn state(&self) -> ConnectionState {
531        self.state
532    }
533
534    fn next_id(&self) -> u64 {
535        self.request_id.fetch_add(1, Ordering::SeqCst)
536    }
537
538    async fn send(&mut self, msg: serde_json::Value) -> Result<()> {
539        self.transport.send(msg).await
540    }
541
542    async fn recv(&mut self, expected_id: u64) -> Result<serde_json::Value> {
543        loop {
544            let value = self.transport.recv().await.inspect_err(|_e| {
545                self.state = ConnectionState::Disconnected;
546            })?;
547
548            // Check if this is a response with the expected id
549            if value.get("id").and_then(serde_json::Value::as_u64) == Some(expected_id) {
550                return Ok(value);
551            }
552            // Skip notifications (no id) and responses with different ids
553        }
554    }
555
556    /// Gracefully close the connection
557    #[allow(dead_code)] // Public API for MCP consumers
558    pub fn close(&mut self) {
559        self.cancel_token.cancel();
560        self.state = ConnectionState::Disconnected;
561    }
562}
563
564impl Drop for McpConnection {
565    fn drop(&mut self) {
566        self.cancel_token.cancel();
567    }
568}