synaptic_deep/middleware/
patch_tool_calls.rs1use async_trait::async_trait;
2use std::collections::HashSet;
3use synaptic_core::{Message, SynapticError};
4use synaptic_middleware::{AgentMiddleware, ModelRequest, ModelResponse};
5
/// Agent middleware that repairs malformed tool calls in model responses.
///
/// Implements the `after_model` hook of `AgentMiddleware`. That hook does three things:
/// - drops tool calls whose name is empty or whitespace-only;
/// - decodes tool-call arguments that arrived as a JSON-encoded string;
/// - gives a fresh `patched_<n>` id to calls whose id is empty or duplicated.
#[derive(Debug, Clone, Copy)]
pub struct PatchToolCallsMiddleware;

15#[async_trait]
16impl AgentMiddleware for PatchToolCallsMiddleware {
17 async fn after_model(
18 &self,
19 _request: &ModelRequest,
20 response: &mut ModelResponse,
21 ) -> Result<(), SynapticError> {
22 let tool_calls = response.message.tool_calls().to_vec();
23 if tool_calls.is_empty() {
24 return Ok(());
25 }
26
27 let mut seen_ids = HashSet::new();
28 let mut patched = Vec::new();
29 let mut id_counter = 0u32;
30 let mut changed = false;
31
32 for mut tc in tool_calls {
33 if tc.name.trim().is_empty() {
35 changed = true;
36 continue;
37 }
38
39 let fixed_args = fix_json_arguments(&tc.arguments);
41 if fixed_args != tc.arguments {
42 tc.arguments = fixed_args;
43 changed = true;
44 }
45
46 if seen_ids.contains(&tc.id) || tc.id.is_empty() {
48 tc.id = format!("patched_{}", id_counter);
49 id_counter += 1;
50 changed = true;
51 }
52 seen_ids.insert(tc.id.clone());
53
54 patched.push(tc);
55 }
56
57 if changed {
58 let content = response.message.content().to_string();
59 let id = response.message.id().map(|s| s.to_string());
60 let mut new_msg = Message::ai_with_tool_calls(content, patched);
61 if let Some(id) = id {
62 new_msg = new_msg.with_id(id);
63 }
64 response.message = new_msg;
65 }
66
67 Ok(())
68 }
69}
70
71fn fix_json_arguments(args: &serde_json::Value) -> serde_json::Value {
72 if let serde_json::Value::String(s) = args {
73 let trimmed = s.trim();
74 let cleaned = if trimmed.starts_with("```") {
76 let without_start = trimmed
77 .trim_start_matches("```json")
78 .trim_start_matches("```");
79 without_start.trim_end_matches("```").trim()
80 } else {
81 trimmed
82 };
83
84 match serde_json::from_str(cleaned) {
86 Ok(v) => v,
87 Err(_) => args.clone(),
88 }
89 } else {
90 args.clone()
91 }
92}