Skip to main content

tower_mcp/middleware/
tool_call_logging.rs

1//! Tool call audit logging middleware.
2//!
3//! This module provides [`ToolCallLoggingLayer`], a Tower middleware that emits
4//! structured [`tracing`] events specifically for `tools/call` requests. Unlike
5//! [`McpTracingLayer`](super::McpTracingLayer) which traces all MCP requests,
6//! this layer focuses on tool invocations and provides richer audit information.
7//!
8//! # Logged Information
9//!
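//! For each tool call, the layer emits a single event after completion with:
//! - **tool**: The tool name
//! - **duration_ms**: Execution time in milliseconds
//! - **status**: One of `"success"`, `"error"`, or `"denied"`; `"denied"` is
//!   reported when the response carries JSON-RPC error code `-32602`
//!   (invalid params)
//! - **error_code** / **error_message**: Present on error responses
//! - **read_only** / **destructive**: Annotation hints for the tool, or `None`
//!   when the request carries no `ToolAnnotationsMap` extension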
15//!
16//! All events use tracing target `mcp::tools` for easy filtering.
17//! Non-`CallTool` requests pass through unchanged with no overhead.
18//!
19//! # Example
20//!
21//! ```rust,ignore
22//! use tower_mcp::{McpRouter, StdioTransport};
23//! use tower_mcp::middleware::ToolCallLoggingLayer;
24//!
25//! let router = McpRouter::new().server_info("my-server", "1.0.0");
26//!
27//! let mut transport = StdioTransport::new(router)
28//!     .layer(ToolCallLoggingLayer::new());
29//! ```
30
31use std::future::Future;
32use std::pin::Pin;
33use std::task::{Context, Poll};
34use std::time::Instant;
35
36use tower::Layer;
37use tower_service::Service;
38use tracing::Level;
39
40use crate::protocol::McpRequest;
41use crate::router::{RouterRequest, RouterResponse, ToolAnnotationsMap};
42
/// JSON-RPC "invalid params" error code.
///
/// Error responses carrying this code are reported with status `"denied"`
/// rather than `"error"`, since the code may indicate a rejected tool call.
const JSONRPC_INVALID_PARAMS: i32 = -32602;
45
46/// Tower layer that adds audit logging for tool call requests.
47///
48/// This layer intercepts `tools/call` requests and emits structured
49/// [`tracing`] events with the tool name, execution duration, and result
50/// status. Non-tool-call requests pass through with zero overhead.
51///
52/// Events are emitted with tracing target `mcp::tools`, which can be used
53/// for filtering in your tracing subscriber configuration.
54///
55/// # Example
56///
57/// ```rust,ignore
58/// use tower_mcp::{McpRouter, StdioTransport};
59/// use tower_mcp::middleware::ToolCallLoggingLayer;
60///
61/// let router = McpRouter::new().server_info("my-server", "1.0.0");
62///
63/// let mut transport = StdioTransport::new(router)
64///     .layer(ToolCallLoggingLayer::new());
65/// ```
66#[derive(Debug, Clone, Copy)]
67pub struct ToolCallLoggingLayer {
68    level: Level,
69}
70
71impl Default for ToolCallLoggingLayer {
72    fn default() -> Self {
73        Self::new()
74    }
75}
76
77impl ToolCallLoggingLayer {
78    /// Create a new tool call logging layer with default settings (INFO level).
79    pub fn new() -> Self {
80        Self { level: Level::INFO }
81    }
82
83    /// Set the log level for tool call events.
84    ///
85    /// Default is `INFO`.
86    pub fn level(mut self, level: Level) -> Self {
87        self.level = level;
88        self
89    }
90}
91
92impl<S> Layer<S> for ToolCallLoggingLayer {
93    type Service = ToolCallLoggingService<S>;
94
95    fn layer(&self, inner: S) -> Self::Service {
96        ToolCallLoggingService {
97            inner,
98            level: self.level,
99        }
100    }
101}
102
103/// Tower service that logs tool call requests.
104///
105/// Created by [`ToolCallLoggingLayer`]. See the layer documentation for details.
106#[derive(Debug, Clone)]
107pub struct ToolCallLoggingService<S> {
108    inner: S,
109    level: Level,
110}
111
112impl<S> Service<RouterRequest> for ToolCallLoggingService<S>
113where
114    S: Service<RouterRequest, Response = RouterResponse> + Clone + Send + 'static,
115    S::Error: Send,
116    S::Future: Send,
117{
118    type Response = RouterResponse;
119    type Error = S::Error;
120    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, S::Error>> + Send>>;
121
122    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
123        self.inner.poll_ready(cx)
124    }
125
126    fn call(&mut self, req: RouterRequest) -> Self::Future {
127        // Only intercept CallTool requests; pass everything else through directly.
128        let tool_name = match &req.inner {
129            McpRequest::CallTool(params) => params.name.clone(),
130            _ => {
131                let fut = self.inner.call(req);
132                return Box::pin(fut);
133            }
134        };
135
136        // Extract annotation hints if the transport injected them.
137        let read_only = req
138            .extensions
139            .get::<ToolAnnotationsMap>()
140            .map(|m| m.is_read_only(&tool_name));
141        let destructive = req
142            .extensions
143            .get::<ToolAnnotationsMap>()
144            .map(|m| m.is_destructive(&tool_name));
145
146        let start = Instant::now();
147        let fut = self.inner.call(req);
148        let level = self.level;
149
150        Box::pin(async move {
151            let result = fut.await;
152            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
153
154            if let Ok(response) = &result {
155                match &response.inner {
156                    Ok(_) => {
157                        log_tool_call(
158                            level,
159                            &tool_name,
160                            duration_ms,
161                            "success",
162                            None,
163                            read_only,
164                            destructive,
165                        );
166                    }
167                    Err(err) => {
168                        let status = if err.code == JSONRPC_INVALID_PARAMS {
169                            "denied"
170                        } else {
171                            "error"
172                        };
173                        log_tool_call(
174                            level,
175                            &tool_name,
176                            duration_ms,
177                            status,
178                            Some((err.code, &err.message)),
179                            read_only,
180                            destructive,
181                        );
182                    }
183                }
184            }
185
186            result
187        })
188    }
189}
190
191/// Emit a structured tracing event for a tool call.
192fn log_tool_call(
193    level: Level,
194    tool: &str,
195    duration_ms: f64,
196    status: &str,
197    error: Option<(i32, &str)>,
198    read_only: Option<bool>,
199    destructive: Option<bool>,
200) {
201    match (level, error) {
202        (Level::TRACE, None) => {
203            tracing::trace!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
204        }
205        (Level::TRACE, Some((code, message))) => {
206            tracing::trace!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
207        }
208        (Level::DEBUG, None) => {
209            tracing::debug!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
210        }
211        (Level::DEBUG, Some((code, message))) => {
212            tracing::debug!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
213        }
214        (Level::INFO, None) => {
215            tracing::info!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
216        }
217        (Level::INFO, Some((code, message))) => {
218            tracing::info!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
219        }
220        (Level::WARN, None) => {
221            tracing::warn!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
222        }
223        (Level::WARN, Some((code, message))) => {
224            tracing::warn!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
225        }
226        (Level::ERROR, None) => {
227            tracing::error!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
228        }
229        (Level::ERROR, Some((code, message))) => {
230            tracing::error!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
231        }
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{CallToolParams, RequestId};
    use crate::router::Extensions;

    /// Wrap `inner` in a request with id `1` and no extensions.
    fn request(inner: McpRequest) -> RouterRequest {
        RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions: Extensions::new(),
        }
    }

    /// Build a `tools/call` request for `tool` with an empty argument object.
    fn call_tool_request(tool: &str) -> RouterRequest {
        request(McpRequest::CallTool(CallToolParams {
            name: tool.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        }))
    }

    #[test]
    fn test_layer_creation() {
        assert_eq!(ToolCallLoggingLayer::new().level, Level::INFO);
    }

    #[test]
    fn test_layer_with_custom_level() {
        let configured = ToolCallLoggingLayer::new().level(Level::DEBUG);
        assert_eq!(configured.level, Level::DEBUG);
    }

    #[test]
    fn test_layer_default() {
        assert_eq!(ToolCallLoggingLayer::default().level, Level::INFO);
    }

    #[tokio::test]
    async fn test_non_tool_call_passthrough() {
        let router = crate::McpRouter::new().server_info("test", "1.0.0");
        let mut service = ToolCallLoggingLayer::new().layer(router);

        // A ping is not a tool call, so it must reach the router untouched.
        let outcome = Service::call(&mut service, request(McpRequest::Ping)).await;
        assert!(outcome.is_ok());

        let response = outcome.unwrap();
        assert!(response.inner.is_ok());
    }

    #[tokio::test]
    async fn test_tool_call_logging() {
        let tool = crate::ToolBuilder::new("test_tool")
            .description("A test tool")
            .handler(|_: serde_json::Value| async move { Ok(crate::CallToolResult::text("done")) })
            .build();
        let router = crate::McpRouter::new()
            .server_info("test", "1.0.0")
            .tool(tool);
        let mut service = ToolCallLoggingLayer::new().layer(router);

        // Only the service-level result is checked: the inner payload may be
        // an error (session not initialized), and the logging layer has to
        // cope with either outcome.
        let outcome = Service::call(&mut service, call_tool_request("test_tool")).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_tool_call_error_logging() {
        // No tools are registered, so calling one yields an error response.
        let router = crate::McpRouter::new().server_info("test", "1.0.0");
        let mut service = ToolCallLoggingLayer::new().layer(router);

        let outcome = Service::call(&mut service, call_tool_request("nonexistent")).await;
        assert!(outcome.is_ok());

        // The response payload should be an error (tool not found).
        let response = outcome.unwrap();
        assert!(response.inner.is_err());
    }
}