Skip to main content

synwire_mcp_adapters/
interceptor.rs

1//! Tool call interceptor chain (onion/middleware pattern).
2//!
3//! Interceptors wrap MCP tool call invocations in an onion-layered chain.
4//! Each interceptor can inspect, modify, short-circuit, or observe requests
5//! and responses. The ordering is `A → B → C → tool → C → B → A`.
6
7use std::sync::Arc;
8
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11
12use crate::error::McpAdapterError;
13
14// ---------------------------------------------------------------------------
15// Request and result types
16// ---------------------------------------------------------------------------
17
18/// An MCP tool call request passed through the interceptor chain.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct McpToolCallRequest {
21    /// Name of the tool (exposed name, may include prefix).
22    pub tool_name: String,
23    /// Server name that will handle the call.
24    pub server_name: String,
25    /// JSON arguments for the tool.
26    pub arguments: Value,
27}
28
29/// The result of an MCP tool call after passing through the interceptor chain.
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct McpToolCallResult {
32    /// Raw result value from the server.
33    pub value: Value,
34    /// Whether the result represents an error.
35    pub is_error: bool,
36}
37
38// ---------------------------------------------------------------------------
39// ToolCallInterceptor trait
40// ---------------------------------------------------------------------------
41
42/// Type alias for the `next` continuation in the interceptor chain.
43///
44/// Calling `next` invokes the remaining interceptors and ultimately the tool.
45pub type InterceptorNext<'a> = Box<
46    dyn FnOnce(
47            McpToolCallRequest,
48        ) -> std::pin::Pin<
49            Box<
50                dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
51                    + Send
52                    + 'a,
53            >,
54        > + Send
55        + 'a,
56>;
57
58/// An interceptor that wraps MCP tool calls.
59///
60/// Implement this trait to inspect, modify, or short-circuit tool call
61/// requests and responses. Implementations must be `Send + Sync`.
62///
63/// # Ordering
64///
65/// Interceptors are executed in the order they are added to the chain.
66/// If you add interceptors A, B, C, the call order is:
67/// `A.intercept → B.intercept → C.intercept → tool → C returns → B returns → A returns`
68pub trait ToolCallInterceptor: Send + Sync {
69    /// Intercept a tool call.
70    ///
71    /// Call `next(request)` to continue the chain, or return a result
72    /// directly to short-circuit.
73    fn intercept<'a>(
74        &'a self,
75        request: McpToolCallRequest,
76        next: InterceptorNext<'a>,
77    ) -> std::pin::Pin<
78        Box<
79            dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
80                + Send
81                + 'a,
82        >,
83    >;
84}
85
86// ---------------------------------------------------------------------------
87// Chain executor
88// ---------------------------------------------------------------------------
89
90/// Executes a list of interceptors in onion order, with a final handler.
91///
92/// Interceptors wrap the `inner` future in order; panics within any
93/// interceptor are caught and converted to [`McpAdapterError::InterceptorPanic`].
94///
95/// # Panic safety
96///
97/// Each interceptor is wrapped with `catch_unwind`. Panics in the `inner`
98/// function itself are **not** caught.
99pub async fn run_interceptor_chain<F>(
100    interceptors: &[Arc<dyn ToolCallInterceptor>],
101    request: McpToolCallRequest,
102    inner: F,
103) -> Result<McpToolCallResult, McpAdapterError>
104where
105    F: FnOnce(
106            McpToolCallRequest,
107        ) -> std::pin::Pin<
108            Box<
109                dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>> + Send,
110            >,
111        > + Send
112        + 'static,
113{
114    // Build the chain from right to left (innermost first).
115    // We represent each level as an Arc<dyn Fn(...)> to allow cloning.
116    // For a small number of interceptors this is acceptable.
117
118    if interceptors.is_empty() {
119        return inner(request).await;
120    }
121
122    // Recursively build the chain using an index.
123    run_chain_from(interceptors, 0, request, inner).await
124}
125
126/// Recursive helper that builds the interceptor chain.
127fn run_chain_from<'a, F>(
128    interceptors: &'a [Arc<dyn ToolCallInterceptor>],
129    index: usize,
130    request: McpToolCallRequest,
131    inner: F,
132) -> std::pin::Pin<
133    Box<dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>> + Send + 'a>,
134>
135where
136    F: FnOnce(
137            McpToolCallRequest,
138        ) -> std::pin::Pin<
139            Box<
140                dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>> + Send,
141            >,
142        > + Send
143        + 'static,
144{
145    if index >= interceptors.len() {
146        return Box::pin(async move { inner(request).await });
147    }
148
149    let interceptor = Arc::clone(&interceptors[index]);
150    let remaining = interceptors;
151    let next_index = index + 1;
152
153    Box::pin(async move {
154        // Catch panics from the interceptor itself
155        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
156            interceptor.intercept(
157                request,
158                Box::new(move |req| run_chain_from(remaining, next_index, req, inner)),
159            )
160        }));
161
162        match result {
163            Ok(fut) => fut.await,
164            Err(payload) => {
165                let msg = payload
166                    .downcast_ref::<&str>()
167                    .map(|s| (*s).to_owned())
168                    .or_else(|| payload.downcast_ref::<String>().cloned())
169                    .unwrap_or_else(|| "unknown panic".to_owned());
170                Err(McpAdapterError::InterceptorPanic { message: msg })
171            }
172        }
173    })
174}
175
176// ---------------------------------------------------------------------------
177// Built-in interceptors
178// ---------------------------------------------------------------------------
179
180/// An interceptor that logs tool call requests and results via `tracing`.
181#[derive(Debug, Default)]
182pub struct LoggingInterceptor;
183
184impl ToolCallInterceptor for LoggingInterceptor {
185    fn intercept<'a>(
186        &'a self,
187        request: McpToolCallRequest,
188        next: InterceptorNext<'a>,
189    ) -> std::pin::Pin<
190        Box<
191            dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
192                + Send
193                + 'a,
194        >,
195    > {
196        Box::pin(async move {
197            tracing::debug!(
198                tool = %request.tool_name,
199                server = %request.server_name,
200                "MCP tool call intercepted"
201            );
202            let result = next(request).await;
203            match &result {
204                Ok(r) => tracing::debug!(is_error = r.is_error, "MCP tool call completed"),
205                Err(e) => tracing::warn!(error = %e, "MCP tool call failed in interceptor chain"),
206            }
207            result
208        })
209    }
210}
211
#[cfg(test)]
// Unwrapping in tests is intentional: a failed unwrap is a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Records its `id` on entry and again on exit, to observe onion order.
    struct RecordingInterceptor {
        id: char,
        order: Arc<tokio::sync::Mutex<Vec<char>>>,
    }

    impl ToolCallInterceptor for RecordingInterceptor {
        fn intercept<'a>(
            &'a self,
            request: McpToolCallRequest,
            next: InterceptorNext<'a>,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
                    + Send
                    + 'a,
            >,
        > {
            let id = self.id;
            let order = Arc::clone(&self.order);
            Box::pin(async move {
                order.lock().await.push(id);
                let result = next(request).await;
                order.lock().await.push(id);
                result
            })
        }
    }

    /// Answers immediately without calling `next`.
    struct ShortCircuitInterceptor;
    impl ToolCallInterceptor for ShortCircuitInterceptor {
        fn intercept<'a>(
            &'a self,
            _request: McpToolCallRequest,
            _next: InterceptorNext<'a>,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
                    + Send
                    + 'a,
            >,
        > {
            Box::pin(async {
                Ok(McpToolCallResult {
                    value: serde_json::json!({"short": "circuit"}),
                    is_error: false,
                })
            })
        }
    }

    /// Replaces the request arguments before forwarding the call.
    struct RewritingInterceptor;
    impl ToolCallInterceptor for RewritingInterceptor {
        fn intercept<'a>(
            &'a self,
            mut request: McpToolCallRequest,
            next: InterceptorNext<'a>,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
                    + Send
                    + 'a,
            >,
        > {
            request.arguments = serde_json::json!({"rewritten": true});
            next(request)
        }
    }

    /// Panics synchronously, before it builds any future.
    struct SyncPanicInterceptor;
    impl ToolCallInterceptor for SyncPanicInterceptor {
        fn intercept<'a>(
            &'a self,
            _request: McpToolCallRequest,
            _next: InterceptorNext<'a>,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
                    + Send
                    + 'a,
            >,
        > {
            panic!("interceptor exploded")
        }
    }

    fn make_inner() -> impl FnOnce(
        McpToolCallRequest,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>> + Send>,
    > + Send
    + 'static {
        |_req| {
            Box::pin(async {
                Ok(McpToolCallResult {
                    value: serde_json::json!({"result": "ok"}),
                    is_error: false,
                })
            })
        }
    }

    /// Final handler that echoes the request arguments back as the result.
    fn make_echo_inner() -> impl FnOnce(
        McpToolCallRequest,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>> + Send>,
    > + Send
    + 'static {
        |req| {
            Box::pin(async move {
                Ok(McpToolCallResult {
                    value: req.arguments,
                    is_error: false,
                })
            })
        }
    }

    fn make_request() -> McpToolCallRequest {
        McpToolCallRequest {
            tool_name: "search".into(),
            server_name: "s1".into(),
            arguments: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn onion_ordering_abc_to_cba() {
        let order: Arc<tokio::sync::Mutex<Vec<char>>> =
            Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let interceptors: Vec<Arc<dyn ToolCallInterceptor>> = vec![
            Arc::new(RecordingInterceptor {
                id: 'A',
                order: Arc::clone(&order),
            }),
            Arc::new(RecordingInterceptor {
                id: 'B',
                order: Arc::clone(&order),
            }),
            Arc::new(RecordingInterceptor {
                id: 'C',
                order: Arc::clone(&order),
            }),
        ];
        let result = run_interceptor_chain(&interceptors, make_request(), make_inner()).await;
        assert!(result.is_ok());
        let sequence = order.lock().await.clone();
        // Enter: A B C, Exit: C B A
        assert_eq!(sequence, vec!['A', 'B', 'C', 'C', 'B', 'A']);
    }

    #[tokio::test]
    async fn short_circuit_interceptor_stops_chain() {
        let order: Arc<tokio::sync::Mutex<Vec<char>>> =
            Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let interceptors: Vec<Arc<dyn ToolCallInterceptor>> = vec![
            Arc::new(RecordingInterceptor {
                id: 'A',
                order: Arc::clone(&order),
            }),
            Arc::new(ShortCircuitInterceptor),
            Arc::new(RecordingInterceptor {
                id: 'C',
                order: Arc::clone(&order),
            }),
        ];
        let result = run_interceptor_chain(&interceptors, make_request(), make_inner()).await;
        assert!(result.is_ok());
        let r = result.unwrap();
        assert_eq!(r.value["short"], "circuit");
        // Only A entered (and then returned after short-circuit)
        let sequence = order.lock().await.clone();
        assert_eq!(sequence, vec!['A', 'A']);
    }

    #[tokio::test]
    async fn no_interceptors_calls_inner() {
        let result = run_interceptor_chain(&[], make_request(), make_inner()).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value["result"], "ok");
    }

    #[tokio::test]
    async fn interceptor_can_rewrite_request() {
        let interceptors: Vec<Arc<dyn ToolCallInterceptor>> = vec![Arc::new(RewritingInterceptor)];
        let result = run_interceptor_chain(&interceptors, make_request(), make_echo_inner())
            .await
            .unwrap();
        assert_eq!(result.value, serde_json::json!({"rewritten": true}));
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn synchronous_interceptor_panic_becomes_error() {
        let interceptors: Vec<Arc<dyn ToolCallInterceptor>> = vec![Arc::new(SyncPanicInterceptor)];
        let result = run_interceptor_chain(&interceptors, make_request(), make_inner()).await;
        assert!(matches!(
            result,
            Err(McpAdapterError::InterceptorPanic { ref message }) if message == "interceptor exploded"
        ));
    }

    #[tokio::test]
    async fn logging_interceptor_passes_result_through() {
        let interceptors: Vec<Arc<dyn ToolCallInterceptor>> = vec![Arc::new(LoggingInterceptor)];
        let result = run_interceptor_chain(&interceptors, make_request(), make_inner())
            .await
            .unwrap();
        assert_eq!(result.value["result"], "ok");
        assert!(!result.is_error);
    }
}