1use std::future::Future;
37use std::pin::Pin;
38use std::task::{Context, Poll};
39use std::time::Instant;
40
41use tower::Layer;
42use tower_service::Service;
43use tracing::Level;
44
45use crate::protocol::McpRequest;
46use crate::router::{RouterRequest, RouterResponse, ToolAnnotationsMap};
47
/// JSON-RPC 2.0 "Invalid params" error code.
///
/// Tool-call responses that fail with this code are logged with status
/// `"denied"` instead of `"error"`.
const JSONRPC_INVALID_PARAMS: i32 = -32_602;
50
51#[derive(Debug, Clone, Copy)]
72pub struct ToolCallLoggingLayer {
73 level: Level,
74}
75
76impl Default for ToolCallLoggingLayer {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl ToolCallLoggingLayer {
83 pub fn new() -> Self {
85 Self { level: Level::INFO }
86 }
87
88 pub fn level(mut self, level: Level) -> Self {
92 self.level = level;
93 self
94 }
95}
96
97impl<S> Layer<S> for ToolCallLoggingLayer {
98 type Service = ToolCallLoggingService<S>;
99
100 fn layer(&self, inner: S) -> Self::Service {
101 ToolCallLoggingService {
102 inner,
103 level: self.level,
104 }
105 }
106}
107
108#[derive(Debug, Clone)]
112pub struct ToolCallLoggingService<S> {
113 inner: S,
114 level: Level,
115}
116
117impl<S> Service<RouterRequest> for ToolCallLoggingService<S>
118where
119 S: Service<RouterRequest, Response = RouterResponse> + Clone + Send + 'static,
120 S::Error: Send,
121 S::Future: Send,
122{
123 type Response = RouterResponse;
124 type Error = S::Error;
125 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, S::Error>> + Send>>;
126
127 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
128 self.inner.poll_ready(cx)
129 }
130
131 fn call(&mut self, req: RouterRequest) -> Self::Future {
132 let tool_name = match &req.inner {
134 McpRequest::CallTool(params) => params.name.clone(),
135 _ => {
136 let fut = self.inner.call(req);
137 return Box::pin(fut);
138 }
139 };
140
141 let read_only = req
143 .extensions
144 .get::<ToolAnnotationsMap>()
145 .map(|m| m.is_read_only(&tool_name));
146 let destructive = req
147 .extensions
148 .get::<ToolAnnotationsMap>()
149 .map(|m| m.is_destructive(&tool_name));
150
151 let start = Instant::now();
152 let fut = self.inner.call(req);
153 let level = self.level;
154
155 Box::pin(async move {
156 let result = fut.await;
157 let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
158
159 if let Ok(response) = &result {
160 match &response.inner {
161 Ok(_) => {
162 log_tool_call(
163 level,
164 &tool_name,
165 duration_ms,
166 "success",
167 None,
168 read_only,
169 destructive,
170 );
171 }
172 Err(err) => {
173 let status = if err.code == JSONRPC_INVALID_PARAMS {
174 "denied"
175 } else {
176 "error"
177 };
178 log_tool_call(
179 level,
180 &tool_name,
181 duration_ms,
182 status,
183 Some((err.code, &err.message)),
184 read_only,
185 destructive,
186 );
187 }
188 }
189 }
190
191 result
192 })
193 }
194}
195
196fn log_tool_call(
198 level: Level,
199 tool: &str,
200 duration_ms: f64,
201 status: &str,
202 error: Option<(i32, &str)>,
203 read_only: Option<bool>,
204 destructive: Option<bool>,
205) {
206 match (level, error) {
207 (Level::TRACE, None) => {
208 tracing::trace!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
209 }
210 (Level::TRACE, Some((code, message))) => {
211 tracing::trace!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
212 }
213 (Level::DEBUG, None) => {
214 tracing::debug!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
215 }
216 (Level::DEBUG, Some((code, message))) => {
217 tracing::debug!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
218 }
219 (Level::INFO, None) => {
220 tracing::info!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
221 }
222 (Level::INFO, Some((code, message))) => {
223 tracing::info!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
224 }
225 (Level::WARN, None) => {
226 tracing::warn!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
227 }
228 (Level::WARN, Some((code, message))) => {
229 tracing::warn!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
230 }
231 (Level::ERROR, None) => {
232 tracing::error!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
233 }
234 (Level::ERROR, Some((code, message))) => {
235 tracing::error!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
236 }
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{CallToolParams, RequestId};
    use crate::router::Extensions;

    /// Builds a `CallTool` request (id 1, empty JSON arguments) for `name`.
    fn call_tool_request(name: &str) -> RouterRequest {
        RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::CallTool(CallToolParams {
                name: name.to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
            extensions: Extensions::new(),
        }
    }

    #[test]
    fn test_layer_creation() {
        assert_eq!(ToolCallLoggingLayer::new().level, Level::INFO);
    }

    #[test]
    fn test_layer_with_custom_level() {
        let configured = ToolCallLoggingLayer::new().level(Level::DEBUG);
        assert_eq!(configured.level, Level::DEBUG);
    }

    #[test]
    fn test_layer_default() {
        assert_eq!(ToolCallLoggingLayer::default().level, Level::INFO);
    }

    #[tokio::test]
    async fn test_non_tool_call_passthrough() {
        let router = crate::McpRouter::new().server_info("test", "1.0.0");
        let mut svc = ToolCallLoggingLayer::new().layer(router);

        let ping = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::Ping,
            extensions: Extensions::new(),
        };

        let outcome = Service::call(&mut svc, ping).await;
        assert!(outcome.is_ok());
        let response = outcome.unwrap();
        assert!(response.inner.is_ok());
    }

    #[tokio::test]
    async fn test_tool_call_logging() {
        let tool = crate::ToolBuilder::new("test_tool")
            .description("A test tool")
            .handler(|_: serde_json::Value| async move { Ok(crate::CallToolResult::text("done")) })
            .build();
        let router = crate::McpRouter::new()
            .server_info("test", "1.0.0")
            .tool(tool);
        let mut svc = ToolCallLoggingLayer::new().layer(router);

        let outcome = Service::call(&mut svc, call_tool_request("test_tool")).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_tool_call_error_logging() {
        let router = crate::McpRouter::new().server_info("test", "1.0.0");
        let mut svc = ToolCallLoggingLayer::new().layer(router);

        let outcome = Service::call(&mut svc, call_tool_request("nonexistent")).await;
        assert!(outcome.is_ok());
        let response = outcome.unwrap();
        assert!(response.inner.is_err());
    }
}