swiftide_core/chat_completion/
tools.rs

1use std::borrow::Cow;
2
3use derive_builder::Builder;
4use schemars::Schema;
5use serde::{Deserialize, Serialize};
6
7/// Output of a `ToolCall` which will be added as a message for the agent to use.
8#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, strum_macros::EnumIs)]
9#[non_exhaustive]
10pub enum ToolOutput {
11    /// Adds the result of the toolcall to messages
12    Text(String),
13
14    /// Indicates that the toolcall requires feedback, i.e. in a human-in-the-loop
15    FeedbackRequired(Option<serde_json::Value>),
16
17    /// Indicates that the toolcall failed, but can be handled by the llm
18    Fail(String),
19
20    /// Stops an agent with an optional message
21    Stop(Option<serde_json::Value>),
22
23    /// Indicates that the agent failed and should stop
24    AgentFailed(Option<Cow<'static, str>>),
25}
26
27impl ToolOutput {
28    pub fn text(text: impl Into<String>) -> Self {
29        ToolOutput::Text(text.into())
30    }
31
32    pub fn feedback_required(feedback: Option<serde_json::Value>) -> Self {
33        ToolOutput::FeedbackRequired(feedback)
34    }
35
36    pub fn stop() -> Self {
37        ToolOutput::Stop(None)
38    }
39
40    pub fn stop_with_args(output: impl Into<serde_json::Value>) -> Self {
41        ToolOutput::Stop(Some(output.into()))
42    }
43
44    pub fn agent_failed(output: impl Into<Cow<'static, str>>) -> Self {
45        ToolOutput::AgentFailed(Some(output.into()))
46    }
47
48    pub fn fail(text: impl Into<String>) -> Self {
49        ToolOutput::Fail(text.into())
50    }
51
52    pub fn content(&self) -> Option<&str> {
53        match self {
54            ToolOutput::Fail(s) | ToolOutput::Text(s) => Some(s),
55            _ => None,
56        }
57    }
58
59    /// Get the inner text if the output is a `Text` variant.
60    pub fn as_text(&self) -> Option<&str> {
61        match self {
62            ToolOutput::Text(s) => Some(s),
63            _ => None,
64        }
65    }
66
67    /// Get the inner text if the output is a `Fail` variant.
68    pub fn as_fail(&self) -> Option<&str> {
69        match self {
70            ToolOutput::Fail(s) => Some(s),
71            _ => None,
72        }
73    }
74
75    /// Get the inner text if the output is a `Stop` variant.
76    pub fn as_stop(&self) -> Option<&serde_json::Value> {
77        match self {
78            ToolOutput::Stop(args) => args.as_ref(),
79            _ => None,
80        }
81    }
82
83    /// Get the inner text if the output is an `AgentFailed` variant.
84    pub fn as_agent_failed(&self) -> Option<&str> {
85        match self {
86            ToolOutput::AgentFailed(args) => args.as_deref(),
87            _ => None,
88        }
89    }
90
91    /// Get the inner feedback if the output is a `FeedbackRequired` variant.
92    pub fn as_feedback_required(&self) -> Option<&serde_json::Value> {
93        match self {
94            ToolOutput::FeedbackRequired(args) => args.as_ref(),
95            _ => None,
96        }
97    }
98}
99
100impl<S: AsRef<str>> From<S> for ToolOutput {
101    fn from(value: S) -> Self {
102        ToolOutput::Text(value.as_ref().to_string())
103    }
104}
105impl std::fmt::Display for ToolOutput {
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        match self {
108            ToolOutput::Text(value) => write!(f, "{value}"),
109            ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"),
110            ToolOutput::Stop(args) => {
111                if let Some(value) = args {
112                    write!(f, "Stop {value}")
113                } else {
114                    write!(f, "Stop")
115                }
116            }
117            ToolOutput::FeedbackRequired(_) => {
118                write!(f, "Feedback required")
119            }
120            ToolOutput::AgentFailed(args) => write!(
121                f,
122                "Agent failed with output: {}",
123                args.as_deref().unwrap_or_default()
124            ),
125        }
126    }
127}
128
129/// A tool call that can be executed by the executor
130#[derive(Clone, Debug, Builder, PartialEq, Serialize, Deserialize, Eq)]
131#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
132#[builder(setter(into, strip_option))]
133pub struct ToolCall {
134    id: String,
135    name: String,
136    #[builder(default)]
137    args: Option<String>,
138}
139
140/// Hash is used for finding tool calls that have been retried by agents
141impl std::hash::Hash for ToolCall {
142    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
143        self.name.hash(state);
144        self.args.hash(state);
145    }
146}
147
148impl std::fmt::Display for ToolCall {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        write!(
151            f,
152            "{id}#{name} {args}",
153            id = self.id,
154            name = self.name,
155            args = self.args.as_deref().unwrap_or("")
156        )
157    }
158}
159
160impl ToolCall {
161    pub fn builder() -> ToolCallBuilder {
162        ToolCallBuilder::default()
163    }
164
165    pub fn id(&self) -> &str {
166        &self.id
167    }
168
169    pub fn name(&self) -> &str {
170        &self.name
171    }
172
173    pub fn args(&self) -> Option<&str> {
174        self.args.as_deref()
175    }
176
177    pub fn with_args(&mut self, args: Option<String>) {
178        self.args = args;
179    }
180}
181
182impl ToolCallBuilder {
183    pub fn maybe_args<T: Into<Option<String>>>(&mut self, args: T) -> &mut Self {
184        self.args = Some(args.into());
185        self
186    }
187
188    pub fn maybe_id<T: Into<Option<String>>>(&mut self, id: T) -> &mut Self {
189        self.id = id.into();
190        self
191    }
192
193    pub fn maybe_name<T: Into<Option<String>>>(&mut self, name: T) -> &mut Self {
194        self.name = name.into();
195        self
196    }
197}
198
199/// A typed tool specification intended to be usable for multiple LLMs
200///
201/// i.e. the json spec `OpenAI` uses to define their tools
202#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Builder, Default)]
203#[builder(setter(into), derive(Debug, Serialize, Deserialize))]
204#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
205#[serde(deny_unknown_fields)]
206pub struct ToolSpec {
207    /// Name of the tool
208    pub name: String,
209    /// Description passed to the LLM for the tool
210    pub description: String,
211
212    #[builder(default, setter(strip_option))]
213    #[serde(skip_serializing_if = "Option::is_none")]
214    /// Optional JSON schema describing the tool arguments
215    pub parameters_schema: Option<Schema>,
216}
217
218impl ToolSpec {
219    pub fn builder() -> ToolSpecBuilder {
220        ToolSpecBuilder::default()
221    }
222}
223
224impl Eq for ToolSpec {}
225
226impl std::hash::Hash for ToolSpec {
227    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
228        self.name.hash(state);
229        self.description.hash(state);
230        if let Some(schema) = &self.parameters_schema
231            && let Ok(serialized) = serde_json::to_vec(schema)
232        {
233            serialized.hash(state);
234        }
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    // Minimal argument type whose derived JSON schema is used as the tool parameters.
    #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
    struct ExampleArgs {
        value: String,
    }

    // Builds the `ToolSpec` fixture shared by the tests below.
    fn example_spec() -> ToolSpec {
        let schema = schemars::schema_for!(ExampleArgs);

        ToolSpec::builder()
            .name("example")
            .description("An example tool")
            .parameters_schema(schema)
            .build()
            .unwrap()
    }

    // The serialized spec exposes its name and includes the schema when one is set.
    #[test]
    fn tool_spec_serializes_schema() {
        let spec = example_spec();

        let serialized = serde_json::to_value(&spec).unwrap();
        let name = serialized.get("name").and_then(serde_json::Value::as_str);

        assert_eq!(name, Some("example"));
        assert!(serialized.get("parameters_schema").is_some());
    }

    // A spec inserted into a `HashSet` can be found again through an equal value.
    #[test]
    fn tool_spec_is_hashable() {
        let spec = example_spec();

        let set: HashSet<ToolSpec> = std::iter::once(spec.clone()).collect();

        assert!(set.contains(&spec));
    }
}