// synwire_agent/middleware/patch_tool_calls.rs
use std::collections::HashSet;

use synwire_core::BoxFuture;
use synwire_core::agents::error::AgentError;
use synwire_core::agents::middleware::{Middleware, MiddlewareInput, MiddlewareResult};

/// Middleware that answers dangling tool calls.
///
/// For every tool call in the message history that has no matching result,
/// it adds a synthetic `"tool"` result message. The type is a stateless unit
/// struct, so it is trivially cheap to copy and compare.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct PatchToolCallsMiddleware;

15impl Middleware for PatchToolCallsMiddleware {
16 fn name(&self) -> &'static str {
17 "patch_tool_calls"
18 }
19
20 fn process(
21 &self,
22 mut input: MiddlewareInput,
23 ) -> BoxFuture<'_, Result<MiddlewareResult, AgentError>> {
24 Box::pin(async move {
25 let mut call_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
27 let mut result_ids: std::collections::HashSet<String> =
28 std::collections::HashSet::new();
29
30 for msg in &input.messages {
31 if let Some(calls) = msg.get("tool_calls").and_then(|v| v.as_array()) {
32 for call in calls {
33 if let Some(id) = call.get("id").and_then(|v| v.as_str()) {
34 let _ = call_ids.insert(id.to_string());
35 }
36 }
37 }
38 if let Some(id) = msg.get("tool_call_id").and_then(|v| v.as_str()) {
39 let _ = result_ids.insert(id.to_string());
40 }
41 }
42
43 let dangling: Vec<String> = call_ids.difference(&result_ids).cloned().collect();
45 if !dangling.is_empty() {
46 tracing::debug!(count = dangling.len(), "Patching dangling tool calls");
47 for id in &dangling {
49 input.messages.push(serde_json::json!({
50 "role": "tool",
51 "tool_call_id": id,
52 "content": "Tool call interrupted. Please retry if needed.",
53 }));
54 }
55 }
56
57 Ok(MiddlewareResult::Continue(input))
58 })
59 }
60}