swiftide_core/chat_completion/
tools.rs

1use derive_builder::Builder;
2use schemars::Schema;
3use serde::{Deserialize, Serialize};
4
5/// Output of a `ToolCall` which will be added as a message for the agent to use.
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, strum_macros::EnumIs)]
7#[non_exhaustive]
8pub enum ToolOutput {
9    /// Adds the result of the toolcall to messages
10    Text(String),
11
12    /// Indicates that the toolcall requires feedback, i.e. in a human-in-the-loop
13    FeedbackRequired(Option<serde_json::Value>),
14
15    /// Indicates that the toolcall failed, but can be handled by the llm
16    Fail(String),
17
18    /// Stops an agent with an optional message
19    Stop(Option<serde_json::Value>),
20
21    /// Indicates that the agent failed and should stop
22    AgentFailed(Option<serde_json::Value>),
23}
24
25impl ToolOutput {
26    pub fn text(text: impl Into<String>) -> Self {
27        ToolOutput::Text(text.into())
28    }
29
30    pub fn feedback_required(feedback: Option<serde_json::Value>) -> Self {
31        ToolOutput::FeedbackRequired(feedback)
32    }
33
34    pub fn stop() -> Self {
35        ToolOutput::Stop(None)
36    }
37
38    pub fn stop_with_args(output: impl Into<serde_json::Value>) -> Self {
39        ToolOutput::Stop(Some(output.into()))
40    }
41
42    pub fn agent_failed(output: impl Into<serde_json::Value>) -> Self {
43        ToolOutput::AgentFailed(Some(output.into()))
44    }
45
46    pub fn fail(text: impl Into<String>) -> Self {
47        ToolOutput::Fail(text.into())
48    }
49
50    pub fn content(&self) -> Option<&str> {
51        match self {
52            ToolOutput::Fail(s) | ToolOutput::Text(s) => Some(s),
53            _ => None,
54        }
55    }
56
57    /// Get the inner text if the output is a `Text` variant.
58    pub fn as_text(&self) -> Option<&str> {
59        match self {
60            ToolOutput::Text(s) => Some(s),
61            _ => None,
62        }
63    }
64
65    /// Get the inner text if the output is a `Fail` variant.
66    pub fn as_fail(&self) -> Option<&str> {
67        match self {
68            ToolOutput::Fail(s) => Some(s),
69            _ => None,
70        }
71    }
72
73    /// Get the inner text if the output is a `Stop` variant.
74    pub fn as_stop(&self) -> Option<&serde_json::Value> {
75        match self {
76            ToolOutput::Stop(args) => args.as_ref(),
77            _ => None,
78        }
79    }
80
81    /// Get the inner text if the output is an `AgentFailed` variant.
82    pub fn as_agent_failed(&self) -> Option<&serde_json::Value> {
83        match self {
84            ToolOutput::AgentFailed(args) => args.as_ref(),
85            _ => None,
86        }
87    }
88
89    /// Get the inner feedback if the output is a `FeedbackRequired` variant.
90    pub fn as_feedback_required(&self) -> Option<&serde_json::Value> {
91        match self {
92            ToolOutput::FeedbackRequired(args) => args.as_ref(),
93            _ => None,
94        }
95    }
96}
97
98impl<S: AsRef<str>> From<S> for ToolOutput {
99    fn from(value: S) -> Self {
100        ToolOutput::Text(value.as_ref().to_string())
101    }
102}
103impl std::fmt::Display for ToolOutput {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        match self {
106            ToolOutput::Text(value) => write!(f, "{value}"),
107            ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"),
108            ToolOutput::Stop(args) => {
109                if let Some(value) = args {
110                    write!(f, "Stop {value}")
111                } else {
112                    write!(f, "Stop")
113                }
114            }
115            ToolOutput::FeedbackRequired(_) => {
116                write!(f, "Feedback required")
117            }
118            ToolOutput::AgentFailed(args) => write!(
119                f,
120                "Agent failed with output: {}",
121                args.as_ref().unwrap_or_default()
122            ),
123        }
124    }
125}
126
127/// A tool call that can be executed by the executor
128#[derive(Clone, Debug, Builder, PartialEq, Serialize, Deserialize, Eq)]
129#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
130#[builder(setter(into, strip_option))]
131pub struct ToolCall {
132    id: String,
133    name: String,
134    #[builder(default)]
135    args: Option<String>,
136}
137
138/// Hash is used for finding tool calls that have been retried by agents
139impl std::hash::Hash for ToolCall {
140    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
141        self.name.hash(state);
142        self.args.hash(state);
143    }
144}
145
146impl std::fmt::Display for ToolCall {
147    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148        write!(
149            f,
150            "{id}#{name} {args}",
151            id = self.id,
152            name = self.name,
153            args = self.args.as_deref().unwrap_or("")
154        )
155    }
156}
157
158impl ToolCall {
159    pub fn builder() -> ToolCallBuilder {
160        ToolCallBuilder::default()
161    }
162
163    pub fn id(&self) -> &str {
164        &self.id
165    }
166
167    pub fn name(&self) -> &str {
168        &self.name
169    }
170
171    pub fn args(&self) -> Option<&str> {
172        self.args.as_deref()
173    }
174
175    pub fn with_args(&mut self, args: Option<String>) {
176        self.args = args;
177    }
178}
179
180impl ToolCallBuilder {
181    pub fn maybe_args<T: Into<Option<String>>>(&mut self, args: T) -> &mut Self {
182        self.args = Some(args.into());
183        self
184    }
185
186    pub fn maybe_id<T: Into<Option<String>>>(&mut self, id: T) -> &mut Self {
187        self.id = id.into();
188        self
189    }
190
191    pub fn maybe_name<T: Into<Option<String>>>(&mut self, name: T) -> &mut Self {
192        self.name = name.into();
193        self
194    }
195}
196
197/// A typed tool specification intended to be usable for multiple LLMs
198///
199/// i.e. the json spec `OpenAI` uses to define their tools
200#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Builder, Default)]
201#[builder(setter(into), derive(Debug, Serialize, Deserialize))]
202#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
203#[serde(deny_unknown_fields)]
204pub struct ToolSpec {
205    /// Name of the tool
206    pub name: String,
207    /// Description passed to the LLM for the tool
208    pub description: String,
209
210    #[builder(default, setter(strip_option))]
211    #[serde(skip_serializing_if = "Option::is_none")]
212    /// Optional JSON schema describing the tool arguments
213    pub parameters_schema: Option<Schema>,
214}
215
216impl ToolSpec {
217    pub fn builder() -> ToolSpecBuilder {
218        ToolSpecBuilder::default()
219    }
220}
221
222impl Eq for ToolSpec {}
223
224impl std::hash::Hash for ToolSpec {
225    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
226        self.name.hash(state);
227        self.description.hash(state);
228        if let Some(schema) = &self.parameters_schema
229            && let Ok(serialized) = serde_json::to_vec(schema)
230        {
231            serialized.hash(state);
232        }
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Minimal argument type used to generate a JSON schema for the specs below.
    #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
    struct ExampleArgs {
        value: String,
    }

    /// Builds the `example` spec with a parameters schema derived from `ExampleArgs`.
    fn example_spec() -> ToolSpec {
        ToolSpec::builder()
            .name("example")
            .description("An example tool")
            .parameters_schema(schemars::schema_for!(ExampleArgs))
            .build()
            .unwrap()
    }

    #[test]
    fn tool_spec_serializes_schema() {
        let json = serde_json::to_value(example_spec()).unwrap();

        let name = json.get("name").and_then(serde_json::Value::as_str);
        assert_eq!(name, Some("example"));
        assert!(json.get("parameters_schema").is_some());
    }

    #[test]
    fn tool_spec_is_hashable() {
        let spec = example_spec();
        let set: HashSet<ToolSpec> = std::iter::once(spec.clone()).collect();

        assert!(set.contains(&spec));
    }
}