swiftide_agents/tools/
control.rs

//! Control tools manage control flow during an agent's lifecycle.
2use anyhow::Result;
3use async_trait::async_trait;
4use schemars::{Schema, schema_for};
5use std::borrow::Cow;
6use swiftide_core::{
7    AgentContext, ToolFeedback,
8    chat_completion::{Tool, ToolCall, ToolOutput, ToolSpec, errors::ToolError},
9};
10
/// The default tool an agent calls to stop.
///
/// It carries no state and takes no arguments.
#[derive(Debug, Clone, Default)]
pub struct Stop {}
14
15#[async_trait]
16impl Tool for Stop {
17    async fn invoke(
18        &self,
19        _agent_context: &dyn AgentContext,
20        _tool_call: &ToolCall,
21    ) -> Result<ToolOutput, ToolError> {
22        Ok(ToolOutput::stop())
23    }
24
25    fn name(&self) -> Cow<'_, str> {
26        "stop".into()
27    }
28
29    fn tool_spec(&self) -> ToolSpec {
30        ToolSpec::builder()
31            .name("stop")
32            .description("When you have completed, or cannot complete, your task, call this")
33            .build()
34            .unwrap()
35    }
36}
37
38impl From<Stop> for Box<dyn Tool> {
39    fn from(val: Stop) -> Self {
40        Box::new(val)
41    }
42}
43
44/// `StopWithArgs` is an alternative stop tool that takes arguments
45#[derive(Clone, Debug)]
46pub struct StopWithArgs {
47    parameters_schema: Option<Schema>,
48    expects_output_field: bool,
49}
50
51impl Default for StopWithArgs {
52    fn default() -> Self {
53        Self {
54            parameters_schema: Some(schema_for!(DefaultStopWithArgsSpec)),
55            expects_output_field: true,
56        }
57    }
58}
59
60impl StopWithArgs {
61    /// Create a new `StopWithArgs` tool with a custom parameters schema.
62    ///
63    /// When providing a custom schema the full argument payload will be forwarded to the
64    /// stop output without requiring an `output` field wrapper.
65    pub fn with_parameters_schema(schema: Schema) -> Self {
66        Self {
67            parameters_schema: Some(schema),
68            expects_output_field: false,
69        }
70    }
71
72    fn parameters_schema(&self) -> Schema {
73        self.parameters_schema
74            .clone()
75            .unwrap_or_else(|| schema_for!(DefaultStopWithArgsSpec))
76    }
77}
78
79#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
80struct DefaultStopWithArgsSpec {
81    pub output: String,
82}
83
84#[async_trait]
85impl Tool for StopWithArgs {
86    async fn invoke(
87        &self,
88        _agent_context: &dyn AgentContext,
89        tool_call: &ToolCall,
90    ) -> Result<ToolOutput, ToolError> {
91        let raw_args = tool_call
92            .args()
93            .ok_or_else(|| ToolError::missing_arguments("arguments"))?;
94
95        let json: serde_json::Value = serde_json::from_str(raw_args)?;
96
97        let output = if self.expects_output_field {
98            json.get("output")
99                .cloned()
100                .ok_or_else(|| ToolError::missing_arguments("output"))?
101        } else {
102            json
103        };
104
105        Ok(ToolOutput::stop_with_args(output))
106    }
107
108    fn name(&self) -> Cow<'_, str> {
109        "stop".into()
110    }
111
112    fn tool_spec(&self) -> ToolSpec {
113        let schema = self.parameters_schema();
114
115        ToolSpec::builder()
116            .name("stop")
117            .description("When you have completed, your task, call this with your expected output")
118            .parameters_schema(schema)
119            .build()
120            .unwrap()
121    }
122}
123
124impl From<StopWithArgs> for Box<dyn Tool> {
125    fn from(val: StopWithArgs) -> Self {
126        Box::new(val)
127    }
128}
129
130#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
131struct AgentFailedArgsSpec {
132    pub reason: String,
133}
134
135/// A utility tool that can be used to let an agent decide it failed
136///
137/// This will _NOT_ have the agent return an error, instead, look at the stop reason of the agent.
138#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, Default)]
139pub struct AgentCanFail {}
140
141#[async_trait]
142impl Tool for AgentCanFail {
143    async fn invoke(
144        &self,
145        _agent_context: &dyn AgentContext,
146        tool_call: &ToolCall,
147    ) -> Result<ToolOutput, ToolError> {
148        let args: AgentFailedArgsSpec = serde_json::from_str(
149            tool_call
150                .args()
151                .ok_or(ToolError::missing_arguments("reason"))?,
152        )?;
153
154        Ok(ToolOutput::agent_failed(args.reason))
155    }
156
157    fn name(&self) -> Cow<'_, str> {
158        "task_failed".into()
159    }
160
161    fn tool_spec(&self) -> ToolSpec {
162        ToolSpec::builder()
163            .name("task_failed")
164            .description("If you cannot complete your task, or have otherwise failed, call this with your reason for failure")
165            .parameters_schema(schema_for!(AgentFailedArgsSpec))
166            .build()
167            .unwrap()
168    }
169}
170
171impl From<AgentCanFail> for Box<dyn Tool> {
172    fn from(val: AgentCanFail) -> Self {
173        Box::new(val)
174    }
175}
176
177#[derive(Clone)]
178/// Wraps a tool and requires approval before it can be used
179pub struct ApprovalRequired(pub Box<dyn Tool>);
180
181impl ApprovalRequired {
182    /// Creates a new `ApprovalRequired` tool
183    pub fn new(tool: impl Tool + 'static) -> Self {
184        Self(Box::new(tool))
185    }
186}
187
188#[async_trait]
189impl Tool for ApprovalRequired {
190    async fn invoke(
191        &self,
192        context: &dyn AgentContext,
193        tool_call: &ToolCall,
194    ) -> Result<ToolOutput, ToolError> {
195        if let Some(feedback) = context.has_received_feedback(tool_call).await {
196            match feedback {
197                ToolFeedback::Approved { .. } => return self.0.invoke(context, tool_call).await,
198                ToolFeedback::Refused { .. } => {
199                    return Ok(ToolOutput::text("This tool call was refused"));
200                }
201            }
202        }
203
204        Ok(ToolOutput::FeedbackRequired(None))
205    }
206
207    fn name(&self) -> Cow<'_, str> {
208        self.0.name()
209    }
210
211    fn tool_spec(&self) -> ToolSpec {
212        self.0.tool_spec()
213    }
214}
215
216impl From<ApprovalRequired> for Box<dyn Tool> {
217    fn from(val: ApprovalRequired) -> Self {
218        Box::new(val)
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::schema_for;
    use serde_json::json;

    /// Builds a `ToolCall` with id `"1"`, the given name, and optional raw JSON arguments.
    fn dummy_tool_call(name: &str, args: Option<&str>) -> ToolCall {
        let mut call = ToolCall::builder().name(name).id("1").to_owned();
        if let Some(raw) = args {
            call.args(raw.to_string());
        }
        call.build().unwrap()
    }

    /// Schema type used to exercise `StopWithArgs` with a custom parameters schema.
    #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
    struct CustomStopArgs {
        value: i32,
    }

    #[tokio::test]
    async fn test_stop_tool() {
        let call = dummy_tool_call("stop", None);
        let ctx = ();

        let result = Stop::default().invoke(&ctx, &call).await.unwrap();

        assert_eq!(result, ToolOutput::stop());
    }

    #[tokio::test]
    async fn test_stop_with_args_tool() {
        let call = dummy_tool_call("stop", Some(r#"{"output":"expected result"}"#));
        let ctx = ();

        let result = StopWithArgs::default()
            .invoke(&ctx, &call)
            .await
            .unwrap();

        assert_eq!(result, ToolOutput::stop_with_args("expected result"));
    }

    #[tokio::test]
    async fn test_agent_can_fail_tool() {
        let call = dummy_tool_call(
            "task_failed",
            Some(r#"{"reason":"something went wrong"}"#),
        );
        let ctx = ();

        let result = AgentCanFail::default()
            .invoke(&ctx, &call)
            .await
            .unwrap();

        assert_eq!(result, ToolOutput::agent_failed("something went wrong"));
    }

    #[tokio::test]
    async fn test_approval_required_feedback_required() {
        let wrapped = ApprovalRequired::new(Stop::default());
        let call = dummy_tool_call("stop", None);
        let ctx = ();

        let result = wrapped.invoke(&ctx, &call).await.unwrap();

        // The unit context `()` always reports existing feedback, so the call goes
        // straight through to the wrapped `Stop` tool.
        assert_eq!(result, ToolOutput::Stop(None));
    }

    #[test]
    fn test_stop_with_args_custom_schema_in_spec() {
        let custom = schema_for!(CustomStopArgs);

        let spec = StopWithArgs::with_parameters_schema(custom.clone()).tool_spec();

        assert_eq!(spec.parameters_schema, Some(custom));
    }

    #[tokio::test]
    async fn test_stop_with_args_custom_schema_forwards_payload() {
        let tool = StopWithArgs::with_parameters_schema(schema_for!(CustomStopArgs));
        let call = dummy_tool_call("stop", Some(r#"{"value":42}"#));
        let ctx = ();

        let result = tool.invoke(&ctx, &call).await.unwrap();

        assert_eq!(result, ToolOutput::stop_with_args(json!({"value": 42})));
    }

    #[test]
    fn test_stop_with_args_default_schema_matches_previous() {
        let spec = StopWithArgs::default().tool_spec();

        let default_schema = schema_for!(DefaultStopWithArgsSpec);
        assert_eq!(spec.parameters_schema, Some(default_schema));
    }
}