sgr_agent_core/
agent_tool.rs1use crate::tool::ToolDef;
4use serde::de::DeserializeOwned;
5use serde_json::Value;
6
7#[derive(Debug, Clone, Default)]
12pub struct ContextModifier {
13 pub system_injection: Option<String>,
15 pub max_tokens_override: Option<u32>,
17 pub custom_context: Vec<(String, serde_json::Value)>,
19 pub max_steps_delta: Option<i32>,
21}
22
23impl ContextModifier {
24 pub fn system(msg: impl Into<String>) -> Self {
25 Self {
26 system_injection: Some(msg.into()),
27 ..Default::default()
28 }
29 }
30
31 pub fn max_tokens(tokens: u32) -> Self {
32 Self {
33 max_tokens_override: Some(tokens),
34 ..Default::default()
35 }
36 }
37
38 pub fn custom(key: impl Into<String>, value: serde_json::Value) -> Self {
39 Self {
40 custom_context: vec![(key.into(), value)],
41 ..Default::default()
42 }
43 }
44
45 pub fn extra_steps(delta: i32) -> Self {
46 Self {
47 max_steps_delta: Some(delta),
48 ..Default::default()
49 }
50 }
51
52 pub fn is_empty(&self) -> bool {
53 self.system_injection.is_none()
54 && self.max_tokens_override.is_none()
55 && self.custom_context.is_empty()
56 && self.max_steps_delta.is_none()
57 }
58}
59
60#[derive(Debug, Clone)]
65pub struct ToolOutput {
66 pub content: String,
67 pub done: bool,
68 pub waiting: bool,
69 pub modifier: Option<ContextModifier>,
70}
71
72impl ToolOutput {
73 pub fn text(content: impl Into<String>) -> Self {
74 Self {
75 content: content.into(),
76 done: false,
77 waiting: false,
78 modifier: None,
79 }
80 }
81
82 pub fn done(content: impl Into<String>) -> Self {
83 Self {
84 content: content.into(),
85 done: true,
86 waiting: false,
87 modifier: None,
88 }
89 }
90
91 pub fn waiting(question: impl Into<String>) -> Self {
92 Self {
93 content: question.into(),
94 done: false,
95 waiting: true,
96 modifier: None,
97 }
98 }
99
100 pub fn with_modifier(mut self, modifier: ContextModifier) -> Self {
101 self.modifier = Some(modifier);
102 self
103 }
104}
105
106#[derive(Debug, thiserror::Error)]
108pub enum ToolError {
109 #[error("{0}")]
111 Execution(String),
112 #[error("invalid args: {0}")]
114 InvalidArgs(String),
115 #[error("permission denied: {0}")]
117 PermissionDenied(String),
118 #[error("not found: {0}")]
120 NotFound(String),
121 #[error("timeout: {0}")]
123 Timeout(String),
124}
125
126impl ToolError {
127 pub fn exec(err: impl std::fmt::Display) -> Self {
129 Self::Execution(err.to_string())
130 }
131}
132
133pub fn parse_args<T: DeserializeOwned>(args: &Value) -> Result<T, ToolError> {
139 serde_json::from_value(args.clone()).map_err(|e| ToolError::InvalidArgs(e.to_string()))
140}
141
142#[async_trait::async_trait]
150pub trait Tool: Send + Sync {
151 fn name(&self) -> &str;
153 fn description(&self) -> &str;
155
156 fn is_system(&self) -> bool {
158 false
159 }
160 fn is_read_only(&self) -> bool {
162 false
163 }
164
165 fn parameters_schema(&self) -> Value;
167
168 async fn execute(
169 &self,
170 args: Value,
171 ctx: &mut crate::context::AgentContext,
172 ) -> Result<ToolOutput, ToolError>;
173
174 async fn execute_readonly(
178 &self,
179 args: Value,
180 ctx: &crate::context::AgentContext,
181 ) -> Result<ToolOutput, ToolError> {
182 let mut ctx_clone = ctx.clone();
183 self.execute(args, &mut ctx_clone).await
184 }
185
186 fn to_def(&self) -> ToolDef {
187 ToolDef {
188 name: self.name().to_string(),
189 description: self.description().to_string(),
190 parameters: self.parameters_schema(),
191 }
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::AgentContext;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize)]
    struct EchoArgs {
        message: String,
    }

    /// Minimal tool that returns its `message` argument unchanged.
    struct EchoTool;

    #[async_trait::async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }
        fn description(&self) -> &str {
            "Echo a message back"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": { "message": { "type": "string" } },
                "required": ["message"]
            })
        }
        async fn execute(
            &self,
            args: Value,
            _ctx: &mut AgentContext,
        ) -> Result<ToolOutput, ToolError> {
            let a: EchoArgs = parse_args(&args)?;
            Ok(ToolOutput::text(a.message))
        }
    }

    #[test]
    fn parse_args_valid() {
        let args = serde_json::json!({"message": "hello"});
        let parsed: EchoArgs = parse_args(&args).unwrap();
        assert_eq!(parsed.message, "hello");
    }

    #[test]
    fn parse_args_invalid() {
        let result = parse_args::<EchoArgs>(&serde_json::json!({"wrong": 42}));
        assert!(matches!(result.unwrap_err(), ToolError::InvalidArgs(_)));
    }

    #[test]
    fn context_modifier_emptiness() {
        assert!(ContextModifier::default().is_empty());
        assert!(!ContextModifier::system("be brief").is_empty());
        assert!(!ContextModifier::max_tokens(128).is_empty());
        assert!(!ContextModifier::custom("k", serde_json::json!(1)).is_empty());
        assert!(!ContextModifier::extra_steps(-1).is_empty());
    }

    #[test]
    fn tool_output_constructors_set_flags() {
        let text = ToolOutput::text("a");
        assert!(!text.done && !text.waiting && text.modifier.is_none());

        let done = ToolOutput::done("b");
        assert!(done.done && !done.waiting);

        let waiting = ToolOutput::waiting("q?");
        assert!(!waiting.done && waiting.waiting);
        assert_eq!(waiting.content, "q?");

        let modified = ToolOutput::text("c").with_modifier(ContextModifier::max_tokens(10));
        assert_eq!(
            modified.modifier.and_then(|m| m.max_tokens_override),
            Some(10)
        );
    }

    #[test]
    fn tool_error_display_formats() {
        assert_eq!(ToolError::exec("boom").to_string(), "boom");
        assert_eq!(
            ToolError::InvalidArgs("x".into()).to_string(),
            "invalid args: x"
        );
        assert_eq!(ToolError::NotFound("f".into()).to_string(), "not found: f");
    }

    #[test]
    fn to_def_uses_trait_metadata() {
        let def = EchoTool.to_def();
        assert_eq!(def.name, "echo");
        assert_eq!(def.description, "Echo a message back");
        assert_eq!(def.parameters["required"][0], "message");
    }

    #[tokio::test]
    async fn tool_execute() {
        let tool = EchoTool;
        let mut ctx = AgentContext::new();
        let output = tool
            .execute(serde_json::json!({"message": "world"}), &mut ctx)
            .await
            .unwrap();
        assert_eq!(output.content, "world");
    }

    #[tokio::test]
    async fn execute_readonly_delegates_to_execute() {
        let ctx = AgentContext::new();
        let output = EchoTool
            .execute_readonly(serde_json::json!({"message": "ro"}), &ctx)
            .await
            .unwrap();
        assert_eq!(output.content, "ro");
        assert!(!output.done);
    }
}