1use std::future::Future;
32use std::pin::Pin;
33use std::task::{Context, Poll};
34use std::time::Instant;
35
36use tower::Layer;
37use tower_service::Service;
38use tracing::Level;
39
40use crate::protocol::McpRequest;
41use crate::router::{RouterRequest, RouterResponse, ToolAnnotationsMap};
42
/// JSON-RPC 2.0 "Invalid params" error code.
///
/// Tool-call responses that fail with this code are logged with the status
/// `"denied"` instead of `"error"`.
const JSONRPC_INVALID_PARAMS: i32 = -32_602;
45
46#[derive(Debug, Clone, Copy)]
67pub struct ToolCallLoggingLayer {
68 level: Level,
69}
70
71impl Default for ToolCallLoggingLayer {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77impl ToolCallLoggingLayer {
78 pub fn new() -> Self {
80 Self { level: Level::INFO }
81 }
82
83 pub fn level(mut self, level: Level) -> Self {
87 self.level = level;
88 self
89 }
90}
91
92impl<S> Layer<S> for ToolCallLoggingLayer {
93 type Service = ToolCallLoggingService<S>;
94
95 fn layer(&self, inner: S) -> Self::Service {
96 ToolCallLoggingService {
97 inner,
98 level: self.level,
99 }
100 }
101}
102
103#[derive(Debug, Clone)]
107pub struct ToolCallLoggingService<S> {
108 inner: S,
109 level: Level,
110}
111
112impl<S> Service<RouterRequest> for ToolCallLoggingService<S>
113where
114 S: Service<RouterRequest, Response = RouterResponse> + Clone + Send + 'static,
115 S::Error: Send,
116 S::Future: Send,
117{
118 type Response = RouterResponse;
119 type Error = S::Error;
120 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, S::Error>> + Send>>;
121
122 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
123 self.inner.poll_ready(cx)
124 }
125
126 fn call(&mut self, req: RouterRequest) -> Self::Future {
127 let tool_name = match &req.inner {
129 McpRequest::CallTool(params) => params.name.clone(),
130 _ => {
131 let fut = self.inner.call(req);
132 return Box::pin(fut);
133 }
134 };
135
136 let read_only = req
138 .extensions
139 .get::<ToolAnnotationsMap>()
140 .map(|m| m.is_read_only(&tool_name));
141 let destructive = req
142 .extensions
143 .get::<ToolAnnotationsMap>()
144 .map(|m| m.is_destructive(&tool_name));
145
146 let start = Instant::now();
147 let fut = self.inner.call(req);
148 let level = self.level;
149
150 Box::pin(async move {
151 let result = fut.await;
152 let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
153
154 if let Ok(response) = &result {
155 match &response.inner {
156 Ok(_) => {
157 log_tool_call(
158 level,
159 &tool_name,
160 duration_ms,
161 "success",
162 None,
163 read_only,
164 destructive,
165 );
166 }
167 Err(err) => {
168 let status = if err.code == JSONRPC_INVALID_PARAMS {
169 "denied"
170 } else {
171 "error"
172 };
173 log_tool_call(
174 level,
175 &tool_name,
176 duration_ms,
177 status,
178 Some((err.code, &err.message)),
179 read_only,
180 destructive,
181 );
182 }
183 }
184 }
185
186 result
187 })
188 }
189}
190
191fn log_tool_call(
193 level: Level,
194 tool: &str,
195 duration_ms: f64,
196 status: &str,
197 error: Option<(i32, &str)>,
198 read_only: Option<bool>,
199 destructive: Option<bool>,
200) {
201 match (level, error) {
202 (Level::TRACE, None) => {
203 tracing::trace!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
204 }
205 (Level::TRACE, Some((code, message))) => {
206 tracing::trace!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
207 }
208 (Level::DEBUG, None) => {
209 tracing::debug!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
210 }
211 (Level::DEBUG, Some((code, message))) => {
212 tracing::debug!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
213 }
214 (Level::INFO, None) => {
215 tracing::info!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
216 }
217 (Level::INFO, Some((code, message))) => {
218 tracing::info!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
219 }
220 (Level::WARN, None) => {
221 tracing::warn!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
222 }
223 (Level::WARN, Some((code, message))) => {
224 tracing::warn!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
225 }
226 (Level::ERROR, None) => {
227 tracing::error!(target: "mcp::tools", tool, duration_ms, status, ?read_only, ?destructive, "tool call completed")
228 }
229 (Level::ERROR, Some((code, message))) => {
230 tracing::error!(target: "mcp::tools", tool, duration_ms, status, error_code = code, error_message = message, ?read_only, ?destructive, "tool call completed")
231 }
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{CallToolParams, RequestId};
    use crate::router::Extensions;

    /// Builds a `CallTool` request for `name` with empty JSON arguments.
    fn tool_call_request(name: &str) -> RouterRequest {
        RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::CallTool(CallToolParams {
                name: name.to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
            extensions: Extensions::new(),
        }
    }

    #[test]
    fn test_layer_creation() {
        assert_eq!(ToolCallLoggingLayer::new().level, Level::INFO);
    }

    #[test]
    fn test_layer_with_custom_level() {
        let configured = ToolCallLoggingLayer::new().level(Level::DEBUG);
        assert_eq!(configured.level, Level::DEBUG);
    }

    #[test]
    fn test_layer_default() {
        assert_eq!(ToolCallLoggingLayer::default().level, Level::INFO);
    }

    #[tokio::test]
    async fn test_non_tool_call_passthrough() {
        let router = crate::McpRouter::new().server_info("test", "1.0.0");
        let mut service = ToolCallLoggingLayer::new().layer(router);

        let ping = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::Ping,
            extensions: Extensions::new(),
        };

        let response = Service::call(&mut service, ping)
            .await
            .expect("inner service should not fail");
        assert!(response.inner.is_ok());
    }

    #[tokio::test]
    async fn test_tool_call_logging() {
        let tool = crate::ToolBuilder::new("test_tool")
            .description("A test tool")
            .handler(|_: serde_json::Value| async move { Ok(crate::CallToolResult::text("done")) })
            .build();

        let router = crate::McpRouter::new()
            .server_info("test", "1.0.0")
            .tool(tool);
        let mut service = ToolCallLoggingLayer::new().layer(router);

        let outcome = Service::call(&mut service, tool_call_request("test_tool")).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_tool_call_error_logging() {
        let router = crate::McpRouter::new().server_info("test", "1.0.0");
        let mut service = ToolCallLoggingLayer::new().layer(router);

        let response = Service::call(&mut service, tool_call_request("nonexistent"))
            .await
            .expect("inner service should not fail");
        assert!(response.inner.is_err());
    }
}