// sgr_agent/agents/sgr.rs

//! SgrAgent — structured output agent.
//!
//! Builds a discriminated union schema from the `ToolRegistry`, sends it via
//! `structured_call`, and parses the response into tool calls using `parse_action`.
use crate::agent::{Agent, AgentError, Decision};
use crate::client::LlmClient;
use crate::registry::ToolRegistry;
use crate::types::{Message, Role, ToolCall};
use crate::union_schema;

12/// Agent that uses structured output (JSON Schema) for tool selection.
13///
14/// System prompt precedence: if `messages` already contains a `Role::System`
15/// message, the agent's `system_prompt` is NOT injected (user's takes priority).
16pub struct SgrAgent<C: LlmClient> {
17    client: C,
18    system_prompt: String,
19}
20
21impl<C: LlmClient> SgrAgent<C> {
22    pub fn new(client: C, system_prompt: impl Into<String>) -> Self {
23        Self {
24            client,
25            system_prompt: system_prompt.into(),
26        }
27    }
28}
29
30#[async_trait::async_trait]
31impl<C: LlmClient> Agent for SgrAgent<C> {
32    async fn decide(
33        &self,
34        messages: &[Message],
35        tools: &ToolRegistry,
36    ) -> Result<Decision, AgentError> {
37        let defs = tools.to_defs();
38        let schema = union_schema::build_action_schema(&defs);
39
40        // Prepend system prompt if not already in messages
41        let mut msgs = Vec::with_capacity(messages.len() + 1);
42        let has_system = messages
43            .iter()
44            .any(|m| m.role == crate::types::Role::System);
45        if !has_system && !self.system_prompt.is_empty() {
46            msgs.push(Message::system(&self.system_prompt));
47        }
48        msgs.extend_from_slice(messages);
49
50        let (output, native_calls, raw) = self.client.structured_call(&msgs, &schema).await?;
51
52        // Try to parse structured output first
53        if let Some(val) = output
54            && let Ok((situation, tool_calls)) = union_schema::parse_action(&val.to_string(), &defs)
55        {
56            let completed =
57                tool_calls.is_empty() || tool_calls.iter().any(|tc| tc.name == "finish_task");
58            return Ok(Decision {
59                situation,
60                task: vec![],
61                tool_calls,
62                completed,
63            });
64        }
65
66        // Fall back to native tool calls
67        if !native_calls.is_empty() {
68            let completed = native_calls.iter().any(|tc| tc.name == "finish_task");
69            return Ok(Decision {
70                situation: String::new(),
71                task: vec![],
72                tool_calls: native_calls,
73                completed,
74            });
75        }
76
77        // Try parsing raw text
78        if let Ok((situation, tool_calls)) = union_schema::parse_action(&raw, &defs) {
79            let completed =
80                tool_calls.is_empty() || tool_calls.iter().any(|tc| tc.name == "finish_task");
81            return Ok(Decision {
82                situation,
83                task: vec![],
84                tool_calls,
85                completed,
86            });
87        }
88
89        // No tool calls — completed
90        Ok(Decision {
91            situation: raw,
92            task: vec![],
93            tool_calls: vec![],
94            completed: true,
95        })
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_tool::{ToolError, ToolOutput};
    use crate::context::AgentContext;
    use crate::tool::ToolDef;
    use crate::types::{Role, SgrError, ToolCall};
    use serde_json::Value;
    use std::sync::Mutex;

    const EMPTY_ACTIONS: &str = r#"{"situation":"done","task":[],"actions":[]}"#;

    /// LLM stub that always answers with `response`.
    ///
    /// It records `(total messages, system messages)` from the most recent
    /// `structured_call`, so tests can check system-prompt handling.
    struct MockClient {
        response: String,
        seen: Mutex<Option<(usize, usize)>>,
    }

    impl MockClient {
        fn new(response: &str) -> Self {
            Self {
                response: response.into(),
                seen: Mutex::new(None),
            }
        }

        /// `(total, system)` message counts observed by the last `structured_call`.
        fn seen(&self) -> Option<(usize, usize)> {
            *self.seen.lock().unwrap()
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for MockClient {
        async fn structured_call(
            &self,
            messages: &[Message],
            _schema: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            let system = messages.iter().filter(|m| m.role == Role::System).count();
            *self.seen.lock().unwrap() = Some((messages.len(), system));
            let val: Value = serde_json::from_str(&self.response).unwrap_or(Value::Null);
            Ok((Some(val), vec![], self.response.clone()))
        }
        async fn tools_call(
            &self,
            _messages: &[Message],
            _tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            Ok(vec![])
        }
        async fn complete(&self, _messages: &[Message]) -> Result<String, SgrError> {
            Ok(self.response.clone())
        }
    }

    struct DummyTool(&'static str);

    #[async_trait::async_trait]
    impl crate::agent_tool::Tool for DummyTool {
        fn name(&self) -> &str {
            self.0
        }
        fn description(&self) -> &str {
            "dummy"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object", "properties": {"arg": {"type": "string"}}})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("ok"))
        }
    }

    #[tokio::test]
    async fn sgr_agent_parses_structured_output() {
        let client = MockClient::new(
            r#"{"situation":"reading file","task":["read"],"actions":[{"tool_name":"read","arg":"main.rs"}]}"#,
        );
        let agent = SgrAgent::new(client, "You are a test agent");
        let tools = ToolRegistry::new().register(DummyTool("read"));
        let msgs = vec![Message::user("read main.rs")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.situation, "reading file");
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "read");
        assert!(!decision.completed);
    }

    #[tokio::test]
    async fn sgr_agent_empty_actions_completes() {
        let agent = SgrAgent::new(MockClient::new(EMPTY_ACTIONS), "test");
        let tools = ToolRegistry::new().register(DummyTool("read"));
        let msgs = vec![Message::user("done")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
    }

    #[tokio::test]
    async fn sgr_agent_finish_task_completes() {
        let client = MockClient::new(
            r#"{"situation":"all done","task":[],"actions":[{"tool_name":"finish_task","arg":"bye"}]}"#,
        );
        let agent = SgrAgent::new(client, "test");
        let tools = ToolRegistry::new().register(DummyTool("finish_task"));
        let msgs = vec![Message::user("wrap up")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "finish_task");
        assert!(decision.completed);
    }

    #[tokio::test]
    async fn sgr_agent_unparseable_text_completes_without_tool_calls() {
        let agent = SgrAgent::new(MockClient::new("I have nothing more to do."), "test");
        let tools = ToolRegistry::new().register(DummyTool("read"));
        let msgs = vec![Message::user("anything else?")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.tool_calls.is_empty());
        assert!(decision.completed);
    }

    #[tokio::test]
    async fn sgr_agent_injects_system_prompt_when_absent() {
        let agent = SgrAgent::new(MockClient::new(EMPTY_ACTIONS), "You are a test agent");
        let tools = ToolRegistry::new().register(DummyTool("read"));
        let msgs = vec![Message::user("hi")];

        agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(agent.client.seen(), Some((2, 1)));
    }

    #[tokio::test]
    async fn sgr_agent_caller_system_prompt_takes_priority() {
        let agent = SgrAgent::new(MockClient::new(EMPTY_ACTIONS), "You are a test agent");
        let tools = ToolRegistry::new().register(DummyTool("read"));
        let msgs = vec![Message::system("custom"), Message::user("hi")];

        agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(agent.client.seen(), Some((2, 1)));
    }

    #[tokio::test]
    async fn sgr_agent_skips_empty_system_prompt() {
        let agent = SgrAgent::new(MockClient::new(EMPTY_ACTIONS), "");
        let tools = ToolRegistry::new().register(DummyTool("read"));
        let msgs = vec![Message::user("hi")];

        agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(agent.client.seen(), Some((1, 0)));
    }
}