swiftide_agents/tools/
control.rs1use anyhow::Result;
3use async_trait::async_trait;
4use schemars::{Schema, schema_for};
5use std::borrow::Cow;
6use swiftide_core::{
7 AgentContext, ToolFeedback,
8 chat_completion::{Tool, ToolCall, ToolOutput, ToolSpec, errors::ToolError},
9};
10
/// Control tool the model calls to signal that it is done.
///
/// It is exposed to the model as `stop`, ignores any arguments, and always
/// produces `ToolOutput::stop()`.
#[derive(Default, Debug, Clone)]
pub struct Stop {}
14
15#[async_trait]
16impl Tool for Stop {
17 async fn invoke(
18 &self,
19 _agent_context: &dyn AgentContext,
20 _tool_call: &ToolCall,
21 ) -> Result<ToolOutput, ToolError> {
22 Ok(ToolOutput::stop())
23 }
24
25 fn name(&self) -> Cow<'_, str> {
26 "stop".into()
27 }
28
29 fn tool_spec(&self) -> ToolSpec {
30 ToolSpec::builder()
31 .name("stop")
32 .description("When you have completed, or cannot complete, your task, call this")
33 .build()
34 .unwrap()
35 }
36}
37
38impl From<Stop> for Box<dyn Tool> {
39 fn from(val: Stop) -> Self {
40 Box::new(val)
41 }
42}
43
44#[derive(Clone, Debug)]
46pub struct StopWithArgs {
47 parameters_schema: Option<Schema>,
48 expects_output_field: bool,
49}
50
51impl Default for StopWithArgs {
52 fn default() -> Self {
53 Self {
54 parameters_schema: Some(schema_for!(DefaultStopWithArgsSpec)),
55 expects_output_field: true,
56 }
57 }
58}
59
60impl StopWithArgs {
61 pub fn with_parameters_schema(schema: Schema) -> Self {
66 Self {
67 parameters_schema: Some(schema),
68 expects_output_field: false,
69 }
70 }
71
72 fn parameters_schema(&self) -> Schema {
73 self.parameters_schema
74 .clone()
75 .unwrap_or_else(|| schema_for!(DefaultStopWithArgsSpec))
76 }
77}
78
79#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
80struct DefaultStopWithArgsSpec {
81 pub output: String,
82}
83
84#[async_trait]
85impl Tool for StopWithArgs {
86 async fn invoke(
87 &self,
88 _agent_context: &dyn AgentContext,
89 tool_call: &ToolCall,
90 ) -> Result<ToolOutput, ToolError> {
91 let raw_args = tool_call
92 .args()
93 .ok_or_else(|| ToolError::missing_arguments("arguments"))?;
94
95 let json: serde_json::Value = serde_json::from_str(raw_args)?;
96
97 let output = if self.expects_output_field {
98 json.get("output")
99 .cloned()
100 .ok_or_else(|| ToolError::missing_arguments("output"))?
101 } else {
102 json
103 };
104
105 Ok(ToolOutput::stop_with_args(output))
106 }
107
108 fn name(&self) -> Cow<'_, str> {
109 "stop".into()
110 }
111
112 fn tool_spec(&self) -> ToolSpec {
113 let schema = self.parameters_schema();
114
115 ToolSpec::builder()
116 .name("stop")
117 .description("When you have completed, your task, call this with your expected output")
118 .parameters_schema(schema)
119 .build()
120 .unwrap()
121 }
122}
123
124impl From<StopWithArgs> for Box<dyn Tool> {
125 fn from(val: StopWithArgs) -> Self {
126 Box::new(val)
127 }
128}
129
130#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
131struct AgentFailedArgsSpec {
132 pub reason: String,
133}
134
135#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, Default)]
139pub struct AgentCanFail {}
140
141#[async_trait]
142impl Tool for AgentCanFail {
143 async fn invoke(
144 &self,
145 _agent_context: &dyn AgentContext,
146 tool_call: &ToolCall,
147 ) -> Result<ToolOutput, ToolError> {
148 let args: AgentFailedArgsSpec = serde_json::from_str(
149 tool_call
150 .args()
151 .ok_or(ToolError::missing_arguments("reason"))?,
152 )?;
153
154 Ok(ToolOutput::agent_failed(args.reason))
155 }
156
157 fn name(&self) -> Cow<'_, str> {
158 "task_failed".into()
159 }
160
161 fn tool_spec(&self) -> ToolSpec {
162 ToolSpec::builder()
163 .name("task_failed")
164 .description("If you cannot complete your task, or have otherwise failed, call this with your reason for failure")
165 .parameters_schema(schema_for!(AgentFailedArgsSpec))
166 .build()
167 .unwrap()
168 }
169}
170
171impl From<AgentCanFail> for Box<dyn Tool> {
172 fn from(val: AgentCanFail) -> Self {
173 Box::new(val)
174 }
175}
176
177#[derive(Clone)]
178pub struct ApprovalRequired(pub Box<dyn Tool>);
180
181impl ApprovalRequired {
182 pub fn new(tool: impl Tool + 'static) -> Self {
184 Self(Box::new(tool))
185 }
186}
187
188#[async_trait]
189impl Tool for ApprovalRequired {
190 async fn invoke(
191 &self,
192 context: &dyn AgentContext,
193 tool_call: &ToolCall,
194 ) -> Result<ToolOutput, ToolError> {
195 if let Some(feedback) = context.has_received_feedback(tool_call).await {
196 match feedback {
197 ToolFeedback::Approved { .. } => return self.0.invoke(context, tool_call).await,
198 ToolFeedback::Refused { .. } => {
199 return Ok(ToolOutput::text("This tool call was refused"));
200 }
201 }
202 }
203
204 Ok(ToolOutput::FeedbackRequired(None))
205 }
206
207 fn name(&self) -> Cow<'_, str> {
208 self.0.name()
209 }
210
211 fn tool_spec(&self) -> ToolSpec {
212 self.0.tool_spec()
213 }
214}
215
216impl From<ApprovalRequired> for Box<dyn Tool> {
217 fn from(val: ApprovalRequired) -> Self {
218 Box::new(val)
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::schema_for;
    use serde_json::json;

    /// Builds a `ToolCall` with id `"1"`, attaching raw JSON arguments when given.
    fn dummy_tool_call(name: &str, args: Option<&str>) -> ToolCall {
        let mut builder = ToolCall::builder().name(name).id("1").to_owned();
        if let Some(raw) = args {
            builder.args(raw.to_string());
        }
        builder.build().unwrap()
    }

    #[tokio::test]
    async fn test_stop_tool() {
        let tool = Stop::default();
        let context = ();
        let call = dummy_tool_call("stop", None);

        let output = tool.invoke(&context, &call).await.unwrap();

        assert_eq!(output, ToolOutput::stop());
    }

    #[tokio::test]
    async fn test_stop_with_args_tool() {
        let tool = StopWithArgs::default();
        let context = ();
        let call = dummy_tool_call("stop", Some(r#"{"output":"expected result"}"#));

        let output = tool.invoke(&context, &call).await.unwrap();

        assert_eq!(output, ToolOutput::stop_with_args("expected result"));
    }

    #[tokio::test]
    async fn test_agent_can_fail_tool() {
        let tool = AgentCanFail::default();
        let context = ();
        let call = dummy_tool_call("task_failed", Some(r#"{"reason":"something went wrong"}"#));

        let output = tool.invoke(&context, &call).await.unwrap();

        assert_eq!(output, ToolOutput::agent_failed("something went wrong"));
    }

    #[tokio::test]
    async fn test_approval_required_feedback_required() {
        let tool = ApprovalRequired::new(Stop::default());
        let context = ();
        let call = dummy_tool_call("stop", None);

        let output = tool.invoke(&context, &call).await.unwrap();

        // The wrapper only yields the inner `Stop` output on the approved path
        // (otherwise it returns `FeedbackRequired` or the refusal text), so this
        // assertion requires the unit `()` context to report the call as approved.
        assert_eq!(output, ToolOutput::Stop(None));
    }

    #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
    struct CustomStopArgs {
        value: i32,
    }

    #[test]
    fn test_stop_with_args_custom_schema_in_spec() {
        let schema = schema_for!(CustomStopArgs);
        let tool = StopWithArgs::with_parameters_schema(schema.clone());

        let spec = tool.tool_spec();

        assert_eq!(spec.parameters_schema, Some(schema));
    }

    #[tokio::test]
    async fn test_stop_with_args_custom_schema_forwards_payload() {
        let tool = StopWithArgs::with_parameters_schema(schema_for!(CustomStopArgs));
        let context = ();
        let call = dummy_tool_call("stop", Some(r#"{"value":42}"#));

        let output = tool.invoke(&context, &call).await.unwrap();

        // With a custom schema the entire argument object is the payload.
        assert_eq!(output, ToolOutput::stop_with_args(json!({"value": 42})));
    }

    #[test]
    fn test_stop_with_args_default_schema_matches_previous() {
        let spec = StopWithArgs::default().tool_spec();
        let expected = schema_for!(DefaultStopWithArgsSpec);

        assert_eq!(spec.parameters_schema, Some(expected));
    }
}