1use crate::agent::{Agent, AgentError, Decision};
7use crate::client::LlmClient;
8use crate::registry::ToolRegistry;
9use crate::types::Message;
10use crate::union_schema;
11
/// Agent that asks an [`LlmClient`] for a schema-constrained action on every turn.
///
/// The trait bound on `C` lives on the `impl` blocks rather than on the struct
/// itself, so the type can be named without repeating `C: LlmClient` everywhere.
pub struct SgrAgent<C> {
    /// Backend used for the structured call.
    client: C,
    /// Prompt prepended by `decide` when the conversation carries no system
    /// message of its own; an empty string disables the injection.
    system_prompt: String,
}
20
21impl<C: LlmClient> SgrAgent<C> {
22 pub fn new(client: C, system_prompt: impl Into<String>) -> Self {
23 Self {
24 client,
25 system_prompt: system_prompt.into(),
26 }
27 }
28}
29
30#[async_trait::async_trait]
31impl<C: LlmClient> Agent for SgrAgent<C> {
32 async fn decide(
33 &self,
34 messages: &[Message],
35 tools: &ToolRegistry,
36 ) -> Result<Decision, AgentError> {
37 let defs = tools.to_defs();
38 let schema = union_schema::build_action_schema(&defs);
39
40 let mut msgs = Vec::with_capacity(messages.len() + 1);
42 let has_system = messages
43 .iter()
44 .any(|m| m.role == crate::types::Role::System);
45 if !has_system && !self.system_prompt.is_empty() {
46 msgs.push(Message::system(&self.system_prompt));
47 }
48 msgs.extend_from_slice(messages);
49
50 let (output, native_calls, raw) = self.client.structured_call(&msgs, &schema).await?;
51
52 if let Some(val) = output
54 && let Ok((situation, tool_calls)) = union_schema::parse_action(&val.to_string(), &defs)
55 {
56 let completed =
57 tool_calls.is_empty() || tool_calls.iter().any(|tc| tc.name == "finish_task");
58 return Ok(Decision {
59 situation,
60 task: vec![],
61 tool_calls,
62 completed,
63 });
64 }
65
66 if !native_calls.is_empty() {
68 let completed = native_calls.iter().any(|tc| tc.name == "finish_task");
69 return Ok(Decision {
70 situation: String::new(),
71 task: vec![],
72 tool_calls: native_calls,
73 completed,
74 });
75 }
76
77 if let Ok((situation, tool_calls)) = union_schema::parse_action(&raw, &defs) {
79 let completed =
80 tool_calls.is_empty() || tool_calls.iter().any(|tc| tc.name == "finish_task");
81 return Ok(Decision {
82 situation,
83 task: vec![],
84 tool_calls,
85 completed,
86 });
87 }
88
89 Ok(Decision {
91 situation: raw,
92 task: vec![],
93 tool_calls: vec![],
94 completed: true,
95 })
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_tool::{ToolError, ToolOutput};
    use crate::context::AgentContext;
    use crate::tool::ToolDef;
    use crate::types::{Role, SgrError, ToolCall};
    use serde_json::Value;
    use std::sync::Mutex;

    /// Client that replays `response` as structured output (parsed as JSON, or
    /// `Null` when it is not JSON) and as raw text; it never returns native calls.
    struct MockClient {
        response: String,
    }

    #[async_trait::async_trait]
    impl LlmClient for MockClient {
        async fn structured_call(
            &self,
            _messages: &[Message],
            _schema: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            let val: Value = serde_json::from_str(&self.response).unwrap_or(Value::Null);
            Ok((Some(val), vec![], self.response.clone()))
        }
        async fn tools_call(
            &self,
            _messages: &[Message],
            _tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            Ok(vec![])
        }
        async fn complete(&self, _messages: &[Message]) -> Result<String, SgrError> {
            Ok(self.response.clone())
        }
    }

    /// Client that records the messages passed to `structured_call` and replies
    /// with no structured output, no native calls, and empty text.
    #[derive(Default)]
    struct RecordingClient {
        seen: Mutex<Vec<Message>>,
    }

    #[async_trait::async_trait]
    impl LlmClient for RecordingClient {
        async fn structured_call(
            &self,
            messages: &[Message],
            _schema: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            *self.seen.lock().unwrap() = messages.to_vec();
            Ok((None, vec![], String::new()))
        }
        async fn tools_call(
            &self,
            _messages: &[Message],
            _tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            Ok(vec![])
        }
        async fn complete(&self, _messages: &[Message]) -> Result<String, SgrError> {
            Ok(String::new())
        }
    }

    struct DummyTool(&'static str);

    #[async_trait::async_trait]
    impl crate::agent_tool::Tool for DummyTool {
        fn name(&self) -> &str {
            self.0
        }
        fn description(&self) -> &str {
            "dummy"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object", "properties": {"arg": {"type": "string"}}})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("ok"))
        }
    }

    /// Runs one `decide` turn and returns the messages the client actually received.
    async fn sent_messages(system_prompt: &str, history: Vec<Message>) -> Vec<Message> {
        let agent = SgrAgent::new(RecordingClient::default(), system_prompt);
        let tools = ToolRegistry::new().register(DummyTool("read"));
        let _decision = agent.decide(&history, &tools).await.unwrap();
        agent.client.seen.into_inner().unwrap()
    }

    #[tokio::test]
    async fn sgr_agent_parses_structured_output() {
        let client = MockClient {
            response: r#"{"situation":"reading file","task":["read"],"actions":[{"tool_name":"read","arg":"main.rs"}]}"#.into(),
        };
        let agent = SgrAgent::new(client, "You are a test agent");
        let tools = ToolRegistry::new().register(DummyTool("read"));
        let msgs = vec![Message::user("read main.rs")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.situation, "reading file");
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "read");
        assert!(!decision.completed);
    }

    #[tokio::test]
    async fn sgr_agent_empty_actions_completes() {
        let client = MockClient {
            response: r#"{"situation":"done","task":[],"actions":[]}"#.into(),
        };
        let agent = SgrAgent::new(client, "test");
        let tools = ToolRegistry::new().register(DummyTool("read"));
        let msgs = vec![Message::user("done")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
    }

    #[tokio::test]
    async fn sgr_agent_finish_task_completes() {
        let client = MockClient {
            response: r#"{"situation":"wrapping up","task":[],"actions":[{"tool_name":"finish_task","arg":"summary"}]}"#.into(),
        };
        let agent = SgrAgent::new(client, "test");
        let tools = ToolRegistry::new()
            .register(DummyTool("read"))
            .register(DummyTool("finish_task"));
        let msgs = vec![Message::user("finish up")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "finish_task");
        assert!(decision.completed);
    }

    #[tokio::test]
    async fn sgr_agent_plain_text_reply_completes() {
        let client = MockClient {
            response: "All done, no tools needed.".into(),
        };
        let agent = SgrAgent::new(client, "test");
        let tools = ToolRegistry::new().register(DummyTool("read"));
        let msgs = vec![Message::user("anything left?")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
        assert!(decision.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn sgr_agent_prepends_system_prompt_when_missing() {
        let sent = sent_messages("You are a test agent", vec![Message::user("hi")]).await;
        assert_eq!(sent.len(), 2);
        assert!(sent[0].role == Role::System);
    }

    #[tokio::test]
    async fn sgr_agent_keeps_existing_system_message() {
        let sent = sent_messages(
            "You are a test agent",
            vec![Message::system("caller prompt"), Message::user("hi")],
        )
        .await;
        assert_eq!(sent.len(), 2);
    }

    #[tokio::test]
    async fn sgr_agent_skips_empty_system_prompt() {
        let sent = sent_messages("", vec![Message::user("hi")]).await;
        assert_eq!(sent.len(), 1);
        assert!(sent[0].role != Role::System);
    }
}