sgr_agent_core/
agent_tool.rs1use crate::tool::ToolDef;
4use serde::de::DeserializeOwned;
5use serde_json::Value;
6
7#[derive(Debug, Clone, Default)]
9pub struct ContextModifier {
10 pub system_injection: Option<String>,
11 pub max_tokens_override: Option<u32>,
12 pub custom_context: Vec<(String, serde_json::Value)>,
13 pub max_steps_delta: Option<i32>,
14}
15
16impl ContextModifier {
17 pub fn system(msg: impl Into<String>) -> Self {
18 Self {
19 system_injection: Some(msg.into()),
20 ..Default::default()
21 }
22 }
23
24 pub fn max_tokens(tokens: u32) -> Self {
25 Self {
26 max_tokens_override: Some(tokens),
27 ..Default::default()
28 }
29 }
30
31 pub fn custom(key: impl Into<String>, value: serde_json::Value) -> Self {
32 Self {
33 custom_context: vec![(key.into(), value)],
34 ..Default::default()
35 }
36 }
37
38 pub fn extra_steps(delta: i32) -> Self {
39 Self {
40 max_steps_delta: Some(delta),
41 ..Default::default()
42 }
43 }
44
45 pub fn is_empty(&self) -> bool {
46 self.system_injection.is_none()
47 && self.max_tokens_override.is_none()
48 && self.custom_context.is_empty()
49 && self.max_steps_delta.is_none()
50 }
51}
52
53#[derive(Debug, Clone)]
55pub struct ToolOutput {
56 pub content: String,
57 pub done: bool,
58 pub waiting: bool,
59 pub modifier: Option<ContextModifier>,
60}
61
62impl ToolOutput {
63 pub fn text(content: impl Into<String>) -> Self {
64 Self {
65 content: content.into(),
66 done: false,
67 waiting: false,
68 modifier: None,
69 }
70 }
71
72 pub fn done(content: impl Into<String>) -> Self {
73 Self {
74 content: content.into(),
75 done: true,
76 waiting: false,
77 modifier: None,
78 }
79 }
80
81 pub fn waiting(question: impl Into<String>) -> Self {
82 Self {
83 content: question.into(),
84 done: false,
85 waiting: true,
86 modifier: None,
87 }
88 }
89
90 pub fn with_modifier(mut self, modifier: ContextModifier) -> Self {
91 self.modifier = Some(modifier);
92 self
93 }
94}
95
96#[derive(Debug, thiserror::Error)]
98pub enum ToolError {
99 #[error("{0}")]
101 Execution(String),
102 #[error("invalid args: {0}")]
104 InvalidArgs(String),
105 #[error("permission denied: {0}")]
107 PermissionDenied(String),
108 #[error("not found: {0}")]
110 NotFound(String),
111 #[error("timeout: {0}")]
113 Timeout(String),
114}
115
116impl ToolError {
117 pub fn exec(err: impl std::fmt::Display) -> Self {
119 Self::Execution(err.to_string())
120 }
121}
122
123pub fn parse_args<T: DeserializeOwned>(args: &Value) -> Result<T, ToolError> {
125 serde_json::from_value(args.clone()).map_err(|e| ToolError::InvalidArgs(e.to_string()))
126}
127
128#[async_trait::async_trait]
130pub trait Tool: Send + Sync {
131 fn name(&self) -> &str;
132 fn description(&self) -> &str;
133
134 fn is_system(&self) -> bool {
135 false
136 }
137 fn is_read_only(&self) -> bool {
138 false
139 }
140
141 fn parameters_schema(&self) -> Value;
142
143 async fn execute(
144 &self,
145 args: Value,
146 ctx: &mut crate::context::AgentContext,
147 ) -> Result<ToolOutput, ToolError>;
148
149 async fn execute_readonly(
153 &self,
154 args: Value,
155 ctx: &crate::context::AgentContext,
156 ) -> Result<ToolOutput, ToolError> {
157 let mut ctx_clone = ctx.clone();
158 self.execute(args, &mut ctx_clone).await
159 }
160
161 fn to_def(&self) -> ToolDef {
162 ToolDef {
163 name: self.name().to_string(),
164 description: self.description().to_string(),
165 parameters: self.parameters_schema(),
166 }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::AgentContext;
    use serde::{Deserialize, Serialize};

    /// Argument shape accepted by `EchoTool`.
    #[derive(Debug, Serialize, Deserialize)]
    struct EchoArgs {
        message: String,
    }

    /// Minimal tool that returns its `message` argument unchanged.
    struct EchoTool;

    #[async_trait::async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }
        fn description(&self) -> &str {
            "Echo a message back"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": { "message": { "type": "string" } },
                "required": ["message"]
            })
        }
        async fn execute(
            &self,
            args: Value,
            _ctx: &mut AgentContext,
        ) -> Result<ToolOutput, ToolError> {
            let a: EchoArgs = parse_args(&args)?;
            Ok(ToolOutput::text(a.message))
        }
    }

    #[test]
    fn parse_args_valid() {
        let args = serde_json::json!({"message": "hello"});
        let parsed: EchoArgs = parse_args(&args).unwrap();
        assert_eq!(parsed.message, "hello");
    }

    #[test]
    fn parse_args_invalid() {
        let result = parse_args::<EchoArgs>(&serde_json::json!({"wrong": 42}));
        assert!(matches!(result.unwrap_err(), ToolError::InvalidArgs(_)));
    }

    #[test]
    fn parse_args_wrong_type_is_invalid_args() {
        let result = parse_args::<EchoArgs>(&serde_json::json!({"message": 42}));
        assert!(matches!(result, Err(ToolError::InvalidArgs(_))));
    }

    #[test]
    fn tool_output_constructors_set_flags() {
        let text = ToolOutput::text("a");
        assert_eq!(text.content, "a");
        assert!(!text.done && !text.waiting && text.modifier.is_none());

        let done = ToolOutput::done("b");
        assert_eq!(done.content, "b");
        assert!(done.done && !done.waiting && done.modifier.is_none());

        let waiting = ToolOutput::waiting("c?");
        assert_eq!(waiting.content, "c?");
        assert!(!waiting.done && waiting.waiting && waiting.modifier.is_none());
    }

    #[test]
    fn with_modifier_attaches_modifier() {
        let out = ToolOutput::text("x").with_modifier(ContextModifier::extra_steps(3));
        assert_eq!(out.modifier.and_then(|m| m.max_steps_delta), Some(3));
    }

    #[test]
    fn context_modifier_is_empty_only_for_default() {
        assert!(ContextModifier::default().is_empty());
        assert!(!ContextModifier::system("s").is_empty());
        assert!(!ContextModifier::max_tokens(10).is_empty());
        assert!(!ContextModifier::custom("k", serde_json::json!(1)).is_empty());
        assert!(!ContextModifier::extra_steps(-1).is_empty());
    }

    #[test]
    fn tool_error_display_formats() {
        assert_eq!(ToolError::exec("boom").to_string(), "boom");
        assert_eq!(ToolError::exec(404).to_string(), "404");
        assert_eq!(
            ToolError::InvalidArgs("bad".into()).to_string(),
            "invalid args: bad"
        );
        assert_eq!(
            ToolError::PermissionDenied("p".into()).to_string(),
            "permission denied: p"
        );
        assert_eq!(ToolError::NotFound("f".into()).to_string(), "not found: f");
        assert_eq!(ToolError::Timeout("t".into()).to_string(), "timeout: t");
    }

    #[test]
    fn to_def_mirrors_tool_metadata_and_defaults() {
        let def = EchoTool.to_def();
        assert_eq!(def.name, "echo");
        assert_eq!(def.description, "Echo a message back");
        assert_eq!(def.parameters, EchoTool.parameters_schema());
        assert!(!EchoTool.is_system());
        assert!(!EchoTool.is_read_only());
    }

    #[tokio::test]
    async fn tool_execute() {
        let tool = EchoTool;
        let mut ctx = AgentContext::new();
        let output = tool
            .execute(serde_json::json!({"message": "world"}), &mut ctx)
            .await
            .unwrap();
        assert_eq!(output.content, "world");
    }

    #[tokio::test]
    async fn execute_readonly_default_delegates_to_execute() {
        let ctx = AgentContext::new();
        let output = EchoTool
            .execute_readonly(serde_json::json!({"message": "ro"}), &ctx)
            .await
            .unwrap();
        assert_eq!(output.content, "ro");
    }
}