sgr_agent_core/
agent_tool.rs1use crate::tool::ToolDef;
4use serde::de::DeserializeOwned;
5use serde_json::Value;
6
7#[derive(Debug, Clone, Default)]
9pub struct ContextModifier {
10 pub system_injection: Option<String>,
11 pub max_tokens_override: Option<u32>,
12 pub custom_context: Vec<(String, serde_json::Value)>,
13 pub max_steps_delta: Option<i32>,
14}
15
16impl ContextModifier {
17 pub fn system(msg: impl Into<String>) -> Self {
18 Self { system_injection: Some(msg.into()), ..Default::default() }
19 }
20
21 pub fn max_tokens(tokens: u32) -> Self {
22 Self { max_tokens_override: Some(tokens), ..Default::default() }
23 }
24
25 pub fn custom(key: impl Into<String>, value: serde_json::Value) -> Self {
26 Self { custom_context: vec![(key.into(), value)], ..Default::default() }
27 }
28
29 pub fn extra_steps(delta: i32) -> Self {
30 Self { max_steps_delta: Some(delta), ..Default::default() }
31 }
32
33 pub fn is_empty(&self) -> bool {
34 self.system_injection.is_none()
35 && self.max_tokens_override.is_none()
36 && self.custom_context.is_empty()
37 && self.max_steps_delta.is_none()
38 }
39}
40
41#[derive(Debug, Clone)]
43pub struct ToolOutput {
44 pub content: String,
45 pub done: bool,
46 pub waiting: bool,
47 pub modifier: Option<ContextModifier>,
48}
49
50impl ToolOutput {
51 pub fn text(content: impl Into<String>) -> Self {
52 Self { content: content.into(), done: false, waiting: false, modifier: None }
53 }
54
55 pub fn done(content: impl Into<String>) -> Self {
56 Self { content: content.into(), done: true, waiting: false, modifier: None }
57 }
58
59 pub fn waiting(question: impl Into<String>) -> Self {
60 Self { content: question.into(), done: false, waiting: true, modifier: None }
61 }
62
63 pub fn with_modifier(mut self, modifier: ContextModifier) -> Self {
64 self.modifier = Some(modifier);
65 self
66 }
67}
68
69#[derive(Debug, thiserror::Error)]
71pub enum ToolError {
72 #[error("{0}")]
73 Execution(String),
74 #[error("invalid args: {0}")]
75 InvalidArgs(String),
76}
77
78pub fn parse_args<T: DeserializeOwned>(args: &Value) -> Result<T, ToolError> {
80 serde_json::from_value(args.clone()).map_err(|e| ToolError::InvalidArgs(e.to_string()))
81}
82
83#[async_trait::async_trait]
85pub trait Tool: Send + Sync {
86 fn name(&self) -> &str;
87 fn description(&self) -> &str;
88
89 fn is_system(&self) -> bool { false }
90 fn is_read_only(&self) -> bool { false }
91
92 fn parameters_schema(&self) -> Value;
93
94 async fn execute(
95 &self,
96 args: Value,
97 ctx: &mut crate::context::AgentContext,
98 ) -> Result<ToolOutput, ToolError>;
99
100 async fn execute_readonly(
101 &self,
102 args: Value,
103 _ctx: &crate::context::AgentContext,
104 ) -> Result<ToolOutput, ToolError> {
105 let _ = args;
106 panic!("execute_readonly called on tool that doesn't implement it")
107 }
108
109 fn to_def(&self) -> ToolDef {
110 ToolDef {
111 name: self.name().to_string(),
112 description: self.description().to_string(),
113 parameters: self.parameters_schema(),
114 }
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::AgentContext;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize)]
    struct EchoArgs {
        message: String,
    }

    /// Minimal tool that returns its `message` argument unchanged.
    struct EchoTool;

    #[async_trait::async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echo a message back"
        }

        fn parameters_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": { "message": { "type": "string" } },
                "required": ["message"]
            })
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &mut AgentContext,
        ) -> Result<ToolOutput, ToolError> {
            let a: EchoArgs = parse_args(&args)?;
            Ok(ToolOutput::text(a.message))
        }
    }

    #[test]
    fn parse_args_valid() {
        let args = serde_json::json!({"message": "hello"});
        let parsed: EchoArgs = parse_args(&args).unwrap();
        assert_eq!(parsed.message, "hello");
    }

    #[test]
    fn parse_args_invalid() {
        let result = parse_args::<EchoArgs>(&serde_json::json!({"wrong": 42}));
        assert!(matches!(result.unwrap_err(), ToolError::InvalidArgs(_)));
    }

    #[tokio::test]
    async fn tool_execute() {
        let tool = EchoTool;
        let mut ctx = AgentContext::new();
        let output = tool
            .execute(serde_json::json!({"message": "world"}), &mut ctx)
            .await
            .unwrap();
        assert_eq!(output.content, "world");
        assert!(!output.done && !output.waiting);
    }

    #[tokio::test]
    async fn tool_execute_rejects_bad_args() {
        let tool = EchoTool;
        let mut ctx = AgentContext::new();
        let err = tool
            .execute(serde_json::json!({}), &mut ctx)
            .await
            .unwrap_err();
        assert!(matches!(err, ToolError::InvalidArgs(_)));
    }

    #[test]
    fn to_def_mirrors_tool_metadata() {
        let def = EchoTool.to_def();
        assert_eq!(def.name, "echo");
        assert_eq!(def.description, "Echo a message back");
        assert_eq!(def.parameters, EchoTool.parameters_schema());
    }

    #[test]
    fn default_flags_are_false() {
        assert!(!EchoTool.is_system());
        assert!(!EchoTool.is_read_only());
    }

    #[test]
    fn tool_error_display() {
        assert_eq!(ToolError::Execution("boom".into()).to_string(), "boom");
        assert_eq!(
            ToolError::InvalidArgs("bad".into()).to_string(),
            "invalid args: bad"
        );
    }

    #[test]
    fn tool_output_constructors() {
        let t = ToolOutput::text("a");
        assert!(!t.done && !t.waiting && t.modifier.is_none());

        let d = ToolOutput::done("b");
        assert!(d.done && !d.waiting);

        let w = ToolOutput::waiting("q?");
        assert_eq!(w.content, "q?");
        assert!(w.waiting && !w.done);

        let m = ToolOutput::text("c").with_modifier(ContextModifier::extra_steps(2));
        assert_eq!(m.modifier.and_then(|m| m.max_steps_delta), Some(2));
    }

    #[test]
    fn context_modifier_is_empty() {
        assert!(ContextModifier::default().is_empty());
        assert!(!ContextModifier::system("s").is_empty());
        assert!(!ContextModifier::max_tokens(10).is_empty());
        assert!(!ContextModifier::custom("k", serde_json::json!(1)).is_empty());
        assert!(!ContextModifier::extra_steps(-1).is_empty());
    }
}