1use crate::agent::{Agent, AgentError, Decision};
12use crate::client::LlmClient;
13use crate::registry::ToolRegistry;
14use crate::types::Message;
15
16pub struct HybridAgent<C: LlmClient> {
18 client: C,
19 system_prompt: String,
20}
21
22impl<C: LlmClient> HybridAgent<C> {
23 pub fn new(client: C, system_prompt: impl Into<String>) -> Self {
24 Self {
25 client,
26 system_prompt: system_prompt.into(),
27 }
28 }
29}
30
31fn reasoning_tool_def() -> crate::tool::ToolDef {
33 crate::tool::ToolDef {
34 name: "reasoning".to_string(),
35 description: "Analyze the situation and decide what tools to use next. Describe your reasoning, the current situation, and which tools you plan to call.".to_string(),
36 parameters: serde_json::json!({
37 "type": "object",
38 "properties": {
39 "situation": {
40 "type": "string",
41 "description": "Your assessment of the current situation"
42 },
43 "plan": {
44 "type": "array",
45 "items": { "type": "string" },
46 "description": "Step-by-step plan of what to do next"
47 },
48 "done": {
49 "type": "boolean",
50 "description": "Set to true if the task is fully complete"
51 }
52 },
53 "required": ["situation", "plan", "done"]
54 }),
55 }
56}
57
58#[async_trait::async_trait]
59impl<C: LlmClient> Agent for HybridAgent<C> {
60 async fn decide(
61 &self,
62 messages: &[Message],
63 tools: &ToolRegistry,
64 ) -> Result<Decision, AgentError> {
65 let mut msgs = Vec::with_capacity(messages.len() + 1);
67 let has_system = messages
68 .iter()
69 .any(|m| m.role == crate::types::Role::System);
70 if !has_system && !self.system_prompt.is_empty() {
71 msgs.push(Message::system(&self.system_prompt));
72 }
73 msgs.extend_from_slice(messages);
74
75 let reasoning_defs = vec![reasoning_tool_def()];
77 let reasoning_calls = self.client.tools_call(&msgs, &reasoning_defs).await?;
78
79 let (situation, plan, done) = if let Some(rc) = reasoning_calls.first() {
81 let sit = rc
82 .arguments
83 .get("situation")
84 .and_then(|s| s.as_str())
85 .unwrap_or("")
86 .to_string();
87 let plan: Vec<String> = rc
88 .arguments
89 .get("plan")
90 .and_then(|p| p.as_array())
91 .map(|arr| {
92 arr.iter()
93 .filter_map(|v| v.as_str().map(String::from))
94 .collect()
95 })
96 .unwrap_or_default();
97 let done = rc
98 .arguments
99 .get("done")
100 .and_then(|d| d.as_bool())
101 .unwrap_or(false);
102 (sit, plan, done)
103 } else {
104 return Ok(Decision {
106 situation: String::new(),
107 task: vec![],
108 tool_calls: vec![],
109 completed: true,
110 });
111 };
112
113 if done {
115 return Ok(Decision {
116 situation,
117 task: plan,
118 tool_calls: vec![],
119 completed: true,
120 });
121 }
122
123 let mut action_msgs = msgs.clone();
125 let reasoning_context = format!("Reasoning: {}\nPlan: {}", situation, plan.join(", "));
127 action_msgs.push(Message::assistant(&reasoning_context));
128 action_msgs.push(Message::user(
130 "Now execute the next step from your plan using the available tools.",
131 ));
132
133 let defs = tools.to_defs();
134 let tool_calls = self.client.tools_call(&action_msgs, &defs).await?;
135
136 let completed =
137 tool_calls.is_empty() || tool_calls.iter().any(|tc| tc.name == "finish_task");
138
139 Ok(Decision {
140 situation,
141 task: plan,
142 tool_calls,
143 completed,
144 })
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_tool::{Tool, ToolError, ToolOutput};
    use crate::context::AgentContext;
    use crate::tool::ToolDef;
    use crate::types::{SgrError, ToolCall};
    use serde_json::Value;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Scripted client for the two-phase flow.
    ///
    /// The first `tools_call` yields an unfinished `reasoning` call. Every later
    /// `tools_call` yields a single `read_file` call.
    struct MockHybridClient {
        call_count: Arc<AtomicUsize>,
    }

    #[async_trait::async_trait]
    impl LlmClient for MockHybridClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, Vec::new(), String::new()))
        }

        async fn tools_call(
            &self,
            _: &[Message],
            _tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            let reply = match self.call_count.fetch_add(1, Ordering::SeqCst) {
                // First round-trip: the reasoning phase.
                0 => ToolCall {
                    id: "r1".into(),
                    name: "reasoning".into(),
                    arguments: serde_json::json!({
                        "situation": "Need to read a file",
                        "plan": ["read main.rs", "analyze contents"],
                        "done": false
                    }),
                },
                // Any later round-trip: the action phase.
                _ => ToolCall {
                    id: "a1".into(),
                    name: "read_file".into(),
                    arguments: serde_json::json!({"path": "main.rs"}),
                },
            };
            Ok(vec![reply])
        }

        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok(String::new())
        }
    }

    /// Minimal tool registered under `read_file` that always returns fixed text.
    struct DummyTool;

    #[async_trait::async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "read_file"
        }

        fn description(&self) -> &str {
            "read a file"
        }

        fn parameters_schema(&self) -> Value {
            let schema =
                serde_json::json!({"type": "object", "properties": {"path": {"type": "string"}}});
            schema
        }

        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            let output = ToolOutput::text("file contents");
            Ok(output)
        }
    }

    #[tokio::test]
    async fn hybrid_two_phases() {
        let counter = Arc::new(AtomicUsize::new(0));
        let agent = HybridAgent::new(MockHybridClient { call_count: counter }, "test agent");
        let registry = ToolRegistry::new().register(DummyTool);
        let history = vec![Message::user("read main.rs")];

        let outcome = agent.decide(&history, &registry).await.unwrap();

        assert_eq!(outcome.situation, "Need to read a file");
        assert_eq!(outcome.task.len(), 2);
        assert_eq!(outcome.tool_calls.len(), 1);
        assert_eq!(outcome.tool_calls[0].name, "read_file");
        assert!(!outcome.completed);
    }

    #[tokio::test]
    async fn hybrid_done_in_reasoning() {
        /// Client whose reasoning call immediately reports the task as done.
        struct DoneClient;

        #[async_trait::async_trait]
        impl LlmClient for DoneClient {
            async fn structured_call(
                &self,
                _: &[Message],
                _: &Value,
            ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
                Ok((None, Vec::new(), String::new()))
            }

            async fn tools_call(
                &self,
                _: &[Message],
                _: &[ToolDef],
            ) -> Result<Vec<ToolCall>, SgrError> {
                let finished = ToolCall {
                    id: "r1".into(),
                    name: "reasoning".into(),
                    arguments: serde_json::json!({
                        "situation": "Task is already complete",
                        "plan": [],
                        "done": true
                    }),
                };
                Ok(vec![finished])
            }

            async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
                Ok(String::new())
            }
        }

        let agent = HybridAgent::new(DoneClient, "test");
        let registry = ToolRegistry::new().register(DummyTool);
        let history = vec![Message::user("done")];

        let outcome = agent.decide(&history, &registry).await.unwrap();

        assert!(outcome.completed);
        assert!(outcome.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn hybrid_no_reasoning_completes() {
        /// Client that never returns any tool call.
        struct EmptyClient;

        #[async_trait::async_trait]
        impl LlmClient for EmptyClient {
            async fn structured_call(
                &self,
                _: &[Message],
                _: &Value,
            ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
                Ok((None, Vec::new(), String::new()))
            }

            async fn tools_call(
                &self,
                _: &[Message],
                _: &[ToolDef],
            ) -> Result<Vec<ToolCall>, SgrError> {
                Ok(Vec::new())
            }

            async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
                Ok(String::new())
            }
        }

        let agent = HybridAgent::new(EmptyClient, "test");
        let registry = ToolRegistry::new().register(DummyTool);
        let history = vec![Message::user("hello")];

        let outcome = agent.decide(&history, &registry).await.unwrap();

        assert!(outcome.completed);
    }
}