sgr_agent/agents/
tool_calling.rs1use crate::agent::{Agent, AgentError, Decision};
7use crate::client::LlmClient;
8use crate::registry::ToolRegistry;
9use crate::types::Message;
10
/// Agent that relies on the LLM's native function calling: each step forwards the
/// conversation plus the registered tool definitions to `client` and returns the tool
/// calls the model selected.
///
/// The `LlmClient` bound lives on the `impl` blocks that actually need it rather than on
/// the struct definition, so the type can be named and stored without repeating it.
pub struct ToolCallingAgent<C> {
    /// Client used to request tool calls from the model.
    client: C,
    /// Prepended as a system message when the conversation has none; an empty string
    /// disables injection.
    system_prompt: String,
}
16
17impl<C: LlmClient> ToolCallingAgent<C> {
18 pub fn new(client: C, system_prompt: impl Into<String>) -> Self {
19 Self {
20 client,
21 system_prompt: system_prompt.into(),
22 }
23 }
24}
25
26#[async_trait::async_trait]
27impl<C: LlmClient> Agent for ToolCallingAgent<C> {
28 async fn decide(
29 &self,
30 messages: &[Message],
31 tools: &ToolRegistry,
32 ) -> Result<Decision, AgentError> {
33 let defs = tools.to_defs();
34
35 let mut msgs = Vec::with_capacity(messages.len() + 1);
36 let has_system = messages
37 .iter()
38 .any(|m| m.role == crate::types::Role::System);
39 if !has_system && !self.system_prompt.is_empty() {
40 msgs.push(Message::system(&self.system_prompt));
41 }
42 msgs.extend_from_slice(messages);
43
44 let tool_calls = self.client.tools_call(&msgs, &defs).await?;
45 let completed =
46 tool_calls.is_empty() || tool_calls.iter().any(|tc| tc.name == "finish_task");
47
48 Ok(Decision {
49 situation: String::new(),
50 task: vec![],
51 tool_calls,
52 completed,
53 })
54 }
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_tool::{ToolError, ToolOutput};
    use crate::client::LlmClient;
    use crate::context::AgentContext;
    use crate::tool::ToolDef;
    use crate::types::{Role, SgrError, ToolCall};
    use serde_json::Value;
    use std::sync::Mutex;

    /// Function-calling client stub. `tools_call` always returns `calls` and records the
    /// message list it received, so tests can inspect what the agent actually sent.
    struct MockFcClient {
        calls: Vec<ToolCall>,
        seen: Mutex<Vec<Message>>,
    }

    impl MockFcClient {
        fn new(calls: Vec<ToolCall>) -> Self {
            Self {
                calls,
                seen: Mutex::new(Vec::new()),
            }
        }

        /// Messages passed to the most recent `tools_call`.
        fn seen(&self) -> Vec<Message> {
            self.seen.lock().unwrap().clone()
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for MockFcClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, vec![], String::new()))
        }
        async fn tools_call(
            &self,
            msgs: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            // The guard is dropped at the end of this statement, never held across an await.
            *self.seen.lock().unwrap() = msgs.to_vec();
            Ok(self.calls.clone())
        }
        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok(String::new())
        }
    }

    struct DummyTool;

    #[async_trait::async_trait]
    impl crate::agent_tool::Tool for DummyTool {
        fn name(&self) -> &str {
            "bash"
        }
        fn description(&self) -> &str {
            "run command"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object", "properties": {"command": {"type": "string"}}})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("ok"))
        }
    }

    /// Builds a `ToolCall` with the given id, name and arguments.
    fn call(id: &str, name: &str, arguments: Value) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            arguments,
        }
    }

    #[tokio::test]
    async fn tool_calling_agent_forwards_calls() {
        let client = MockFcClient::new(vec![call(
            "1",
            "bash",
            serde_json::json!({"command": "ls"}),
        )]);
        let agent = ToolCallingAgent::new(client, "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("list files")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "bash");
        assert!(!decision.completed);
    }

    #[tokio::test]
    async fn tool_calling_agent_no_calls_completes() {
        let client = MockFcClient::new(vec![]);
        let agent = ToolCallingAgent::new(client, "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("done")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
    }

    #[tokio::test]
    async fn tool_calling_agent_finish_task_completes() {
        let client = MockFcClient::new(vec![
            call("1", "bash", serde_json::json!({"command": "ls"})),
            call("2", "finish_task", serde_json::json!({})),
        ]);
        let agent = ToolCallingAgent::new(client, "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("wrap up")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.tool_calls.len(), 2);
        assert!(decision.completed);
    }

    #[tokio::test]
    async fn tool_calling_agent_prepends_system_prompt() {
        let agent = ToolCallingAgent::new(MockFcClient::new(vec![]), "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("hi")];

        agent.decide(&msgs, &tools).await.unwrap();
        let seen = agent.client.seen();
        assert_eq!(seen.len(), 2);
        assert!(seen[0].role == Role::System);
    }

    #[tokio::test]
    async fn tool_calling_agent_keeps_existing_system_message() {
        let agent = ToolCallingAgent::new(MockFcClient::new(vec![]), "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::system("custom"), Message::user("hi")];

        agent.decide(&msgs, &tools).await.unwrap();
        let seen = agent.client.seen();
        assert_eq!(seen.len(), 2);
        assert_eq!(seen.iter().filter(|m| m.role == Role::System).count(), 1);
    }

    #[tokio::test]
    async fn tool_calling_agent_empty_prompt_not_injected() {
        let agent = ToolCallingAgent::new(MockFcClient::new(vec![]), "");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("hi")];

        agent.decide(&msgs, &tools).await.unwrap();
        let seen = agent.client.seen();
        assert_eq!(seen.len(), 1);
        assert!(!seen.iter().any(|m| m.role == Role::System));
    }
}