zagens_runtime_adapters/mcp/
connection.rs1use std::sync::atomic::{AtomicU64, Ordering};
2use std::time::Duration;
3
4use anyhow::{Context, Result};
5use tracing::Instrument;
6
7use crate::http_client::apply_env_proxy;
8use crate::network_policy::{Decision, NetworkPolicyDecider, host_from_url};
9
10use super::auth::apply_default_headers;
11use super::config::{McpServerConfig, McpTimeouts, McpTransportKind};
12use super::transport::{McpTransport, SseTransport, StdioTransport, StreamableHttpTransport};
13
/// MCP protocol revision this client asks for in its `initialize` request.
const PREFERRED_PROTOCOL_VERSION: &str = "2025-06-18";

/// Protocol revisions the client accepts from the server without warning,
/// newest first. The preferred revision is always the first entry.
const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &[PREFERRED_PROTOCOL_VERSION, "2025-03-26", "2024-11-05"];
19use super::types::{ConnectionState, McpPrompt, McpResource, McpResourceTemplate, McpTool};
20
21pub struct McpConnection {
22 name: String,
23 pub(super) transport: Box<dyn McpTransport>,
24 tools: Vec<McpTool>,
25 resources: Vec<McpResource>,
26 resource_templates: Vec<McpResourceTemplate>,
27 prompts: Vec<McpPrompt>,
28 request_id: AtomicU64,
29 state: ConnectionState,
30 config: McpServerConfig,
31 cancel_token: tokio_util::sync::CancellationToken,
32}
33
34impl McpConnection {
35 pub async fn connect_with_policy(
41 name: String,
42 config: McpServerConfig,
43 global_timeouts: &McpTimeouts,
44 network_policy: Option<&NetworkPolicyDecider>,
45 ) -> Result<Self> {
46 let connect_timeout_secs = config.effective_connect_timeout(global_timeouts);
47 let read_timeout_secs = config.effective_read_timeout(global_timeouts);
48 let cancel_token = tokio_util::sync::CancellationToken::new();
49
50 let transport_kind = config
51 .transport_kind()
52 .with_context(|| format!("MCP server '{name}' has an invalid transport config"))?;
53
54 let transport: Box<dyn McpTransport> = match transport_kind {
55 McpTransportKind::Sse | McpTransportKind::Http => {
56 let url = config
57 .url
58 .as_ref()
59 .ok_or_else(|| anyhow::anyhow!("MCP server '{name}' requires a 'url'"))?;
60
61 if let Some(decider) = network_policy
65 && let Some(host) = host_from_url(url)
66 {
67 match decider.evaluate(&host, "mcp") {
68 Decision::Allow => {}
69 Decision::Deny => {
70 anyhow::bail!(
71 "MCP server '{name}' connection to '{host}' blocked by network policy"
72 );
73 }
74 Decision::Prompt => {
75 anyhow::bail!(
76 "MCP server '{name}' connection to '{host}' requires approval; \
77 re-run after `/network allow {host}` or set network.default = \"allow\" in config"
78 );
79 }
80 }
81 }
82
83 let http_headers = config.resolve_http_headers(&name)?;
84
85 if transport_kind == McpTransportKind::Http {
86 let builder = apply_env_proxy(
92 reqwest::Client::builder()
93 .connect_timeout(Duration::from_secs(connect_timeout_secs))
94 .timeout(Duration::from_secs(read_timeout_secs)),
95 );
96 let builder = apply_default_headers(builder, &http_headers)?;
97 let client = builder.build()?;
98 Box::new(StreamableHttpTransport::new(client, url.clone()))
99 } else {
100 let builder = apply_env_proxy(
101 reqwest::Client::builder()
102 .timeout(Duration::from_secs(connect_timeout_secs)),
103 );
104 let builder = apply_default_headers(builder, &http_headers)?;
105 let client = builder.build()?;
106 Box::new(
107 SseTransport::connect(client, url.clone(), cancel_token.clone()).await?,
108 )
109 }
110 }
111 McpTransportKind::Stdio => {
112 let command = config
113 .command
114 .as_ref()
115 .ok_or_else(|| anyhow::anyhow!("MCP server '{name}' requires a 'command'"))?;
116 let mut cmd = super::stdio_spawn::build_stdio_command(
117 command,
118 &config.args,
119 &config.env,
120 )
121 .with_context(|| {
122 format!(
123 "MCP stdio command resolution failed (server={name} cmd={command:?} args={:?})",
124 config.args,
125 )
126 })?;
127
128 let mut child = cmd.spawn().with_context(|| {
129 let env_keys: Vec<&str> = config.env.keys().map(String::as_str).collect();
130 format!(
131 "MCP stdio spawn failed (transport=stdio server={name} cmd={command:?} args={:?} env_keys={env_keys:?}). \
132 On Windows ensure Node.js is installed; try full path to npx.cmd in mcp.json.",
133 config.args,
134 )
135 })?;
136
137 let stdin = child.stdin.take().context("Failed to get MCP stdin")?;
138 let stdout = child.stdout.take().context("Failed to get MCP stdout")?;
139
140 Box::new(StdioTransport {
141 child,
142 stdin,
143 reader: tokio::io::BufReader::new(stdout),
144 })
145 }
146 };
147
148 let mut conn = Self {
149 name: name.clone(),
150 transport,
151 tools: Vec::new(),
152 resources: Vec::new(),
153 resource_templates: Vec::new(),
154 prompts: Vec::new(),
155 request_id: AtomicU64::new(1),
156 state: ConnectionState::Connecting,
157 config,
158 cancel_token,
159 };
160
161 tokio::time::timeout(Duration::from_secs(connect_timeout_secs), conn.initialize())
163 .await
164 .with_context(|| format!("MCP server '{name}' initialization timed out"))??;
165
166 tokio::time::timeout(
168 Duration::from_secs(connect_timeout_secs),
169 conn.discover_all(),
170 )
171 .await
172 .with_context(|| format!("MCP server '{name}' discovery timed out"))??;
173
174 conn.state = ConnectionState::Ready;
175 Ok(conn)
176 }
177
178 async fn initialize(&mut self) -> Result<()> {
187 let init_id = self.next_id();
188 self.send(serde_json::json!({
189 "jsonrpc": "2.0",
190 "id": init_id,
191 "method": "initialize",
192 "params": {
193 "protocolVersion": PREFERRED_PROTOCOL_VERSION,
194 "clientInfo": {
195 "name": "deepseek-runtime",
196 "version": env!("CARGO_PKG_VERSION")
197 },
198 "capabilities": {
199 "tools": {},
200 "resources": {},
201 "prompts": {}
202 }
203 }
204 }))
205 .await?;
206
207 let response = self.recv(init_id).await?;
208 let negotiated = self.negotiate_protocol_version(&response);
209 self.transport.set_protocol_version(&negotiated);
210
211 self.send(serde_json::json!({
213 "jsonrpc": "2.0",
214 "method": "notifications/initialized"
215 }))
216 .await?;
217
218 Ok(())
219 }
220
221 fn negotiate_protocol_version(&self, response: &serde_json::Value) -> String {
224 let server_version = response
225 .get("result")
226 .and_then(|r| r.get("protocolVersion"))
227 .and_then(serde_json::Value::as_str);
228
229 match server_version {
230 Some(version) if SUPPORTED_PROTOCOL_VERSIONS.contains(&version) => version.to_string(),
231 Some(version) => {
232 tracing::warn!(
233 server = %self.name,
234 server_version = version,
235 preferred = PREFERRED_PROTOCOL_VERSION,
236 "MCP server reported an unsupported protocol version; proceeding best-effort"
237 );
238 version.to_string()
239 }
240 None => PREFERRED_PROTOCOL_VERSION.to_string(),
241 }
242 }
243
244 async fn discover_all(&mut self) -> Result<()> {
246 self.discover_tools().await?;
249 self.discover_resources().await?;
250 self.discover_resource_templates().await?;
251 self.discover_prompts().await?;
252 Ok(())
253 }
254
255 async fn discover_tools(&mut self) -> Result<()> {
257 let list_id = self.next_id();
258 self.send(serde_json::json!({
259 "jsonrpc": "2.0",
260 "id": list_id,
261 "method": "tools/list",
262 "params": {}
263 }))
264 .await?;
265
266 let response = self.recv(list_id).await?;
267
268 if let Some(result) = response.get("result")
269 && let Some(tools) = result.get("tools")
270 {
271 self.tools = serde_json::from_value(tools.clone()).unwrap_or_default();
272 }
273
274 Ok(())
275 }
276
277 async fn discover_resources(&mut self) -> Result<()> {
279 let list_id = self.next_id();
280 self.send(serde_json::json!({
281 "jsonrpc": "2.0",
282 "id": list_id,
283 "method": "resources/list",
284 "params": {}
285 }))
286 .await?;
287
288 let response = self.recv(list_id).await?;
289
290 if let Some(result) = response.get("result")
291 && let Some(resources) = result.get("resources")
292 {
293 self.resources = serde_json::from_value(resources.clone()).unwrap_or_default();
294 }
295
296 Ok(())
297 }
298
299 async fn discover_resource_templates(&mut self) -> Result<()> {
301 let list_id = self.next_id();
302 self.send(serde_json::json!({
303 "jsonrpc": "2.0",
304 "id": list_id,
305 "method": "resources/templates/list",
306 "params": {}
307 }))
308 .await?;
309
310 let response = self.recv(list_id).await?;
311
312 if let Some(result) = response.get("result") {
313 let templates = result
314 .get("resourceTemplates")
315 .or_else(|| result.get("templates"))
316 .or_else(|| result.get("resource_templates"));
317 if let Some(templates) = templates {
318 self.resource_templates =
319 serde_json::from_value(templates.clone()).unwrap_or_default();
320 }
321 }
322
323 Ok(())
324 }
325
326 async fn discover_prompts(&mut self) -> Result<()> {
328 let list_id = self.next_id();
329 self.send(serde_json::json!({
330 "jsonrpc": "2.0",
331 "id": list_id,
332 "method": "prompts/list",
333 "params": {}
334 }))
335 .await?;
336
337 let response = self.recv(list_id).await?;
338
339 if let Some(result) = response.get("result")
340 && let Some(prompts) = result.get("prompts")
341 {
342 self.prompts = serde_json::from_value(prompts.clone()).unwrap_or_default();
343 }
344
345 Ok(())
346 }
347
348 pub async fn call_tool(
350 &mut self,
351 tool_name: &str,
352 arguments: serde_json::Value,
353 timeout_secs: u64,
354 ) -> Result<serde_json::Value> {
355 self.call_method(
356 "tools/call",
357 serde_json::json!({
358 "name": tool_name,
359 "arguments": arguments
360 }),
361 timeout_secs,
362 )
363 .await
364 }
365
366 pub async fn read_resource(
368 &mut self,
369 uri: &str,
370 timeout_secs: u64,
371 ) -> Result<serde_json::Value> {
372 self.call_method(
373 "resources/read",
374 serde_json::json!({
375 "uri": uri
376 }),
377 timeout_secs,
378 )
379 .await
380 }
381
382 pub async fn get_prompt(
384 &mut self,
385 prompt_name: &str,
386 arguments: serde_json::Value,
387 timeout_secs: u64,
388 ) -> Result<serde_json::Value> {
389 self.call_method(
390 "prompts/get",
391 serde_json::json!({
392 "name": prompt_name,
393 "arguments": arguments
394 }),
395 timeout_secs,
396 )
397 .await
398 }
399
400 async fn call_method(
402 &mut self,
403 method: &str,
404 params: serde_json::Value,
405 timeout_secs: u64,
406 ) -> Result<serde_json::Value> {
407 let started = std::time::Instant::now();
408 let server = self.name.clone();
409 let method_name = method.to_string();
410 let span = tracing::info_span!(
411 "mcp.rpc",
412 server = %server,
413 method = %method_name,
414 timeout_secs
415 );
416
417 let outcome = self
418 .call_method_inner(method, params, timeout_secs)
419 .instrument(span)
420 .await;
421 let duration_ms = started.elapsed().as_millis() as u64;
422 let (success, err_msg, result_bytes) = match &outcome {
423 Ok(value) => (
424 true,
425 None,
426 serde_json::to_string(value).map(|s| s.len()).unwrap_or(0),
427 ),
428 Err(err) => (false, Some(err.to_string()), 0),
429 };
430 super::observability::record_mcp_call(
431 &server,
432 &method_name,
433 duration_ms,
434 success,
435 err_msg,
436 result_bytes,
437 );
438 outcome
439 }
440
441 async fn call_method_inner(
442 &mut self,
443 method: &str,
444 params: serde_json::Value,
445 timeout_secs: u64,
446 ) -> Result<serde_json::Value> {
447 if self.state != ConnectionState::Ready {
448 anyhow::bail!(
449 "Failed to call MCP method '{}': connection '{}' is not ready",
450 method,
451 self.name
452 );
453 }
454
455 let call_id = self.next_id();
456 let request = serde_json::json!({
457 "jsonrpc": "2.0",
458 "id": call_id,
459 "method": method,
460 "params": params
461 });
462
463 let response = tokio::time::timeout(Duration::from_secs(timeout_secs), async {
467 self.send(request).await?;
468 self.recv(call_id).await
469 })
470 .await
471 .with_context(|| {
472 format!(
473 "MCP method '{}' on server '{}' timed out after {}s",
474 method, self.name, timeout_secs
475 )
476 })??;
477
478 if let Some(error) = response.get("error") {
479 return Err(anyhow::anyhow!(
480 "MCP error in '{}': {}",
481 method,
482 serde_json::to_string_pretty(error)?
483 ));
484 }
485
486 Ok(response
487 .get("result")
488 .cloned()
489 .unwrap_or(serde_json::json!(null)))
490 }
491
492 pub fn tools(&self) -> &[McpTool] {
494 &self.tools
495 }
496
497 pub fn resources(&self) -> &[McpResource] {
499 &self.resources
500 }
501
502 pub fn resource_templates(&self) -> &[McpResourceTemplate] {
504 &self.resource_templates
505 }
506
507 pub fn prompts(&self) -> &[McpPrompt] {
509 &self.prompts
510 }
511
512 #[allow(dead_code)] pub fn name(&self) -> &str {
515 &self.name
516 }
517
518 pub fn is_ready(&self) -> bool {
520 self.state == ConnectionState::Ready
521 }
522
523 pub fn config(&self) -> &McpServerConfig {
525 &self.config
526 }
527
528 #[allow(dead_code)] pub fn state(&self) -> ConnectionState {
531 self.state
532 }
533
534 fn next_id(&self) -> u64 {
535 self.request_id.fetch_add(1, Ordering::SeqCst)
536 }
537
538 async fn send(&mut self, msg: serde_json::Value) -> Result<()> {
539 self.transport.send(msg).await
540 }
541
542 async fn recv(&mut self, expected_id: u64) -> Result<serde_json::Value> {
543 loop {
544 let value = self.transport.recv().await.inspect_err(|_e| {
545 self.state = ConnectionState::Disconnected;
546 })?;
547
548 if value.get("id").and_then(serde_json::Value::as_u64) == Some(expected_id) {
550 return Ok(value);
551 }
552 }
554 }
555
556 #[allow(dead_code)] pub fn close(&mut self) {
559 self.cancel_token.cancel();
560 self.state = ConnectionState::Disconnected;
561 }
562}
563
564impl Drop for McpConnection {
565 fn drop(&mut self) {
566 self.cancel_token.cancel();
567 }
568}