1use crate::agent::{Agent, AgentError, Decision};
12use crate::client::LlmClient;
13use crate::registry::ToolRegistry;
14use crate::types::Message;
15
16pub struct HybridAgent<C: LlmClient> {
18 client: C,
19 system_prompt: String,
20}
21
22impl<C: LlmClient> HybridAgent<C> {
23 pub fn new(client: C, system_prompt: impl Into<String>) -> Self {
24 Self {
25 client,
26 system_prompt: system_prompt.into(),
27 }
28 }
29}
30
31fn reasoning_tool_def() -> crate::tool::ToolDef {
33 crate::tool::ToolDef {
34 name: "reasoning".to_string(),
35 description: "Analyze the situation and decide what tools to use next. Describe your reasoning, the current situation, and which tools you plan to call.".to_string(),
36 parameters: serde_json::json!({
37 "type": "object",
38 "properties": {
39 "situation": {
40 "type": "string",
41 "description": "Your assessment of the current situation"
42 },
43 "plan": {
44 "type": "array",
45 "items": { "type": "string" },
46 "description": "Step-by-step plan of what to do next"
47 },
48 "done": {
49 "type": "boolean",
50 "description": "Set to true if the task is fully complete"
51 }
52 },
53 "required": ["situation", "plan", "done"]
54 }),
55 }
56}
57
58#[async_trait::async_trait]
59impl<C: LlmClient> Agent for HybridAgent<C> {
60 async fn decide(
61 &self,
62 messages: &[Message],
63 tools: &ToolRegistry,
64 ) -> Result<Decision, AgentError> {
65 self.decide_stateful(messages, tools, None)
66 .await
67 .map(|(d, _)| d)
68 }
69
70 async fn decide_stateful(
71 &self,
72 messages: &[Message],
73 tools: &ToolRegistry,
74 previous_response_id: Option<&str>,
75 ) -> Result<(Decision, Option<String>), AgentError> {
76 let mut msgs = Vec::with_capacity(messages.len() + 1);
78 let has_system = messages
79 .iter()
80 .any(|m| m.role == crate::types::Role::System);
81 if !has_system && !self.system_prompt.is_empty() {
82 msgs.push(Message::system(&self.system_prompt));
83 }
84 msgs.extend_from_slice(messages);
85
86 let reasoning_defs = vec![reasoning_tool_def()];
88 let reasoning_calls = self.client.tools_call(&msgs, &reasoning_defs).await?;
89
90 let (situation, plan, done) = if let Some(rc) = reasoning_calls.first() {
92 let sit = rc
93 .arguments
94 .get("situation")
95 .and_then(|s| s.as_str())
96 .unwrap_or("")
97 .to_string();
98 let plan: Vec<String> = rc
99 .arguments
100 .get("plan")
101 .and_then(|p| p.as_array())
102 .map(|arr| {
103 arr.iter()
104 .filter_map(|v| v.as_str().map(String::from))
105 .collect()
106 })
107 .unwrap_or_default();
108 let done = rc
109 .arguments
110 .get("done")
111 .and_then(|d| d.as_bool())
112 .unwrap_or(false);
113 (sit, plan, done)
114 } else {
115 return Ok((
116 Decision {
117 situation: String::new(),
118 task: vec![],
119 tool_calls: vec![],
120 completed: true,
121 },
122 None,
123 ));
124 };
125
126 let mut action_msgs = msgs.clone();
128 let reasoning_context = if done {
129 format!(
130 "Reasoning: {}\nStatus: Task appears complete. Call the answer/finish tool with the final result.",
131 situation
132 )
133 } else {
134 format!("Reasoning: {}\nPlan: {}", situation, plan.join(", "))
135 };
136 action_msgs.push(Message::assistant(&reasoning_context));
137 action_msgs.push(Message::user(
138 "Now execute the next step from your plan using the available tools.",
139 ));
140
141 let context_lower = format!("{} {}", situation, plan.join(" ")).to_lowercase();
144 let filtered: Vec<_> = tools
145 .to_defs()
146 .into_iter()
147 .filter(|t| {
148 t.name == "answer"
150 || t.name == "finish_task"
151 || t.name.contains("answer")
152 || context_lower.contains(&t.name.to_lowercase())
154 || matches!(t.name.as_str(), "read" | "write" | "search")
156 })
157 .collect();
158 let defs = if filtered.is_empty() {
159 tools.to_defs()
160 } else {
161 filtered
162 };
163
164 let (tool_calls, new_response_id) = self
165 .client
166 .tools_call_stateful(&action_msgs, &defs, previous_response_id)
167 .await?;
168
169 let completed =
170 tool_calls.is_empty() || tool_calls.iter().any(|tc| tc.name == "finish_task");
171
172 Ok((
173 Decision {
174 situation,
175 task: plan,
176 tool_calls,
177 completed,
178 },
179 new_response_id,
180 ))
181 }
182}
183
#[cfg(test)]
mod tests {
    //! Tests for `HybridAgent` driven by scripted `LlmClient` mocks whose replies depend
    //! on how many times `tools_call` has been invoked.

    use super::*;
    use crate::agent_tool::{Tool, ToolError, ToolOutput};
    use crate::context::AgentContext;
    use crate::tool::ToolDef;
    use crate::types::{SgrError, ToolCall};
    use serde_json::Value;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Scripted client for the happy path: the first `tools_call` returns a `reasoning`
    /// call with a two-step, not-done plan; every later call returns one `read_file` call.
    ///
    /// NOTE(review): only `tools_call` is overridden, while phase 2 goes through
    /// `tools_call_stateful`. These tests presumably rely on the trait's default
    /// `tools_call_stateful` delegating to `tools_call`; confirm against `LlmClient`.
    struct MockHybridClient {
        // Counts `tools_call` invocations; 0 selects the reasoning reply.
        call_count: Arc<AtomicUsize>,
    }

    #[async_trait::async_trait]
    impl LlmClient for MockHybridClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, vec![], String::new()))
        }
        async fn tools_call(
            &self,
            _: &[Message],
            _tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
            if n == 0 {
                // Phase 1 reply: reasoning with a plan that is not yet done.
                Ok(vec![ToolCall {
                    id: "r1".into(),
                    name: "reasoning".into(),
                    arguments: serde_json::json!({
                        "situation": "Need to read a file",
                        "plan": ["read main.rs", "analyze contents"],
                        "done": false
                    }),
                }])
            } else {
                // Phase 2 reply: the action call.
                Ok(vec![ToolCall {
                    id: "a1".into(),
                    name: "read_file".into(),
                    arguments: serde_json::json!({"path": "main.rs"}),
                }])
            }
        }
        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok(String::new())
        }
    }

    /// Minimal tool registered as `read_file`. `decide` only needs its definition; these
    /// tests never execute it.
    struct DummyTool;
    #[async_trait::async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "read_file"
        }
        fn description(&self) -> &str {
            "read a file"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object", "properties": {"path": {"type": "string"}}})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("file contents"))
        }
    }

    /// The decision combines phase-1 reasoning (situation and plan) with phase-2 tool calls.
    #[tokio::test]
    async fn hybrid_two_phases() {
        let client = MockHybridClient {
            call_count: Arc::new(AtomicUsize::new(0)),
        };
        let agent = HybridAgent::new(client, "test agent");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("read main.rs")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.situation, "Need to read a file");
        assert_eq!(decision.task.len(), 2);
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "read_file");
        assert!(!decision.completed);
    }

    /// Even when reasoning reports `done: true`, phase 2 still runs; its `finish_task`
    /// call is what marks the decision completed.
    #[tokio::test]
    async fn hybrid_done_still_runs_phase2() {
        struct DoneClient {
            // Counts `tools_call` invocations; 0 selects the reasoning reply.
            call_count: Arc<AtomicUsize>,
        }
        #[async_trait::async_trait]
        impl LlmClient for DoneClient {
            async fn structured_call(
                &self,
                _: &[Message],
                _: &Value,
            ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
                Ok((None, vec![], String::new()))
            }
            async fn tools_call(
                &self,
                _: &[Message],
                _: &[ToolDef],
            ) -> Result<Vec<ToolCall>, SgrError> {
                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    Ok(vec![ToolCall {
                        id: "r1".into(),
                        name: "reasoning".into(),
                        arguments: serde_json::json!({
                            "situation": "Task is already complete",
                            "plan": [],
                            "done": true
                        }),
                    }])
                } else {
                    // Phase 2: the model finishes the task.
                    Ok(vec![ToolCall {
                        id: "a1".into(),
                        name: "finish_task".into(),
                        arguments: serde_json::json!({"summary": "done"}),
                    }])
                }
            }
            async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
                Ok(String::new())
            }
        }

        let agent = HybridAgent::new(
            DoneClient {
                call_count: Arc::new(AtomicUsize::new(0)),
            },
            "test",
        );
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("done")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
        // The phase-2 call is surfaced, not swallowed by the early `done` signal.
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "finish_task");
    }

    /// A model that returns no reasoning call ends the round as completed.
    #[tokio::test]
    async fn hybrid_no_reasoning_completes() {
        struct EmptyClient;
        #[async_trait::async_trait]
        impl LlmClient for EmptyClient {
            async fn structured_call(
                &self,
                _: &[Message],
                _: &Value,
            ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
                Ok((None, vec![], String::new()))
            }
            async fn tools_call(
                &self,
                _: &[Message],
                _: &[ToolDef],
            ) -> Result<Vec<ToolCall>, SgrError> {
                Ok(vec![])
            }
            async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
                Ok(String::new())
            }
        }

        let agent = HybridAgent::new(EmptyClient, "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("hello")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
    }

    /// Checks what each phase is given:
    /// - phase 1 sees only the `reasoning` tool;
    /// - phase 2 sees real tools and ends with the user action prompt;
    /// - exactly two client calls are made.
    #[tokio::test]
    async fn hybrid_two_phases_independent() {
        struct PhaseTrackingClient {
            // Counts `tools_call` invocations; also read by the test after `decide`.
            call_count: Arc<AtomicUsize>,
        }

        #[async_trait::async_trait]
        impl LlmClient for PhaseTrackingClient {
            async fn structured_call(
                &self,
                _: &[Message],
                _: &Value,
            ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
                Ok((None, vec![], String::new()))
            }
            async fn tools_call(
                &self,
                msgs: &[Message],
                tools: &[ToolDef],
            ) -> Result<Vec<ToolCall>, SgrError> {
                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    assert_eq!(tools.len(), 1, "Phase 1 should only have reasoning tool");
                    assert_eq!(tools[0].name, "reasoning");
                    Ok(vec![ToolCall {
                        id: "r1".into(),
                        name: "reasoning".into(),
                        arguments: serde_json::json!({
                            "situation": "Testing phase independence",
                            "plan": ["call read_file"],
                            "done": false
                        }),
                    }])
                } else {
                    // NOTE(review): `tools[0]` panics if phase 2 ever receives an empty
                    // slice; that would show up as a panic rather than this message.
                    assert!(
                        tools.len() > 1 || tools[0].name != "reasoning",
                        "Phase 2 should have the real tools, not just reasoning"
                    );
                    let last_msg = msgs.last().unwrap();
                    // Phase 2 appends the assistant reasoning summary followed by a
                    // user instruction, so the user message must come last.
                    assert_eq!(
                        last_msg.role,
                        crate::types::Role::User,
                        "Last message in phase 2 should be the action prompt"
                    );
                    Ok(vec![ToolCall {
                        id: "a1".into(),
                        name: "read_file".into(),
                        arguments: serde_json::json!({"path": "test.rs"}),
                    }])
                }
            }
            async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
                Ok(String::new())
            }
        }

        let call_count = Arc::new(AtomicUsize::new(0));
        let agent = HybridAgent::new(
            PhaseTrackingClient {
                call_count: call_count.clone(),
            },
            "test agent",
        );
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("read test.rs")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();

        // One reasoning call plus one action call.
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
        // Only the phase-2 calls are surfaced in the decision.
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "read_file");
        assert!(!decision.completed);
    }
}
449}