sgr_agent_core/
agent_tool.rs1use crate::tool::ToolDef;
4use serde::de::DeserializeOwned;
5use serde_json::Value;
6
7#[derive(Debug, Clone, Default)]
9pub struct ContextModifier {
10 pub system_injection: Option<String>,
11 pub max_tokens_override: Option<u32>,
12 pub custom_context: Vec<(String, serde_json::Value)>,
13 pub max_steps_delta: Option<i32>,
14}
15
16impl ContextModifier {
17 pub fn system(msg: impl Into<String>) -> Self {
18 Self {
19 system_injection: Some(msg.into()),
20 ..Default::default()
21 }
22 }
23
24 pub fn max_tokens(tokens: u32) -> Self {
25 Self {
26 max_tokens_override: Some(tokens),
27 ..Default::default()
28 }
29 }
30
31 pub fn custom(key: impl Into<String>, value: serde_json::Value) -> Self {
32 Self {
33 custom_context: vec![(key.into(), value)],
34 ..Default::default()
35 }
36 }
37
38 pub fn extra_steps(delta: i32) -> Self {
39 Self {
40 max_steps_delta: Some(delta),
41 ..Default::default()
42 }
43 }
44
45 pub fn is_empty(&self) -> bool {
46 self.system_injection.is_none()
47 && self.max_tokens_override.is_none()
48 && self.custom_context.is_empty()
49 && self.max_steps_delta.is_none()
50 }
51}
52
53#[derive(Debug, Clone)]
55pub struct ToolOutput {
56 pub content: String,
57 pub done: bool,
58 pub waiting: bool,
59 pub modifier: Option<ContextModifier>,
60}
61
62impl ToolOutput {
63 pub fn text(content: impl Into<String>) -> Self {
64 Self {
65 content: content.into(),
66 done: false,
67 waiting: false,
68 modifier: None,
69 }
70 }
71
72 pub fn done(content: impl Into<String>) -> Self {
73 Self {
74 content: content.into(),
75 done: true,
76 waiting: false,
77 modifier: None,
78 }
79 }
80
81 pub fn waiting(question: impl Into<String>) -> Self {
82 Self {
83 content: question.into(),
84 done: false,
85 waiting: true,
86 modifier: None,
87 }
88 }
89
90 pub fn with_modifier(mut self, modifier: ContextModifier) -> Self {
91 self.modifier = Some(modifier);
92 self
93 }
94}
95
96#[derive(Debug, thiserror::Error)]
98pub enum ToolError {
99 #[error("{0}")]
100 Execution(String),
101 #[error("invalid args: {0}")]
102 InvalidArgs(String),
103}
104
105pub fn parse_args<T: DeserializeOwned>(args: &Value) -> Result<T, ToolError> {
107 serde_json::from_value(args.clone()).map_err(|e| ToolError::InvalidArgs(e.to_string()))
108}
109
110#[async_trait::async_trait]
112pub trait Tool: Send + Sync {
113 fn name(&self) -> &str;
114 fn description(&self) -> &str;
115
116 fn is_system(&self) -> bool {
117 false
118 }
119 fn is_read_only(&self) -> bool {
120 false
121 }
122
123 fn parameters_schema(&self) -> Value;
124
125 async fn execute(
126 &self,
127 args: Value,
128 ctx: &mut crate::context::AgentContext,
129 ) -> Result<ToolOutput, ToolError>;
130
131 async fn execute_readonly(
135 &self,
136 args: Value,
137 ctx: &crate::context::AgentContext,
138 ) -> Result<ToolOutput, ToolError> {
139 let mut ctx_clone = ctx.clone();
140 self.execute(args, &mut ctx_clone).await
141 }
142
143 fn to_def(&self) -> ToolDef {
144 ToolDef {
145 name: self.name().to_string(),
146 description: self.description().to_string(),
147 parameters: self.parameters_schema(),
148 }
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::AgentContext;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize)]
    struct EchoArgs {
        message: String,
    }

    /// Minimal tool that returns its `message` argument unchanged.
    struct EchoTool;

    #[async_trait::async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }
        fn description(&self) -> &str {
            "Echo a message back"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": { "message": { "type": "string" } },
                "required": ["message"]
            })
        }
        async fn execute(
            &self,
            args: Value,
            _ctx: &mut AgentContext,
        ) -> Result<ToolOutput, ToolError> {
            let a: EchoArgs = parse_args(&args)?;
            Ok(ToolOutput::text(a.message))
        }
    }

    #[test]
    fn parse_args_valid() {
        let args = serde_json::json!({"message": "hello"});
        let parsed: EchoArgs = parse_args(&args).unwrap();
        assert_eq!(parsed.message, "hello");
    }

    #[test]
    fn parse_args_invalid() {
        let result = parse_args::<EchoArgs>(&serde_json::json!({"wrong": 42}));
        assert!(matches!(result.unwrap_err(), ToolError::InvalidArgs(_)));
    }

    #[test]
    fn parse_args_wrong_field_type() {
        // The field exists but has the wrong JSON type, so this is still an args error.
        let result = parse_args::<EchoArgs>(&serde_json::json!({"message": 42}));
        assert!(matches!(result, Err(ToolError::InvalidArgs(_))));
    }

    #[tokio::test]
    async fn tool_execute() {
        let tool = EchoTool;
        let mut ctx = AgentContext::new();
        let output = tool
            .execute(serde_json::json!({"message": "world"}), &mut ctx)
            .await
            .unwrap();
        assert_eq!(output.content, "world");
    }

    #[tokio::test]
    async fn execute_readonly_delegates_to_execute() {
        let ctx = AgentContext::new();
        let output = EchoTool
            .execute_readonly(serde_json::json!({"message": "ro"}), &ctx)
            .await
            .unwrap();
        assert_eq!(output.content, "ro");
        assert!(!output.done);
        assert!(!output.waiting);
    }

    #[test]
    fn to_def_uses_tool_metadata() {
        let def = EchoTool.to_def();
        assert_eq!(def.name, "echo");
        assert_eq!(def.description, "Echo a message back");
        assert_eq!(def.parameters, EchoTool.parameters_schema());
    }

    #[test]
    fn trait_defaults_are_false() {
        assert!(!EchoTool.is_system());
        assert!(!EchoTool.is_read_only());
    }

    #[test]
    fn tool_output_constructors_set_flags() {
        let text = ToolOutput::text("a");
        assert!(!text.done && !text.waiting && text.modifier.is_none());

        let done = ToolOutput::done("b");
        assert!(done.done && !done.waiting && done.modifier.is_none());
        assert_eq!(done.content, "b");

        let waiting = ToolOutput::waiting("c?");
        assert!(!waiting.done && waiting.waiting && waiting.modifier.is_none());
        assert_eq!(waiting.content, "c?");
    }

    #[test]
    fn with_modifier_attaches_modifier() {
        let out = ToolOutput::text("x").with_modifier(ContextModifier::max_tokens(128));
        let modifier = out.modifier.expect("modifier was just attached");
        assert_eq!(modifier.max_tokens_override, Some(128));
        assert!(!modifier.is_empty());
    }

    #[test]
    fn context_modifier_is_empty() {
        assert!(ContextModifier::default().is_empty());
        assert!(!ContextModifier::system("hi").is_empty());
        assert!(!ContextModifier::max_tokens(1).is_empty());
        assert!(!ContextModifier::custom("k", serde_json::json!(1)).is_empty());
        assert!(!ContextModifier::extra_steps(-1).is_empty());
    }

    #[test]
    fn tool_error_display() {
        assert_eq!(ToolError::Execution("boom".into()).to_string(), "boom");
        assert_eq!(
            ToolError::InvalidArgs("bad".into()).to_string(),
            "invalid args: bad"
        );
    }
}