swiftide_agents/tools/
control.rs

//! Control tools manage control flow during an agent's lifecycle.
2use anyhow::Result;
3use async_trait::async_trait;
4use std::borrow::Cow;
5use swiftide_core::{
6    AgentContext, ToolFeedback,
7    chat_completion::{
8        ParamSpec, ParamType, Tool, ToolCall, ToolOutput, ToolSpec, errors::ToolError,
9    },
10};
11
12/// `Stop` tool is a default tool used by agents to stop
13#[derive(Clone, Debug, Default)]
14pub struct Stop {}
15
16#[async_trait]
17impl Tool for Stop {
18    async fn invoke(
19        &self,
20        _agent_context: &dyn AgentContext,
21        _tool_call: &ToolCall,
22    ) -> Result<ToolOutput, ToolError> {
23        Ok(ToolOutput::stop())
24    }
25
26    fn name(&self) -> Cow<'_, str> {
27        "stop".into()
28    }
29
30    fn tool_spec(&self) -> ToolSpec {
31        ToolSpec::builder()
32            .name("stop")
33            .description("When you have completed, or cannot complete, your task, call this")
34            .build()
35            .unwrap()
36    }
37}
38
39impl From<Stop> for Box<dyn Tool> {
40    fn from(val: Stop) -> Self {
41        Box::new(val)
42    }
43}
44
45/// `StopWithArgs` is an alternative stop tool that takes arguments
46#[derive(Clone, Debug, Default)]
47pub struct StopWithArgs {}
48
49#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
50struct StopWithArgsSpec {
51    pub output: String,
52}
53
54#[async_trait]
55impl Tool for StopWithArgs {
56    async fn invoke(
57        &self,
58        _agent_context: &dyn AgentContext,
59        tool_call: &ToolCall,
60    ) -> Result<ToolOutput, ToolError> {
61        let args: StopWithArgsSpec = serde_json::from_str(
62            tool_call
63                .args()
64                .ok_or(ToolError::missing_arguments("output"))?,
65        )?;
66
67        Ok(ToolOutput::stop_with_args(args.output))
68    }
69
70    fn name(&self) -> Cow<'_, str> {
71        "stop".into()
72    }
73
74    fn tool_spec(&self) -> ToolSpec {
75        ToolSpec::builder()
76            .name("stop")
77            .description("When you have completed, your task, call this with your expected output")
78            .parameters(vec![
79                ParamSpec::builder()
80                    .name("output")
81                    .description("The expected output of the task")
82                    .ty(ParamType::String)
83                    .required(true)
84                    .build()
85                    .unwrap(),
86            ])
87            .build()
88            .unwrap()
89    }
90}
91
92impl From<StopWithArgs> for Box<dyn Tool> {
93    fn from(val: StopWithArgs) -> Self {
94        Box::new(val)
95    }
96}
97
98#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
99struct AgentFailedArgsSpec {
100    pub reason: String,
101}
102
103/// A utility tool that can be used to let an agent decide it failed
104///
105/// This will _NOT_ have the agent return an error, instead, look at the stop reason of the agent.
106#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
107pub struct AgentCanFail {}
108
109#[async_trait]
110impl Tool for AgentCanFail {
111    async fn invoke(
112        &self,
113        _agent_context: &dyn AgentContext,
114        tool_call: &ToolCall,
115    ) -> Result<ToolOutput, ToolError> {
116        let args: StopWithArgsSpec = serde_json::from_str(
117            tool_call
118                .args()
119                .ok_or(ToolError::missing_arguments("reason"))?,
120        )?;
121
122        Ok(ToolOutput::agent_failed(args.output))
123    }
124
125    fn name(&self) -> Cow<'_, str> {
126        "task_failed".into()
127    }
128
129    fn tool_spec(&self) -> ToolSpec {
130        ToolSpec::builder()
131            .name("stop")
132            .description("If you cannot complete your task, or have otherwise failed, call this with your reason for failure")
133            .parameters(vec![
134                ParamSpec::builder()
135                    .name("reason")
136                    .description("The reason for failure")
137                    .ty(ParamType::String)
138                    .required(true)
139                    .build()
140                    .unwrap(),
141            ])
142            .build()
143            .unwrap()
144    }
145}
146
147impl From<AgentCanFail> for Box<dyn Tool> {
148    fn from(val: AgentCanFail) -> Self {
149        Box::new(val)
150    }
151}
152
153#[derive(Clone)]
154/// Wraps a tool and requires approval before it can be used
155pub struct ApprovalRequired(pub Box<dyn Tool>);
156
157impl ApprovalRequired {
158    /// Creates a new `ApprovalRequired` tool
159    pub fn new(tool: impl Tool + 'static) -> Self {
160        Self(Box::new(tool))
161    }
162}
163
164#[async_trait]
165impl Tool for ApprovalRequired {
166    async fn invoke(
167        &self,
168        context: &dyn AgentContext,
169        tool_call: &ToolCall,
170    ) -> Result<ToolOutput, ToolError> {
171        if let Some(feedback) = context.has_received_feedback(tool_call).await {
172            match feedback {
173                ToolFeedback::Approved { .. } => return self.0.invoke(context, tool_call).await,
174                ToolFeedback::Refused { .. } => {
175                    return Ok(ToolOutput::text("This tool call was refused"));
176                }
177            }
178        }
179
180        Ok(ToolOutput::FeedbackRequired(None))
181    }
182
183    fn name(&self) -> Cow<'_, str> {
184        self.0.name()
185    }
186
187    fn tool_spec(&self) -> ToolSpec {
188        self.0.tool_spec()
189    }
190}
191
192impl From<ApprovalRequired> for Box<dyn Tool> {
193    fn from(val: ApprovalRequired) -> Self {
194        Box::new(val)
195    }
196}