Skip to main content

tower_mcp/middleware/
tracing.rs

1//! MCP request tracing middleware.
2//!
3//! This module provides [`McpTracingLayer`], a Tower middleware that logs
4//! structured information about MCP requests using the [`tracing`] crate.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use tower_mcp::{McpRouter, StdioTransport};
10//! use tower_mcp::middleware::McpTracingLayer;
11//!
12//! let router = McpRouter::new().server_info("my-server", "1.0.0");
13//!
14//! // Add tracing to all MCP requests
15//! let mut transport = StdioTransport::new(router)
16//!     .layer(McpTracingLayer::new());
17//! ```
18//!
19//! # Logged Information
20//!
21//! For each request, the layer logs:
22//! - Request method (e.g., `tools/call`, `resources/read`)
23//! - Request ID
24//! - Operation-specific details:
25//!   - Tool calls: tool name
26//!   - Resource reads: resource URI
27//!   - Prompt gets: prompt name
28//! - Request duration
29//! - Response status (success or error code)
30//!
31//! # Log Levels
32//!
//! - Configured level (`INFO` by default, see [`McpTracingLayer::level`]):
//!   the per-request span and the "request completed" event
//! - `WARN`: Error responses (always, regardless of the configured level)
36
37use std::convert::Infallible;
38use std::future::Future;
39use std::pin::Pin;
40use std::task::{Context, Poll};
41use std::time::Instant;
42
43use tower::Layer;
44use tower_service::Service;
45use tracing::{Instrument, Level, Span};
46
47use crate::protocol::McpRequest;
48use crate::router::{RouterRequest, RouterResponse};
49
50/// Tower layer that adds structured tracing to MCP requests.
51///
52/// This layer wraps a service and logs information about each request
53/// using the [`tracing`] crate. It's designed to work with tower-mcp's
54/// `RouterRequest`/`RouterResponse` types.
55///
56/// # Example
57///
58/// ```rust,ignore
59/// use tower_mcp::{McpRouter, StdioTransport};
60/// use tower_mcp::middleware::McpTracingLayer;
61///
62/// let router = McpRouter::new().server_info("my-server", "1.0.0");
63///
64/// // Apply at the transport level for all requests
65/// let mut transport = StdioTransport::new(router)
66///     .layer(McpTracingLayer::new());
67///
68/// // Or apply to specific tools via ToolBuilder
69/// let tool = ToolBuilder::new("search")
70///     .handler(|input: SearchInput| async move { ... })
71///     .layer(McpTracingLayer::new())
72///     .build();
73/// ```
74#[derive(Debug, Clone, Copy)]
75pub struct McpTracingLayer {
76    level: Level,
77}
78
79impl Default for McpTracingLayer {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl McpTracingLayer {
86    /// Create a new tracing layer with default settings (INFO level).
87    pub fn new() -> Self {
88        Self { level: Level::INFO }
89    }
90
91    /// Set the log level for request/response logging.
92    ///
93    /// Default is `INFO`.
94    pub fn level(mut self, level: Level) -> Self {
95        self.level = level;
96        self
97    }
98}
99
100impl<S> Layer<S> for McpTracingLayer {
101    type Service = McpTracingService<S>;
102
103    fn layer(&self, inner: S) -> Self::Service {
104        McpTracingService {
105            inner,
106            level: self.level,
107        }
108    }
109}
110
111/// Tower service that adds tracing to MCP requests.
112///
113/// Created by [`McpTracingLayer`].
114#[derive(Debug, Clone)]
115pub struct McpTracingService<S> {
116    inner: S,
117    level: Level,
118}
119
120impl<S> Service<RouterRequest> for McpTracingService<S>
121where
122    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
123        + Clone
124        + Send
125        + 'static,
126    S::Future: Send,
127{
128    type Response = RouterResponse;
129    type Error = Infallible;
130    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
131
132    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
133        self.inner.poll_ready(cx)
134    }
135
136    fn call(&mut self, req: RouterRequest) -> Self::Future {
137        let method = req.inner.method_name().to_string();
138        let request_id = format!("{:?}", req.id);
139
140        // Extract operation-specific details
141        let (operation_name, operation_target) = extract_operation_details(&req.inner);
142
143        // Create the span based on the configured level
144        let span = create_span(
145            self.level,
146            &method,
147            &request_id,
148            operation_name,
149            operation_target,
150        );
151
152        let start = Instant::now();
153        let fut = self.inner.call(req);
154        let level = self.level;
155
156        Box::pin(
157            async move {
158                let result = fut.await;
159                let duration = start.elapsed();
160
161                match &result {
162                    Ok(response) => {
163                        let duration_ms = duration.as_secs_f64() * 1000.0;
164                        match &response.inner {
165                            Ok(_) => {
166                                log_success(level, &method, duration_ms);
167                            }
168                            Err(err) => {
169                                tracing::warn!(
170                                    method = %method,
171                                    error_code = err.code,
172                                    error_message = %err.message,
173                                    duration_ms = duration_ms,
174                                    "MCP request failed"
175                                );
176                            }
177                        }
178                    }
179                    Err(_) => {
180                        // Infallible, but handle for completeness
181                        tracing::error!(method = %method, "MCP request error (infallible)");
182                    }
183                }
184
185                result
186            }
187            .instrument(span),
188        )
189    }
190}
191
192/// Extract operation-specific name and target from the request.
193pub(crate) fn extract_operation_details(
194    req: &McpRequest,
195) -> (Option<&'static str>, Option<String>) {
196    match req {
197        McpRequest::CallTool(params) => (Some("tool"), Some(params.name.clone())),
198        McpRequest::ReadResource(params) => (Some("resource"), Some(params.uri.clone())),
199        McpRequest::GetPrompt(params) => (Some("prompt"), Some(params.name.clone())),
200        McpRequest::ListTools(_) => (Some("list"), Some("tools".to_string())),
201        McpRequest::ListResources(_) => (Some("list"), Some("resources".to_string())),
202        McpRequest::ListResourceTemplates(_) => {
203            (Some("list"), Some("resource_templates".to_string()))
204        }
205        McpRequest::ListPrompts(_) => (Some("list"), Some("prompts".to_string())),
206        McpRequest::SubscribeResource(params) => (Some("subscribe"), Some(params.uri.clone())),
207        McpRequest::UnsubscribeResource(params) => (Some("unsubscribe"), Some(params.uri.clone())),
208        McpRequest::ListTasks(_) => (Some("list"), Some("tasks".to_string())),
209        McpRequest::GetTaskInfo(params) => (Some("task"), Some(params.task_id.clone())),
210        McpRequest::GetTaskResult(params) => (Some("task_result"), Some(params.task_id.clone())),
211        McpRequest::CancelTask(params) => (Some("cancel"), Some(params.task_id.clone())),
212        McpRequest::Complete(params) => {
213            let ref_type = match &params.reference {
214                crate::protocol::CompletionReference::Resource { uri } => {
215                    format!("resource:{}", uri)
216                }
217                crate::protocol::CompletionReference::Prompt { name } => {
218                    format!("prompt:{}", name)
219                }
220                _ => "unknown".to_string(),
221            };
222            (Some("complete"), Some(ref_type))
223        }
224        McpRequest::SetLoggingLevel(params) => {
225            (Some("logging"), Some(format!("{:?}", params.level)))
226        }
227        McpRequest::Initialize(_) => (Some("init"), None),
228        McpRequest::Ping => (Some("ping"), None),
229        McpRequest::Unknown { method, .. } => (Some("unknown"), Some(method.clone())),
230        _ => (Some("unknown"), None),
231    }
232}
233
234/// Create a tracing span with the appropriate level.
235fn create_span(
236    level: Level,
237    method: &str,
238    request_id: &str,
239    operation_name: Option<&str>,
240    operation_target: Option<String>,
241) -> Span {
242    match level {
243        Level::TRACE => tracing::trace_span!(
244            "mcp_request",
245            method = %method,
246            request_id = %request_id,
247            operation = operation_name,
248            target = operation_target.as_deref(),
249        ),
250        Level::DEBUG => tracing::debug_span!(
251            "mcp_request",
252            method = %method,
253            request_id = %request_id,
254            operation = operation_name,
255            target = operation_target.as_deref(),
256        ),
257        Level::INFO => tracing::info_span!(
258            "mcp_request",
259            method = %method,
260            request_id = %request_id,
261            operation = operation_name,
262            target = operation_target.as_deref(),
263        ),
264        Level::WARN => tracing::warn_span!(
265            "mcp_request",
266            method = %method,
267            request_id = %request_id,
268            operation = operation_name,
269            target = operation_target.as_deref(),
270        ),
271        Level::ERROR => tracing::error_span!(
272            "mcp_request",
273            method = %method,
274            request_id = %request_id,
275            operation = operation_name,
276            target = operation_target.as_deref(),
277        ),
278    }
279}
280
281/// Log successful request completion at the configured level.
282fn log_success(level: Level, method: &str, duration_ms: f64) {
283    match level {
284        Level::TRACE => {
285            tracing::trace!(method = %method, duration_ms = duration_ms, "MCP request completed")
286        }
287        Level::DEBUG => {
288            tracing::debug!(method = %method, duration_ms = duration_ms, "MCP request completed")
289        }
290        Level::INFO => {
291            tracing::info!(method = %method, duration_ms = duration_ms, "MCP request completed")
292        }
293        Level::WARN => {
294            tracing::warn!(method = %method, duration_ms = duration_ms, "MCP request completed")
295        }
296        Level::ERROR => {
297            tracing::error!(method = %method, duration_ms = duration_ms, "MCP request completed")
298        }
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// `new()` starts at INFO and `level()` overrides it.
    #[test]
    fn test_layer_creation() {
        let default_layer = McpTracingLayer::new();
        assert_eq!(default_layer.level, Level::INFO);

        let debug_layer = McpTracingLayer::new().level(Level::DEBUG);
        assert_eq!(debug_layer.level, Level::DEBUG);
    }

    /// Representative requests map to the expected (kind, target) pairs.
    #[test]
    fn test_extract_operation_details() {
        use crate::protocol::{CallToolParams, GetPromptParams, ReadResourceParams};
        use serde_json::Value;
        use std::collections::HashMap;

        // Tool call: target is the tool name.
        let tool_call = McpRequest::CallTool(CallToolParams {
            name: "my_tool".to_string(),
            arguments: Value::Null,
            meta: None,
            task: None,
        });
        assert_eq!(
            extract_operation_details(&tool_call),
            (Some("tool"), Some("my_tool".to_string()))
        );

        // Resource read: target is the resource URI.
        let resource_read = McpRequest::ReadResource(ReadResourceParams {
            uri: "file:///test.txt".to_string(),
            meta: None,
        });
        assert_eq!(
            extract_operation_details(&resource_read),
            (Some("resource"), Some("file:///test.txt".to_string()))
        );

        // Prompt get: target is the prompt name.
        let prompt_get = McpRequest::GetPrompt(GetPromptParams {
            name: "my_prompt".to_string(),
            arguments: HashMap::new(),
            meta: None,
        });
        assert_eq!(
            extract_operation_details(&prompt_get),
            (Some("prompt"), Some("my_prompt".to_string()))
        );

        // Ping: a kind but no target.
        let ping = McpRequest::Ping;
        assert_eq!(extract_operation_details(&ping), (Some("ping"), None));
    }
}