Skip to main content

sgr_agent/
union_schema.rs

//! Dynamic flat action schema builder — generates a single-object JSON Schema
//! from tool definitions at runtime. Used by `SgrAgent` for structured output.
//!
//! Instead of `anyOf` discriminated unions (which break OpenAI constrained decoding),
//! this module uses a flat schema: `tool_name` as a string enum plus every tool
//! parameter as a nullable field.

7use crate::tool::ToolDef;
8use crate::types::ToolCall;
9use serde_json::Value;
10use std::collections::BTreeMap;
11
12/// Build a flat JSON Schema from tool definitions.
13///
14/// The schema is already OpenAI strict-compatible:
15/// - All properties are in `required`
16/// - Non-universal params use `anyOf [type, null]`
17/// - `additionalProperties: false`
18///
19/// IMPORTANT: Do NOT run `ensure_strict` on this schema.
20pub fn build_action_schema(tools: &[ToolDef]) -> Value {
21    let tool_names: Vec<Value> = tools
22        .iter()
23        .map(|t| Value::String(t.name.clone()))
24        .collect();
25
26    // Collect all unique parameter names across all tools with their schemas.
27    // If multiple tools define the same param name, merge descriptions.
28    let mut all_params: BTreeMap<String, Value> = BTreeMap::new();
29    let mut param_required_by: BTreeMap<String, Vec<String>> = BTreeMap::new();
30
31    for t in tools {
32        if let Some(props) = t.parameters.get("properties").and_then(|p| p.as_object()) {
33            let required_names: Vec<String> = t
34                .parameters
35                .get("required")
36                .and_then(|r| r.as_array())
37                .map(|arr| {
38                    arr.iter()
39                        .filter_map(|v| v.as_str().map(String::from))
40                        .collect()
41                })
42                .unwrap_or_default();
43
44            for (name, schema) in props {
45                all_params
46                    .entry(name.clone())
47                    .or_insert_with(|| schema.clone());
48                if required_names.contains(name) {
49                    param_required_by
50                        .entry(name.clone())
51                        .or_default()
52                        .push(t.name.clone());
53                }
54            }
55        }
56    }
57
58    // Build properties: tool_name enum + situation + plan + all params (nullable)
59    let mut properties = serde_json::Map::new();
60
61    properties.insert(
62        "situation".into(),
63        serde_json::json!({"type": "string", "description": "Brief assessment of current state"}),
64    );
65    properties.insert(
66        "plan".into(),
67        serde_json::json!({
68            "type": "array",
69            "items": {"type": "string"},
70            "minItems": 1,
71            "maxItems": 5,
72            "description": "1-5 brief remaining steps"
73        }),
74    );
75    properties.insert(
76        "tool_name".into(),
77        serde_json::json!({"type": "string", "enum": tool_names, "description": "Tool to execute"}),
78    );
79
80    // All tool params — wrapped in anyOf [type, null] so model can set unused params to null
81    for (name, schema) in &all_params {
82        let nullable = serde_json::json!({
83            "anyOf": [schema, {"type": "null"}],
84            "description": schema.get("description").and_then(|d| d.as_str()).unwrap_or("")
85        });
86        properties.insert(name.clone(), nullable);
87    }
88
89    // All properties are required (strict mode) — nullable handles optionality
90    let required: Vec<Value> = properties
91        .keys()
92        .map(|k| Value::String(k.clone()))
93        .collect();
94
95    serde_json::json!({
96        "type": "object",
97        "properties": properties,
98        "required": required,
99        "additionalProperties": false
100    })
101}
102
103/// Parse raw LLM output into tool calls.
104/// Supports both flat format (tool_name at top level) and legacy nested format.
105pub fn parse_action(raw: &str, _tools: &[ToolDef]) -> Result<(String, Vec<ToolCall>), ParseError> {
106    let value: Value = match crate::flexible_parser::parse_flexible::<Value>(raw) {
107        Ok(r) => r.value,
108        Err(_) => serde_json::from_str::<Value>(raw).map_err(|e| ParseError(e.to_string()))?,
109    };
110
111    let situation = match value.get("situation") {
112        Some(Value::String(s)) => s.clone(),
113        _ => String::new(),
114    };
115
116    // Flat format: tool_name at top level
117    if let Some(Value::String(tool_name)) = value.get("tool_name") {
118        let mut args = serde_json::Map::new();
119        if let Value::Object(obj) = &value {
120            for (k, v) in obj {
121                match k.as_str() {
122                    "situation" | "plan" | "task" | "tool_name" => continue,
123                    _ => {
124                        // Skip null values (unused params from other tools)
125                        if !v.is_null() {
126                            args.insert(k.clone(), v.clone());
127                        }
128                    }
129                }
130            }
131        }
132        return Ok((
133            situation,
134            vec![ToolCall {
135                id: "call_0".into(),
136                name: tool_name.clone(),
137                arguments: Value::Object(args),
138            }],
139        ));
140    }
141
142    // Legacy nested format: "action" object or "actions" array
143    let actions: Vec<Value> = match value.get("action") {
144        Some(Value::Object(_)) => vec![value["action"].clone()],
145        _ => match value.get("actions") {
146            Some(Value::Array(arr)) => arr.clone(),
147            _ => Vec::new(),
148        },
149    };
150
151    let mut tool_calls: Vec<ToolCall> = Vec::new();
152    for (i, action) in actions.into_iter().enumerate() {
153        let name = match action.get("tool_name") {
154            Some(Value::String(s)) => s.clone(),
155            _ => continue,
156        };
157
158        let arguments = if let Value::Object(mut obj) = action {
159            obj.remove("tool_name");
160            // Unwrap known wrapper keys (Gemini compat)
161            if obj.len() == 1 {
162                let key = obj.keys().next().unwrap().clone();
163                if WRAPPER_KEYS.contains(&key.as_str()) && obj[&key].is_object() {
164                    obj.remove(&key).unwrap()
165                } else {
166                    Value::Object(obj)
167                }
168            } else {
169                Value::Object(obj)
170            }
171        } else {
172            action
173        };
174
175        tool_calls.push(ToolCall {
176            id: format!("call_{}", i),
177            name,
178            arguments,
179        });
180    }
181
182    Ok((situation, tool_calls))
183}
184
/// Keys under which Gemini-style responses may nest a legacy action's arguments,
/// e.g. `{"tool_name": "x", "args": {...}}`. When such a key is the only one left
/// after removing `tool_name` and it holds an object, the parser unwraps it.
const WRAPPER_KEYS: [&str; 4] = ["parameters", "params", "args", "arguments"];
187
188/// Parse error for action extraction.
189#[derive(Debug, thiserror::Error)]
190#[error("{0}")]
191pub struct ParseError(pub String);
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::ToolDef;

    /// A tool with exactly one required string parameter.
    fn string_tool(name: &str, description: &str, param: &str, param_desc: &str) -> ToolDef {
        ToolDef {
            name: name.into(),
            description: description.into(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    param: { "type": "string", "description": param_desc }
                },
                "required": [param]
            }),
        }
    }

    fn mock_tools() -> Vec<ToolDef> {
        vec![
            string_tool("read_file", "Read a file", "path", "File path"),
            string_tool("bash", "Run command", "command", "Shell command"),
        ]
    }

    /// The schema's top-level `required` list.
    fn required_of(schema: &Value) -> &[Value] {
        schema["required"]
            .as_array()
            .expect("schema must have a `required` array")
    }

    #[test]
    fn build_schema_flat_with_enum() {
        let schema = build_action_schema(&mock_tools());
        let props = &schema["properties"];

        let names = props["tool_name"]["enum"].as_array().unwrap();
        assert_eq!(names.len(), 2);
        for expected in &["read_file", "bash"] {
            assert!(names.contains(&serde_json::json!(expected)));
        }

        // Every tool parameter is nullable via anyOf.
        for param in &["path", "command"] {
            assert!(props[param]["anyOf"].is_array());
        }
        assert_eq!(schema["additionalProperties"], false);
    }

    #[test]
    fn build_schema_has_situation_and_plan() {
        let schema = build_action_schema(&mock_tools());
        let props = &schema["properties"];

        assert!(props["situation"].is_object());
        assert!(props["plan"].is_object());
        assert_eq!(props["plan"]["maxItems"], 5);

        let required = required_of(&schema);
        for field in &["situation", "tool_name"] {
            assert!(required.contains(&serde_json::json!(field)));
        }
        assert_eq!(schema["additionalProperties"], false);
    }

    #[test]
    fn parse_flat_action() {
        let raw = r#"{
            "situation": "need to read a file",
            "plan": ["read main.rs"],
            "tool_name": "read_file",
            "path": "/src/main.rs",
            "command": null
        }"#;
        let (situation, calls) = parse_action(raw, &mock_tools()).unwrap();
        assert_eq!(situation, "need to read a file");
        assert_eq!(calls.len(), 1);

        let call = &calls[0];
        assert_eq!(call.name, "read_file");
        assert_eq!(call.arguments["path"], "/src/main.rs");
        // Params set to null (belonging to other tools) are stripped.
        assert!(call.arguments.get("command").is_none());
    }

    #[test]
    fn parse_legacy_nested_action() {
        let raw = r#"{
            "situation": "need to read a file",
            "plan": ["read main.rs"],
            "action": {"tool_name": "read_file", "path": "/src/main.rs"}
        }"#;
        let (situation, calls) = parse_action(raw, &mock_tools()).unwrap();
        assert_eq!(situation, "need to read a file");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "read_file");
    }

    #[test]
    fn parse_legacy_actions_array() {
        let raw = r#"{
            "situation": "multi",
            "task": ["a", "b"],
            "actions": [
                {"tool_name": "read_file", "path": "/src/main.rs"},
                {"tool_name": "bash", "command": "ls -la"}
            ]
        }"#;
        let calls = parse_action(raw, &mock_tools()).unwrap().1;
        assert_eq!(calls.len(), 2);
    }

    #[test]
    fn parse_missing_action_returns_empty() {
        let (situation, calls) =
            parse_action(r#"{"situation": "thinking"}"#, &mock_tools()).unwrap();
        assert_eq!(situation, "thinking");
        assert!(calls.is_empty());
    }

    #[test]
    fn parse_markdown_wrapped() {
        let raw = "```json\n{\"situation\": \"ok\", \"plan\": [\"do it\"], \"tool_name\": \"bash\", \"command\": \"pwd\", \"path\": null}\n```";
        let calls = parse_action(raw, &mock_tools()).unwrap().1;
        assert_eq!(calls.len(), 1);

        let call = &calls[0];
        assert_eq!(call.name, "bash");
        assert_eq!(call.arguments["command"], "pwd");
    }

    #[test]
    fn ensure_strict_skipped_for_pre_strict() {
        // build_action_schema emits additionalProperties:false, which is how
        // OxideClient's structured_call knows to skip ensure_strict — running it
        // would break the anyOf-nullable fields.
        let schema = build_action_schema(&mock_tools());

        // Already strict-compatible.
        assert_eq!(schema["additionalProperties"], false);

        // Every property is listed in `required`.
        let required = required_of(&schema);
        let props = schema["properties"].as_object().unwrap();
        for key in props.keys() {
            let listed = required.iter().any(|r| r.as_str() == Some(key.as_str()));
            assert!(listed, "Property '{}' must be in required list", key);
        }

        // Tool params are anyOf [type, null]; ensure_strict would wrap them again.
        let variants = schema["properties"]["path"]["anyOf"].as_array().unwrap();
        assert_eq!(variants.len(), 2, "path should have anyOf with 2 variants");
        let has_null = variants.iter().any(|v| v["type"] == "null");
        assert!(has_null, "path anyOf should include null variant");
    }
}