Skip to main content

sgr_agent/
agent_tool.rs

//! Tool trait — the core abstraction for agent tools.
//!
//! Implement [`Tool`] for each capability you want to expose to the agent.
//! Arguments arrive as a `serde_json::Value`; use the [`parse_args`] helper to deserialize them into a typed struct.
5
6use crate::tool::ToolDef;
7use serde::de::DeserializeOwned;
8use serde_json::Value;
9
/// Output from a single tool execution.
///
/// The constructors `ToolOutput::text`, `ToolOutput::done` and
/// `ToolOutput::waiting` each set at most one of the two flags; building the
/// struct literally with both `done` and `waiting` set is not given any
/// meaning by this module.
///
/// `PartialEq`/`Eq` are derived so outputs can be compared directly
/// (e.g. `assert_eq!` in tests or deduplication by callers).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolOutput {
    /// Human-readable result content. When `waiting` is set, this is the
    /// question to present to the user instead.
    pub content: String,
    /// If true, the agent should stop (e.g. a FinishTask tool).
    pub done: bool,
    /// If true, the loop should pause and wait for user input; `content`
    /// holds the question to ask.
    pub waiting: bool,
}
21
22impl ToolOutput {
23    pub fn text(content: impl Into<String>) -> Self {
24        Self {
25            content: content.into(),
26            done: false,
27            waiting: false,
28        }
29    }
30
31    pub fn done(content: impl Into<String>) -> Self {
32        Self {
33            content: content.into(),
34            done: true,
35            waiting: false,
36        }
37    }
38
39    /// Signal that the agent needs user input before continuing.
40    /// The content is the question to present to the user.
41    pub fn waiting(question: impl Into<String>) -> Self {
42        Self {
43            content: question.into(),
44            done: false,
45            waiting: true,
46        }
47    }
48}
49
50/// Errors from tool execution.
51#[derive(Debug, thiserror::Error)]
52pub enum ToolError {
53    #[error("{0}")]
54    Execution(String),
55    #[error("invalid args: {0}")]
56    InvalidArgs(String),
57}
58
59/// Parse JSON args into a typed struct. Use inside `Tool::execute`.
60pub fn parse_args<T: DeserializeOwned>(args: &Value) -> Result<T, ToolError> {
61    serde_json::from_value(args.clone()).map_err(|e| ToolError::InvalidArgs(e.to_string()))
62}
63
64/// A tool that an agent can invoke.
65#[async_trait::async_trait]
66pub trait Tool: Send + Sync {
67    /// Unique tool name (used as discriminator in LLM output).
68    fn name(&self) -> &str;
69
70    /// Human-readable description for the LLM.
71    fn description(&self) -> &str;
72
73    /// System tools are always visible (not subject to progressive discovery).
74    fn is_system(&self) -> bool {
75        false
76    }
77
78    /// Whether this tool only reads state (no side effects).
79    /// Read-only tools can be executed in parallel.
80    fn is_read_only(&self) -> bool {
81        false
82    }
83
84    /// JSON Schema for the tool's parameters.
85    fn parameters_schema(&self) -> Value;
86
87    /// Execute the tool with JSON arguments.
88    async fn execute(
89        &self,
90        args: Value,
91        ctx: &mut super::context::AgentContext,
92    ) -> Result<ToolOutput, ToolError>;
93
94    /// Execute without mutable context access. Used for parallel execution of read-only tools.
95    /// Default implementation panics — override if is_read_only() returns true.
96    async fn execute_readonly(&self, args: Value) -> Result<ToolOutput, ToolError> {
97        let _ = args;
98        panic!("execute_readonly called on tool that doesn't implement it")
99    }
100
101    /// Convert to a `ToolDef` for LLM API submission.
102    fn to_def(&self) -> ToolDef {
103        ToolDef {
104            name: self.name().to_string(),
105            description: self.description().to_string(),
106            parameters: self.parameters_schema(),
107        }
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::AgentContext;
    use serde::{Deserialize, Serialize};

    /// Typed arguments for [`EchoTool`].
    #[derive(Debug, Serialize, Deserialize)]
    struct EchoArgs {
        message: String,
    }

    /// Minimal tool that returns its `message` argument unchanged.
    struct EchoTool;

    #[async_trait::async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }
        fn description(&self) -> &str {
            "Echo a message back"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                },
                "required": ["message"]
            })
        }
        async fn execute(
            &self,
            args: Value,
            _ctx: &mut AgentContext,
        ) -> Result<ToolOutput, ToolError> {
            let a: EchoArgs = parse_args(&args)?;
            Ok(ToolOutput::text(a.message))
        }
    }

    #[test]
    fn parse_args_valid() {
        let args = serde_json::json!({"message": "hello"});
        let parsed: EchoArgs = parse_args(&args).unwrap();
        assert_eq!(parsed.message, "hello");
    }

    #[test]
    fn parse_args_invalid() {
        // Required field missing entirely.
        let args = serde_json::json!({"wrong_field": 42});
        let result = parse_args::<EchoArgs>(&args);
        assert!(matches!(result, Err(ToolError::InvalidArgs(_))));
    }

    #[test]
    fn parse_args_wrong_type() {
        // Field present but with the wrong JSON type.
        let args = serde_json::json!({"message": 42});
        assert!(matches!(
            parse_args::<EchoArgs>(&args),
            Err(ToolError::InvalidArgs(_))
        ));
    }

    #[test]
    fn invalid_args_display_is_prefixed() {
        let err = parse_args::<EchoArgs>(&serde_json::json!({})).unwrap_err();
        assert!(err.to_string().starts_with("invalid args: "));
    }

    #[test]
    fn tool_to_def() {
        let tool = EchoTool;
        let def = tool.to_def();
        assert_eq!(def.name, "echo");
        assert_eq!(def.description, "Echo a message back");
        assert!(def.parameters["properties"]["message"].is_object());
    }

    #[test]
    fn tool_default_flags() {
        let tool = EchoTool;
        assert!(!tool.is_system());
        assert!(!tool.is_read_only());
    }

    #[tokio::test]
    async fn tool_execute() {
        let tool = EchoTool;
        let mut ctx = AgentContext::new();
        let args = serde_json::json!({"message": "world"});
        let output = tool.execute(args, &mut ctx).await.unwrap();
        assert_eq!(output.content, "world");
        assert!(!output.done);
        assert!(!output.waiting);
    }

    #[test]
    fn tool_output_text() {
        let out = ToolOutput::text("plain");
        assert_eq!(out.content, "plain");
        assert!(!out.done);
        assert!(!out.waiting);
    }

    #[test]
    fn tool_output_done() {
        let out = ToolOutput::done("finished");
        assert!(out.done);
        assert!(!out.waiting);
        assert_eq!(out.content, "finished");
    }

    #[test]
    fn tool_output_waiting() {
        let out = ToolOutput::waiting("which file?");
        assert!(out.waiting);
        assert!(!out.done);
        assert_eq!(out.content, "which file?");
    }
}