1use std::convert::Infallible;
38use std::future::Future;
39use std::pin::Pin;
40use std::task::{Context, Poll};
41use std::time::Instant;
42
43use tower::Layer;
44use tower_service::Service;
45use tracing::{Instrument, Level, Span};
46
47use crate::protocol::McpRequest;
48use crate::router::{RouterRequest, RouterResponse};
49
50#[derive(Debug, Clone, Copy)]
75pub struct McpTracingLayer {
76 level: Level,
77}
78
79impl Default for McpTracingLayer {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl McpTracingLayer {
86 pub fn new() -> Self {
88 Self { level: Level::INFO }
89 }
90
91 pub fn level(mut self, level: Level) -> Self {
95 self.level = level;
96 self
97 }
98}
99
100impl<S> Layer<S> for McpTracingLayer {
101 type Service = McpTracingService<S>;
102
103 fn layer(&self, inner: S) -> Self::Service {
104 McpTracingService {
105 inner,
106 level: self.level,
107 }
108 }
109}
110
111#[derive(Debug, Clone)]
115pub struct McpTracingService<S> {
116 inner: S,
117 level: Level,
118}
119
120impl<S> Service<RouterRequest> for McpTracingService<S>
121where
122 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
123 + Clone
124 + Send
125 + 'static,
126 S::Future: Send,
127{
128 type Response = RouterResponse;
129 type Error = Infallible;
130 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
131
132 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
133 self.inner.poll_ready(cx)
134 }
135
136 fn call(&mut self, req: RouterRequest) -> Self::Future {
137 let method = req.inner.method_name().to_string();
138 let request_id = format!("{:?}", req.id);
139
140 let (operation_name, operation_target) = extract_operation_details(&req.inner);
142
143 let span = create_span(
145 self.level,
146 &method,
147 &request_id,
148 operation_name,
149 operation_target,
150 );
151
152 let start = Instant::now();
153 let fut = self.inner.call(req);
154 let level = self.level;
155
156 Box::pin(
157 async move {
158 let result = fut.await;
159 let duration = start.elapsed();
160
161 match &result {
162 Ok(response) => {
163 let duration_ms = duration.as_secs_f64() * 1000.0;
164 match &response.inner {
165 Ok(_) => {
166 log_success(level, &method, duration_ms);
167 }
168 Err(err) => {
169 tracing::warn!(
170 method = %method,
171 error_code = err.code,
172 error_message = %err.message,
173 duration_ms = duration_ms,
174 "MCP request failed"
175 );
176 }
177 }
178 }
179 Err(_) => {
180 tracing::error!(method = %method, "MCP request error (infallible)");
182 }
183 }
184
185 result
186 }
187 .instrument(span),
188 )
189 }
190}
191
192pub(crate) fn extract_operation_details(
194 req: &McpRequest,
195) -> (Option<&'static str>, Option<String>) {
196 match req {
197 McpRequest::CallTool(params) => (Some("tool"), Some(params.name.clone())),
198 McpRequest::ReadResource(params) => (Some("resource"), Some(params.uri.clone())),
199 McpRequest::GetPrompt(params) => (Some("prompt"), Some(params.name.clone())),
200 McpRequest::ListTools(_) => (Some("list"), Some("tools".to_string())),
201 McpRequest::ListResources(_) => (Some("list"), Some("resources".to_string())),
202 McpRequest::ListResourceTemplates(_) => {
203 (Some("list"), Some("resource_templates".to_string()))
204 }
205 McpRequest::ListPrompts(_) => (Some("list"), Some("prompts".to_string())),
206 McpRequest::SubscribeResource(params) => (Some("subscribe"), Some(params.uri.clone())),
207 McpRequest::UnsubscribeResource(params) => (Some("unsubscribe"), Some(params.uri.clone())),
208 McpRequest::ListTasks(_) => (Some("list"), Some("tasks".to_string())),
209 McpRequest::GetTaskInfo(params) => (Some("task"), Some(params.task_id.clone())),
210 McpRequest::GetTaskResult(params) => (Some("task_result"), Some(params.task_id.clone())),
211 McpRequest::CancelTask(params) => (Some("cancel"), Some(params.task_id.clone())),
212 McpRequest::Complete(params) => {
213 let ref_type = match ¶ms.reference {
214 crate::protocol::CompletionReference::Resource { uri } => {
215 format!("resource:{}", uri)
216 }
217 crate::protocol::CompletionReference::Prompt { name } => {
218 format!("prompt:{}", name)
219 }
220 _ => "unknown".to_string(),
221 };
222 (Some("complete"), Some(ref_type))
223 }
224 McpRequest::SetLoggingLevel(params) => {
225 (Some("logging"), Some(format!("{:?}", params.level)))
226 }
227 McpRequest::Initialize(_) => (Some("init"), None),
228 McpRequest::Ping => (Some("ping"), None),
229 McpRequest::Unknown { method, .. } => (Some("unknown"), Some(method.clone())),
230 _ => (Some("unknown"), None),
231 }
232}
233
234fn create_span(
236 level: Level,
237 method: &str,
238 request_id: &str,
239 operation_name: Option<&str>,
240 operation_target: Option<String>,
241) -> Span {
242 match level {
243 Level::TRACE => tracing::trace_span!(
244 "mcp_request",
245 method = %method,
246 request_id = %request_id,
247 operation = operation_name,
248 target = operation_target.as_deref(),
249 ),
250 Level::DEBUG => tracing::debug_span!(
251 "mcp_request",
252 method = %method,
253 request_id = %request_id,
254 operation = operation_name,
255 target = operation_target.as_deref(),
256 ),
257 Level::INFO => tracing::info_span!(
258 "mcp_request",
259 method = %method,
260 request_id = %request_id,
261 operation = operation_name,
262 target = operation_target.as_deref(),
263 ),
264 Level::WARN => tracing::warn_span!(
265 "mcp_request",
266 method = %method,
267 request_id = %request_id,
268 operation = operation_name,
269 target = operation_target.as_deref(),
270 ),
271 Level::ERROR => tracing::error_span!(
272 "mcp_request",
273 method = %method,
274 request_id = %request_id,
275 operation = operation_name,
276 target = operation_target.as_deref(),
277 ),
278 }
279}
280
281fn log_success(level: Level, method: &str, duration_ms: f64) {
283 match level {
284 Level::TRACE => {
285 tracing::trace!(method = %method, duration_ms = duration_ms, "MCP request completed")
286 }
287 Level::DEBUG => {
288 tracing::debug!(method = %method, duration_ms = duration_ms, "MCP request completed")
289 }
290 Level::INFO => {
291 tracing::info!(method = %method, duration_ms = duration_ms, "MCP request completed")
292 }
293 Level::WARN => {
294 tracing::warn!(method = %method, duration_ms = duration_ms, "MCP request completed")
295 }
296 Level::ERROR => {
297 tracing::error!(method = %method, duration_ms = duration_ms, "MCP request completed")
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` starts at INFO and `level` overrides it.
    #[test]
    fn test_layer_creation() {
        let default_layer = McpTracingLayer::new();
        assert_eq!(default_layer.level, Level::INFO);

        let debug_layer = McpTracingLayer::new().level(Level::DEBUG);
        assert_eq!(debug_layer.level, Level::DEBUG);
    }

    /// Spot-checks the `(operation, target)` mapping for a few request kinds.
    #[test]
    fn test_extract_operation_details() {
        use crate::protocol::{CallToolParams, GetPromptParams, ReadResourceParams};
        use serde_json::Value;
        use std::collections::HashMap;

        // Tool call: the target is the tool name.
        let call_tool = McpRequest::CallTool(CallToolParams {
            name: "my_tool".to_string(),
            arguments: Value::Null,
            meta: None,
            task: None,
        });
        assert_eq!(
            extract_operation_details(&call_tool),
            (Some("tool"), Some("my_tool".to_string()))
        );

        // Resource read: the target is the resource URI.
        let read_resource = McpRequest::ReadResource(ReadResourceParams {
            uri: "file:///test.txt".to_string(),
            meta: None,
        });
        assert_eq!(
            extract_operation_details(&read_resource),
            (Some("resource"), Some("file:///test.txt".to_string()))
        );

        // Prompt fetch: the target is the prompt name.
        let get_prompt = McpRequest::GetPrompt(GetPromptParams {
            name: "my_prompt".to_string(),
            arguments: HashMap::new(),
            meta: None,
        });
        assert_eq!(
            extract_operation_details(&get_prompt),
            (Some("prompt"), Some("my_prompt".to_string()))
        );

        // Ping has an operation name but no target.
        assert_eq!(
            extract_operation_details(&McpRequest::Ping),
            (Some("ping"), None)
        );
    }
}