Skip to main content

sgr_agent/agents/
tool_calling.rs

1//! ToolCallingAgent — uses native function calling (Gemini FC / OpenAI tools API).
2//!
3//! Sends tool definitions directly to the LLM's native function calling endpoint.
4//! Simplest agent variant — no schema building, no parsing.
5
6use crate::agent::{Agent, AgentError, Decision};
7use crate::client::LlmClient;
8use crate::registry::ToolRegistry;
9use crate::types::Message;
10
/// Agent that delegates tool selection to the LLM's native function-calling
/// API (e.g. Gemini FC or the OpenAI tools API).
///
/// `C` is the LLM client. The [`Agent`] implementation requires
/// `C: LlmClient`; that bound lives on the `impl` blocks rather than on the
/// struct, so the type can be named without repeating the bound.
pub struct ToolCallingAgent<C> {
    /// Client used to issue function-calling requests.
    client: C,
    /// System prompt prepended to conversations that contain no system
    /// message; never injected when empty.
    system_prompt: String,
}
16
17impl<C: LlmClient> ToolCallingAgent<C> {
18    pub fn new(client: C, system_prompt: impl Into<String>) -> Self {
19        Self {
20            client,
21            system_prompt: system_prompt.into(),
22        }
23    }
24}
25
26#[async_trait::async_trait]
27impl<C: LlmClient> Agent for ToolCallingAgent<C> {
28    async fn decide(
29        &self,
30        messages: &[Message],
31        tools: &ToolRegistry,
32    ) -> Result<Decision, AgentError> {
33        let defs = tools.to_defs();
34
35        let mut msgs = Vec::with_capacity(messages.len() + 1);
36        let has_system = messages
37            .iter()
38            .any(|m| m.role == crate::types::Role::System);
39        if !has_system && !self.system_prompt.is_empty() {
40            msgs.push(Message::system(&self.system_prompt));
41        }
42        msgs.extend_from_slice(messages);
43
44        let tool_calls = self.client.tools_call(&msgs, &defs).await?;
45        let completed =
46            tool_calls.is_empty() || tool_calls.iter().any(|tc| tc.name == "finish_task");
47
48        Ok(Decision {
49            situation: String::new(),
50            task: vec![],
51            tool_calls,
52            completed,
53        })
54    }
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_tool::{ToolError, ToolOutput};
    use crate::client::LlmClient;
    use crate::context::AgentContext;
    use crate::tool::ToolDef;
    use crate::types::{Role, SgrError, ToolCall};
    use serde_json::Value;
    use std::sync::Mutex;

    /// LLM client stub: `tools_call` returns a fixed set of tool calls and
    /// records the messages it was sent, so tests can inspect prompt injection.
    struct MockFcClient {
        calls: Vec<ToolCall>,
        seen: Mutex<Vec<Message>>,
    }

    impl MockFcClient {
        fn new(calls: Vec<ToolCall>) -> Self {
            Self {
                calls,
                seen: Mutex::new(Vec::new()),
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for MockFcClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, vec![], String::new()))
        }
        async fn tools_call(
            &self,
            msgs: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            *self.seen.lock().unwrap() = msgs.to_vec();
            Ok(self.calls.clone())
        }
        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok(String::new())
        }
    }

    struct DummyTool;

    #[async_trait::async_trait]
    impl crate::agent_tool::Tool for DummyTool {
        fn name(&self) -> &str {
            "bash"
        }
        fn description(&self) -> &str {
            "run command"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object", "properties": {"command": {"type": "string"}}})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("ok"))
        }
    }

    /// Builds a tool call with the given id and name and an empty-ish payload.
    fn call(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            arguments: serde_json::json!({"command": "ls"}),
        }
    }

    /// Number of system-role messages the mock client received.
    fn system_count(agent: &ToolCallingAgent<MockFcClient>) -> usize {
        agent
            .client
            .seen
            .lock()
            .unwrap()
            .iter()
            .filter(|m| m.role == Role::System)
            .count()
    }

    #[tokio::test]
    async fn tool_calling_agent_forwards_calls() {
        let client = MockFcClient::new(vec![call("1", "bash")]);
        let agent = ToolCallingAgent::new(client, "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("list files")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "bash");
        assert!(!decision.completed);
    }

    #[tokio::test]
    async fn tool_calling_agent_no_calls_completes() {
        let client = MockFcClient::new(vec![]);
        let agent = ToolCallingAgent::new(client, "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("done")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
    }

    #[tokio::test]
    async fn tool_calling_agent_finish_task_completes() {
        let client = MockFcClient::new(vec![call("1", "bash"), call("2", "finish_task")]);
        let agent = ToolCallingAgent::new(client, "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("wrap up")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.tool_calls.len(), 2);
        assert!(decision.completed);
    }

    #[tokio::test]
    async fn tool_calling_agent_injects_missing_system_prompt() {
        let agent = ToolCallingAgent::new(MockFcClient::new(vec![]), "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("hi")];

        agent.decide(&msgs, &tools).await.unwrap();
        let seen = agent.client.seen.lock().unwrap();
        assert_eq!(seen.len(), 2);
        assert!(seen[0].role == Role::System);
    }

    #[tokio::test]
    async fn tool_calling_agent_keeps_existing_system_message() {
        let agent = ToolCallingAgent::new(MockFcClient::new(vec![]), "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::system("custom"), Message::user("hi")];

        agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(agent.client.seen.lock().unwrap().len(), 2);
        assert_eq!(system_count(&agent), 1);
    }

    #[tokio::test]
    async fn tool_calling_agent_skips_empty_system_prompt() {
        let agent = ToolCallingAgent::new(MockFcClient::new(vec![]), "");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("hi")];

        agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(agent.client.seen.lock().unwrap().len(), 1);
        assert_eq!(system_count(&agent), 0);
    }
}