1use crate::agent::{Agent, AgentError, Decision};
10use crate::client::LlmClient;
11use crate::registry::ToolRegistry;
12use crate::schema_simplifier;
13use crate::types::Message;
14use crate::union_schema;
15
16pub struct FlexibleAgent<C: LlmClient> {
18 client: C,
19 system_prompt: String,
20 max_retries: usize,
22}
23
24impl<C: LlmClient> FlexibleAgent<C> {
25 pub fn new(client: C, system_prompt: impl Into<String>, max_retries: usize) -> Self {
26 Self {
27 client,
28 system_prompt: system_prompt.into(),
29 max_retries: max_retries.max(1),
30 }
31 }
32}
33
34fn tools_prompt(tools: &ToolRegistry) -> String {
36 let mut s = String::from(
37 "## Available Tools\n\nRespond with JSON: {\"situation\": \"...\", \"task\": [...], \"actions\": [{\"tool_name\": \"...\", ...args}]}\n\n",
38 );
39 for t in tools.list() {
40 s.push_str(&schema_simplifier::simplify_tool(
41 t.name(),
42 t.description(),
43 &t.parameters_schema(),
44 ));
45 s.push_str("\n\n");
46 }
47 s
48}
49
/// Builds the corrective user message sent after an unparseable reply.
///
/// Every entry of `errors` is listed on its own line, numbered from 1, between a fixed
/// header and a fixed closing instruction. An empty slice yields just the header and the
/// instruction.
fn format_error_prompt(errors: &[String]) -> String {
    const HEADER: &str = "Your previous response(s) could not be parsed as valid JSON. Please fix and try again.\n\nErrors:\n";
    const FOOTER: &str =
        "\nRespond with ONLY valid JSON matching the schema. No markdown, no explanations.";

    let mut message = String::from(HEADER);
    message.extend(
        errors
            .iter()
            .enumerate()
            .map(|(idx, text)| format!("{}. {}\n", idx + 1, text)),
    );
    message.push_str(FOOTER);
    message
}
63
64#[async_trait::async_trait]
65impl<C: LlmClient> Agent for FlexibleAgent<C> {
66 async fn decide(
67 &self,
68 messages: &[Message],
69 tools: &ToolRegistry,
70 ) -> Result<Decision, AgentError> {
71 let defs = tools.to_defs();
72
73 let full_system = format!("{}\n\n{}", self.system_prompt, tools_prompt(tools));
75 let mut msgs = Vec::with_capacity(messages.len() + 1);
76 let has_system = messages
77 .iter()
78 .any(|m| m.role == crate::types::Role::System);
79 if !has_system {
80 msgs.push(Message::system(&full_system));
81 }
82 msgs.extend_from_slice(messages);
83
84 let mut errors: Vec<String> = Vec::new();
85
86 for attempt in 0..self.max_retries {
87 if attempt > 0 && !errors.is_empty() {
89 msgs.push(Message::user(format_error_prompt(&errors)));
90 }
91
92 let raw = self.client.complete(&msgs).await?;
93
94 match union_schema::parse_action(&raw, &defs) {
95 Ok((situation, tool_calls)) => {
96 let completed = tool_calls.is_empty()
97 || tool_calls.iter().any(|tc| tc.name == "finish_task");
98 return Ok(Decision {
99 situation,
100 task: vec![],
101 tool_calls,
102 completed,
103 });
104 }
105 Err(e) => {
106 errors.push(e.to_string());
107 msgs.push(Message::assistant(&raw));
109 }
110 }
111 }
112
113 Ok(Decision {
115 situation: format!(
116 "Failed to parse after {} attempts. Errors: {}",
117 self.max_retries,
118 errors.join("; ")
119 ),
120 task: vec![],
121 tool_calls: vec![],
122 completed: true,
123 })
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_tool::{ToolError, ToolOutput};
    use crate::client::LlmClient;
    use crate::context::AgentContext;
    use crate::tool::ToolDef;
    use crate::types::{Role, SgrError, ToolCall};
    use serde_json::Value;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    /// A well-formed reply that calls the `search` tool once.
    const VALID_SEARCH_REPLY: &str = r#"{"situation": "found it", "task": [], "actions": [{"tool_name": "search", "query": "test"}]}"#;

    /// Client whose `complete` always returns the same canned text.
    struct MockTextClient {
        response: String,
    }

    #[async_trait::async_trait]
    impl LlmClient for MockTextClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, vec![], String::new()))
        }
        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            Ok(vec![])
        }
        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok(self.response.clone())
        }
    }

    /// Client that replays `responses` in order (repeating the last one) and records
    /// every message list passed to `complete`.
    struct RecordingClient {
        responses: Vec<&'static str>,
        seen: Mutex<Vec<Vec<Message>>>,
    }

    impl RecordingClient {
        fn new(responses: Vec<&'static str>) -> Self {
            assert!(!responses.is_empty(), "RecordingClient needs at least one response");
            Self {
                responses,
                seen: Mutex::new(Vec::new()),
            }
        }

        /// Snapshot of the message lists received so far, one entry per call.
        fn calls(&self) -> Vec<Vec<Message>> {
            self.seen.lock().unwrap().clone()
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for RecordingClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, vec![], String::new()))
        }
        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            Ok(vec![])
        }
        async fn complete(&self, msgs: &[Message]) -> Result<String, SgrError> {
            let mut seen = self.seen.lock().unwrap();
            seen.push(msgs.to_vec());
            let idx = (seen.len() - 1).min(self.responses.len() - 1);
            Ok(self.responses[idx].to_string())
        }
    }

    struct DummyTool;

    #[async_trait::async_trait]
    impl crate::agent_tool::Tool for DummyTool {
        fn name(&self) -> &str {
            "search"
        }
        fn description(&self) -> &str {
            "search files"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object", "properties": {"query": {"type": "string"}}})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("ok"))
        }
    }

    #[tokio::test]
    async fn flexible_agent_parses_json_from_text() {
        let client = MockTextClient {
            response: r#"Sure, let me search for that.
```json
{"situation": "searching", "task": ["find files"], "actions": [{"tool_name": "search", "query": "main.rs"}]}
```"#
                .into(),
        };
        let agent = FlexibleAgent::new(client, "You are a test agent", 1);
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("find main.rs")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "search");
        // A non-`finish_task` tool call means the task is still in progress.
        assert!(!decision.completed);
    }

    #[tokio::test]
    async fn flexible_agent_plain_text_completes() {
        let client = MockTextClient {
            response: "I can't find any tools to use here.".into(),
        };
        let agent = FlexibleAgent::new(client, "test", 1);
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("hello")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
        assert!(decision.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn flexible_agent_retry_succeeds() {
        struct RetryClient {
            call_count: Arc<AtomicUsize>,
        }
        #[async_trait::async_trait]
        impl LlmClient for RetryClient {
            async fn structured_call(
                &self,
                _: &[Message],
                _: &Value,
            ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
                Ok((None, vec![], String::new()))
            }
            async fn tools_call(
                &self,
                _: &[Message],
                _: &[ToolDef],
            ) -> Result<Vec<ToolCall>, SgrError> {
                Ok(vec![])
            }
            async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    Ok("not valid json at all".into())
                } else {
                    Ok(VALID_SEARCH_REPLY.into())
                }
            }
        }

        // Keep our own handle so the number of calls can be checked afterwards.
        let calls = Arc::new(AtomicUsize::new(0));
        let client = RetryClient {
            call_count: Arc::clone(&calls),
        };
        let agent = FlexibleAgent::new(client, "test", 3);
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("search")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.situation, "found it");
        // One failed attempt, one successful retry, and no further calls.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn flexible_agent_retry_exhausted() {
        let client = MockTextClient {
            response: "garbage output always".into(),
        };
        let agent = FlexibleAgent::new(client, "test", 3);
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("do something")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
        assert!(decision.tool_calls.is_empty());
        assert!(
            decision
                .situation
                .contains("Failed to parse after 3 attempts")
        );
    }

    #[tokio::test]
    async fn flexible_agent_retry_exhausted_makes_max_retries_calls() {
        let agent = FlexibleAgent::new(RecordingClient::new(vec!["garbage"]), "test", 3);
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("do something")];

        agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(agent.client.calls().len(), 3);
    }

    #[tokio::test]
    async fn flexible_agent_injects_system_prompt_when_missing() {
        let agent = FlexibleAgent::new(RecordingClient::new(vec![VALID_SEARCH_REPLY]), "test", 1);
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("search")];

        agent.decide(&msgs, &tools).await.unwrap();

        let calls = agent.client.calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].len(), 2);
        assert!(calls[0][0].role == Role::System);
    }

    #[tokio::test]
    async fn flexible_agent_keeps_existing_system_message() {
        let agent = FlexibleAgent::new(RecordingClient::new(vec![VALID_SEARCH_REPLY]), "test", 1);
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::system("caller system"), Message::user("search")];

        agent.decide(&msgs, &tools).await.unwrap();

        let calls = agent.client.calls();
        assert_eq!(calls.len(), 1);
        // Nothing is prepended when the caller already supplied a system message.
        assert_eq!(calls[0].len(), 2);
        let systems = calls[0].iter().filter(|m| m.role == Role::System).count();
        assert_eq!(systems, 1);
    }

    #[tokio::test]
    async fn flexible_agent_retry_appends_reply_and_feedback() {
        let agent = FlexibleAgent::new(
            RecordingClient::new(vec!["not json", VALID_SEARCH_REPLY]),
            "test",
            3,
        );
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("search")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.situation, "found it");

        let calls = agent.client.calls();
        assert_eq!(calls.len(), 2);
        // The retry sees the bad assistant reply plus the user error-feedback message.
        assert_eq!(calls[1].len(), calls[0].len() + 2);
    }

    #[test]
    fn format_error_prompt_content() {
        let errors = vec!["bad json".to_string(), "missing field".to_string()];
        let prompt = format_error_prompt(&errors);
        assert!(prompt.contains("1. bad json"));
        assert!(prompt.contains("2. missing field"));
        assert!(prompt.contains("valid JSON"));
    }

    #[test]
    fn tools_prompt_uses_simplifier() {
        let tools = ToolRegistry::new().register(DummyTool);
        let prompt = tools_prompt(&tools);
        assert!(prompt.contains("### search"));
        assert!(prompt.contains("search files"));
    }
}