1use std::future::Future;
42use std::pin::Pin;
43use std::task::{Context, Poll};
44use std::time::Instant;
45
46use tower::Layer;
47use tower_service::Service;
48use tracing::Level;
49
50use crate::protocol::McpRequest;
51use crate::router::{RouterRequest, RouterResponse, ToolAnnotationsMap};
52
53#[derive(Debug, Clone, Copy)]
68pub struct AuditLayer {
69 level: Level,
70}
71
72impl Default for AuditLayer {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78impl AuditLayer {
79 pub fn new() -> Self {
81 Self { level: Level::INFO }
82 }
83
84 pub fn level(mut self, level: Level) -> Self {
88 self.level = level;
89 self
90 }
91}
92
93impl<S> Layer<S> for AuditLayer {
94 type Service = AuditService<S>;
95
96 fn layer(&self, inner: S) -> Self::Service {
97 AuditService {
98 inner,
99 level: self.level,
100 }
101 }
102}
103
104#[derive(Debug, Clone)]
108pub struct AuditService<S> {
109 inner: S,
110 level: Level,
111}
112
/// Request metadata captured before the request is moved into the inner
/// service, so that it is still available once the response arrives.
struct AuditInfo {
    /// JSON-RPC method name, e.g. `tools/call` or `resources/read`.
    method: String,
    /// `Debug` rendering of the request id.
    request_id: String,
    /// Tool name, set for tool-call requests.
    tool: Option<String>,
    /// Resource URI, set for resource read / subscribe / unsubscribe requests.
    resource_uri: Option<String>,
    /// Prompt name, set for prompt-get requests.
    prompt: Option<String>,
    /// Read-only hint of the called tool; `None` when no annotations were available.
    read_only: Option<bool>,
    /// Destructive hint of the called tool; `None` when no annotations were available.
    destructive: Option<bool>,
}
123
124impl AuditInfo {
125 fn extract(req: &RouterRequest) -> Self {
126 let method = req.inner.method_name().to_string();
127 let request_id = format!("{:?}", req.id);
128
129 let mut info = Self {
130 method,
131 request_id,
132 tool: None,
133 resource_uri: None,
134 prompt: None,
135 read_only: None,
136 destructive: None,
137 };
138
139 match &req.inner {
140 McpRequest::CallTool(params) => {
141 info.tool = Some(params.name.clone());
142
143 if let Some(annotations) = req.extensions.get::<ToolAnnotationsMap>() {
144 info.read_only = Some(annotations.is_read_only(¶ms.name));
145 info.destructive = Some(annotations.is_destructive(¶ms.name));
146 }
147 }
148 McpRequest::ReadResource(params) => {
149 info.resource_uri = Some(params.uri.clone());
150 }
151 McpRequest::GetPrompt(params) => {
152 info.prompt = Some(params.name.clone());
153 }
154 McpRequest::SubscribeResource(params) => {
155 info.resource_uri = Some(params.uri.clone());
156 }
157 McpRequest::UnsubscribeResource(params) => {
158 info.resource_uri = Some(params.uri.clone());
159 }
160 _ => {}
161 }
162
163 info
164 }
165}
166
/// JSON-RPC 2.0 "Invalid params" error code.
///
/// The audit layer reports responses carrying this code with status `denied`
/// rather than `error`.
const JSONRPC_INVALID_PARAMS: i32 = -32_602;
169
170impl<S> Service<RouterRequest> for AuditService<S>
171where
172 S: Service<RouterRequest, Response = RouterResponse> + Clone + Send + 'static,
173 S::Error: Send,
174 S::Future: Send,
175{
176 type Response = RouterResponse;
177 type Error = S::Error;
178 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, S::Error>> + Send>>;
179
180 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
181 self.inner.poll_ready(cx)
182 }
183
184 fn call(&mut self, req: RouterRequest) -> Self::Future {
185 let info = AuditInfo::extract(&req);
186 let start = Instant::now();
187 let fut = self.inner.call(req);
188 let level = self.level;
189
190 Box::pin(async move {
191 let result = fut.await;
192 let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
193
194 if let Ok(response) = &result {
195 let (status, error) = match &response.inner {
196 Ok(_) => ("success", None),
197 Err(err) => {
198 let s = if err.code == JSONRPC_INVALID_PARAMS {
199 "denied"
200 } else {
201 "error"
202 };
203 (s, Some((err.code, err.message.as_str())))
204 }
205 };
206
207 emit_audit_event(level, &info, duration_ms, status, error);
208 }
209
210 result
211 })
212 }
213}
214
215fn emit_audit_event(
220 level: Level,
221 info: &AuditInfo,
222 duration_ms: f64,
223 status: &str,
224 error: Option<(i32, &str)>,
225) {
226 let method = info.method.as_str();
227 let request_id = info.request_id.as_str();
228 let tool = info.tool.as_deref();
229 let resource_uri = info.resource_uri.as_deref();
230 let prompt = info.prompt.as_deref();
231 let read_only = info.read_only;
232 let destructive = info.destructive;
233
234 match (level, error) {
235 (Level::TRACE, None) => {
236 tracing::trace!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, ?read_only, ?destructive, "audit")
237 }
238 (Level::TRACE, Some((code, msg))) => {
239 tracing::trace!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, error_code = code, error_message = msg, ?read_only, ?destructive, "audit")
240 }
241 (Level::DEBUG, None) => {
242 tracing::debug!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, ?read_only, ?destructive, "audit")
243 }
244 (Level::DEBUG, Some((code, msg))) => {
245 tracing::debug!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, error_code = code, error_message = msg, ?read_only, ?destructive, "audit")
246 }
247 (Level::INFO, None) => {
248 tracing::info!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, ?read_only, ?destructive, "audit")
249 }
250 (Level::INFO, Some((code, msg))) => {
251 tracing::info!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, error_code = code, error_message = msg, ?read_only, ?destructive, "audit")
252 }
253 (Level::WARN, None) => {
254 tracing::warn!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, ?read_only, ?destructive, "audit")
255 }
256 (Level::WARN, Some((code, msg))) => {
257 tracing::warn!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, error_code = code, error_message = msg, ?read_only, ?destructive, "audit")
258 }
259 (Level::ERROR, None) => {
260 tracing::error!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, ?read_only, ?destructive, "audit")
261 }
262 (Level::ERROR, Some((code, msg))) => {
263 tracing::error!(target: "mcp::audit", method, request_id, ?tool, ?resource_uri, ?prompt, duration_ms, status, error_code = code, error_message = msg, ?read_only, ?destructive, "audit")
264 }
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{CallToolParams, GetPromptParams, ReadResourceParams, RequestId};
    use crate::router::Extensions;
    use std::collections::HashMap;

    // ---- AuditLayer configuration -------------------------------------------------

    /// `AuditLayer::new` defaults to INFO level.
    #[test]
    fn test_layer_creation() {
        let layer = AuditLayer::new();
        assert_eq!(layer.level, Level::INFO);
    }

    /// The `level` builder overrides the default level.
    #[test]
    fn test_layer_with_custom_level() {
        let layer = AuditLayer::new().level(Level::DEBUG);
        assert_eq!(layer.level, Level::DEBUG);
    }

    /// `Default` agrees with `new`.
    #[test]
    fn test_layer_default() {
        let layer = AuditLayer::default();
        assert_eq!(layer.level, Level::INFO);
    }

    // ---- AuditInfo::extract --------------------------------------------------------

    /// A tool call records the method and tool name, but no resource or prompt.
    #[test]
    fn test_audit_info_tool_call() {
        let req = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::CallTool(CallToolParams {
                name: "my_tool".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
            extensions: Extensions::new(),
        };

        let info = AuditInfo::extract(&req);
        assert_eq!(info.method, "tools/call");
        assert_eq!(info.tool, Some("my_tool".to_string()));
        assert!(info.resource_uri.is_none());
        assert!(info.prompt.is_none());
    }

    /// A resource read records the resource URI and no tool.
    #[test]
    fn test_audit_info_resource_read() {
        let req = RouterRequest {
            id: RequestId::Number(2),
            inner: McpRequest::ReadResource(ReadResourceParams {
                uri: "file:///test.txt".to_string(),
                meta: None,
            }),
            extensions: Extensions::new(),
        };

        let info = AuditInfo::extract(&req);
        assert_eq!(info.method, "resources/read");
        assert!(info.tool.is_none());
        assert_eq!(info.resource_uri, Some("file:///test.txt".to_string()));
    }

    /// A prompt get records the prompt name and no tool.
    #[test]
    fn test_audit_info_prompt_get() {
        let req = RouterRequest {
            id: RequestId::Number(3),
            inner: McpRequest::GetPrompt(GetPromptParams {
                name: "review".to_string(),
                arguments: HashMap::new(),
                meta: None,
            }),
            extensions: Extensions::new(),
        };

        let info = AuditInfo::extract(&req);
        assert_eq!(info.method, "prompts/get");
        assert!(info.tool.is_none());
        assert_eq!(info.prompt, Some("review".to_string()));
    }

    /// A request without target-specific metadata records only the method.
    #[test]
    fn test_audit_info_ping() {
        let req = RouterRequest {
            id: RequestId::Number(4),
            inner: McpRequest::Ping,
            extensions: Extensions::new(),
        };

        let info = AuditInfo::extract(&req);
        assert_eq!(info.method, "ping");
        assert!(info.tool.is_none());
        assert!(info.resource_uri.is_none());
        assert!(info.prompt.is_none());
    }

    // ---- End-to-end through the wrapped router -------------------------------------

    /// The layer does not alter a successful response.
    #[tokio::test]
    async fn test_passthrough() {
        let router = crate::McpRouter::new().server_info("test", "1.0.0");
        let layer = AuditLayer::new();
        let mut service = layer.layer(router);

        let req = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::Ping,
            extensions: Extensions::new(),
        };

        let result = Service::call(&mut service, req).await;
        assert!(result.is_ok());
        assert!(result.unwrap().inner.is_ok());
    }

    /// A call to a registered tool completes through the audit layer.
    #[tokio::test]
    async fn test_tool_call_audit() {
        let tool = crate::ToolBuilder::new("test_tool")
            .description("A test tool")
            .handler(|_: serde_json::Value| async move { Ok(crate::CallToolResult::text("done")) })
            .build();

        let router = crate::McpRouter::new()
            .server_info("test", "1.0.0")
            .tool(tool);
        let layer = AuditLayer::new();
        let mut service = layer.layer(router);

        let req = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::CallTool(CallToolParams {
                name: "test_tool".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
            extensions: Extensions::new(),
        };

        let result = Service::call(&mut service, req).await;
        assert!(result.is_ok());
    }

    /// Calling an unknown tool yields a JSON-RPC error inside an `Ok` service
    /// result, and the layer passes it through unchanged.
    #[tokio::test]
    async fn test_error_audit() {
        let router = crate::McpRouter::new().server_info("test", "1.0.0");
        let layer = AuditLayer::new();
        let mut service = layer.layer(router);

        let req = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::CallTool(CallToolParams {
                name: "nonexistent".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
            extensions: Extensions::new(),
        };

        let result = Service::call(&mut service, req).await;
        assert!(result.is_ok());
        assert!(result.unwrap().inner.is_err());
    }
}
429}