Skip to main content

sgr_agent_core/
agent_tool.rs

1//! Tool trait — the core abstraction for agent tools.
2
3use crate::tool::ToolDef;
4use serde::de::DeserializeOwned;
5use serde_json::Value;
6
7/// Modifier that a tool can return to change agent runtime behavior.
8///
9/// Attach to `ToolOutput` via `.with_modifier()`. The agent loop applies
10/// these after tool execution — injecting system messages, adjusting token limits, etc.
11#[derive(Debug, Clone, Default)]
12pub struct ContextModifier {
13    /// Inject a message into context for the next LLM call.
14    pub system_injection: Option<String>,
15    /// Override max_tokens for subsequent calls.
16    pub max_tokens_override: Option<u32>,
17    /// Add key-value pairs to `AgentContext.custom`.
18    pub custom_context: Vec<(String, serde_json::Value)>,
19    /// Adjust remaining max_steps (positive = more steps allowed).
20    pub max_steps_delta: Option<i32>,
21}
22
23impl ContextModifier {
24    pub fn system(msg: impl Into<String>) -> Self {
25        Self {
26            system_injection: Some(msg.into()),
27            ..Default::default()
28        }
29    }
30
31    pub fn max_tokens(tokens: u32) -> Self {
32        Self {
33            max_tokens_override: Some(tokens),
34            ..Default::default()
35        }
36    }
37
38    pub fn custom(key: impl Into<String>, value: serde_json::Value) -> Self {
39        Self {
40            custom_context: vec![(key.into(), value)],
41            ..Default::default()
42        }
43    }
44
45    pub fn extra_steps(delta: i32) -> Self {
46        Self {
47            max_steps_delta: Some(delta),
48            ..Default::default()
49        }
50    }
51
52    pub fn is_empty(&self) -> bool {
53        self.system_injection.is_none()
54            && self.max_tokens_override.is_none()
55            && self.custom_context.is_empty()
56            && self.max_steps_delta.is_none()
57    }
58}
59
60/// Output from a tool execution.
61///
62/// Construct via `ToolOutput::text("result")`, `ToolOutput::done("finished")`,
63/// or `ToolOutput::waiting("question for user")`.
64#[derive(Debug, Clone)]
65pub struct ToolOutput {
66    pub content: String,
67    pub done: bool,
68    pub waiting: bool,
69    pub modifier: Option<ContextModifier>,
70}
71
72impl ToolOutput {
73    pub fn text(content: impl Into<String>) -> Self {
74        Self {
75            content: content.into(),
76            done: false,
77            waiting: false,
78            modifier: None,
79        }
80    }
81
82    pub fn done(content: impl Into<String>) -> Self {
83        Self {
84            content: content.into(),
85            done: true,
86            waiting: false,
87            modifier: None,
88        }
89    }
90
91    pub fn waiting(question: impl Into<String>) -> Self {
92        Self {
93            content: question.into(),
94            done: false,
95            waiting: true,
96            modifier: None,
97        }
98    }
99
100    pub fn with_modifier(mut self, modifier: ContextModifier) -> Self {
101        self.modifier = Some(modifier);
102        self
103    }
104}
105
106/// Errors from tool execution.
107#[derive(Debug, thiserror::Error)]
108pub enum ToolError {
109    /// Tool execution failed (I/O, network, logic error).
110    #[error("{0}")]
111    Execution(String),
112    /// Tool arguments failed to parse or validate.
113    #[error("invalid args: {0}")]
114    InvalidArgs(String),
115    /// Permission denied (sandbox, policy, auth).
116    #[error("permission denied: {0}")]
117    PermissionDenied(String),
118    /// Tool not found or not available.
119    #[error("not found: {0}")]
120    NotFound(String),
121    /// Timeout exceeded.
122    #[error("timeout: {0}")]
123    Timeout(String),
124}
125
126impl ToolError {
127    /// Create an execution error from any error type.
128    pub fn exec(err: impl std::fmt::Display) -> Self {
129        Self::Execution(err.to_string())
130    }
131}
132
133/// Parse JSON args into a typed struct. Use inside `Tool::execute`.
134///
135/// ```rust,ignore
136/// let args: MyArgs = parse_args(&args)?;
137/// ```
138pub fn parse_args<T: DeserializeOwned>(args: &Value) -> Result<T, ToolError> {
139    serde_json::from_value(args.clone()).map_err(|e| ToolError::InvalidArgs(e.to_string()))
140}
141
142/// A tool that an agent can invoke.
143///
144/// Implement this trait for each capability you want to expose to the LLM agent.
145/// Tools are registered in a `ToolRegistry` and dispatched by the agent loop.
146///
147/// Read-only tools (`is_read_only() -> true`) can execute in parallel.
148/// Write tools execute sequentially with exclusive `&mut AgentContext`.
149#[async_trait::async_trait]
150pub trait Tool: Send + Sync {
151    /// Unique tool name (used as discriminator in LLM function calling).
152    fn name(&self) -> &str;
153    /// Human-readable description shown to the LLM.
154    fn description(&self) -> &str;
155
156    /// System tools are always visible (not subject to progressive discovery).
157    fn is_system(&self) -> bool {
158        false
159    }
160    /// Read-only tools can execute in parallel via `execute_readonly`.
161    fn is_read_only(&self) -> bool {
162        false
163    }
164
165    /// JSON Schema for the tool's parameters (generated via `json_schema_for::<Args>()`).
166    fn parameters_schema(&self) -> Value;
167
168    async fn execute(
169        &self,
170        args: Value,
171        ctx: &mut crate::context::AgentContext,
172    ) -> Result<ToolOutput, ToolError>;
173
174    /// Execute without mutable context (for parallel read-only dispatch).
175    /// Default: delegates to `execute` with a cloned context. Override for true
176    /// read-only tools to avoid the clone.
177    async fn execute_readonly(
178        &self,
179        args: Value,
180        ctx: &crate::context::AgentContext,
181    ) -> Result<ToolOutput, ToolError> {
182        let mut ctx_clone = ctx.clone();
183        self.execute(args, &mut ctx_clone).await
184    }
185
186    fn to_def(&self) -> ToolDef {
187        ToolDef {
188            name: self.name().to_string(),
189            description: self.description().to_string(),
190            parameters: self.parameters_schema(),
191        }
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::AgentContext;
    use serde::{Deserialize, Serialize};

    /// Arguments accepted by [`EchoTool`].
    #[derive(Debug, Serialize, Deserialize)]
    struct EchoArgs {
        message: String,
    }

    /// Minimal tool that returns its `message` argument unchanged.
    struct EchoTool;

    #[async_trait::async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echo a message back"
        }

        fn parameters_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": { "message": { "type": "string" } },
                "required": ["message"]
            })
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &mut AgentContext,
        ) -> Result<ToolOutput, ToolError> {
            let EchoArgs { message } = parse_args(&args)?;
            Ok(ToolOutput::text(message))
        }
    }

    #[test]
    fn parse_args_valid() {
        let raw = serde_json::json!({"message": "hello"});
        let EchoArgs { message } = parse_args(&raw).expect("well-formed args must parse");
        assert_eq!(message, "hello");
    }

    #[test]
    fn parse_args_invalid() {
        let raw = serde_json::json!({"wrong": 42});
        let outcome = parse_args::<EchoArgs>(&raw);
        assert!(
            matches!(&outcome, Err(ToolError::InvalidArgs(_))),
            "expected InvalidArgs, got {:?}",
            outcome
        );
    }

    #[tokio::test]
    async fn tool_execute() {
        let echo = EchoTool;
        let mut ctx = AgentContext::new();
        let output = echo
            .execute(serde_json::json!({"message": "world"}), &mut ctx)
            .await
            .expect("echo succeeds on valid args");
        assert_eq!(output.content, "world");
    }
}