1use async_trait::async_trait;
58
59use crate::protocol::{
60 CreateMessageParams, CreateMessageResult, ElicitRequestParams, ElicitResult, ListRootsResult,
61 LogLevel, LoggingMessageParams, ProgressParams,
62};
63use tower_mcp_types::JsonRpcError;
64
/// A notification sent by the server to the client.
///
/// Values of this type are delivered to [`ClientHandler::on_notification`].
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ServerNotification {
    /// Progress report for an in-flight operation (token, current progress,
    /// optional total and message).
    Progress(ProgressParams),
    /// A log message emitted by the server.
    LogMessage(LoggingMessageParams),
    /// A single resource changed.
    ResourceUpdated {
        /// URI of the resource that changed.
        uri: String,
    },
    /// The list of resources offered by the server changed.
    ResourcesListChanged,
    /// The list of tools offered by the server changed.
    ToolsListChanged,
    /// The list of prompts offered by the server changed.
    PromptsListChanged,
    /// A notification this enum does not model, kept in raw form.
    Unknown {
        /// The notification's method name, as received.
        method: String,
        /// The notification's parameters as raw JSON, if any were sent.
        params: Option<serde_json::Value>,
    },
}
95
96#[async_trait]
106pub trait ClientHandler: Send + Sync + 'static {
107 async fn handle_create_message(
114 &self,
115 _params: CreateMessageParams,
116 ) -> Result<CreateMessageResult, JsonRpcError> {
117 Err(JsonRpcError::method_not_found("sampling/createMessage"))
118 }
119
120 async fn handle_elicit(
126 &self,
127 _params: ElicitRequestParams,
128 ) -> Result<ElicitResult, JsonRpcError> {
129 Err(JsonRpcError::method_not_found("elicitation/create"))
130 }
131
132 async fn handle_list_roots(&self) -> Result<ListRootsResult, JsonRpcError> {
141 Ok(ListRootsResult {
142 roots: vec![],
143 meta: None,
144 })
145 }
146
147 async fn on_notification(&self, _notification: ServerNotification) {}
153}
154
/// The unit type is a do-nothing client handler.
///
/// It keeps every [`ClientHandler`] default: it rejects sampling and
/// elicitation requests, reports no roots, and ignores notifications.
#[async_trait]
impl ClientHandler for () {}
158
// Boxed closure types stored by `NotificationHandler`. The `Send + Sync`
// bounds let `NotificationHandler` satisfy the `Send + Sync` supertrait bound
// of `ClientHandler`.

/// Callback for progress notifications.
type ProgressCallback = Box<dyn Fn(ProgressParams) + Send + Sync>;
/// Callback for server log messages.
type LogMessageCallback = Box<dyn Fn(LoggingMessageParams) + Send + Sync>;
/// Callback for a single resource update; receives the resource URI.
type ResourceUpdatedCallback = Box<dyn Fn(String) + Send + Sync>;
/// Callback for the payload-less "list changed" notifications.
type SimpleCallback = Box<dyn Fn() + Send + Sync>;

/// A [`ClientHandler`] assembled from optional per-notification closures.
///
/// Create one with [`NotificationHandler::new`] (or `Default`) and attach
/// closures with the `on_*` builder methods. Notifications with no registered
/// closure are ignored.
pub struct NotificationHandler {
    // Invoked for `ServerNotification::Progress`.
    on_progress: Option<ProgressCallback>,
    // Invoked for `ServerNotification::LogMessage`.
    on_log_message: Option<LogMessageCallback>,
    // Invoked for `ServerNotification::ResourceUpdated` with the URI.
    on_resource_updated: Option<ResourceUpdatedCallback>,
    // Invoked for `ServerNotification::ResourcesListChanged`.
    on_resources_changed: Option<SimpleCallback>,
    // Invoked for `ServerNotification::ToolsListChanged`.
    on_tools_changed: Option<SimpleCallback>,
    // Invoked for `ServerNotification::PromptsListChanged`.
    on_prompts_changed: Option<SimpleCallback>,
}
193
194impl NotificationHandler {
195 pub fn new() -> Self {
197 Self {
198 on_progress: None,
199 on_log_message: None,
200 on_resource_updated: None,
201 on_resources_changed: None,
202 on_tools_changed: None,
203 on_prompts_changed: None,
204 }
205 }
206
207 pub fn with_log_forwarding() -> Self {
216 Self::new().on_log_message(|msg| {
217 let logger = msg.logger.as_deref().unwrap_or("mcp");
218 match msg.level {
219 LogLevel::Emergency | LogLevel::Alert | LogLevel::Critical | LogLevel::Error => {
220 tracing::error!(logger = logger, "{}", msg.data);
221 }
222 LogLevel::Warning => {
223 tracing::warn!(logger = logger, "{}", msg.data);
224 }
225 LogLevel::Notice | LogLevel::Info => {
226 tracing::info!(logger = logger, "{}", msg.data);
227 }
228 LogLevel::Debug => {
229 tracing::debug!(logger = logger, "{}", msg.data);
230 }
231 _ => {
232 tracing::trace!(logger = logger, "{}", msg.data);
233 }
234 }
235 })
236 }
237
238 pub fn on_progress(mut self, f: impl Fn(ProgressParams) + Send + Sync + 'static) -> Self {
240 self.on_progress = Some(Box::new(f));
241 self
242 }
243
244 pub fn on_log_message(
246 mut self,
247 f: impl Fn(LoggingMessageParams) + Send + Sync + 'static,
248 ) -> Self {
249 self.on_log_message = Some(Box::new(f));
250 self
251 }
252
253 pub fn on_resource_updated(mut self, f: impl Fn(String) + Send + Sync + 'static) -> Self {
257 self.on_resource_updated = Some(Box::new(f));
258 self
259 }
260
261 pub fn on_resources_changed(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
263 self.on_resources_changed = Some(Box::new(f));
264 self
265 }
266
267 pub fn on_tools_changed(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
269 self.on_tools_changed = Some(Box::new(f));
270 self
271 }
272
273 pub fn on_prompts_changed(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
275 self.on_prompts_changed = Some(Box::new(f));
276 self
277 }
278}
279
280impl Default for NotificationHandler {
281 fn default() -> Self {
282 Self::new()
283 }
284}
285
286impl std::fmt::Debug for NotificationHandler {
287 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 f.debug_struct("NotificationHandler")
289 .field("on_progress", &self.on_progress.is_some())
290 .field("on_log_message", &self.on_log_message.is_some())
291 .field("on_resource_updated", &self.on_resource_updated.is_some())
292 .field("on_resources_changed", &self.on_resources_changed.is_some())
293 .field("on_tools_changed", &self.on_tools_changed.is_some())
294 .field("on_prompts_changed", &self.on_prompts_changed.is_some())
295 .finish()
296 }
297}
298
299#[async_trait]
300impl ClientHandler for NotificationHandler {
301 async fn on_notification(&self, notification: ServerNotification) {
302 match notification {
303 ServerNotification::Progress(params) => {
304 if let Some(cb) = &self.on_progress {
305 cb(params);
306 }
307 }
308 ServerNotification::LogMessage(params) => {
309 if let Some(cb) = &self.on_log_message {
310 cb(params);
311 }
312 }
313 ServerNotification::ResourceUpdated { uri } => {
314 if let Some(cb) = &self.on_resource_updated {
315 cb(uri);
316 }
317 }
318 ServerNotification::ResourcesListChanged => {
319 if let Some(cb) = &self.on_resources_changed {
320 cb();
321 }
322 }
323 ServerNotification::ToolsListChanged => {
324 if let Some(cb) = &self.on_tools_changed {
325 cb();
326 }
327 }
328 ServerNotification::PromptsListChanged => {
329 if let Some(cb) = &self.on_prompts_changed {
330 cb();
331 }
332 }
333 ServerNotification::Unknown { .. } => {}
334 }
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// Builds a minimal, valid sampling request used by the rejection tests.
    fn sample_create_message_params() -> CreateMessageParams {
        serde_json::from_value(serde_json::json!({
            "messages": [],
            "maxTokens": 100
        }))
        .unwrap()
    }

    #[tokio::test]
    async fn test_notification_handler_progress() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let handler = NotificationHandler::new().on_progress(move |p| {
            assert!((p.progress - 0.5).abs() < f64::EPSILON);
            called_clone.store(true, Ordering::SeqCst);
        });

        handler
            .on_notification(ServerNotification::Progress(ProgressParams {
                progress_token: crate::protocol::ProgressToken::String("t1".into()),
                progress: 0.5,
                total: Some(1.0),
                message: None,
                meta: None,
            }))
            .await;

        assert!(called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_notification_handler_log_message() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let handler = NotificationHandler::new().on_log_message(move |msg| {
            assert_eq!(msg.level, LogLevel::Info);
            called_clone.store(true, Ordering::SeqCst);
        });

        handler
            .on_notification(ServerNotification::LogMessage(LoggingMessageParams {
                level: LogLevel::Info,
                logger: Some("test".into()),
                data: serde_json::json!("hello"),
                meta: None,
            }))
            .await;

        assert!(called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_notification_handler_resource_updated() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let handler = NotificationHandler::new().on_resource_updated(move |uri| {
            assert_eq!(uri, "file:///test.txt");
            called_clone.store(true, Ordering::SeqCst);
        });

        handler
            .on_notification(ServerNotification::ResourceUpdated {
                uri: "file:///test.txt".to_string(),
            })
            .await;

        assert!(called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_notification_handler_list_changed() {
        let tools_count = Arc::new(AtomicUsize::new(0));
        let resources_count = Arc::new(AtomicUsize::new(0));
        let prompts_count = Arc::new(AtomicUsize::new(0));

        let tc = tools_count.clone();
        let rc = resources_count.clone();
        let pc = prompts_count.clone();

        let handler = NotificationHandler::new()
            .on_tools_changed(move || {
                tc.fetch_add(1, Ordering::SeqCst);
            })
            .on_resources_changed(move || {
                rc.fetch_add(1, Ordering::SeqCst);
            })
            .on_prompts_changed(move || {
                pc.fetch_add(1, Ordering::SeqCst);
            });

        handler
            .on_notification(ServerNotification::ToolsListChanged)
            .await;
        handler
            .on_notification(ServerNotification::ResourcesListChanged)
            .await;
        handler
            .on_notification(ServerNotification::PromptsListChanged)
            .await;

        assert_eq!(tools_count.load(Ordering::SeqCst), 1);
        assert_eq!(resources_count.load(Ordering::SeqCst), 1);
        assert_eq!(prompts_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_notification_handler_unset_callbacks_are_noop() {
        let handler = NotificationHandler::new();

        handler
            .on_notification(ServerNotification::ToolsListChanged)
            .await;
        handler
            .on_notification(ServerNotification::Progress(ProgressParams {
                progress_token: crate::protocol::ProgressToken::String("t".into()),
                progress: 1.0,
                total: None,
                message: None,
                meta: None,
            }))
            .await;
        handler
            .on_notification(ServerNotification::LogMessage(LoggingMessageParams {
                level: LogLevel::Debug,
                logger: None,
                data: serde_json::json!("test"),
                meta: None,
            }))
            .await;
        handler
            .on_notification(ServerNotification::Unknown {
                method: "custom/thing".into(),
                params: None,
            })
            .await;
    }

    #[tokio::test]
    async fn test_notification_handler_rejects_requests() {
        use crate::protocol::{ElicitFormParams, ElicitFormSchema};

        let handler = NotificationHandler::new();

        let err = handler
            .handle_create_message(sample_create_message_params())
            .await
            .unwrap_err();
        assert_eq!(err.code, -32601);

        let err = handler
            .handle_elicit(ElicitRequestParams::Form(ElicitFormParams {
                mode: None,
                message: "test".into(),
                requested_schema: ElicitFormSchema {
                    schema_type: "object".into(),
                    properties: Default::default(),
                    required: vec![],
                },
                meta: None,
            }))
            .await
            .unwrap_err();
        assert_eq!(err.code, -32601);
    }

    #[tokio::test]
    async fn test_unit_handler_uses_defaults() {
        let err = ()
            .handle_create_message(sample_create_message_params())
            .await
            .unwrap_err();
        assert_eq!(err.code, -32601);

        let Ok(listed) = ().handle_list_roots().await else {
            panic!("default handle_list_roots must succeed");
        };
        assert!(listed.roots.is_empty());
        assert!(listed.meta.is_none());

        // Notifications are accepted and silently ignored.
        ().on_notification(ServerNotification::ToolsListChanged).await;
    }

    #[tokio::test]
    async fn test_with_log_forwarding_registers_log_callback() {
        let handler = NotificationHandler::with_log_forwarding();
        let debug = format!("{:?}", handler);
        assert!(debug.contains("on_log_message: true"));
        assert!(debug.contains("on_progress: false"));

        // Exercise several severity buckets with string and structured payloads.
        for (level, data) in [
            (LogLevel::Emergency, serde_json::json!("plain text")),
            (LogLevel::Warning, serde_json::json!({"key": "value"})),
            (LogLevel::Notice, serde_json::json!([1, 2, 3])),
            (LogLevel::Debug, serde_json::json!(null)),
        ] {
            handler
                .on_notification(ServerNotification::LogMessage(LoggingMessageParams {
                    level,
                    logger: None,
                    data,
                    meta: None,
                }))
                .await;
        }
    }

    #[test]
    fn test_notification_handler_debug() {
        let handler = NotificationHandler::new().on_progress(|_| {});
        let debug = format!("{:?}", handler);
        assert!(debug.contains("on_progress: true"));
        assert!(debug.contains("on_log_message: false"));
    }

    #[test]
    fn test_notification_handler_default() {
        let _handler = NotificationHandler::default();
    }
}
516}