Skip to main content

tower_mcp/client/
handler.rs

1//! Handler trait for server-initiated requests and notifications.
2//!
3//! The [`ClientHandler`] trait defines how the client responds to requests
4//! and notifications sent by the server. All methods have default
5//! implementations, so you only need to override the ones you care about.
6//!
7//! The unit type `()` implements this trait with all defaults, which is
8//! used by [`McpClient::connect()`](super::McpClient::connect).
9//!
10//! # Notification Handler
11//!
12//! For notification-only use cases, [`NotificationHandler`] provides a
13//! builder-based alternative to implementing the full trait:
14//!
15//! ```rust
16//! use tower_mcp::client::NotificationHandler;
17//!
18//! let handler = NotificationHandler::new()
19//!     .on_tools_changed(|| {
20//!         println!("Tools changed, re-fetching...");
21//!     })
22//!     .on_log_message(|msg| {
23//!         println!("[{}] {}", msg.level, msg.data);
24//!     });
25//! ```
26//!
27//! For forwarding MCP log messages to the [`tracing`] crate:
28//!
29//! ```rust
30//! use tower_mcp::client::NotificationHandler;
31//!
32//! let handler = NotificationHandler::with_log_forwarding();
33//! ```
34//!
35//! # Custom Handler
36//!
37//! ```rust,ignore
38//! use async_trait::async_trait;
39//! use tower_mcp::client::ClientHandler;
40//! use tower_mcp::protocol::{CreateMessageParams, CreateMessageResult};
41//! use tower_mcp_types::JsonRpcError;
42//!
43//! struct MySamplingHandler;
44//!
45//! #[async_trait]
46//! impl ClientHandler for MySamplingHandler {
47//!     async fn handle_create_message(
48//!         &self,
49//!         params: CreateMessageParams,
50//!     ) -> Result<CreateMessageResult, JsonRpcError> {
51//!         // Forward to your LLM and return the result
52//!         todo!()
53//!     }
54//! }
55//! ```
56
57use async_trait::async_trait;
58
59use crate::protocol::{
60    CreateMessageParams, CreateMessageResult, ElicitRequestParams, ElicitResult, ListRootsResult,
61    LogLevel, LoggingMessageParams, ProgressParams,
62};
63use tower_mcp_types::JsonRpcError;
64
65/// Notification sent from the server to the client.
66///
67/// These correspond to the `notifications/` methods defined in the MCP spec
68/// that flow from server to client.
69#[derive(Debug, Clone)]
70#[non_exhaustive]
71pub enum ServerNotification {
72    /// Progress update for a request (`notifications/progress`).
73    Progress(ProgressParams),
74    /// Log message (`notifications/message`).
75    LogMessage(LoggingMessageParams),
76    /// A subscribed resource has been updated (`notifications/resources/updated`).
77    ResourceUpdated {
78        /// The URI of the updated resource.
79        uri: String,
80    },
81    /// The list of available resources has changed.
82    ResourcesListChanged,
83    /// The list of available tools has changed.
84    ToolsListChanged,
85    /// The list of available prompts has changed.
86    PromptsListChanged,
87    /// An unknown or unrecognized notification.
88    Unknown {
89        /// The notification method name.
90        method: String,
91        /// The notification parameters, if any.
92        params: Option<serde_json::Value>,
93    },
94}
95
96/// Handler for server-initiated requests and notifications.
97///
98/// Implement this trait to handle sampling requests, elicitation requests,
99/// roots listing, and server notifications. All methods have default
100/// implementations that either return sensible defaults or reject with
101/// `method_not_found`.
102///
103/// The unit type `()` implements this trait with all defaults, which is
104/// used by [`McpClient::connect()`](super::McpClient::connect).
105#[async_trait]
106pub trait ClientHandler: Send + Sync + 'static {
107    /// Handle a `sampling/createMessage` request from the server.
108    ///
109    /// The server is asking the client to perform LLM inference. Override
110    /// this to forward the request to your LLM provider.
111    ///
112    /// Default: returns `method_not_found` error.
113    async fn handle_create_message(
114        &self,
115        _params: CreateMessageParams,
116    ) -> Result<CreateMessageResult, JsonRpcError> {
117        Err(JsonRpcError::method_not_found("sampling/createMessage"))
118    }
119
120    /// Handle an `elicitation/create` request from the server.
121    ///
122    /// The server is asking the client for user input (form data or URL).
123    ///
124    /// Default: returns `method_not_found` error.
125    async fn handle_elicit(
126        &self,
127        _params: ElicitRequestParams,
128    ) -> Result<ElicitResult, JsonRpcError> {
129        Err(JsonRpcError::method_not_found("elicitation/create"))
130    }
131
132    /// Handle a `roots/list` request from the server.
133    ///
134    /// The server is asking which filesystem roots the client has access to.
135    /// If roots were configured on the [`McpClient`](super::McpClient) via
136    /// the builder, those are returned automatically before this method
137    /// is called.
138    ///
139    /// Default: returns an empty list.
140    async fn handle_list_roots(&self) -> Result<ListRootsResult, JsonRpcError> {
141        Ok(ListRootsResult {
142            roots: vec![],
143            meta: None,
144        })
145    }
146
147    /// Called when the server sends a notification.
148    ///
149    /// Override to handle progress updates, log messages, resource changes, etc.
150    ///
151    /// Default: no-op.
152    async fn on_notification(&self, _notification: ServerNotification) {}
153}
154
155/// Unit type implements [`ClientHandler`] with all defaults.
156#[async_trait]
157impl ClientHandler for () {}
158
159// Type aliases for notification callback boxes.
160type ProgressCallback = Box<dyn Fn(ProgressParams) + Send + Sync>;
161type LogMessageCallback = Box<dyn Fn(LoggingMessageParams) + Send + Sync>;
162type ResourceUpdatedCallback = Box<dyn Fn(String) + Send + Sync>;
163type SimpleCallback = Box<dyn Fn() + Send + Sync>;
164
165/// Callback-based handler for server notifications.
166///
167/// Provides typed callback registration for each notification type,
168/// without requiring a full [`ClientHandler`] trait implementation.
169/// Server-initiated requests (sampling, elicitation, roots) are
170/// rejected with `method_not_found`.
171///
172/// # Example
173///
174/// ```rust
175/// use tower_mcp::client::NotificationHandler;
176///
177/// let handler = NotificationHandler::new()
178///     .on_progress(|p| {
179///         println!("Progress: {}/{}", p.progress, p.total.unwrap_or(1.0));
180///     })
181///     .on_tools_changed(|| {
182///         println!("Server tools changed!");
183///     });
184/// ```
185pub struct NotificationHandler {
186    on_progress: Option<ProgressCallback>,
187    on_log_message: Option<LogMessageCallback>,
188    on_resource_updated: Option<ResourceUpdatedCallback>,
189    on_resources_changed: Option<SimpleCallback>,
190    on_tools_changed: Option<SimpleCallback>,
191    on_prompts_changed: Option<SimpleCallback>,
192}
193
194impl NotificationHandler {
195    /// Create a new handler with no callbacks registered.
196    pub fn new() -> Self {
197        Self {
198            on_progress: None,
199            on_log_message: None,
200            on_resource_updated: None,
201            on_resources_changed: None,
202            on_tools_changed: None,
203            on_prompts_changed: None,
204        }
205    }
206
207    /// Create a handler that forwards MCP log messages to [`tracing`].
208    ///
209    /// Maps MCP log levels to tracing levels:
210    /// - Emergency, Alert, Critical -> `error!`
211    /// - Error -> `error!`
212    /// - Warning -> `warn!`
213    /// - Notice, Info -> `info!`
214    /// - Debug -> `debug!`
215    pub fn with_log_forwarding() -> Self {
216        Self::new().on_log_message(|msg| {
217            let logger = msg.logger.as_deref().unwrap_or("mcp");
218            match msg.level {
219                LogLevel::Emergency | LogLevel::Alert | LogLevel::Critical | LogLevel::Error => {
220                    tracing::error!(logger = logger, "{}", msg.data);
221                }
222                LogLevel::Warning => {
223                    tracing::warn!(logger = logger, "{}", msg.data);
224                }
225                LogLevel::Notice | LogLevel::Info => {
226                    tracing::info!(logger = logger, "{}", msg.data);
227                }
228                LogLevel::Debug => {
229                    tracing::debug!(logger = logger, "{}", msg.data);
230                }
231                _ => {
232                    tracing::trace!(logger = logger, "{}", msg.data);
233                }
234            }
235        })
236    }
237
238    /// Register a callback for progress notifications.
239    pub fn on_progress(mut self, f: impl Fn(ProgressParams) + Send + Sync + 'static) -> Self {
240        self.on_progress = Some(Box::new(f));
241        self
242    }
243
244    /// Register a callback for log message notifications.
245    pub fn on_log_message(
246        mut self,
247        f: impl Fn(LoggingMessageParams) + Send + Sync + 'static,
248    ) -> Self {
249        self.on_log_message = Some(Box::new(f));
250        self
251    }
252
253    /// Register a callback for resource updated notifications.
254    ///
255    /// The callback receives the URI of the updated resource.
256    pub fn on_resource_updated(mut self, f: impl Fn(String) + Send + Sync + 'static) -> Self {
257        self.on_resource_updated = Some(Box::new(f));
258        self
259    }
260
261    /// Register a callback for resources list changed notifications.
262    pub fn on_resources_changed(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
263        self.on_resources_changed = Some(Box::new(f));
264        self
265    }
266
267    /// Register a callback for tools list changed notifications.
268    pub fn on_tools_changed(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
269        self.on_tools_changed = Some(Box::new(f));
270        self
271    }
272
273    /// Register a callback for prompts list changed notifications.
274    pub fn on_prompts_changed(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
275        self.on_prompts_changed = Some(Box::new(f));
276        self
277    }
278}
279
280impl Default for NotificationHandler {
281    fn default() -> Self {
282        Self::new()
283    }
284}
285
286impl std::fmt::Debug for NotificationHandler {
287    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288        f.debug_struct("NotificationHandler")
289            .field("on_progress", &self.on_progress.is_some())
290            .field("on_log_message", &self.on_log_message.is_some())
291            .field("on_resource_updated", &self.on_resource_updated.is_some())
292            .field("on_resources_changed", &self.on_resources_changed.is_some())
293            .field("on_tools_changed", &self.on_tools_changed.is_some())
294            .field("on_prompts_changed", &self.on_prompts_changed.is_some())
295            .finish()
296    }
297}
298
299#[async_trait]
300impl ClientHandler for NotificationHandler {
301    async fn on_notification(&self, notification: ServerNotification) {
302        match notification {
303            ServerNotification::Progress(params) => {
304                if let Some(cb) = &self.on_progress {
305                    cb(params);
306                }
307            }
308            ServerNotification::LogMessage(params) => {
309                if let Some(cb) = &self.on_log_message {
310                    cb(params);
311                }
312            }
313            ServerNotification::ResourceUpdated { uri } => {
314                if let Some(cb) = &self.on_resource_updated {
315                    cb(uri);
316                }
317            }
318            ServerNotification::ResourcesListChanged => {
319                if let Some(cb) = &self.on_resources_changed {
320                    cb();
321                }
322            }
323            ServerNotification::ToolsListChanged => {
324                if let Some(cb) = &self.on_tools_changed {
325                    cb();
326                }
327            }
328            ServerNotification::PromptsListChanged => {
329                if let Some(cb) = &self.on_prompts_changed {
330                    cb();
331                }
332            }
333            ServerNotification::Unknown { .. } => {}
334        }
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    // A registered progress callback receives the notification payload.
    #[tokio::test]
    async fn test_notification_handler_progress() {
        let fired = Arc::new(AtomicBool::new(false));
        let fired_in_cb = Arc::clone(&fired);
        let handler = NotificationHandler::new().on_progress(move |p| {
            assert!((p.progress - 0.5).abs() < f64::EPSILON);
            fired_in_cb.store(true, Ordering::SeqCst);
        });

        let update = ProgressParams {
            progress_token: crate::protocol::ProgressToken::String("t1".into()),
            progress: 0.5,
            total: Some(1.0),
            message: None,
            meta: None,
        };
        handler
            .on_notification(ServerNotification::Progress(update))
            .await;

        assert!(fired.load(Ordering::SeqCst));
    }

    // A registered log callback receives the log message payload.
    #[tokio::test]
    async fn test_notification_handler_log_message() {
        let fired = Arc::new(AtomicBool::new(false));
        let fired_in_cb = Arc::clone(&fired);
        let handler = NotificationHandler::new().on_log_message(move |msg| {
            assert_eq!(msg.level, LogLevel::Info);
            fired_in_cb.store(true, Ordering::SeqCst);
        });

        let entry = LoggingMessageParams {
            level: LogLevel::Info,
            logger: Some("test".into()),
            data: serde_json::json!("hello"),
            meta: None,
        };
        handler
            .on_notification(ServerNotification::LogMessage(entry))
            .await;

        assert!(fired.load(Ordering::SeqCst));
    }

    // A registered resource-updated callback receives the resource URI.
    #[tokio::test]
    async fn test_notification_handler_resource_updated() {
        let fired = Arc::new(AtomicBool::new(false));
        let fired_in_cb = Arc::clone(&fired);
        let handler = NotificationHandler::new().on_resource_updated(move |uri| {
            assert_eq!(uri, "file:///test.txt");
            fired_in_cb.store(true, Ordering::SeqCst);
        });

        handler
            .on_notification(ServerNotification::ResourceUpdated {
                uri: "file:///test.txt".to_string(),
            })
            .await;

        assert!(fired.load(Ordering::SeqCst));
    }

    // Each list-changed notification reaches only its own callback, once.
    #[tokio::test]
    async fn test_notification_handler_list_changed() {
        let tools_count = Arc::new(AtomicUsize::new(0));
        let resources_count = Arc::new(AtomicUsize::new(0));
        let prompts_count = Arc::new(AtomicUsize::new(0));

        let tools_in_cb = Arc::clone(&tools_count);
        let resources_in_cb = Arc::clone(&resources_count);
        let prompts_in_cb = Arc::clone(&prompts_count);

        let handler = NotificationHandler::new()
            .on_tools_changed(move || {
                tools_in_cb.fetch_add(1, Ordering::SeqCst);
            })
            .on_resources_changed(move || {
                resources_in_cb.fetch_add(1, Ordering::SeqCst);
            })
            .on_prompts_changed(move || {
                prompts_in_cb.fetch_add(1, Ordering::SeqCst);
            });

        handler
            .on_notification(ServerNotification::ToolsListChanged)
            .await;
        handler
            .on_notification(ServerNotification::ResourcesListChanged)
            .await;
        handler
            .on_notification(ServerNotification::PromptsListChanged)
            .await;

        assert_eq!(tools_count.load(Ordering::SeqCst), 1);
        assert_eq!(resources_count.load(Ordering::SeqCst), 1);
        assert_eq!(prompts_count.load(Ordering::SeqCst), 1);
    }

    // With nothing registered, every notification kind must be a silent no-op.
    #[tokio::test]
    async fn test_notification_handler_unset_callbacks_are_noop() {
        let handler = NotificationHandler::new();

        handler
            .on_notification(ServerNotification::ToolsListChanged)
            .await;
        handler
            .on_notification(ServerNotification::Progress(ProgressParams {
                progress_token: crate::protocol::ProgressToken::String("t".into()),
                progress: 1.0,
                total: None,
                message: None,
                meta: None,
            }))
            .await;
        handler
            .on_notification(ServerNotification::LogMessage(LoggingMessageParams {
                level: LogLevel::Debug,
                logger: None,
                data: serde_json::json!("test"),
                meta: None,
            }))
            .await;
        handler
            .on_notification(ServerNotification::Unknown {
                method: "custom/thing".into(),
                params: None,
            })
            .await;
    }

    // Sampling and elicitation requests fall through to the trait defaults,
    // which reject with JSON-RPC `method_not_found` (-32601).
    #[tokio::test]
    async fn test_notification_handler_rejects_requests() {
        use crate::protocol::{ElicitFormParams, ElicitFormSchema};

        let handler = NotificationHandler::new();

        let params = serde_json::from_value::<CreateMessageParams>(serde_json::json!({
            "messages": [],
            "maxTokens": 100
        }))
        .unwrap();
        let err = handler.handle_create_message(params).await.unwrap_err();
        assert_eq!(err.code, -32601); // method_not_found

        let err = handler
            .handle_elicit(ElicitRequestParams::Form(ElicitFormParams {
                mode: None,
                message: "test".into(),
                requested_schema: ElicitFormSchema {
                    schema_type: "object".into(),
                    properties: Default::default(),
                    required: vec![],
                },
                meta: None,
            }))
            .await
            .unwrap_err();
        assert_eq!(err.code, -32601);
    }

    // Unlike sampling and elicitation, `roots/list` is not rejected:
    // `NotificationHandler` does not override `handle_list_roots`, so the
    // trait default (an empty list) applies.
    #[tokio::test]
    async fn test_notification_handler_list_roots_defaults_to_empty() {
        let handler = NotificationHandler::new();

        let outcome = handler.handle_list_roots().await;
        assert!(matches!(outcome, Ok(ref listing) if listing.roots.is_empty()));
    }

    // Debug output reports, per field, whether a callback is registered.
    #[test]
    fn test_notification_handler_debug() {
        let handler = NotificationHandler::new().on_progress(|_| {});
        let debug = format!("{:?}", handler);
        assert!(debug.contains("on_progress: true"));
        assert!(debug.contains("on_log_message: false"));
    }

    // `Default` must be constructible without panicking.
    #[test]
    fn test_notification_handler_default() {
        let _handler = NotificationHandler::default();
    }
}
516}