1use anyhow::Result;
3use async_trait::async_trait;
4use schemars::{Schema, schema_for};
5use std::borrow::Cow;
6use swiftide_core::{
7 AgentContext, ToolFeedback,
8 chat_completion::{Tool, ToolCall, ToolOutput, ToolSpec, errors::ToolError},
9};
10
/// Argument-less control tool an agent calls to end its run.
///
/// The type has no fields; all behavior lives in its `Tool` implementation.
#[derive(Debug, Clone, Default)]
pub struct Stop {}
14
15#[async_trait]
16impl Tool for Stop {
17 async fn invoke(
18 &self,
19 _agent_context: &dyn AgentContext,
20 _tool_call: &ToolCall,
21 ) -> Result<ToolOutput, ToolError> {
22 Ok(ToolOutput::stop())
23 }
24
25 fn name(&self) -> Cow<'_, str> {
26 "stop".into()
27 }
28
29 fn tool_spec(&self) -> ToolSpec {
30 ToolSpec::builder()
31 .name("stop")
32 .description("When you have completed, or cannot complete, your task, call this")
33 .build()
34 .unwrap()
35 }
36}
37
38impl From<Stop> for Box<dyn Tool> {
39 fn from(val: Stop) -> Self {
40 Box::new(val)
41 }
42}
43
44#[derive(Clone, Debug)]
46pub struct StopWithArgs {
47 parameters_schema: Option<Schema>,
48 expects_output_field: bool,
49}
50
51impl Default for StopWithArgs {
52 fn default() -> Self {
53 Self {
54 parameters_schema: Some(schema_for!(DefaultStopWithArgsSpec)),
55 expects_output_field: true,
56 }
57 }
58}
59
60impl StopWithArgs {
61 pub fn with_parameters_schema(schema: Schema) -> Self {
66 Self {
67 parameters_schema: Some(schema),
68 expects_output_field: false,
69 }
70 }
71
72 fn parameters_schema(&self) -> Schema {
73 self.parameters_schema
74 .clone()
75 .unwrap_or_else(|| schema_for!(DefaultStopWithArgsSpec))
76 }
77}
78
79#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
80struct DefaultStopWithArgsSpec {
81 pub output: String,
82}
83
84#[async_trait]
85impl Tool for StopWithArgs {
86 async fn invoke(
87 &self,
88 _agent_context: &dyn AgentContext,
89 tool_call: &ToolCall,
90 ) -> Result<ToolOutput, ToolError> {
91 let raw_args = tool_call
92 .args()
93 .ok_or_else(|| ToolError::missing_arguments("arguments"))?;
94
95 let json: serde_json::Value = serde_json::from_str(raw_args)?;
96
97 let output = if self.expects_output_field {
98 json.get("output")
99 .cloned()
100 .ok_or_else(|| ToolError::missing_arguments("output"))?
101 } else {
102 json
103 };
104
105 Ok(ToolOutput::stop_with_args(output))
106 }
107
108 fn name(&self) -> Cow<'_, str> {
109 "stop".into()
110 }
111
112 fn tool_spec(&self) -> ToolSpec {
113 let schema = self.parameters_schema();
114
115 ToolSpec::builder()
116 .name("stop")
117 .description("When you have completed, your task, call this with your expected output")
118 .parameters_schema(schema)
119 .build()
120 .unwrap()
121 }
122}
123
124impl From<StopWithArgs> for Box<dyn Tool> {
125 fn from(val: StopWithArgs) -> Self {
126 Box::new(val)
127 }
128}
129
130#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
131struct AgentFailedArgsSpec {
132 pub reason: String,
133}
134
135#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
139pub struct AgentCanFail {
140 parameters_schema: Option<Schema>,
141 expects_reason_field: bool,
142}
143
144impl Default for AgentCanFail {
145 fn default() -> Self {
146 Self {
147 parameters_schema: Some(schema_for!(AgentFailedArgsSpec)),
148 expects_reason_field: true,
149 }
150 }
151}
152
153impl AgentCanFail {
154 pub fn with_parameters_schema(schema: Schema) -> Self {
159 Self {
160 parameters_schema: Some(schema),
161 expects_reason_field: false,
162 }
163 }
164
165 fn parameters_schema(&self) -> Schema {
166 self.parameters_schema
167 .clone()
168 .unwrap_or_else(|| schema_for!(AgentFailedArgsSpec))
169 }
170}
171
172#[async_trait]
173impl Tool for AgentCanFail {
174 async fn invoke(
175 &self,
176 _agent_context: &dyn AgentContext,
177 tool_call: &ToolCall,
178 ) -> Result<ToolOutput, ToolError> {
179 let raw_args = tool_call.args().ok_or_else(|| {
180 if self.expects_reason_field {
181 ToolError::missing_arguments("reason")
182 } else {
183 ToolError::missing_arguments("arguments")
184 }
185 })?;
186
187 let reason = if self.expects_reason_field {
188 let args: AgentFailedArgsSpec = serde_json::from_str(raw_args)?;
189 args.reason
190 } else {
191 let json: serde_json::Value = serde_json::from_str(raw_args)?;
192 json.to_string()
193 };
194
195 Ok(ToolOutput::agent_failed(reason))
196 }
197
198 fn name(&self) -> Cow<'_, str> {
199 "task_failed".into()
200 }
201
202 fn tool_spec(&self) -> ToolSpec {
203 let schema = self.parameters_schema();
204
205 ToolSpec::builder()
206 .name("task_failed")
207 .description("If you cannot complete your task, or have otherwise failed, call this with your reason for failure")
208 .parameters_schema(schema)
209 .build()
210 .unwrap()
211 }
212}
213
214impl From<AgentCanFail> for Box<dyn Tool> {
215 fn from(val: AgentCanFail) -> Self {
216 Box::new(val)
217 }
218}
219
220#[derive(Clone)]
221pub struct ApprovalRequired(pub Box<dyn Tool>);
223
224impl ApprovalRequired {
225 pub fn new(tool: impl Tool + 'static) -> Self {
227 Self(Box::new(tool))
228 }
229}
230
231#[async_trait]
232impl Tool for ApprovalRequired {
233 async fn invoke(
234 &self,
235 context: &dyn AgentContext,
236 tool_call: &ToolCall,
237 ) -> Result<ToolOutput, ToolError> {
238 if let Some(feedback) = context.has_received_feedback(tool_call).await {
239 match feedback {
240 ToolFeedback::Approved { .. } => return self.0.invoke(context, tool_call).await,
241 ToolFeedback::Refused { .. } => {
242 return Ok(ToolOutput::text("This tool call was refused"));
243 }
244 }
245 }
246
247 Ok(ToolOutput::FeedbackRequired(None))
248 }
249
250 fn name(&self) -> Cow<'_, str> {
251 self.0.name()
252 }
253
254 fn tool_spec(&self) -> ToolSpec {
255 self.0.tool_spec()
256 }
257}
258
259impl From<ApprovalRequired> for Box<dyn Tool> {
260 fn from(val: ApprovalRequired) -> Self {
261 Box::new(val)
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::schema_for;
    use serde_json::json;

    /// Builds a tool call with id `"1"`, the given tool name, and optional raw JSON arguments.
    fn dummy_tool_call(name: &str, args: Option<&str>) -> ToolCall {
        let mut call = ToolCall::builder().name(name).id("1").to_owned();
        if let Some(raw) = args {
            call.args(raw.to_string());
        }
        call.build().unwrap()
    }

    /// `Stop` yields a plain stop output even without arguments.
    #[tokio::test]
    async fn test_stop_tool() {
        let context = ();
        let call = dummy_tool_call("stop", None);

        let output = Stop::default().invoke(&context, &call).await.unwrap();

        assert_eq!(output, ToolOutput::stop());
    }

    /// The default `StopWithArgs` forwards only the value of the `output` field.
    #[tokio::test]
    async fn test_stop_with_args_tool() {
        let context = ();
        let call = dummy_tool_call("stop", Some(r#"{"output":"expected result"}"#));

        let output = StopWithArgs::default()
            .invoke(&context, &call)
            .await
            .unwrap();

        assert_eq!(output, ToolOutput::stop_with_args("expected result"));
    }

    /// The default `AgentCanFail` reports the `reason` field.
    #[tokio::test]
    async fn test_agent_can_fail_tool() {
        let context = ();
        let call = dummy_tool_call("task_failed", Some(r#"{"reason":"something went wrong"}"#));

        let output = AgentCanFail::default()
            .invoke(&context, &call)
            .await
            .unwrap();

        assert_eq!(output, ToolOutput::agent_failed("something went wrong"));
    }

    /// Custom failure payload used to exercise `AgentCanFail::with_parameters_schema`.
    #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
    struct CustomFailArgs {
        code: i32,
        message: String,
    }

    /// A custom schema is exposed unchanged through the tool spec.
    #[test]
    fn test_agent_can_fail_custom_schema_in_spec() {
        let custom = schema_for!(CustomFailArgs);
        let failing_tool = AgentCanFail::with_parameters_schema(custom.clone());

        assert_eq!(failing_tool.tool_spec().parameters_schema, Some(custom));
    }

    /// With a custom schema the whole payload is reported as a JSON string.
    #[tokio::test]
    async fn test_agent_can_fail_custom_schema_forwards_payload() {
        let failing_tool = AgentCanFail::with_parameters_schema(schema_for!(CustomFailArgs));
        let context = ();
        let call = dummy_tool_call("task_failed", Some(r#"{"code":7,"message":"error"}"#));

        let output = failing_tool.invoke(&context, &call).await.unwrap();

        let expected_reason = json!({"code":7,"message":"error"}).to_string();
        assert_eq!(output, ToolOutput::agent_failed(expected_reason));
    }

    /// The default tool spec carries the schema generated from `AgentFailedArgsSpec`.
    #[test]
    fn test_agent_can_fail_default_schema_matches_previous() {
        let spec = AgentCanFail::default().tool_spec();

        assert_eq!(
            spec.parameters_schema,
            Some(schema_for!(AgentFailedArgsSpec))
        );
    }

    /// Invokes a wrapped `Stop` through `ApprovalRequired` using the unit context.
    #[tokio::test]
    async fn test_approval_required_feedback_required() {
        let gated = ApprovalRequired::new(Stop::default());
        let context = ();
        let call = dummy_tool_call("stop", None);

        let output = gated.invoke(&context, &call).await.unwrap();

        // NOTE(review): despite the test name, the expected result is the wrapped tool's stop
        // output rather than `FeedbackRequired` — presumably the unit context reports the call
        // as already approved; confirm against the `AgentContext` impl for `()`.
        assert_eq!(output, ToolOutput::Stop(None));
    }

    /// Custom stop payload used to exercise `StopWithArgs::with_parameters_schema`.
    #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
    struct CustomStopArgs {
        value: i32,
    }

    /// A custom schema is exposed unchanged through the tool spec.
    #[test]
    fn test_stop_with_args_custom_schema_in_spec() {
        let custom = schema_for!(CustomStopArgs);
        let stop_tool = StopWithArgs::with_parameters_schema(custom.clone());

        assert_eq!(stop_tool.tool_spec().parameters_schema, Some(custom));
    }

    /// With a custom schema the whole parsed payload is forwarded.
    #[tokio::test]
    async fn test_stop_with_args_custom_schema_forwards_payload() {
        let stop_tool = StopWithArgs::with_parameters_schema(schema_for!(CustomStopArgs));
        let context = ();
        let call = dummy_tool_call("stop", Some(r#"{"value":42}"#));

        let output = stop_tool.invoke(&context, &call).await.unwrap();

        assert_eq!(output, ToolOutput::stop_with_args(json!({"value": 42})));
    }

    /// The default tool spec carries the schema generated from `DefaultStopWithArgsSpec`.
    #[test]
    fn test_stop_with_args_default_schema_matches_previous() {
        let spec = StopWithArgs::default().tool_spec();

        assert_eq!(
            spec.parameters_schema,
            Some(schema_for!(DefaultStopWithArgsSpec))
        );
    }
}