Skip to main content

tower_mcp/middleware/
audit.rs

1//! MCP audit logging middleware.
2//!
3//! This module provides [`AuditLayer`], a Tower middleware that emits structured
4//! audit events for all MCP requests using the [`tracing`] crate. Unlike
5//! [`McpTracingLayer`](super::McpTracingLayer) which focuses on request tracing
6//! with spans, or [`ToolCallLoggingLayer`](super::ToolCallLoggingLayer) which
7//! focuses only on tool calls, `AuditLayer` emits a single, well-structured
8//! audit event per request designed for compliance and security monitoring.
9//!
10//! # Audit Event Fields
11//!
12//! Every audit event includes:
13//! - **method**: The MCP method (e.g., `tools/call`, `resources/read`)
14//! - **request_id**: The JSON-RPC request ID
15//! - **duration_ms**: Request processing time in milliseconds
//! - **status**: `"success"`, `"error"`, or `"denied"` (`"denied"` is reported for JSON-RPC error code `-32602`, invalid params)
17//!
18//! Operation-specific fields:
19//! - **tool**: Tool name (for `tools/call`)
20//! - **resource_uri**: Resource URI (for `resources/read`)
21//! - **prompt**: Prompt name (for `prompts/get`)
22//! - **error_code** / **error_message**: Present on error responses
23//! - **read_only** / **destructive**: Tool annotation hints (when available)
24//!
25//! All events use tracing target `mcp::audit` for easy filtering and routing
26//! to dedicated audit log sinks.
27//!
28//! # Example
29//!
30//! ```rust,ignore
31//! use tower_mcp::middleware::AuditLayer;
32//!
33//! // Route audit events to a file via tracing-subscriber
34//! let transport = StdioTransport::new(router)
35//!     .layer(AuditLayer::new());
36//!
37//! // Filter audit events in subscriber config:
38//! // RUST_LOG="mcp::audit=info"
39//! ```
40
41use std::future::Future;
42use std::pin::Pin;
43use std::task::{Context, Poll};
44use std::time::Instant;
45
46use tower::Layer;
47use tower_service::Service;
48use tracing::Level;
49
50use crate::protocol::McpRequest;
51use crate::router::{RouterRequest, RouterResponse, ToolAnnotationsMap};
52
53/// Tower layer that adds structured audit logging to all MCP requests.
54///
55/// This layer emits one structured [`tracing`] event per request at completion,
56/// using target `mcp::audit`. Events contain the method, operation details,
57/// duration, status, and tool annotation hints when available.
58///
59/// # Example
60///
61/// ```rust,ignore
62/// use tower_mcp::middleware::AuditLayer;
63///
64/// let transport = HttpTransport::new(router)
65///     .layer(AuditLayer::new());
66/// ```
67#[derive(Debug, Clone, Copy)]
68pub struct AuditLayer {
69    level: Level,
70}
71
72impl Default for AuditLayer {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78impl AuditLayer {
79    /// Create a new audit layer with default settings (INFO level).
80    pub fn new() -> Self {
81        Self { level: Level::INFO }
82    }
83
84    /// Set the log level for audit events.
85    ///
86    /// Default is `INFO`.
87    pub fn level(mut self, level: Level) -> Self {
88        self.level = level;
89        self
90    }
91}
92
93impl<S> Layer<S> for AuditLayer {
94    type Service = AuditService<S>;
95
96    fn layer(&self, inner: S) -> Self::Service {
97        AuditService {
98            inner,
99            level: self.level,
100        }
101    }
102}
103
104/// Tower service that emits structured audit events for MCP requests.
105///
106/// Created by [`AuditLayer`]. See the layer documentation for details.
107#[derive(Debug, Clone)]
108pub struct AuditService<S> {
109    inner: S,
110    level: Level,
111}
112
/// Snapshot of the audit-relevant parts of a request.
///
/// It is taken before the request is forwarded, because the inner service
/// takes ownership of the request.
struct AuditInfo {
    /// MCP method name, e.g. `tools/call`.
    method: String,
    /// `Debug` rendering of the JSON-RPC request id.
    request_id: String,
    /// Tool name for `tools/call` requests.
    tool: Option<String>,
    /// Resource URI for `resources/read`, subscribe and unsubscribe requests.
    resource_uri: Option<String>,
    /// Prompt name for `prompts/get` requests.
    prompt: Option<String>,
    /// Read-only hint for the called tool, if a `ToolAnnotationsMap` is present.
    read_only: Option<bool>,
    /// Destructive hint for the called tool, if a `ToolAnnotationsMap` is present.
    destructive: Option<bool>,
}
123
124impl AuditInfo {
125    fn extract(req: &RouterRequest) -> Self {
126        let method = req.inner.method_name().to_string();
127        let request_id = format!("{:?}", req.id);
128
129        let mut info = Self {
130            method,
131            request_id,
132            tool: None,
133            resource_uri: None,
134            prompt: None,
135            read_only: None,
136            destructive: None,
137        };
138
139        match &req.inner {
140            McpRequest::CallTool(params) => {
141                info.tool = Some(params.name.clone());
142
143                if let Some(annotations) = req.extensions.get::<ToolAnnotationsMap>() {
144                    info.read_only = Some(annotations.is_read_only(&params.name));
145                    info.destructive = Some(annotations.is_destructive(&params.name));
146                }
147            }
148            McpRequest::ReadResource(params) => {
149                info.resource_uri = Some(params.uri.clone());
150            }
151            McpRequest::GetPrompt(params) => {
152                info.prompt = Some(params.name.clone());
153            }
154            McpRequest::SubscribeResource(params) => {
155                info.resource_uri = Some(params.uri.clone());
156            }
157            McpRequest::UnsubscribeResource(params) => {
158                info.resource_uri = Some(params.uri.clone());
159            }
160            _ => {}
161        }
162
163        info
164    }
165}
166
/// JSON-RPC "invalid params" error code (`-32602`).
///
/// This layer reports responses carrying this code with status `"denied"`.
/// Note that an invalid-params error is not necessarily an authorization
/// denial; it is treated as a likely one.
const JSONRPC_INVALID_PARAMS: i32 = -32602;
169
170impl<S> Service<RouterRequest> for AuditService<S>
171where
172    S: Service<RouterRequest, Response = RouterResponse> + Clone + Send + 'static,
173    S::Error: Send,
174    S::Future: Send,
175{
176    type Response = RouterResponse;
177    type Error = S::Error;
178    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, S::Error>> + Send>>;
179
180    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
181        self.inner.poll_ready(cx)
182    }
183
184    fn call(&mut self, req: RouterRequest) -> Self::Future {
185        let info = AuditInfo::extract(&req);
186        let start = Instant::now();
187        let fut = self.inner.call(req);
188        let level = self.level;
189
190        Box::pin(async move {
191            let result = fut.await;
192            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
193
194            if let Ok(response) = &result {
195                let (status, error) = match &response.inner {
196                    Ok(_) => ("success", None),
197                    Err(err) => {
198                        let s = if err.code == JSONRPC_INVALID_PARAMS {
199                            "denied"
200                        } else {
201                            "error"
202                        };
203                        (s, Some((err.code, err.message.as_str())))
204                    }
205                };
206
207                emit_audit_event(level, &info, duration_ms, status, error);
208            }
209
210            result
211        })
212    }
213}
214
215/// Emit a structured audit event at the configured level.
216///
217/// Uses tracing target `mcp::audit` so events can be routed independently
218/// from other application logs.
219fn emit_audit_event(
220    level: Level,
221    info: &AuditInfo,
222    duration_ms: f64,
223    status: &str,
224    error: Option<(i32, &str)>,
225) {
226    let method = info.method.as_str();
227    let request_id = info.request_id.as_str();
228    let tool = info.tool.as_deref();
229    let resource_uri = info.resource_uri.as_deref();
230    let prompt = info.prompt.as_deref();
231    let read_only = info.read_only;
232    let destructive = info.destructive;
233
234    match (level, error) {
235        (Level::TRACE, None) => {
236            tracing::trace!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, ?read_only, ?destructive, "audit")
237        }
238        (Level::TRACE, Some((code, msg))) => {
239            tracing::trace!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, error_code = code, error_message = msg, ?read_only, ?destructive, "audit")
240        }
241        (Level::DEBUG, None) => {
242            tracing::debug!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, ?read_only, ?destructive, "audit")
243        }
244        (Level::DEBUG, Some((code, msg))) => {
245            tracing::debug!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, error_code = code, error_message = msg, ?read_only, ?destructive, "audit")
246        }
247        (Level::INFO, None) => {
248            tracing::info!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, ?read_only, ?destructive, "audit")
249        }
250        (Level::INFO, Some((code, msg))) => {
251            tracing::info!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, error_code = code, error_message = msg, ?read_only, ?destructive, "audit")
252        }
253        (Level::WARN, None) => {
254            tracing::warn!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, ?read_only, ?destructive, "audit")
255        }
256        (Level::WARN, Some((code, msg))) => {
257            tracing::warn!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, error_code = code, error_message = msg, ?read_only, ?destructive, "audit")
258        }
259        (Level::ERROR, None) => {
260            tracing::error!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, ?read_only, ?destructive, "audit")
261        }
262        (Level::ERROR, Some((code, msg))) => {
263            tracing::error!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, error_code = code, error_message = msg, ?read_only, ?destructive, "audit")
264        }
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{CallToolParams, GetPromptParams, ReadResourceParams, RequestId};
    use crate::router::Extensions;
    use std::collections::HashMap;

    /// Build a `RouterRequest` with the given id and payload and empty extensions.
    fn router_request(id: RequestId, inner: McpRequest) -> RouterRequest {
        RouterRequest {
            id,
            inner,
            extensions: Extensions::new(),
        }
    }

    /// A `tools/call` request for `name` with an empty JSON object as arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    #[test]
    fn test_layer_creation() {
        assert_eq!(AuditLayer::new().level, Level::INFO);
    }

    #[test]
    fn test_layer_with_custom_level() {
        assert_eq!(AuditLayer::new().level(Level::DEBUG).level, Level::DEBUG);
    }

    #[test]
    fn test_layer_default() {
        assert_eq!(AuditLayer::default().level, Level::INFO);
    }

    #[test]
    fn test_audit_info_tool_call() {
        let info = AuditInfo::extract(&router_request(RequestId::Number(1), call_tool("my_tool")));

        assert_eq!(info.method, "tools/call");
        assert_eq!(info.tool.as_deref(), Some("my_tool"));
        assert_eq!(info.resource_uri, None);
        assert_eq!(info.prompt, None);
    }

    #[test]
    fn test_audit_info_resource_read() {
        let read = McpRequest::ReadResource(ReadResourceParams {
            uri: "file:///test.txt".to_string(),
            meta: None,
        });
        let info = AuditInfo::extract(&router_request(RequestId::Number(2), read));

        assert_eq!(info.method, "resources/read");
        assert_eq!(info.tool, None);
        assert_eq!(info.resource_uri.as_deref(), Some("file:///test.txt"));
    }

    #[test]
    fn test_audit_info_prompt_get() {
        let get = McpRequest::GetPrompt(GetPromptParams {
            name: "review".to_string(),
            arguments: HashMap::new(),
            meta: None,
        });
        let info = AuditInfo::extract(&router_request(RequestId::Number(3), get));

        assert_eq!(info.method, "prompts/get");
        assert_eq!(info.tool, None);
        assert_eq!(info.prompt.as_deref(), Some("review"));
    }

    #[test]
    fn test_audit_info_ping() {
        let info = AuditInfo::extract(&router_request(RequestId::Number(4), McpRequest::Ping));

        assert_eq!(info.method, "ping");
        assert_eq!(info.tool, None);
        assert_eq!(info.resource_uri, None);
        assert_eq!(info.prompt, None);
    }

    #[tokio::test]
    async fn test_passthrough() {
        let router = crate::McpRouter::new().server_info("test", "1.0.0");
        let mut service = AuditLayer::new().layer(router);

        let response = Service::call(
            &mut service,
            router_request(RequestId::Number(1), McpRequest::Ping),
        )
        .await
        .expect("inner service should not fail for ping");

        assert!(response.inner.is_ok());
    }

    #[tokio::test]
    async fn test_tool_call_audit() {
        let tool = crate::ToolBuilder::new("test_tool")
            .description("A test tool")
            .handler(|_: serde_json::Value| async move { Ok(crate::CallToolResult::text("done")) })
            .build();
        let router = crate::McpRouter::new()
            .server_info("test", "1.0.0")
            .tool(tool);
        let mut service = AuditLayer::new().layer(router);

        let outcome = Service::call(
            &mut service,
            router_request(RequestId::Number(1), call_tool("test_tool")),
        )
        .await;

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_error_audit() {
        let router = crate::McpRouter::new().server_info("test", "1.0.0");
        let mut service = AuditLayer::new().layer(router);

        let response = Service::call(
            &mut service,
            router_request(RequestId::Number(1), call_tool("nonexistent")),
        )
        .await
        .expect("unknown tools are reported as JSON-RPC errors, not service errors");

        assert!(response.inner.is_err());
    }
}