Skip to main content

tower_mcp/
tracing_layer.rs

1//! MCP request tracing middleware.
2//!
3//! This module provides [`McpTracingLayer`], a Tower middleware that logs
4//! structured information about MCP requests using the [`tracing`] crate.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use tower_mcp::{McpRouter, McpTracingLayer, HttpTransport};
10//!
11//! let router = McpRouter::new().server_info("my-server", "1.0.0");
12//!
13//! // Add tracing to all MCP requests
14//! let transport = HttpTransport::new(router)
15//!     .layer(McpTracingLayer::new());
16//! ```
17//!
18//! # Logged Information
19//!
20//! For each request, the layer logs:
21//! - Request method (e.g., `tools/call`, `resources/read`)
22//! - Request ID
23//! - Operation-specific details:
24//!   - Tool calls: tool name
25//!   - Resource reads: resource URI
26//!   - Prompt gets: prompt name
27//! - Request duration
28//! - Response status (success or error code)
29//!
30//! # Log Levels
31//!
//! - The configured level (default `INFO`, see [`McpTracingLayer::level`]): the
//!   `mcp_request` span and the "MCP request completed" event
//! - `WARN`: error responses, regardless of the configured level
35
36use std::convert::Infallible;
37use std::future::Future;
38use std::pin::Pin;
39use std::task::{Context, Poll};
40use std::time::Instant;
41
42use tower::Layer;
43use tower_service::Service;
44use tracing::{Instrument, Level, Span};
45
46use crate::protocol::McpRequest;
47use crate::router::{RouterRequest, RouterResponse};
48
49/// Tower layer that adds structured tracing to MCP requests.
50///
51/// This layer wraps a service and logs information about each request
52/// using the [`tracing`] crate. It's designed to work with tower-mcp's
53/// `RouterRequest`/`RouterResponse` types.
54///
55/// # Example
56///
57/// ```rust,ignore
58/// use tower_mcp::{McpRouter, McpTracingLayer, HttpTransport};
59///
60/// let router = McpRouter::new().server_info("my-server", "1.0.0");
61///
62/// // Apply at the transport level for all requests
63/// let transport = HttpTransport::new(router)
64///     .layer(McpTracingLayer::new());
65///
66/// // Or apply to specific tools via ToolBuilder
67/// let tool = ToolBuilder::new("search")
68///     .handler(|input: SearchInput| async move { ... })
69///     .layer(McpTracingLayer::new())
70///     .build()?;
71/// ```
72#[derive(Debug, Clone, Copy)]
73pub struct McpTracingLayer {
74    level: Level,
75}
76
77impl Default for McpTracingLayer {
78    fn default() -> Self {
79        Self::new()
80    }
81}
82
83impl McpTracingLayer {
84    /// Create a new tracing layer with default settings (INFO level).
85    pub fn new() -> Self {
86        Self { level: Level::INFO }
87    }
88
89    /// Set the log level for request/response logging.
90    ///
91    /// Default is `INFO`.
92    pub fn level(mut self, level: Level) -> Self {
93        self.level = level;
94        self
95    }
96}
97
98impl<S> Layer<S> for McpTracingLayer {
99    type Service = McpTracingService<S>;
100
101    fn layer(&self, inner: S) -> Self::Service {
102        McpTracingService {
103            inner,
104            level: self.level,
105        }
106    }
107}
108
109/// Tower service that adds tracing to MCP requests.
110///
111/// Created by [`McpTracingLayer`].
112#[derive(Debug, Clone)]
113pub struct McpTracingService<S> {
114    inner: S,
115    level: Level,
116}
117
118impl<S> Service<RouterRequest> for McpTracingService<S>
119where
120    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
121        + Clone
122        + Send
123        + 'static,
124    S::Future: Send,
125{
126    type Response = RouterResponse;
127    type Error = Infallible;
128    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
129
130    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
131        self.inner.poll_ready(cx)
132    }
133
134    fn call(&mut self, req: RouterRequest) -> Self::Future {
135        let method = req.inner.method_name().to_string();
136        let request_id = format!("{:?}", req.id);
137
138        // Extract operation-specific details
139        let (operation_name, operation_target) = extract_operation_details(&req.inner);
140
141        // Create the span based on the configured level
142        let span = create_span(
143            self.level,
144            &method,
145            &request_id,
146            operation_name,
147            operation_target,
148        );
149
150        let start = Instant::now();
151        let fut = self.inner.call(req);
152        let level = self.level;
153
154        Box::pin(
155            async move {
156                let result = fut.await;
157                let duration = start.elapsed();
158
159                match &result {
160                    Ok(response) => {
161                        let duration_ms = duration.as_secs_f64() * 1000.0;
162                        match &response.inner {
163                            Ok(_) => {
164                                log_success(level, &method, duration_ms);
165                            }
166                            Err(err) => {
167                                tracing::warn!(
168                                    method = %method,
169                                    error_code = err.code,
170                                    error_message = %err.message,
171                                    duration_ms = duration_ms,
172                                    "MCP request failed"
173                                );
174                            }
175                        }
176                    }
177                    Err(_) => {
178                        // Infallible, but handle for completeness
179                        tracing::error!(method = %method, "MCP request error (infallible)");
180                    }
181                }
182
183                result
184            }
185            .instrument(span),
186        )
187    }
188}
189
190/// Extract operation-specific name and target from the request.
191fn extract_operation_details(req: &McpRequest) -> (Option<&'static str>, Option<String>) {
192    match req {
193        McpRequest::CallTool(params) => (Some("tool"), Some(params.name.clone())),
194        McpRequest::ReadResource(params) => (Some("resource"), Some(params.uri.clone())),
195        McpRequest::GetPrompt(params) => (Some("prompt"), Some(params.name.clone())),
196        McpRequest::ListTools(_) => (Some("list"), Some("tools".to_string())),
197        McpRequest::ListResources(_) => (Some("list"), Some("resources".to_string())),
198        McpRequest::ListResourceTemplates(_) => {
199            (Some("list"), Some("resource_templates".to_string()))
200        }
201        McpRequest::ListPrompts(_) => (Some("list"), Some("prompts".to_string())),
202        McpRequest::SubscribeResource(params) => (Some("subscribe"), Some(params.uri.clone())),
203        McpRequest::UnsubscribeResource(params) => (Some("unsubscribe"), Some(params.uri.clone())),
204        McpRequest::EnqueueTask(params) => (Some("task"), Some(params.tool_name.clone())),
205        McpRequest::ListTasks(_) => (Some("list"), Some("tasks".to_string())),
206        McpRequest::GetTaskInfo(params) => (Some("task"), Some(params.task_id.clone())),
207        McpRequest::GetTaskResult(params) => (Some("task_result"), Some(params.task_id.clone())),
208        McpRequest::CancelTask(params) => (Some("cancel"), Some(params.task_id.clone())),
209        McpRequest::Complete(params) => {
210            let ref_type = match &params.reference {
211                crate::protocol::CompletionReference::Resource { uri } => {
212                    format!("resource:{}", uri)
213                }
214                crate::protocol::CompletionReference::Prompt { name } => {
215                    format!("prompt:{}", name)
216                }
217            };
218            (Some("complete"), Some(ref_type))
219        }
220        McpRequest::SetLoggingLevel(params) => {
221            (Some("logging"), Some(format!("{:?}", params.level)))
222        }
223        McpRequest::Initialize(_) => (Some("init"), None),
224        McpRequest::Ping => (Some("ping"), None),
225        McpRequest::Unknown { method, .. } => (Some("unknown"), Some(method.clone())),
226    }
227}
228
229/// Create a tracing span with the appropriate level.
230fn create_span(
231    level: Level,
232    method: &str,
233    request_id: &str,
234    operation_name: Option<&str>,
235    operation_target: Option<String>,
236) -> Span {
237    match level {
238        Level::TRACE => tracing::trace_span!(
239            "mcp_request",
240            method = %method,
241            request_id = %request_id,
242            operation = operation_name,
243            target = operation_target.as_deref(),
244        ),
245        Level::DEBUG => tracing::debug_span!(
246            "mcp_request",
247            method = %method,
248            request_id = %request_id,
249            operation = operation_name,
250            target = operation_target.as_deref(),
251        ),
252        Level::INFO => tracing::info_span!(
253            "mcp_request",
254            method = %method,
255            request_id = %request_id,
256            operation = operation_name,
257            target = operation_target.as_deref(),
258        ),
259        Level::WARN => tracing::warn_span!(
260            "mcp_request",
261            method = %method,
262            request_id = %request_id,
263            operation = operation_name,
264            target = operation_target.as_deref(),
265        ),
266        Level::ERROR => tracing::error_span!(
267            "mcp_request",
268            method = %method,
269            request_id = %request_id,
270            operation = operation_name,
271            target = operation_target.as_deref(),
272        ),
273    }
274}
275
276/// Log successful request completion at the configured level.
277fn log_success(level: Level, method: &str, duration_ms: f64) {
278    match level {
279        Level::TRACE => {
280            tracing::trace!(method = %method, duration_ms = duration_ms, "MCP request completed")
281        }
282        Level::DEBUG => {
283            tracing::debug!(method = %method, duration_ms = duration_ms, "MCP request completed")
284        }
285        Level::INFO => {
286            tracing::info!(method = %method, duration_ms = duration_ms, "MCP request completed")
287        }
288        Level::WARN => {
289            tracing::warn!(method = %method, duration_ms = duration_ms, "MCP request completed")
290        }
291        Level::ERROR => {
292            tracing::error!(method = %method, duration_ms = duration_ms, "MCP request completed")
293        }
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{CallToolParams, GetPromptParams, ReadResourceParams};
    use serde_json::Value;
    use std::collections::HashMap;

    #[test]
    fn test_layer_creation() {
        let default_layer = McpTracingLayer::new();
        assert_eq!(default_layer.level, Level::INFO);

        let debug_layer = McpTracingLayer::new().level(Level::DEBUG);
        assert_eq!(debug_layer.level, Level::DEBUG);
    }

    #[test]
    fn test_extract_operation_details() {
        // Each case: (request, expected operation kind, expected target).
        let cases: Vec<(McpRequest, Option<&str>, Option<&str>)> = vec![
            (
                McpRequest::CallTool(CallToolParams {
                    name: "my_tool".to_string(),
                    arguments: Value::Null,
                    meta: None,
                }),
                Some("tool"),
                Some("my_tool"),
            ),
            (
                McpRequest::ReadResource(ReadResourceParams {
                    uri: "file:///test.txt".to_string(),
                }),
                Some("resource"),
                Some("file:///test.txt"),
            ),
            (
                McpRequest::GetPrompt(GetPromptParams {
                    name: "my_prompt".to_string(),
                    arguments: HashMap::new(),
                }),
                Some("prompt"),
                Some("my_prompt"),
            ),
            (McpRequest::Ping, Some("ping"), None),
        ];

        for (index, (req, expected_kind, expected_target)) in cases.into_iter().enumerate() {
            let (kind, target) = extract_operation_details(&req);
            assert_eq!(kind, expected_kind, "operation kind mismatch in case {}", index);
            assert_eq!(
                target.as_deref(),
                expected_target,
                "operation target mismatch in case {}",
                index
            );
        }
    }
}