Skip to main content

selfware/mcp/
client.rs

1#![allow(dead_code, unused_imports, unused_variables)]
2//! MCP client that manages the protocol lifecycle.
3//!
4//! Handles initialization, tool discovery, and tool execution via the transport.
5
6use anyhow::{Context, Result};
7use serde_json::Value;
8use std::sync::Arc;
9use tracing::{debug, info, warn};
10
11use super::transport::Transport;
12use super::McpServerConfig;
13
14/// MCP protocol version we support.
15const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
16
17/// Information about the client sent during initialization.
18const CLIENT_NAME: &str = "selfware";
19const CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
20
21/// MCP client wrapping a transport connection to a single MCP server.
22pub struct McpClient {
23    transport: Arc<dyn Transport>,
24    server_name: String,
25    server_info: Option<Value>,
26}
27
28impl McpClient {
29    /// Connect to an MCP server and perform the initialization handshake.
30    pub async fn connect(config: &McpServerConfig) -> Result<Self> {
31        let transport = super::StdioTransport::spawn(&config.command, &config.args, &config.env)
32            .await
33            .with_context(|| format!("Failed to spawn MCP server '{}'", config.name))?;
34
35        let transport: Arc<dyn Transport> = Arc::new(transport);
36        let mut client = Self {
37            transport,
38            server_name: config.name.clone(),
39            server_info: None,
40        };
41
42        // Perform MCP initialization with timeout
43        tokio::time::timeout(
44            std::time::Duration::from_secs(config.init_timeout_secs.max(5)),
45            client.initialize(),
46        )
47        .await
48        .map_err(|_| {
49            anyhow::anyhow!(
50                "MCP server '{}' initialization timed out after {}s",
51                config.name,
52                config.init_timeout_secs
53            )
54        })??;
55
56        info!("MCP server '{}' initialized successfully", config.name);
57        Ok(client)
58    }
59
60    /// Perform the MCP initialization handshake.
61    async fn initialize(&mut self) -> Result<()> {
62        let params = serde_json::json!({
63            "protocolVersion": MCP_PROTOCOL_VERSION,
64            "capabilities": {
65                "roots": { "listChanged": false },
66            },
67            "clientInfo": {
68                "name": CLIENT_NAME,
69                "version": CLIENT_VERSION,
70            }
71        });
72
73        let result = self
74            .transport
75            .request("initialize", Some(params))
76            .await
77            .with_context(|| {
78                format!("MCP initialize handshake failed for '{}'", self.server_name)
79            })?;
80
81        self.server_info = Some(result.clone());
82
83        // Send initialized notification
84        self.transport
85            .notify("notifications/initialized", None)
86            .await?;
87
88        let server_name = result
89            .get("serverInfo")
90            .and_then(|i| i.get("name"))
91            .and_then(|n| n.as_str())
92            .unwrap_or("unknown");
93        let protocol_version = result
94            .get("protocolVersion")
95            .and_then(|v| v.as_str())
96            .unwrap_or("unknown");
97
98        info!(
99            "MCP server '{}' (protocol: {})",
100            server_name, protocol_version
101        );
102
103        Ok(())
104    }
105
106    /// List all tools available from this MCP server.
107    pub async fn list_tools(&self) -> Result<Vec<Value>> {
108        let result = self.transport.request("tools/list", None).await?;
109
110        let tools = result
111            .get("tools")
112            .and_then(|t| t.as_array())
113            .cloned()
114            .unwrap_or_default();
115
116        debug!(
117            "MCP server '{}' offers {} tool(s)",
118            self.server_name,
119            tools.len()
120        );
121
122        Ok(tools)
123    }
124
125    /// Call a tool on the MCP server.
126    pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value> {
127        let params = serde_json::json!({
128            "name": name,
129            "arguments": arguments,
130        });
131
132        let result = self
133            .transport
134            .request("tools/call", Some(params))
135            .await
136            .with_context(|| {
137                format!(
138                    "MCP tool call '{}' failed on server '{}'",
139                    name, self.server_name
140                )
141            })?;
142
143        // MCP tool results have a `content` array with text/image/resource blocks
144        // Extract text content for simple use
145        if let Some(content) = result.get("content").and_then(|c| c.as_array()) {
146            let text_parts: Vec<&str> = content
147                .iter()
148                .filter_map(|block| {
149                    if block.get("type").and_then(|t| t.as_str()) == Some("text") {
150                        block.get("text").and_then(|t| t.as_str())
151                    } else {
152                        None
153                    }
154                })
155                .collect();
156
157            if !text_parts.is_empty() {
158                return Ok(serde_json::json!({
159                    "content": text_parts.join("\n"),
160                    "isError": result.get("isError").and_then(|e| e.as_bool()).unwrap_or(false),
161                }));
162            }
163        }
164
165        Ok(result)
166    }
167
168    /// Get the server name.
169    pub fn server_name(&self) -> &str {
170        &self.server_name
171    }
172
173    /// Shut down the client and its transport.
174    pub async fn shutdown(&self) -> Result<()> {
175        info!("Shutting down MCP client for '{}'", self.server_name);
176        self.transport.shutdown().await
177    }
178}
179
180impl std::fmt::Debug for McpClient {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        f.debug_struct("McpClient")
183            .field("server_name", &self.server_name)
184            .finish()
185    }
186}