Skip to main content

rustic_ai/
mcp.rs

1use std::pin::Pin;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_stream::try_stream;
6use async_trait::async_trait;
7use eventsource_stream::Eventsource;
8use futures::lock::Mutex;
9use futures::stream::StreamExt;
10use reqwest::header::HeaderMap;
11use reqwest::{Client, Url};
12use serde::Deserialize;
13use serde_json::{Value, json};
14use uuid::Uuid;
15
16use crate::tools::{RunContext, ToolDefinition, ToolError, ToolKind, Toolset};
17
18#[derive(Clone, Debug)]
19pub struct McpServerStreamableHttp {
20    url: Url,
21    headers: HeaderMap,
22    timeout: Duration,
23    tool_prefix: Option<String>,
24    client: Client,
25    events_url: Option<Url>,
26    cache_tools: bool,
27    cache_resources: bool,
28    cache_prompts: bool,
29    cached_tools: Arc<Mutex<Option<Vec<ToolDefinition>>>>,
30    cached_resources: Arc<Mutex<Option<Vec<McpResource>>>>,
31    cached_prompts: Arc<Mutex<Option<Vec<McpPrompt>>>>,
32}
33
34impl McpServerStreamableHttp {
35    pub fn new(url: impl AsRef<str>) -> Result<Self, ToolError> {
36        let url = Url::parse(url.as_ref())
37            .map_err(|e| ToolError::Toolset(format!("invalid MCP URL: {e}")))?;
38        let timeout = Duration::from_secs(10);
39        let client = Client::builder()
40            .timeout(timeout)
41            .build()
42            .map_err(|e| ToolError::Toolset(format!("failed to build HTTP client: {e}")))?;
43        Ok(Self {
44            url,
45            headers: HeaderMap::new(),
46            timeout,
47            tool_prefix: None,
48            client,
49            events_url: None,
50            cache_tools: true,
51            cache_resources: true,
52            cache_prompts: true,
53            cached_tools: Arc::new(Mutex::new(None)),
54            cached_resources: Arc::new(Mutex::new(None)),
55            cached_prompts: Arc::new(Mutex::new(None)),
56        })
57    }
58
59    pub fn with_headers(mut self, headers: HeaderMap) -> Self {
60        self.headers = headers;
61        self
62    }
63
64    pub fn with_timeout(mut self, timeout: Duration) -> Self {
65        self.timeout = timeout;
66        self.client = Client::builder()
67            .timeout(timeout)
68            .build()
69            .unwrap_or_else(|_| Client::new());
70        self
71    }
72
73    pub fn with_tool_prefix(mut self, prefix: impl Into<String>) -> Self {
74        self.tool_prefix = Some(prefix.into());
75        self
76    }
77
78    pub fn with_events_url(mut self, url: impl AsRef<str>) -> Result<Self, ToolError> {
79        self.events_url = Some(
80            Url::parse(url.as_ref())
81                .map_err(|e| ToolError::Toolset(format!("invalid MCP events URL: {e}")))?,
82        );
83        Ok(self)
84    }
85
86    pub fn cache_tools(mut self, enabled: bool) -> Self {
87        self.cache_tools = enabled;
88        self
89    }
90
91    pub fn cache_resources(mut self, enabled: bool) -> Self {
92        self.cache_resources = enabled;
93        self
94    }
95
96    pub fn cache_prompts(mut self, enabled: bool) -> Self {
97        self.cache_prompts = enabled;
98        self
99    }
100
101    pub async fn invalidate_tools_cache(&self) {
102        *self.cached_tools.lock().await = None;
103    }
104
105    pub async fn invalidate_resources_cache(&self) {
106        *self.cached_resources.lock().await = None;
107    }
108
109    pub async fn invalidate_prompts_cache(&self) {
110        *self.cached_prompts.lock().await = None;
111    }
112
113    async fn rpc(&self, method: &str, params: Value) -> Result<Value, ToolError> {
114        let request_id = Uuid::new_v4().to_string();
115        let payload = json!({
116            "jsonrpc": "2.0",
117            "id": request_id,
118            "method": method,
119            "params": params,
120        });
121        let response = self
122            .client
123            .post(self.url.clone())
124            .headers(self.headers.clone())
125            .json(&payload)
126            .send()
127            .await
128            .map_err(|e| ToolError::Toolset(format!("MCP request failed: {e}")))?;
129
130        let status = response.status();
131        let value: Value = response
132            .json()
133            .await
134            .map_err(|e| ToolError::Toolset(format!("MCP response parse failed: {e}")))?;
135
136        if let Some(error) = value.get("error") {
137            return Err(ToolError::Toolset(format!(
138                "MCP error (status {status}): {error}"
139            )));
140        }
141        value
142            .get("result")
143            .cloned()
144            .ok_or_else(|| ToolError::Toolset("MCP response missing result".to_string()))
145    }
146
147    fn prefix_name(&self, name: &str) -> String {
148        if let Some(prefix) = &self.tool_prefix {
149            format!("{}__{}", prefix, name)
150        } else {
151            name.to_string()
152        }
153    }
154
155    fn unprefix_name<'a>(&self, name: &'a str) -> &'a str {
156        if let Some(prefix) = &self.tool_prefix {
157            let expected = format!("{}__", prefix);
158            name.strip_prefix(&expected).unwrap_or(name)
159        } else {
160            name
161        }
162    }
163
164    pub async fn list_resources(&self) -> Result<Vec<McpResource>, ToolError> {
165        if self.cache_resources
166            && let Some(cached) = self.cached_resources.lock().await.clone()
167        {
168            return Ok(cached);
169        }
170
171        let result = self.rpc("resources/list", json!({})).await?;
172        let resources: RpcResourcesList = serde_json::from_value(result)
173            .map_err(|e| ToolError::Toolset(format!("invalid MCP resources list: {e}")))?;
174        if self.cache_resources {
175            *self.cached_resources.lock().await = Some(resources.resources.clone());
176        }
177        Ok(resources.resources)
178    }
179
180    pub async fn list_resource_templates(&self) -> Result<Vec<McpResourceTemplate>, ToolError> {
181        let result = self.rpc("resources/templates/list", json!({})).await?;
182        let templates: RpcResourceTemplatesList = serde_json::from_value(result)
183            .map_err(|e| ToolError::Toolset(format!("invalid MCP resource templates list: {e}")))?;
184        Ok(templates.resource_templates)
185    }
186
187    pub async fn read_resource(&self, uri: &str) -> Result<Value, ToolError> {
188        let result = self.rpc("resources/read", json!({ "uri": uri })).await?;
189        Ok(result)
190    }
191
192    pub async fn list_prompts(&self) -> Result<Vec<McpPrompt>, ToolError> {
193        if self.cache_prompts
194            && let Some(cached) = self.cached_prompts.lock().await.clone()
195        {
196            return Ok(cached);
197        }
198
199        let result = self.rpc("prompts/list", json!({})).await?;
200        let prompts: RpcPromptsList = serde_json::from_value(result)
201            .map_err(|e| ToolError::Toolset(format!("invalid MCP prompts list: {e}")))?;
202        if self.cache_prompts {
203            *self.cached_prompts.lock().await = Some(prompts.prompts.clone());
204        }
205        Ok(prompts.prompts)
206    }
207
208    pub async fn get_prompt(
209        &self,
210        name: &str,
211        arguments: Option<Value>,
212    ) -> Result<Vec<McpPromptMessage>, ToolError> {
213        let mut params = json!({ "name": name });
214        if let Some(arguments) = arguments
215            && let Value::Object(map) = &mut params
216        {
217            map.insert("arguments".to_string(), arguments);
218        }
219        let result = self.rpc("prompts/get", params).await?;
220        let prompt: RpcPromptGet = serde_json::from_value(result)
221            .map_err(|e| ToolError::Toolset(format!("invalid MCP prompt: {e}")))?;
222        Ok(prompt.messages)
223    }
224
225    pub async fn sample(&self, params: Value) -> Result<Value, ToolError> {
226        self.rpc("sampling/createMessage", params).await
227    }
228
229    pub async fn notifications(&self) -> Result<McpNotificationStream, ToolError> {
230        let events_url = self
231            .events_url
232            .clone()
233            .ok_or_else(|| ToolError::Toolset("MCP events URL not configured".to_string()))?;
234
235        let response = self
236            .client
237            .get(events_url)
238            .headers(self.headers.clone())
239            .send()
240            .await
241            .map_err(|e| ToolError::Toolset(format!("MCP events request failed: {e}")))?;
242
243        if !response.status().is_success() {
244            return Err(ToolError::Toolset(format!(
245                "MCP events error status {}",
246                response.status()
247            )));
248        }
249
250        let mut event_stream = response.bytes_stream().eventsource();
251        let cached_tools = Arc::clone(&self.cached_tools);
252        let cached_resources = Arc::clone(&self.cached_resources);
253        let cached_prompts = Arc::clone(&self.cached_prompts);
254
255        let stream = try_stream! {
256            while let Some(event) = event_stream.next().await {
257                let event = event.map_err(|e| ToolError::Toolset(format!("MCP events stream error: {e}")))?;
258                let notification: McpNotification = serde_json::from_str(&event.data)
259                    .map_err(|e| ToolError::Toolset(format!("MCP notification parse error: {e}")))?;
260
261                match notification.method.as_str() {
262                    "notifications/tools/list_changed" => {
263                        *cached_tools.lock().await = None;
264                    }
265                    "notifications/resources/list_changed" => {
266                        *cached_resources.lock().await = None;
267                    }
268                    "notifications/prompts/list_changed" => {
269                        *cached_prompts.lock().await = None;
270                    }
271                    _ => {}
272                }
273
274                yield notification;
275            }
276        };
277
278        Ok(Box::pin(stream))
279    }
280}
281
282#[derive(Debug, Deserialize)]
283struct RpcToolsList {
284    tools: Vec<RpcTool>,
285}
286
287#[derive(Debug, Deserialize)]
288struct RpcTool {
289    name: String,
290    description: Option<String>,
291    #[serde(rename = "inputSchema")]
292    input_schema: Value,
293    meta: Option<Value>,
294    annotations: Option<Value>,
295    #[serde(rename = "outputSchema")]
296    output_schema: Option<Value>,
297}
298
299#[derive(Debug, Deserialize)]
300struct RpcResourcesList {
301    resources: Vec<McpResource>,
302}
303
304#[derive(Debug, Deserialize)]
305struct RpcResourceTemplatesList {
306    #[serde(rename = "resourceTemplates")]
307    resource_templates: Vec<McpResourceTemplate>,
308}
309
310#[derive(Debug, Deserialize)]
311struct RpcPromptsList {
312    prompts: Vec<McpPrompt>,
313}
314
315#[derive(Debug, Deserialize)]
316struct RpcPromptGet {
317    messages: Vec<McpPromptMessage>,
318}
319
320#[derive(Clone, Debug, Deserialize)]
321pub struct McpResource {
322    pub uri: String,
323    pub name: Option<String>,
324    pub description: Option<String>,
325    #[serde(rename = "mimeType")]
326    pub mime_type: Option<String>,
327    pub metadata: Option<Value>,
328}
329
330#[derive(Clone, Debug, Deserialize)]
331pub struct McpResourceTemplate {
332    pub name: String,
333    pub description: Option<String>,
334    pub uri_template: Option<String>,
335    pub metadata: Option<Value>,
336}
337
338#[derive(Clone, Debug, Deserialize)]
339pub struct McpPrompt {
340    pub name: String,
341    pub description: Option<String>,
342    pub arguments: Option<Vec<McpPromptArgument>>,
343}
344
345#[derive(Clone, Debug, Deserialize)]
346pub struct McpPromptArgument {
347    pub name: String,
348    pub description: Option<String>,
349    pub required: Option<bool>,
350}
351
352#[derive(Clone, Debug, Deserialize)]
353pub struct McpPromptMessage {
354    pub role: String,
355    pub content: Value,
356}
357
358#[derive(Clone, Debug, Deserialize)]
359pub struct McpNotification {
360    pub method: String,
361    pub params: Option<Value>,
362}
363
364pub type McpNotificationStream =
365    Pin<Box<dyn futures::stream::Stream<Item = Result<McpNotification, ToolError>> + Send>>;
366
367#[async_trait]
368impl<Deps> Toolset<Deps> for McpServerStreamableHttp
369where
370    Deps: Send + Sync,
371{
372    async fn list_tools(&self, _ctx: &RunContext<Deps>) -> Result<Vec<ToolDefinition>, ToolError> {
373        if self.cache_tools
374            && let Some(cached) = self.cached_tools.lock().await.clone()
375        {
376            return Ok(cached);
377        }
378
379        let result = self.rpc("tools/list", json!({})).await?;
380        let tools: RpcToolsList = serde_json::from_value(result)
381            .map_err(|e| ToolError::Toolset(format!("invalid MCP tools list: {e}")))?;
382        let mapped: Vec<ToolDefinition> = tools
383            .tools
384            .into_iter()
385            .map(|tool| {
386                let mut def = ToolDefinition::new(
387                    self.prefix_name(&tool.name),
388                    tool.description,
389                    tool.input_schema,
390                );
391                def.kind = ToolKind::Function;
392                def.metadata = Some(json!({
393                    "meta": tool.meta,
394                    "annotations": tool.annotations,
395                    "output_schema": tool.output_schema,
396                }));
397                def
398            })
399            .collect();
400
401        if self.cache_tools {
402            *self.cached_tools.lock().await = Some(mapped.clone());
403        }
404
405        Ok(mapped)
406    }
407
408    async fn call_tool(
409        &self,
410        _ctx: &RunContext<Deps>,
411        name: &str,
412        args: Value,
413    ) -> Result<Value, ToolError> {
414        let name = self.unprefix_name(name).to_string();
415        let result = self
416            .rpc("tools/call", json!({"name": name, "arguments": args}))
417            .await?;
418
419        if let Some(structured) = result.get("structuredContent") {
420            return Ok(structured.clone());
421        }
422
423        if let Some(content) = result.get("content")
424            && let Some(array) = content.as_array()
425            && array.len() == 1
426            && let Some(text) = array[0].get("text").and_then(|v| v.as_str())
427        {
428            return Ok(Value::String(text.to_string()));
429        }
430
431        Ok(result)
432    }
433
434    fn name(&self) -> &str {
435        "mcp-http"
436    }
437}