1use crate::transport::Transport;
2use crate::types::*;
3use crate::McpError;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::sync::mpsc::Sender;
8use std::sync::{Arc, Mutex};
9use tracing::{debug, warn};
10use url::Url;
11
12pub type RequestId = u64;
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
22#[serde(transparent)]
23pub struct JsonRpcVersion(String);
24
25impl Default for JsonRpcVersion {
26 fn default() -> Self {
27 JsonRpcVersion("2.0".to_owned())
28 }
29}
30
/// Any JSON-RPC message exchanged with the peer.
///
/// Deserialization is `untagged`: serde tries the variants in declaration order and
/// keeps the first one that parses, so the order below is significant. `Response` is
/// tried first. Because [`JsonRpcResponse`] rejects unknown fields, an object carrying
/// a `method` key falls through to `Request` (which requires `id`) or `Notification`.
/// Because `JsonRpcResponse` also defaults every missing field, an object containing
/// only response keys (even `{}`) parses as a `Response`.
///
/// NOTE(review): `deny_unknown_fields` on an untagged enum presumably has no effect.
/// Unknown-field rejection comes from the variant structs' own attributes — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
#[serde(untagged)]
pub enum JsonRpcMessage {
    /// A reply to a request we sent (matched against pending requests by `id`).
    Response(JsonRpcResponse),
    /// A call from the peer that carries an `id` and a `method`.
    Request(JsonRpcRequest),
    /// A one-way message with a `method` but no `id`.
    Notification(JsonRpcNotification),
}
39
/// A JSON-RPC request: a call to `method` identified by `id`.
///
/// Unknown fields are rejected. `params` may be absent, while `id`, `method` and
/// `jsonrpc` are required on input (there is no container-level `#[serde(default)]`).
///
/// NOTE(review): `id` is a `u64`, so a request using a string id will fail to
/// deserialize — confirm this is acceptable for the peers in use.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct JsonRpcRequest {
    /// Identifier the response must echo back.
    pub id: RequestId,
    /// Name of the method being invoked, e.g. `"sampling/createMessage"`.
    pub method: String,
    /// Method parameters; omitted from the output JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
    /// Protocol version string.
    pub jsonrpc: JsonRpcVersion,
}
49
/// A JSON-RPC response to an earlier request.
///
/// The container-level `#[serde(default)]` makes every field optional on input: a
/// missing `id` becomes `0` and a missing `jsonrpc` becomes `"2.0"`. Unknown fields are
/// rejected, which is what lets [`JsonRpcMessage`] distinguish responses from requests
/// and notifications.
///
/// NOTE(review): with serde's `Option` handling, `"result": null` deserializes to
/// `None`, which is indistinguishable from a missing `result`. When re-serialized, the
/// field is then omitted — confirm callers never need a literal `null` result.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
#[serde(deny_unknown_fields)]
#[serde(rename_all = "camelCase")]
#[serde(default)]
pub struct JsonRpcResponse {
    /// Id of the request this response answers.
    pub id: RequestId,
    /// Success payload; omitted from the output JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Error payload; omitted from the output JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
    /// Protocol version string.
    pub jsonrpc: JsonRpcVersion,
}
62
/// The `error` object of a failed JSON-RPC response.
///
/// Because of the container-level `#[serde(default)]`, missing fields take their
/// defaults on input (`code` → `0`, `message` → `""`, `data` → `None`).
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
#[serde(default)]
pub struct JsonRpcError {
    /// Numeric error code.
    pub code: i32,
    /// Human-readable error description.
    pub message: String,
    /// Optional extra error information; omitted from the output JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
72
/// The notification methods this module understands.
///
/// Each variant's wire name comes from its explicit `rename`, which takes precedence
/// over the container's `rename_all`. The same strings are repeated in
/// [`NotificationMethod::as_str`], so the two lists must be kept in sync. A method
/// string not listed here fails to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub enum NotificationMethod {
    #[serde(rename = "notifications/initialized")]
    Initialized,
    #[serde(rename = "notifications/progress")]
    Progress,
    #[serde(rename = "notifications/resources/updated")]
    ResourceUpdated,
    #[serde(rename = "notifications/resources/list_changed")]
    ResourceListChanged,
    #[serde(rename = "notifications/tools/list_changed")]
    ToolListChanged,
    #[serde(rename = "notifications/prompts/list_changed")]
    PromptListChanged,
    #[serde(rename = "notifications/message")]
    LogMessage,
}
96
97impl NotificationMethod {
98 pub fn as_str(&self) -> &'static str {
99 match self {
100 Self::Initialized => "notifications/initialized",
101 Self::Progress => "notifications/progress",
102 Self::ResourceUpdated => "notifications/resources/updated",
103 Self::ResourceListChanged => "notifications/resources/list_changed",
104 Self::ToolListChanged => "notifications/tools/list_changed",
105 Self::PromptListChanged => "notifications/prompts/list_changed",
106 Self::LogMessage => "notifications/message",
107 }
108 }
109}
110
/// Params of a `notifications/progress` notification.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ProgressParams {
    /// Progress token; serialized as `progressToken`.
    pub progress_token: String,
    /// Current progress value.
    pub progress: f64,
    /// Optional total; omitted from the output JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total: Option<f64>,
}
119
/// Params of a `notifications/resources/updated` notification.
///
/// Unknown fields are not rejected, so any object with a string `uri` can match this
/// shape when it is tried inside [`NotificationParams`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ResourceUpdatedParams {
    /// URI of the updated resource; `MessageHandler::handle_notification` parses it as a URL.
    pub uri: String,
}
125
/// Severity of a `notifications/message` log entry, serialized in lowercase
/// (`"debug"`, `"info"`, `"warning"`, `"error"`).
///
/// NOTE(review): any other level string fails to deserialize, which fails the whole
/// notification. The MCP logging spec may define further levels (e.g. `"notice"`,
/// `"critical"`) — confirm against the spec version in use.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub enum LogLevel {
    Debug,
    Info,
    Warning,
    Error,
}
134
/// Params of a `notifications/message` (log message) notification.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct LogMessageParams {
    /// Severity of the entry.
    pub level: LogLevel,
    /// Arbitrary JSON log payload.
    pub data: Value,
    /// Optional logger name; omitted from the output JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logger: Option<String>,
}
143
/// Parameters of a notification, distinguished only by their JSON shape.
///
/// Deserialization is `untagged`: variants are tried in declaration order and the first
/// that parses wins, independently of the notification's `method`. The order below is
/// therefore significant and must not be changed casually.
///
/// NOTE(review): `None` is a unit variant and presumably matches only JSON `null`.
/// Since `JsonRpcNotification::params` is an `Option`, a `null` there likely becomes
/// `Option::None` instead — confirm whether this variant is ever produced by parsing.
/// An object that fits no variant (e.g. `{}`) presumably fails deserialization.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum NotificationParams {
    Progress(ProgressParams),
    ResourceUpdated(ResourceUpdatedParams),
    LogMessage(LogMessageParams),
    None,
}
152
/// A one-way JSON-RPC notification: it has a `method` but no `id`.
///
/// `method` is serialized through `serialize_notification_method` (i.e. via
/// `NotificationMethod::as_str`) but deserialized through the enum's serde renames.
/// Both paths must yield the same strings. Unknown fields are rejected.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
#[serde(deny_unknown_fields)]
pub struct JsonRpcNotification {
    /// Which notification this is.
    #[serde(serialize_with = "serialize_notification_method")]
    pub method: NotificationMethod,
    /// Method-specific params; omitted from the output JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<NotificationParams>,
    /// Protocol version string.
    pub jsonrpc: JsonRpcVersion,
}
163
164fn serialize_notification_method<S>(
165 method: &NotificationMethod,
166 serializer: S,
167) -> Result<S::Ok, S::Error>
168where
169 S: serde::Serializer,
170{
171 serializer.serialize_str(method.as_str())
172}
173
174impl JsonRpcNotification {
175 pub fn new(method: NotificationMethod, params: Option<NotificationParams>) -> Self {
176 Self {
177 method,
178 params,
179 jsonrpc: JsonRpcVersion::default(),
180 }
181 }
182
183 pub fn initialized() -> Self {
184 Self::new(NotificationMethod::Initialized, None)
185 }
186
187 pub fn progress(token: impl Into<String>, progress: f64, total: Option<f64>) -> Self {
188 Self::new(
189 NotificationMethod::Progress,
190 Some(NotificationParams::Progress(ProgressParams {
191 progress_token: token.into(),
192 progress,
193 total,
194 })),
195 )
196 }
197
198 pub fn resource_updated(uri: impl Into<String>) -> Self {
199 Self::new(
200 NotificationMethod::ResourceUpdated,
201 Some(NotificationParams::ResourceUpdated(ResourceUpdatedParams {
202 uri: uri.into(),
203 })),
204 )
205 }
206
207 pub fn log_message(level: LogLevel, data: impl Into<Value>, logger: Option<String>) -> Self {
208 Self::new(
209 NotificationMethod::LogMessage,
210 Some(NotificationParams::LogMessage(LogMessageParams {
211 level,
212 data: data.into(),
213 logger,
214 })),
215 )
216 }
217}
218
219pub struct MessageHandler<T: Transport + 'static> {
225 transport: Arc<T>,
226 pending_requests: Arc<Mutex<HashMap<u64, Sender<JsonRpcResponse>>>>,
227 sampling_handler: Option<Arc<Box<dyn SamplingHandler + Send>>>,
228 notification_handler: Option<Arc<Box<dyn NotificationHandler + Send>>>,
229}
230
231impl<T: Transport + 'static> MessageHandler<T> {
232 pub fn new(
233 transport: Arc<T>,
234 pending_requests: Arc<Mutex<HashMap<u64, Sender<JsonRpcResponse>>>>,
235 sampling_handler: Option<Arc<Box<dyn SamplingHandler + Send>>>,
236 notification_handler: Option<Arc<Box<dyn NotificationHandler + Send>>>,
237 ) -> Self {
238 Self {
239 transport,
240 pending_requests,
241 sampling_handler,
242 notification_handler,
243 }
244 }
245
246 pub fn handle_message(&self, message: JsonRpcMessage) -> Result<(), McpError> {
247 match message {
248 JsonRpcMessage::Request(request) => self.handle_request(request),
249 JsonRpcMessage::Response(response) => self.handle_response(response),
250 JsonRpcMessage::Notification(notification) => self.handle_notification(notification),
251 }
252 }
253
254 pub fn handle_request(&self, request: JsonRpcRequest) -> Result<(), McpError> {
255 match request.method.as_str() {
256 "sampling/createMessage" => {
257 if let Some(handler) = &self.sampling_handler {
258 self.handle_sampling_request(handler, &request)?;
259 }
260 Ok(())
261 }
262 _ => {
263 warn!("Unknown request method: {}", request.method);
264 Ok(())
265 }
266 }
267 }
268
269 pub fn handle_sampling_request(
270 &self,
271 handler: &Arc<Box<dyn SamplingHandler + Send>>,
272 request: &JsonRpcRequest,
273 ) -> Result<(), McpError> {
274 debug!("Processing sampling request...");
275 let params =
276 serde_json::from_value(request.clone().params.unwrap_or(serde_json::Value::Null))?;
277 let result = handler.handle_message(params)?;
278 let value = serde_json::to_value(result)?;
279
280 let response = JsonRpcResponse {
281 id: request.id,
282 result: Some(value),
283 error: None,
284 jsonrpc: Default::default(),
285 };
286
287 debug!("Sending sampling response...");
288 self.transport.send(&JsonRpcMessage::Response(response))?;
289 debug!("Sampling response sent");
290 Ok(())
291 }
292
293 pub fn handle_response(&self, response: JsonRpcResponse) -> Result<(), McpError> {
294 debug!("Got response with id: {}", response.id);
295 let mut pending = self.pending_requests.lock().unwrap();
296 debug!(
297 "Current pending request IDs: {:?}",
298 pending.keys().collect::<Vec<_>>()
299 );
300
301 if let Some(sender) = pending.remove(&response.id) {
302 sender
303 .send(response.clone())
304 .map_err(|e| McpError::SendError {
305 id: response.id,
306 source: e,
307 })?;
308 }
309 Ok(())
310 }
311
312 pub fn handle_notification(&self, notification: JsonRpcNotification) -> Result<(), McpError> {
313 debug!("Got notification: method={:?}", notification.method);
314
315 if let Some(handler) = &self.notification_handler {
316 match notification.method {
317 NotificationMethod::ResourceUpdated => {
318 if let Some(NotificationParams::ResourceUpdated(params)) = notification.params {
319 let url = Url::parse(¶ms.uri)?;
320 handler.handle_resource_update(&url)?;
321 }
322 }
323 NotificationMethod::LogMessage => {
324 if let Some(NotificationParams::LogMessage(params)) = notification.params {
325 handler.handle_log_message(¶ms.level, ¶ms.data, ¶ms.logger);
326 }
327 }
328 NotificationMethod::Progress => {
329 if let Some(NotificationParams::Progress(params)) = notification.params {
330 handler.handle_progress_update(
331 ¶ms.progress_token,
332 ¶ms.progress,
333 ¶ms.total,
334 );
335 }
336 }
337 NotificationMethod::Initialized => {
338 debug!("Server initialization completed");
339 handler.handle_initialized();
340 }
341 NotificationMethod::ToolListChanged
342 | NotificationMethod::PromptListChanged
343 | NotificationMethod::ResourceListChanged => {
344 debug!("List changed notification: {:?}", notification.method);
345 handler.handle_list_changed(¬ification.method);
346 }
347 }
348 } else {
349 debug!("Received notification but no handler is registered");
350 }
351 Ok(())
352 }
353}