1use std::convert::Infallible;
37use std::future::Future;
38use std::pin::Pin;
39use std::task::{Context, Poll};
40use std::time::Instant;
41
42use tower::Layer;
43use tower_service::Service;
44use tracing::{Instrument, Level, Span};
45
46use crate::protocol::McpRequest;
47use crate::router::{RouterRequest, RouterResponse};
48
49#[derive(Debug, Clone, Copy)]
73pub struct McpTracingLayer {
74 level: Level,
75}
76
77impl Default for McpTracingLayer {
78 fn default() -> Self {
79 Self::new()
80 }
81}
82
83impl McpTracingLayer {
84 pub fn new() -> Self {
86 Self { level: Level::INFO }
87 }
88
89 pub fn level(mut self, level: Level) -> Self {
93 self.level = level;
94 self
95 }
96}
97
98impl<S> Layer<S> for McpTracingLayer {
99 type Service = McpTracingService<S>;
100
101 fn layer(&self, inner: S) -> Self::Service {
102 McpTracingService {
103 inner,
104 level: self.level,
105 }
106 }
107}
108
/// Service produced by [`McpTracingLayer`].
///
/// Forwards every request to `inner` and records an `mcp_request` span plus a
/// completion (or failure) event for it.
#[derive(Debug, Clone)]
pub struct McpTracingService<S> {
    /// The wrapped router service that actually handles requests.
    inner: S,
    /// Level used for the request span and the success event.
    level: Level,
}
117
118impl<S> Service<RouterRequest> for McpTracingService<S>
119where
120 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
121 + Clone
122 + Send
123 + 'static,
124 S::Future: Send,
125{
126 type Response = RouterResponse;
127 type Error = Infallible;
128 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
129
130 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
131 self.inner.poll_ready(cx)
132 }
133
134 fn call(&mut self, req: RouterRequest) -> Self::Future {
135 let method = req.inner.method_name().to_string();
136 let request_id = format!("{:?}", req.id);
137
138 let (operation_name, operation_target) = extract_operation_details(&req.inner);
140
141 let span = create_span(
143 self.level,
144 &method,
145 &request_id,
146 operation_name,
147 operation_target,
148 );
149
150 let start = Instant::now();
151 let fut = self.inner.call(req);
152 let level = self.level;
153
154 Box::pin(
155 async move {
156 let result = fut.await;
157 let duration = start.elapsed();
158
159 match &result {
160 Ok(response) => {
161 let duration_ms = duration.as_secs_f64() * 1000.0;
162 match &response.inner {
163 Ok(_) => {
164 log_success(level, &method, duration_ms);
165 }
166 Err(err) => {
167 tracing::warn!(
168 method = %method,
169 error_code = err.code,
170 error_message = %err.message,
171 duration_ms = duration_ms,
172 "MCP request failed"
173 );
174 }
175 }
176 }
177 Err(_) => {
178 tracing::error!(method = %method, "MCP request error (infallible)");
180 }
181 }
182
183 result
184 }
185 .instrument(span),
186 )
187 }
188}
189
190fn extract_operation_details(req: &McpRequest) -> (Option<&'static str>, Option<String>) {
192 match req {
193 McpRequest::CallTool(params) => (Some("tool"), Some(params.name.clone())),
194 McpRequest::ReadResource(params) => (Some("resource"), Some(params.uri.clone())),
195 McpRequest::GetPrompt(params) => (Some("prompt"), Some(params.name.clone())),
196 McpRequest::ListTools(_) => (Some("list"), Some("tools".to_string())),
197 McpRequest::ListResources(_) => (Some("list"), Some("resources".to_string())),
198 McpRequest::ListResourceTemplates(_) => {
199 (Some("list"), Some("resource_templates".to_string()))
200 }
201 McpRequest::ListPrompts(_) => (Some("list"), Some("prompts".to_string())),
202 McpRequest::SubscribeResource(params) => (Some("subscribe"), Some(params.uri.clone())),
203 McpRequest::UnsubscribeResource(params) => (Some("unsubscribe"), Some(params.uri.clone())),
204 McpRequest::EnqueueTask(params) => (Some("task"), Some(params.tool_name.clone())),
205 McpRequest::ListTasks(_) => (Some("list"), Some("tasks".to_string())),
206 McpRequest::GetTaskInfo(params) => (Some("task"), Some(params.task_id.clone())),
207 McpRequest::GetTaskResult(params) => (Some("task_result"), Some(params.task_id.clone())),
208 McpRequest::CancelTask(params) => (Some("cancel"), Some(params.task_id.clone())),
209 McpRequest::Complete(params) => {
210 let ref_type = match ¶ms.reference {
211 crate::protocol::CompletionReference::Resource { uri } => {
212 format!("resource:{}", uri)
213 }
214 crate::protocol::CompletionReference::Prompt { name } => {
215 format!("prompt:{}", name)
216 }
217 };
218 (Some("complete"), Some(ref_type))
219 }
220 McpRequest::SetLoggingLevel(params) => {
221 (Some("logging"), Some(format!("{:?}", params.level)))
222 }
223 McpRequest::Initialize(_) => (Some("init"), None),
224 McpRequest::Ping => (Some("ping"), None),
225 McpRequest::Unknown { method, .. } => (Some("unknown"), Some(method.clone())),
226 }
227}
228
229fn create_span(
231 level: Level,
232 method: &str,
233 request_id: &str,
234 operation_name: Option<&str>,
235 operation_target: Option<String>,
236) -> Span {
237 match level {
238 Level::TRACE => tracing::trace_span!(
239 "mcp_request",
240 method = %method,
241 request_id = %request_id,
242 operation = operation_name,
243 target = operation_target.as_deref(),
244 ),
245 Level::DEBUG => tracing::debug_span!(
246 "mcp_request",
247 method = %method,
248 request_id = %request_id,
249 operation = operation_name,
250 target = operation_target.as_deref(),
251 ),
252 Level::INFO => tracing::info_span!(
253 "mcp_request",
254 method = %method,
255 request_id = %request_id,
256 operation = operation_name,
257 target = operation_target.as_deref(),
258 ),
259 Level::WARN => tracing::warn_span!(
260 "mcp_request",
261 method = %method,
262 request_id = %request_id,
263 operation = operation_name,
264 target = operation_target.as_deref(),
265 ),
266 Level::ERROR => tracing::error_span!(
267 "mcp_request",
268 method = %method,
269 request_id = %request_id,
270 operation = operation_name,
271 target = operation_target.as_deref(),
272 ),
273 }
274}
275
276fn log_success(level: Level, method: &str, duration_ms: f64) {
278 match level {
279 Level::TRACE => {
280 tracing::trace!(method = %method, duration_ms = duration_ms, "MCP request completed")
281 }
282 Level::DEBUG => {
283 tracing::debug!(method = %method, duration_ms = duration_ms, "MCP request completed")
284 }
285 Level::INFO => {
286 tracing::info!(method = %method, duration_ms = duration_ms, "MCP request completed")
287 }
288 Level::WARN => {
289 tracing::warn!(method = %method, duration_ms = duration_ms, "MCP request completed")
290 }
291 Level::ERROR => {
292 tracing::error!(method = %method, duration_ms = duration_ms, "MCP request completed")
293 }
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{CallToolParams, GetPromptParams, ReadResourceParams};
    use serde_json::Value;
    use std::collections::HashMap;

    /// Asserts that `req` is classified as `(Some(operation), target)`.
    fn assert_details(req: &McpRequest, operation: &str, target: Option<&str>) {
        let (actual_operation, actual_target) = extract_operation_details(req);
        assert_eq!(actual_operation, Some(operation));
        assert_eq!(actual_target.as_deref(), target);
    }

    #[test]
    fn test_layer_creation() {
        let default_layer = McpTracingLayer::new();
        assert_eq!(default_layer.level, Level::INFO);

        let debug_layer = McpTracingLayer::new().level(Level::DEBUG);
        assert_eq!(debug_layer.level, Level::DEBUG);
    }

    #[test]
    fn test_extract_operation_details() {
        let call_tool = McpRequest::CallTool(CallToolParams {
            name: "my_tool".to_string(),
            arguments: Value::Null,
            meta: None,
        });
        assert_details(&call_tool, "tool", Some("my_tool"));

        let read_resource = McpRequest::ReadResource(ReadResourceParams {
            uri: "file:///test.txt".to_string(),
        });
        assert_details(&read_resource, "resource", Some("file:///test.txt"));

        let get_prompt = McpRequest::GetPrompt(GetPromptParams {
            name: "my_prompt".to_string(),
            arguments: HashMap::new(),
        });
        assert_details(&get_prompt, "prompt", Some("my_prompt"));

        assert_details(&McpRequest::Ping, "ping", None);
    }
}