1use crate::tool::ToolDef;
8use crate::types::ToolCall;
9use serde_json::Value;
10use std::collections::BTreeMap;
11
12pub fn build_action_schema(tools: &[ToolDef]) -> Value {
21 let tool_names: Vec<Value> = tools
22 .iter()
23 .map(|t| Value::String(t.name.clone()))
24 .collect();
25
26 let mut all_params: BTreeMap<String, Value> = BTreeMap::new();
29 let mut param_required_by: BTreeMap<String, Vec<String>> = BTreeMap::new();
30
31 for t in tools {
32 if let Some(props) = t.parameters.get("properties").and_then(|p| p.as_object()) {
33 let required_names: Vec<String> = t
34 .parameters
35 .get("required")
36 .and_then(|r| r.as_array())
37 .map(|arr| {
38 arr.iter()
39 .filter_map(|v| v.as_str().map(String::from))
40 .collect()
41 })
42 .unwrap_or_default();
43
44 for (name, schema) in props {
45 all_params
46 .entry(name.clone())
47 .or_insert_with(|| schema.clone());
48 if required_names.contains(name) {
49 param_required_by
50 .entry(name.clone())
51 .or_default()
52 .push(t.name.clone());
53 }
54 }
55 }
56 }
57
58 let mut properties = serde_json::Map::new();
60
61 properties.insert(
62 "situation".into(),
63 serde_json::json!({"type": "string", "description": "Brief assessment of current state"}),
64 );
65 properties.insert(
66 "plan".into(),
67 serde_json::json!({
68 "type": "array",
69 "items": {"type": "string"},
70 "minItems": 1,
71 "maxItems": 5,
72 "description": "1-5 brief remaining steps"
73 }),
74 );
75 properties.insert(
76 "tool_name".into(),
77 serde_json::json!({"type": "string", "enum": tool_names, "description": "Tool to execute"}),
78 );
79
80 for (name, schema) in &all_params {
82 let nullable = serde_json::json!({
83 "anyOf": [schema, {"type": "null"}],
84 "description": schema.get("description").and_then(|d| d.as_str()).unwrap_or("")
85 });
86 properties.insert(name.clone(), nullable);
87 }
88
89 let required: Vec<Value> = properties
91 .keys()
92 .map(|k| Value::String(k.clone()))
93 .collect();
94
95 serde_json::json!({
96 "type": "object",
97 "properties": properties,
98 "required": required,
99 "additionalProperties": false
100 })
101}
102
103pub fn parse_action(raw: &str, _tools: &[ToolDef]) -> Result<(String, Vec<ToolCall>), ParseError> {
106 let value: Value = match crate::flexible_parser::parse_flexible::<Value>(raw) {
107 Ok(r) => r.value,
108 Err(_) => serde_json::from_str::<Value>(raw).map_err(|e| ParseError(e.to_string()))?,
109 };
110
111 let situation = match value.get("situation") {
112 Some(Value::String(s)) => s.clone(),
113 _ => String::new(),
114 };
115
116 if let Some(Value::String(tool_name)) = value.get("tool_name") {
118 let mut args = serde_json::Map::new();
119 if let Value::Object(obj) = &value {
120 for (k, v) in obj {
121 match k.as_str() {
122 "situation" | "plan" | "task" | "tool_name" => continue,
123 _ => {
124 if !v.is_null() {
126 args.insert(k.clone(), v.clone());
127 }
128 }
129 }
130 }
131 }
132 return Ok((
133 situation,
134 vec![ToolCall {
135 id: "call_0".into(),
136 name: tool_name.clone(),
137 arguments: Value::Object(args),
138 }],
139 ));
140 }
141
142 let actions: Vec<Value> = match value.get("action") {
144 Some(Value::Object(_)) => vec![value["action"].clone()],
145 _ => match value.get("actions") {
146 Some(Value::Array(arr)) => arr.clone(),
147 _ => Vec::new(),
148 },
149 };
150
151 let mut tool_calls: Vec<ToolCall> = Vec::new();
152 for (i, action) in actions.into_iter().enumerate() {
153 let name = match action.get("tool_name") {
154 Some(Value::String(s)) => s.clone(),
155 _ => continue,
156 };
157
158 let arguments = if let Value::Object(mut obj) = action {
159 obj.remove("tool_name");
160 if obj.len() == 1 {
162 let key = obj.keys().next().unwrap().clone();
163 if WRAPPER_KEYS.contains(&key.as_str()) && obj[&key].is_object() {
164 obj.remove(&key).unwrap()
165 } else {
166 Value::Object(obj)
167 }
168 } else {
169 Value::Object(obj)
170 }
171 } else {
172 action
173 };
174
175 tool_calls.push(ToolCall {
176 id: format!("call_{}", i),
177 name,
178 arguments,
179 });
180 }
181
182 Ok((situation, tool_calls))
183}
184
/// Field names that `parse_action` treats as an argument wrapper in legacy
/// actions: when one of these is the only field left after removing `tool_name`
/// and its value is an object, that object is used as the call's arguments.
const WRAPPER_KEYS: &[&str] = &["parameters", "params", "args", "arguments"];
187
/// Error returned by `parse_action` when the raw response cannot be parsed.
///
/// Wraps the underlying parser's message as text; its `Display` output is that
/// message verbatim.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct ParseError(pub String);
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::ToolDef;

    /// Creates a tool definition that takes a single required string parameter.
    fn single_param_tool(
        name: &str,
        description: &str,
        param: &str,
        param_description: &str,
    ) -> ToolDef {
        ToolDef {
            name: name.into(),
            description: description.into(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    param: { "type": "string", "description": param_description }
                },
                "required": [param]
            }),
        }
    }

    /// Two-tool fixture shared by all tests: `read_file(path)` and `bash(command)`.
    fn mock_tools() -> Vec<ToolDef> {
        vec![
            single_param_tool("read_file", "Read a file", "path", "File path"),
            single_param_tool("bash", "Run command", "command", "Shell command"),
        ]
    }

    #[test]
    fn build_schema_flat_with_enum() {
        let schema = build_action_schema(&mock_tools());

        let names = schema["properties"]["tool_name"]["enum"]
            .as_array()
            .unwrap();
        assert_eq!(names.len(), 2);
        for expected in ["read_file", "bash"] {
            assert!(names.contains(&serde_json::json!(expected)));
        }

        // Each tool parameter is exposed as a nullable `anyOf` wrapper.
        for param in ["path", "command"] {
            assert!(schema["properties"][param]["anyOf"].is_array());
        }
        assert_eq!(schema["additionalProperties"], false);
    }

    #[test]
    fn build_schema_has_situation_and_plan() {
        let schema = build_action_schema(&mock_tools());
        let props = &schema["properties"];

        assert!(props["situation"].is_object());
        assert!(props["plan"].is_object());
        assert_eq!(props["plan"]["maxItems"], 5);

        let required = schema["required"].as_array().unwrap();
        for field in ["situation", "tool_name"] {
            assert!(required.contains(&serde_json::json!(field)));
        }
        assert_eq!(schema["additionalProperties"], false);
    }

    #[test]
    fn parse_flat_action() {
        let raw = r#"{
            "situation": "need to read a file",
            "plan": ["read main.rs"],
            "tool_name": "read_file",
            "path": "/src/main.rs",
            "command": null
        }"#;

        let (situation, calls) = parse_action(raw, &mock_tools()).unwrap();

        assert_eq!(situation, "need to read a file");
        assert_eq!(calls.len(), 1);
        let call = &calls[0];
        assert_eq!(call.name, "read_file");
        assert_eq!(call.arguments["path"], "/src/main.rs");
        // Null parameters are dropped rather than forwarded to the tool.
        assert!(call.arguments.get("command").is_none());
    }

    #[test]
    fn parse_legacy_nested_action() {
        let raw = r#"{
            "situation": "need to read a file",
            "plan": ["read main.rs"],
            "action": {"tool_name": "read_file", "path": "/src/main.rs"}
        }"#;

        let (situation, calls) = parse_action(raw, &mock_tools()).unwrap();

        assert_eq!(situation, "need to read a file");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "read_file");
    }

    #[test]
    fn parse_legacy_actions_array() {
        let raw = r#"{
            "situation": "multi",
            "task": ["a", "b"],
            "actions": [
                {"tool_name": "read_file", "path": "/src/main.rs"},
                {"tool_name": "bash", "command": "ls -la"}
            ]
        }"#;

        let (_, calls) = parse_action(raw, &mock_tools()).unwrap();

        assert_eq!(calls.len(), 2);
    }

    #[test]
    fn parse_missing_action_returns_empty() {
        let raw = r#"{"situation": "thinking"}"#;

        let (situation, calls) = parse_action(raw, &mock_tools()).unwrap();

        assert_eq!(situation, "thinking");
        assert!(calls.is_empty());
    }

    #[test]
    fn parse_markdown_wrapped() {
        let raw = "```json\n{\"situation\": \"ok\", \"plan\": [\"do it\"], \"tool_name\": \"bash\", \"command\": \"pwd\", \"path\": null}\n```";

        let (_, calls) = parse_action(raw, &mock_tools()).unwrap();

        assert_eq!(calls.len(), 1);
        let call = &calls[0];
        assert_eq!(call.name, "bash");
        assert_eq!(call.arguments["command"], "pwd");
    }

    #[test]
    fn ensure_strict_skipped_for_pre_strict() {
        let schema = build_action_schema(&mock_tools());

        assert_eq!(schema["additionalProperties"], false);

        // Every declared property must also appear in `required`.
        let required = schema["required"].as_array().unwrap();
        for key in schema["properties"].as_object().unwrap().keys() {
            assert!(
                required.contains(&Value::String(key.clone())),
                "Property '{}' must be in required list",
                key
            );
        }

        // Tool parameters are nullable: original schema plus a `null` variant.
        let any_of = schema["properties"]["path"]["anyOf"].as_array().unwrap();
        assert_eq!(any_of.len(), 2, "path should have anyOf with 2 variants");
        let null_type = Value::String("null".into());
        let has_null = any_of
            .iter()
            .any(|variant| variant.get("type") == Some(&null_type));
        assert!(has_null, "path anyOf should include null variant");
    }
}