Skip to main content

synaptic_deep/middleware/
patch_tool_calls.rs

1use async_trait::async_trait;
2use std::collections::HashSet;
3use synaptic_core::{Message, SynapticError};
4use synaptic_middleware::{AgentMiddleware, ModelRequest, ModelResponse};
5
/// Middleware that repairs malformed tool calls in model responses.
///
/// Runs after each model call and applies, per tool call and in this order:
/// - Removes tool calls whose name is empty or whitespace-only.
/// - Repairs string-valued arguments by stripping a surrounding markdown code
///   fence and parsing the result as JSON. If parsing fails, the original value
///   is kept.
/// - Gives a generated `patched_N` ID to tool calls whose ID is empty or
///   duplicates an earlier tool call's ID.
///
/// The response message is only rebuilt when at least one patch was applied.
#[derive(Debug, Clone, Copy, Default)]
pub struct PatchToolCallsMiddleware;
14
15#[async_trait]
16impl AgentMiddleware for PatchToolCallsMiddleware {
17    async fn after_model(
18        &self,
19        _request: &ModelRequest,
20        response: &mut ModelResponse,
21    ) -> Result<(), SynapticError> {
22        let tool_calls = response.message.tool_calls().to_vec();
23        if tool_calls.is_empty() {
24            return Ok(());
25        }
26
27        let mut seen_ids = HashSet::new();
28        let mut patched = Vec::new();
29        let mut id_counter = 0u32;
30        let mut changed = false;
31
32        for mut tc in tool_calls {
33            // Skip empty names
34            if tc.name.trim().is_empty() {
35                changed = true;
36                continue;
37            }
38
39            // Fix JSON arguments
40            let fixed_args = fix_json_arguments(&tc.arguments);
41            if fixed_args != tc.arguments {
42                tc.arguments = fixed_args;
43                changed = true;
44            }
45
46            // Deduplicate IDs
47            if seen_ids.contains(&tc.id) || tc.id.is_empty() {
48                tc.id = format!("patched_{}", id_counter);
49                id_counter += 1;
50                changed = true;
51            }
52            seen_ids.insert(tc.id.clone());
53
54            patched.push(tc);
55        }
56
57        if changed {
58            let content = response.message.content().to_string();
59            let id = response.message.id().map(|s| s.to_string());
60            let mut new_msg = Message::ai_with_tool_calls(content, patched);
61            if let Some(id) = id {
62                new_msg = new_msg.with_id(id);
63            }
64            response.message = new_msg;
65        }
66
67        Ok(())
68    }
69}
70
71fn fix_json_arguments(args: &serde_json::Value) -> serde_json::Value {
72    if let serde_json::Value::String(s) = args {
73        let trimmed = s.trim();
74        // Strip markdown code fences
75        let cleaned = if trimmed.starts_with("```") {
76            let without_start = trimmed
77                .trim_start_matches("```json")
78                .trim_start_matches("```");
79            without_start.trim_end_matches("```").trim()
80        } else {
81            trimmed
82        };
83
84        // Try to parse as JSON
85        match serde_json::from_str(cleaned) {
86            Ok(v) => v,
87            Err(_) => args.clone(),
88        }
89    } else {
90        args.clone()
91    }
92}