1use std::sync::Arc;
8
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11
12use crate::error::McpAdapterError;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct McpToolCallRequest {
21 pub tool_name: String,
23 pub server_name: String,
25 pub arguments: Value,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct McpToolCallResult {
32 pub value: Value,
34 pub is_error: bool,
36}
37
38pub type InterceptorNext<'a> = Box<
46 dyn FnOnce(
47 McpToolCallRequest,
48 ) -> std::pin::Pin<
49 Box<
50 dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
51 + Send
52 + 'a,
53 >,
54 > + Send
55 + 'a,
56>;
57
58pub trait ToolCallInterceptor: Send + Sync {
69 fn intercept<'a>(
74 &'a self,
75 request: McpToolCallRequest,
76 next: InterceptorNext<'a>,
77 ) -> std::pin::Pin<
78 Box<
79 dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
80 + Send
81 + 'a,
82 >,
83 >;
84}
85
86pub async fn run_interceptor_chain<F>(
100 interceptors: &[Arc<dyn ToolCallInterceptor>],
101 request: McpToolCallRequest,
102 inner: F,
103) -> Result<McpToolCallResult, McpAdapterError>
104where
105 F: FnOnce(
106 McpToolCallRequest,
107 ) -> std::pin::Pin<
108 Box<
109 dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>> + Send,
110 >,
111 > + Send
112 + 'static,
113{
114 if interceptors.is_empty() {
119 return inner(request).await;
120 }
121
122 run_chain_from(interceptors, 0, request, inner).await
124}
125
126fn run_chain_from<'a, F>(
128 interceptors: &'a [Arc<dyn ToolCallInterceptor>],
129 index: usize,
130 request: McpToolCallRequest,
131 inner: F,
132) -> std::pin::Pin<
133 Box<dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>> + Send + 'a>,
134>
135where
136 F: FnOnce(
137 McpToolCallRequest,
138 ) -> std::pin::Pin<
139 Box<
140 dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>> + Send,
141 >,
142 > + Send
143 + 'static,
144{
145 if index >= interceptors.len() {
146 return Box::pin(async move { inner(request).await });
147 }
148
149 let interceptor = Arc::clone(&interceptors[index]);
150 let remaining = interceptors;
151 let next_index = index + 1;
152
153 Box::pin(async move {
154 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
156 interceptor.intercept(
157 request,
158 Box::new(move |req| run_chain_from(remaining, next_index, req, inner)),
159 )
160 }));
161
162 match result {
163 Ok(fut) => fut.await,
164 Err(payload) => {
165 let msg = payload
166 .downcast_ref::<&str>()
167 .map(|s| (*s).to_owned())
168 .or_else(|| payload.downcast_ref::<String>().cloned())
169 .unwrap_or_else(|| "unknown panic".to_owned());
170 Err(McpAdapterError::InterceptorPanic { message: msg })
171 }
172 }
173 })
174}
175
176#[derive(Debug, Default)]
182pub struct LoggingInterceptor;
183
184impl ToolCallInterceptor for LoggingInterceptor {
185 fn intercept<'a>(
186 &'a self,
187 request: McpToolCallRequest,
188 next: InterceptorNext<'a>,
189 ) -> std::pin::Pin<
190 Box<
191 dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
192 + Send
193 + 'a,
194 >,
195 > {
196 Box::pin(async move {
197 tracing::debug!(
198 tool = %request.tool_name,
199 server = %request.server_name,
200 "MCP tool call intercepted"
201 );
202 let result = next(request).await;
203 match &result {
204 Ok(r) => tracing::debug!(is_error = r.is_error, "MCP tool call completed"),
205 Err(e) => tracing::warn!(error = %e, "MCP tool call failed in interceptor chain"),
206 }
207 result
208 })
209 }
210}
211
#[cfg(test)]
// `unwrap` is acceptable in tests: a failed unwrap simply fails the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Shared, async-lockable log of interceptor ids in the order they fired.
    type CallLog = Arc<tokio::sync::Mutex<Vec<char>>>;

    /// Boxed tool-call future, as returned by interceptors and the inner call.
    type CallFuture<'a> = std::pin::Pin<
        Box<
            dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
                + Send
                + 'a,
        >,
    >;

    /// Pushes its `id` onto the log before and after delegating to `next`,
    /// which makes the nesting order of the chain observable.
    struct RecordingInterceptor {
        id: char,
        order: CallLog,
    }

    impl ToolCallInterceptor for RecordingInterceptor {
        fn intercept<'a>(
            &'a self,
            request: McpToolCallRequest,
            next: InterceptorNext<'a>,
        ) -> CallFuture<'a> {
            let tag = self.id;
            let log = Arc::clone(&self.order);
            Box::pin(async move {
                log.lock().await.push(tag);
                let outcome = next(request).await;
                log.lock().await.push(tag);
                outcome
            })
        }
    }

    /// Answers immediately with a fixed payload and never calls `next`.
    struct ShortCircuitInterceptor;

    impl ToolCallInterceptor for ShortCircuitInterceptor {
        fn intercept<'a>(
            &'a self,
            _request: McpToolCallRequest,
            _next: InterceptorNext<'a>,
        ) -> CallFuture<'a> {
            let canned = McpToolCallResult {
                value: serde_json::json!({"short": "circuit"}),
                is_error: false,
            };
            Box::pin(async move { Ok(canned) })
        }
    }

    fn new_log() -> CallLog {
        Arc::new(tokio::sync::Mutex::new(Vec::new()))
    }

    fn recorder(id: char, log: &CallLog) -> Arc<dyn ToolCallInterceptor> {
        Arc::new(RecordingInterceptor {
            id,
            order: Arc::clone(log),
        })
    }

    /// Stand-in for the real tool call: always succeeds with `{"result": "ok"}`.
    fn make_inner() -> impl FnOnce(McpToolCallRequest) -> CallFuture<'static> + Send + 'static {
        |_req| -> CallFuture<'static> {
            Box::pin(async {
                Ok(McpToolCallResult {
                    value: serde_json::json!({"result": "ok"}),
                    is_error: false,
                })
            })
        }
    }

    fn make_request() -> McpToolCallRequest {
        McpToolCallRequest {
            tool_name: "search".into(),
            server_name: "s1".into(),
            arguments: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn onion_ordering_abc_to_cba() {
        let log = new_log();
        let interceptors: Vec<Arc<dyn ToolCallInterceptor>> =
            ['A', 'B', 'C'].iter().map(|&id| recorder(id, &log)).collect();

        let result = run_interceptor_chain(&interceptors, make_request(), make_inner()).await;

        assert!(result.is_ok());
        // Entry order A -> B -> C, then exit order C -> B -> A.
        assert_eq!(*log.lock().await, ['A', 'B', 'C', 'C', 'B', 'A']);
    }

    #[tokio::test]
    async fn short_circuit_interceptor_stops_chain() {
        let log = new_log();
        let interceptors: Vec<Arc<dyn ToolCallInterceptor>> = vec![
            recorder('A', &log),
            Arc::new(ShortCircuitInterceptor),
            recorder('C', &log),
        ];

        let response = run_interceptor_chain(&interceptors, make_request(), make_inner())
            .await
            .unwrap();

        assert_eq!(response.value["short"], "circuit");
        // `C` never runs because the short-circuit interceptor never calls `next`.
        assert_eq!(*log.lock().await, ['A', 'A']);
    }

    #[tokio::test]
    async fn no_interceptors_calls_inner() {
        let response = run_interceptor_chain(&[], make_request(), make_inner())
            .await
            .unwrap();
        assert_eq!(response.value["result"], "ok");
    }
}