Skip to main content

synwire_agent/middleware/
patch_tool_calls.rs

//! Patch tool calls middleware — fixes dangling tool call references.
2
3use synwire_core::BoxFuture;
4use synwire_core::agents::error::AgentError;
5use synwire_core::agents::middleware::{Middleware, MiddlewareInput, MiddlewareResult};
6
/// Repairs tool calls that were issued but never answered.
///
/// A tool call counts as dangling when a message lists it (by `id`) under
/// `tool_calls`, yet no message in the conversation history carries a
/// matching `tool_call_id`. This middleware supplies a synthetic tool result
/// for each such call.
#[derive(Default, Debug)]
pub struct PatchToolCallsMiddleware;
14
15impl Middleware for PatchToolCallsMiddleware {
16    fn name(&self) -> &'static str {
17        "patch_tool_calls"
18    }
19
20    fn process(
21        &self,
22        mut input: MiddlewareInput,
23    ) -> BoxFuture<'_, Result<MiddlewareResult, AgentError>> {
24        Box::pin(async move {
25            // Collect tool_call_ids from the conversation.
26            let mut call_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
27            let mut result_ids: std::collections::HashSet<String> =
28                std::collections::HashSet::new();
29
30            for msg in &input.messages {
31                if let Some(calls) = msg.get("tool_calls").and_then(|v| v.as_array()) {
32                    for call in calls {
33                        if let Some(id) = call.get("id").and_then(|v| v.as_str()) {
34                            let _ = call_ids.insert(id.to_string());
35                        }
36                    }
37                }
38                if let Some(id) = msg.get("tool_call_id").and_then(|v| v.as_str()) {
39                    let _ = result_ids.insert(id.to_string());
40                }
41            }
42
43            // Find dangling calls (calls with no result).
44            let dangling: Vec<String> = call_ids.difference(&result_ids).cloned().collect();
45            if !dangling.is_empty() {
46                tracing::debug!(count = dangling.len(), "Patching dangling tool calls");
47                // Inject synthetic tool result messages for each dangling call.
48                for id in &dangling {
49                    input.messages.push(serde_json::json!({
50                        "role": "tool",
51                        "tool_call_id": id,
52                        "content": "Tool call interrupted. Please retry if needed.",
53                    }));
54                }
55            }
56
57            Ok(MiddlewareResult::Continue(input))
58        })
59    }
60}