Skip to main content

tower_mcp/middleware/
tool_call_logging.rs

1//! Tool call audit logging middleware.
2//!
3//! This module provides [`ToolCallLoggingLayer`], a Tower middleware that emits
4//! structured [`tracing`] events specifically for `tools/call` requests. Unlike
5//! [`McpTracingLayer`](super::McpTracingLayer) which traces all MCP requests,
6//! this layer focuses on tool invocations and provides richer audit information.
7//!
8//! Note: [`McpRouter`](crate::McpRouter) now emits basic tool call logging
9//! (tool name, duration, status) by default at `INFO` level on the `mcp::tools`
10//! target. Use this layer when you need additional detail such as annotation
11//! hints (`read_only`, `destructive`) or a custom log level.
12//!
13//! # Logged Information
14//!
//! For each tool call that produces a response, the layer emits a single event after completion with:
16//! - **tool**: The tool name
17//! - **duration_ms**: Execution time in milliseconds
//! - **status**: One of `"success"`, `"error"`, or `"denied"` (`"denied"` is reported for any JSON-RPC invalid-params error, code `-32602`)
19//! - **error_code** / **error_message**: Present on error responses
20//!
21//! All events use tracing target `mcp::tools` for easy filtering.
//! Non-`CallTool` requests are forwarded to the inner service unchanged and are not logged.
23//!
24//! # Example
25//!
26//! ```rust,ignore
27//! use tower_mcp::{McpRouter, StdioTransport};
28//! use tower_mcp::middleware::ToolCallLoggingLayer;
29//!
30//! let router = McpRouter::new().server_info("my-server", "1.0.0");
31//!
32//! let mut transport = StdioTransport::new(router)
33//!     .layer(ToolCallLoggingLayer::new());
34//! ```
35
36use std::future::Future;
37use std::pin::Pin;
38use std::task::{Context, Poll};
39use std::time::Instant;
40
41use tower::Layer;
42use tower_service::Service;
43use tracing::Level;
44
45use crate::protocol::McpRequest;
46use crate::router::{RouterRequest, RouterResponse, ToolAnnotationsMap};
47
/// JSON-RPC 2.0 "Invalid params" error code.
///
/// Tool call error responses carrying this code are logged with status
/// `"denied"`. The code is generic, so other invalid-params failures are
/// classified the same way.
const JSONRPC_INVALID_PARAMS: i32 = -32_602;
50
51/// Tower layer that adds audit logging for tool call requests.
52///
53/// This layer intercepts `tools/call` requests and emits structured
54/// [`tracing`] events with the tool name, execution duration, and result
55/// status. Non-tool-call requests pass through with zero overhead.
56///
57/// Events are emitted with tracing target `mcp::tools`, which can be used
58/// for filtering in your tracing subscriber configuration.
59///
60/// # Example
61///
62/// ```rust,ignore
63/// use tower_mcp::{McpRouter, StdioTransport};
64/// use tower_mcp::middleware::ToolCallLoggingLayer;
65///
66/// let router = McpRouter::new().server_info("my-server", "1.0.0");
67///
68/// let mut transport = StdioTransport::new(router)
69///     .layer(ToolCallLoggingLayer::new());
70/// ```
71#[derive(Debug, Clone, Copy)]
72pub struct ToolCallLoggingLayer {
73    level: Level,
74}
75
76impl Default for ToolCallLoggingLayer {
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82impl ToolCallLoggingLayer {
83    /// Create a new tool call logging layer with default settings (INFO level).
84    pub fn new() -> Self {
85        Self { level: Level::INFO }
86    }
87
88    /// Set the log level for tool call events.
89    ///
90    /// Default is `INFO`.
91    pub fn level(mut self, level: Level) -> Self {
92        self.level = level;
93        self
94    }
95}
96
97impl<S> Layer<S> for ToolCallLoggingLayer {
98    type Service = ToolCallLoggingService<S>;
99
100    fn layer(&self, inner: S) -> Self::Service {
101        ToolCallLoggingService {
102            inner,
103            level: self.level,
104        }
105    }
106}
107
108/// Tower service that logs tool call requests.
109///
110/// Created by [`ToolCallLoggingLayer`]. See the layer documentation for details.
111#[derive(Debug, Clone)]
112pub struct ToolCallLoggingService<S> {
113    inner: S,
114    level: Level,
115}
116
117impl<S> Service<RouterRequest> for ToolCallLoggingService<S>
118where
119    S: Service<RouterRequest, Response = RouterResponse> + Clone + Send + 'static,
120    S::Error: Send,
121    S::Future: Send,
122{
123    type Response = RouterResponse;
124    type Error = S::Error;
125    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, S::Error>> + Send>>;
126
127    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
128        self.inner.poll_ready(cx)
129    }
130
131    fn call(&mut self, req: RouterRequest) -> Self::Future {
132        // Only intercept CallTool requests; pass everything else through directly.
133        let tool_name = match &req.inner {
134            McpRequest::CallTool(params) => params.name.clone(),
135            _ => {
136                let fut = self.inner.call(req);
137                return Box::pin(fut);
138            }
139        };
140
141        // Extract annotation hints if the transport injected them.
142        let read_only = req
143            .extensions
144            .get::<ToolAnnotationsMap>()
145            .map(|m| m.is_read_only(&tool_name));
146        let destructive = req
147            .extensions
148            .get::<ToolAnnotationsMap>()
149            .map(|m| m.is_destructive(&tool_name));
150
151        let start = Instant::now();
152        let fut = self.inner.call(req);
153        let level = self.level;
154
155        Box::pin(async move {
156            let result = fut.await;
157            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
158
159            if let Ok(response) = &result {
160                match &response.inner {
161                    Ok(_) => {
162                        log_tool_call(
163                            level,
164                            &tool_name,
165                            duration_ms,
166                            "success",
167                            None,
168                            read_only,
169                            destructive,
170                        );
171                    }
172                    Err(err) => {
173                        let status = if err.code == JSONRPC_INVALID_PARAMS {
174                            "denied"
175                        } else {
176                            "error"
177                        };
178                        log_tool_call(
179                            level,
180                            &tool_name,
181                            duration_ms,
182                            status,
183                            Some((err.code, &err.message)),
184                            read_only,
185                            destructive,
186                        );
187                    }
188                }
189            }
190
191            result
192        })
193    }
194}
195
196/// Emit a structured tracing event for a tool call.
197fn log_tool_call(
198    level: Level,
199    tool: &str,
200    duration_ms: f64,
201    status: &str,
202    error: Option<(i32, &str)>,
203    read_only: Option<bool>,
204    destructive: Option<bool>,
205) {
206    match (level, error) {
207        (Level::TRACE, None) => {
208            tracing::trace!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
209        }
210        (Level::TRACE, Some((code, message))) => {
211            tracing::trace!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
212        }
213        (Level::DEBUG, None) => {
214            tracing::debug!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
215        }
216        (Level::DEBUG, Some((code, message))) => {
217            tracing::debug!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
218        }
219        (Level::INFO, None) => {
220            tracing::info!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
221        }
222        (Level::INFO, Some((code, message))) => {
223            tracing::info!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
224        }
225        (Level::WARN, None) => {
226            tracing::warn!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
227        }
228        (Level::WARN, Some((code, message))) => {
229            tracing::warn!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
230        }
231        (Level::ERROR, None) => {
232            tracing::error!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
233        }
234        (Level::ERROR, Some((code, message))) => {
235            tracing::error!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
236        }
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{CallToolParams, RequestId};
    use crate::router::Extensions;

    /// Build a `tools/call` request for `name` with empty JSON arguments.
    fn call_tool_request(name: &str) -> RouterRequest {
        RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::CallTool(CallToolParams {
                name: name.to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
            extensions: Extensions::new(),
        }
    }

    #[test]
    fn test_layer_creation() {
        assert_eq!(ToolCallLoggingLayer::new().level, Level::INFO);
    }

    #[test]
    fn test_layer_with_custom_level() {
        let debug_layer = ToolCallLoggingLayer::new().level(Level::DEBUG);
        assert_eq!(debug_layer.level, Level::DEBUG);
    }

    #[test]
    fn test_layer_default() {
        assert_eq!(ToolCallLoggingLayer::default().level, Level::INFO);
    }

    #[tokio::test]
    async fn test_non_tool_call_passthrough() {
        let router = crate::McpRouter::new().server_info("test", "1.0.0");
        let mut service = ToolCallLoggingLayer::new().layer(router);

        // A ping is not a tool call, so it goes straight to the router.
        let ping = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::Ping,
            extensions: Extensions::new(),
        };

        let response = Service::call(&mut service, ping)
            .await
            .expect("router service should not fail");
        assert!(response.inner.is_ok());
    }

    #[tokio::test]
    async fn test_tool_call_logging() {
        let tool = crate::ToolBuilder::new("test_tool")
            .description("A test tool")
            .handler(|_: serde_json::Value| async move { Ok(crate::CallToolResult::text("done")) })
            .build();
        let router = crate::McpRouter::new()
            .server_info("test", "1.0.0")
            .tool(tool);
        let mut service = ToolCallLoggingLayer::new().layer(router);

        // The JSON-RPC outcome may be an error (session not initialized); the
        // layer must still hand back a response either way.
        let outcome = Service::call(&mut service, call_tool_request("test_tool")).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_tool_call_error_logging() {
        // No tools are registered, so calling one yields an error response.
        let router = crate::McpRouter::new().server_info("test", "1.0.0");
        let mut service = ToolCallLoggingLayer::new().layer(router);

        let response = Service::call(&mut service, call_tool_request("nonexistent"))
            .await
            .expect("router service should not fail");
        assert!(response.inner.is_err());
    }
}