sovran_mcp/
messaging.rs

1use crate::transport::Transport;
2use crate::types::*;
3use crate::McpError;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::sync::mpsc::Sender;
8use std::sync::{Arc, Mutex};
9use tracing::{debug, warn};
10use url::Url;
11
12//
13// Core JSON-RPC Types
14// These types represent the basic building blocks of the JSON-RPC protocol
15//
16
/// Identifier that pairs a JSON-RPC request with its response.
///
/// Modelled as an unsigned 64-bit integer throughout this module.
pub type RequestId = u64;
19
20/// JSON RPC version type
21#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
22#[serde(transparent)]
23pub struct JsonRpcVersion(String);
24
25impl Default for JsonRpcVersion {
26    fn default() -> Self {
27        JsonRpcVersion("2.0".to_owned())
28    }
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
32#[serde(deny_unknown_fields)]
33#[serde(untagged)]
34pub enum JsonRpcMessage {
35    Response(JsonRpcResponse),
36    Request(JsonRpcRequest),
37    Notification(JsonRpcNotification),
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
41#[serde(deny_unknown_fields)]
42pub struct JsonRpcRequest {
43    pub id: RequestId,
44    pub method: String,
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub params: Option<Value>,
47    pub jsonrpc: JsonRpcVersion,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
51#[serde(deny_unknown_fields)]
52#[serde(rename_all = "camelCase")]
53#[serde(default)]
54pub struct JsonRpcResponse {
55    pub id: RequestId,
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub result: Option<Value>,
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub error: Option<JsonRpcError>,
60    pub jsonrpc: JsonRpcVersion,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
64#[serde(rename_all = "camelCase")]
65#[serde(default)]
66pub struct JsonRpcError {
67    pub code: i32,
68    pub message: String,
69    #[serde(skip_serializing_if = "Option::is_none")]
70    pub data: Option<Value>,
71}
72
73//
74// Notification System
75// Types and implementations for the MCP notification system
76//
77
78#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
79#[serde(rename_all = "camelCase")]
80pub enum NotificationMethod {
81    #[serde(rename = "notifications/initialized")]
82    Initialized,
83    #[serde(rename = "notifications/progress")]
84    Progress,
85    #[serde(rename = "notifications/resources/updated")]
86    ResourceUpdated,
87    #[serde(rename = "notifications/resources/list_changed")]
88    ResourceListChanged,
89    #[serde(rename = "notifications/tools/list_changed")]
90    ToolListChanged,
91    #[serde(rename = "notifications/prompts/list_changed")]
92    PromptListChanged,
93    #[serde(rename = "notifications/message")]
94    LogMessage,
95}
96
97impl NotificationMethod {
98    pub fn as_str(&self) -> &'static str {
99        match self {
100            Self::Initialized => "notifications/initialized",
101            Self::Progress => "notifications/progress",
102            Self::ResourceUpdated => "notifications/resources/updated",
103            Self::ResourceListChanged => "notifications/resources/list_changed",
104            Self::ToolListChanged => "notifications/tools/list_changed",
105            Self::PromptListChanged => "notifications/prompts/list_changed",
106            Self::LogMessage => "notifications/message",
107        }
108    }
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
112#[serde(rename_all = "camelCase")]
113pub struct ProgressParams {
114    pub progress_token: String,
115    pub progress: f64,
116    #[serde(skip_serializing_if = "Option::is_none")]
117    pub total: Option<f64>,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
121#[serde(rename_all = "camelCase")]
122pub struct ResourceUpdatedParams {
123    pub uri: String,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
127#[serde(rename_all = "camelCase")]
128pub enum LogLevel {
129    Debug,
130    Info,
131    Warning,
132    Error,
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
136#[serde(rename_all = "camelCase")]
137pub struct LogMessageParams {
138    pub level: LogLevel,
139    pub data: Value,
140    #[serde(skip_serializing_if = "Option::is_none")]
141    pub logger: Option<String>,
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
145#[serde(untagged)]
146pub enum NotificationParams {
147    Progress(ProgressParams),
148    ResourceUpdated(ResourceUpdatedParams),
149    LogMessage(LogMessageParams),
150    None,
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
154#[serde(rename_all = "camelCase")]
155#[serde(deny_unknown_fields)]
156pub struct JsonRpcNotification {
157    #[serde(serialize_with = "serialize_notification_method")]
158    pub method: NotificationMethod,
159    #[serde(skip_serializing_if = "Option::is_none")]
160    pub params: Option<NotificationParams>,
161    pub jsonrpc: JsonRpcVersion,
162}
163
164fn serialize_notification_method<S>(
165    method: &NotificationMethod,
166    serializer: S,
167) -> Result<S::Ok, S::Error>
168where
169    S: serde::Serializer,
170{
171    serializer.serialize_str(method.as_str())
172}
173
174impl JsonRpcNotification {
175    pub fn new(method: NotificationMethod, params: Option<NotificationParams>) -> Self {
176        Self {
177            method,
178            params,
179            jsonrpc: JsonRpcVersion::default(),
180        }
181    }
182
183    pub fn initialized() -> Self {
184        Self::new(NotificationMethod::Initialized, None)
185    }
186
187    pub fn progress(token: impl Into<String>, progress: f64, total: Option<f64>) -> Self {
188        Self::new(
189            NotificationMethod::Progress,
190            Some(NotificationParams::Progress(ProgressParams {
191                progress_token: token.into(),
192                progress,
193                total,
194            })),
195        )
196    }
197
198    pub fn resource_updated(uri: impl Into<String>) -> Self {
199        Self::new(
200            NotificationMethod::ResourceUpdated,
201            Some(NotificationParams::ResourceUpdated(ResourceUpdatedParams {
202                uri: uri.into(),
203            })),
204        )
205    }
206
207    pub fn log_message(level: LogLevel, data: impl Into<Value>, logger: Option<String>) -> Self {
208        Self::new(
209            NotificationMethod::LogMessage,
210            Some(NotificationParams::LogMessage(LogMessageParams {
211                level,
212                data: data.into(),
213                logger,
214            })),
215        )
216    }
217}
218
219//
220// Message Handler
221// Implementation of the MCP message handling system
222//
223
224pub struct MessageHandler<T: Transport + 'static> {
225    transport: Arc<T>,
226    pending_requests: Arc<Mutex<HashMap<u64, Sender<JsonRpcResponse>>>>,
227    sampling_handler: Option<Arc<Box<dyn SamplingHandler + Send>>>,
228    notification_handler: Option<Arc<Box<dyn NotificationHandler + Send>>>,
229}
230
231impl<T: Transport + 'static> MessageHandler<T> {
232    pub fn new(
233        transport: Arc<T>,
234        pending_requests: Arc<Mutex<HashMap<u64, Sender<JsonRpcResponse>>>>,
235        sampling_handler: Option<Arc<Box<dyn SamplingHandler + Send>>>,
236        notification_handler: Option<Arc<Box<dyn NotificationHandler + Send>>>,
237    ) -> Self {
238        Self {
239            transport,
240            pending_requests,
241            sampling_handler,
242            notification_handler,
243        }
244    }
245
246    pub fn handle_message(&self, message: JsonRpcMessage) -> Result<(), McpError> {
247        match message {
248            JsonRpcMessage::Request(request) => self.handle_request(request),
249            JsonRpcMessage::Response(response) => self.handle_response(response),
250            JsonRpcMessage::Notification(notification) => self.handle_notification(notification),
251        }
252    }
253
254    pub fn handle_request(&self, request: JsonRpcRequest) -> Result<(), McpError> {
255        match request.method.as_str() {
256            "sampling/createMessage" => {
257                if let Some(handler) = &self.sampling_handler {
258                    self.handle_sampling_request(handler, &request)?;
259                }
260                Ok(())
261            }
262            _ => {
263                warn!("Unknown request method: {}", request.method);
264                Ok(())
265            }
266        }
267    }
268
269    pub fn handle_sampling_request(
270        &self,
271        handler: &Arc<Box<dyn SamplingHandler + Send>>,
272        request: &JsonRpcRequest,
273    ) -> Result<(), McpError> {
274        debug!("Processing sampling request...");
275        let params =
276            serde_json::from_value(request.clone().params.unwrap_or(serde_json::Value::Null))?;
277        let result = handler.handle_message(params)?;
278        let value = serde_json::to_value(result)?;
279
280        let response = JsonRpcResponse {
281            id: request.id,
282            result: Some(value),
283            error: None,
284            jsonrpc: Default::default(),
285        };
286
287        debug!("Sending sampling response...");
288        self.transport.send(&JsonRpcMessage::Response(response))?;
289        debug!("Sampling response sent");
290        Ok(())
291    }
292
293    pub fn handle_response(&self, response: JsonRpcResponse) -> Result<(), McpError> {
294        debug!("Got response with id: {}", response.id);
295        let mut pending = self.pending_requests.lock().unwrap();
296        debug!(
297            "Current pending request IDs: {:?}",
298            pending.keys().collect::<Vec<_>>()
299        );
300
301        if let Some(sender) = pending.remove(&response.id) {
302            sender
303                .send(response.clone())
304                .map_err(|e| McpError::SendError {
305                    id: response.id,
306                    source: e,
307                })?;
308        }
309        Ok(())
310    }
311
312    pub fn handle_notification(&self, notification: JsonRpcNotification) -> Result<(), McpError> {
313        debug!("Got notification: method={:?}", notification.method);
314
315        if let Some(handler) = &self.notification_handler {
316            match notification.method {
317                NotificationMethod::ResourceUpdated => {
318                    if let Some(NotificationParams::ResourceUpdated(params)) = notification.params {
319                        let url = Url::parse(&params.uri)?;
320                        handler.handle_resource_update(&url)?;
321                    }
322                }
323                NotificationMethod::LogMessage => {
324                    if let Some(NotificationParams::LogMessage(params)) = notification.params {
325                        handler.handle_log_message(&params.level, &params.data, &params.logger);
326                    }
327                }
328                NotificationMethod::Progress => {
329                    if let Some(NotificationParams::Progress(params)) = notification.params {
330                        handler.handle_progress_update(
331                            &params.progress_token,
332                            &params.progress,
333                            &params.total,
334                        );
335                    }
336                }
337                NotificationMethod::Initialized => {
338                    debug!("Server initialization completed");
339                    handler.handle_initialized();
340                }
341                NotificationMethod::ToolListChanged
342                | NotificationMethod::PromptListChanged
343                | NotificationMethod::ResourceListChanged => {
344                    debug!("List changed notification: {:?}", notification.method);
345                    handler.handle_list_changed(&notification.method);
346                }
347            }
348        } else {
349            debug!("Received notification but no handler is registered");
350        }
351        Ok(())
352    }
353}