swiftide_core/chat_completion/tools.rs

1use std::borrow::Cow;
2use std::fmt;
3
4use derive_builder::Builder;
5use serde::de::{Deserializer, Error as DeError, SeqAccess, Unexpected, Visitor};
6use serde::ser::{Error as SerError, SerializeSeq, Serializer};
7use serde::{Deserialize, Serialize};
8
/// Output of a `ToolCall` which will be added as a message for the agent to use.
///
/// The `ToolOutput::text`, `fail`, `stop`, `stop_with_args`, `agent_failed` and
/// `feedback_required` constructors build the variants below; `ToolOutput::content`
/// returns the message only for `Text` and `Fail`. `strum_macros::EnumIs` adds an
/// `is_*` predicate per variant (e.g. `is_text()`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, strum_macros::EnumIs)]
#[non_exhaustive]
pub enum ToolOutput {
    /// Adds the result of the toolcall to messages
    Text(String),

    /// Indicates that the toolcall requires feedback, i.e. in a human-in-the-loop flow.
    /// Carries an optional JSON payload alongside the request.
    FeedbackRequired(Option<serde_json::Value>),

    /// Indicates that the toolcall failed, but can be handled by the llm
    Fail(String),

    /// Stops an agent with an optional message
    Stop(Option<Cow<'static, str>>),

    /// Indicates that the agent failed and should stop
    AgentFailed(Option<Cow<'static, str>>),
}
28
29impl ToolOutput {
30    pub fn text(text: impl Into<String>) -> Self {
31        ToolOutput::Text(text.into())
32    }
33
34    pub fn feedback_required(feedback: Option<serde_json::Value>) -> Self {
35        ToolOutput::FeedbackRequired(feedback)
36    }
37
38    pub fn stop() -> Self {
39        ToolOutput::Stop(None)
40    }
41
42    pub fn stop_with_args(output: impl Into<Cow<'static, str>>) -> Self {
43        ToolOutput::Stop(Some(output.into()))
44    }
45
46    pub fn agent_failed(output: impl Into<Cow<'static, str>>) -> Self {
47        ToolOutput::AgentFailed(Some(output.into()))
48    }
49
50    pub fn fail(text: impl Into<String>) -> Self {
51        ToolOutput::Fail(text.into())
52    }
53
54    pub fn content(&self) -> Option<&str> {
55        match self {
56            ToolOutput::Fail(s) | ToolOutput::Text(s) => Some(s),
57            _ => None,
58        }
59    }
60}
61
62impl<S: AsRef<str>> From<S> for ToolOutput {
63    fn from(value: S) -> Self {
64        ToolOutput::Text(value.as_ref().to_string())
65    }
66}
67impl std::fmt::Display for ToolOutput {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        match self {
70            ToolOutput::Text(value) => write!(f, "{value}"),
71            ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"),
72            ToolOutput::Stop(args) => write!(f, "Stop {}", args.as_deref().unwrap_or_default()),
73            ToolOutput::FeedbackRequired(_) => {
74                write!(f, "Feedback required")
75            }
76            ToolOutput::AgentFailed(args) => write!(
77                f,
78                "Agent failed with output: {}",
79                args.as_deref().unwrap_or_default()
80            ),
81        }
82    }
83}
84
/// A tool call that can be executed by the executor
//
// Fields are private; read them through `id()`, `name()` and `args()`. `PartialEq`/`Eq`
// compare all three fields, whereas the hand-written `Hash` impl only hashes `name` and
// `args`. With `setter(into, strip_option)` the generated `args` setter takes the inner
// `String`; `ToolCallBuilder::maybe_args` accepts an `Option` instead.
#[derive(Clone, Debug, Builder, PartialEq, Serialize, Deserialize, Eq)]
#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
#[builder(setter(into, strip_option))]
pub struct ToolCall {
    // Identifier of this call; excluded from `Hash`.
    id: String,
    // Name of the tool to invoke.
    name: String,
    // Raw, unparsed argument string; `None` when not set on the builder.
    // NOTE(review): presumably JSON emitted by the model — confirm with the providers.
    #[builder(default)]
    args: Option<String>,
}
95
96/// Hash is used for finding tool calls that have been retried by agents
97impl std::hash::Hash for ToolCall {
98    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
99        self.name.hash(state);
100        self.args.hash(state);
101    }
102}
103
104impl std::fmt::Display for ToolCall {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        write!(
107            f,
108            "{id}#{name} {args}",
109            id = self.id,
110            name = self.name,
111            args = self.args.as_deref().unwrap_or("")
112        )
113    }
114}
115
116impl ToolCall {
117    pub fn builder() -> ToolCallBuilder {
118        ToolCallBuilder::default()
119    }
120
121    pub fn id(&self) -> &str {
122        &self.id
123    }
124
125    pub fn name(&self) -> &str {
126        &self.name
127    }
128
129    pub fn args(&self) -> Option<&str> {
130        self.args.as_deref()
131    }
132
133    pub fn with_args(&mut self, args: Option<String>) {
134        self.args = args;
135    }
136}
137
138impl ToolCallBuilder {
139    pub fn maybe_args<T: Into<Option<String>>>(&mut self, args: T) -> &mut Self {
140        self.args = Some(args.into());
141        self
142    }
143
144    pub fn maybe_id<T: Into<Option<String>>>(&mut self, id: T) -> &mut Self {
145        self.id = id.into();
146        self
147    }
148
149    pub fn maybe_name<T: Into<Option<String>>>(&mut self, name: T) -> &mut Self {
150        self.name = name.into();
151        self
152    }
153}
154
/// A typed tool specification intended to be usable for multiple LLMs
///
/// i.e. the json spec `OpenAI` uses to define their tools
//
// Built via `ToolSpec::builder()`. `parameters` is the only field marked
// `#[builder(default)]` and so falls back to an empty `Vec`; `name` and `description`
// have no builder default. The generated `ToolSpecBuilder` itself derives `Debug`,
// `Serialize` and `Deserialize` (see the `#[builder(...)]` attribute).
#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, Builder, Serialize, Deserialize)]
#[builder(setter(into), derive(Debug, Serialize, Deserialize))]
#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
pub struct ToolSpec {
    /// Name of the tool
    pub name: String,
    /// Description passed to the LLM for the tool
    pub description: String,

    #[builder(default)]
    /// Optional parameters for the tool
    pub parameters: Vec<ParamSpec>,
}
171
172impl ToolSpec {
173    pub fn builder() -> ToolSpecBuilder {
174        ToolSpecBuilder::default()
175    }
176}
177
// JSON Schema `type` of a tool parameter.
//
// (De)serialized by hand in the `Serialize`/`Deserialize` impls in this file: primitives
// become a bare string such as "string", and `Nullable(inner)` becomes
// `[<inner>, "null"]`. `as_ref()` (from `strum_macros::AsRefStr` with camelCase) yields
// the variant name instead, e.g. "nullable" for `Nullable`.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, strum_macros::AsRefStr)]
#[strum(serialize_all = "camelCase")]
#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
pub enum ParamType {
    // Used when no type is specified (`Default`).
    #[default]
    String,
    Number,
    Boolean,
    Array,
    // Nullable wrapper; serialization fails if the inner type is itself `Nullable`.
    Nullable(Box<ParamType>),
}
189
/// The non-nullable subset of [`ParamType`]'s variants.
///
/// NOTE(review): nothing in the visible code constructs or matches on this enum, and it
/// derives no traits (not even `Debug`); confirm whether it is used elsewhere before
/// extending or removing it.
pub enum InnerParamType {
    String,
    Number,
    Boolean,
    Array,
}
196
197impl Serialize for ParamType {
198    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
199    where
200        S: Serializer,
201    {
202        match self {
203            // Non-nullable => single string
204            ParamType::String => serializer.serialize_str("string"),
205            ParamType::Number => serializer.serialize_str("number"),
206            ParamType::Boolean => serializer.serialize_str("boolean"),
207            ParamType::Array => serializer.serialize_str("array"),
208
209            // Nullable => an array of exactly two items, e.g. ["string", "null"]
210            ParamType::Nullable(inner) => {
211                // If you want to forbid nested nullables:
212                if let ParamType::Nullable(_) = inner.as_ref() {
213                    return Err(serde::ser::Error::custom("Nested Nullable not supported"));
214                }
215                // Otherwise, produce an array like `["string", "null"]`.
216                let mut seq = serializer.serialize_seq(Some(2))?;
217                seq.serialize_element(&primitive_variant_str(inner).map_err(S::Error::custom)?)?;
218                seq.serialize_element("null")?;
219                seq.end()
220            }
221        }
222    }
223}
224
225fn primitive_variant_str(pt: &ParamType) -> Result<&'static str, &'static str> {
226    match pt {
227        ParamType::String => Ok("string"),
228        ParamType::Number => Ok("number"),
229        ParamType::Boolean => Ok("boolean"),
230        ParamType::Array => Ok("array"),
231        ParamType::Nullable(_) => Err("Nested Nullable found"),
232    }
233}
234
235impl<'de> Deserialize<'de> for ParamType {
236    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
237    where
238        D: Deserializer<'de>,
239    {
240        deserializer.deserialize_any(ParamTypeVisitor)
241    }
242}
243
/// Serde visitor behind `ParamType`'s `Deserialize` impl: accepts a type-name string or a
/// two-element array containing exactly one `"null"`.
struct ParamTypeVisitor;
245
246impl<'de> Visitor<'de> for ParamTypeVisitor {
247    type Value = ParamType;
248
249    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
250        write!(
251            formatter,
252            "a string (e.g. \"string\", \"number\") \
253             or a 2-element array [<type>, \"null\"]"
254        )
255    }
256
257    // Single strings => simple ParamType
258    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
259    where
260        E: DeError,
261    {
262        match value {
263            "string" => Ok(ParamType::String),
264            "number" => Ok(ParamType::Number),
265            "boolean" => Ok(ParamType::Boolean),
266            "array" => Ok(ParamType::Array),
267            other => Err(E::unknown_variant(
268                other,
269                &["string", "number", "boolean", "array"],
270            )),
271        }
272    }
273
274    // Arrays => expect exactly 2 items, one must be "null"
275    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
276    where
277        A: SeqAccess<'de>,
278    {
279        let mut items = Vec::new();
280        while let Some(item) = seq.next_element::<String>()? {
281            items.push(item);
282        }
283
284        // Must have exactly 2 elements
285        if items.len() != 2 {
286            return Err(A::Error::invalid_length(items.len(), &"2"));
287        }
288
289        let mut first = items[0].as_str();
290        let mut second = items[1].as_str();
291
292        // If the first is "null", swap so second is "null" and first is the real type
293        if first == "null" {
294            std::mem::swap(&mut first, &mut second);
295        }
296
297        // Now 'second' must be "null".
298        if second != "null" {
299            return Err(A::Error::invalid_value(
300                Unexpected::Str(second),
301                &"expected exactly one 'null' in [<type>, 'null']",
302            ));
303        }
304
305        // 'first' must be a known primitive
306        let inner = match first {
307            "string" => ParamType::String,
308            "number" => ParamType::Number,
309            "boolean" => ParamType::Boolean,
310            "array" => ParamType::Array,
311            other => {
312                return Err(A::Error::unknown_variant(
313                    other,
314                    &["string", "number", "boolean", "array", "null"],
315                ));
316            }
317        };
318
319        Ok(ParamType::Nullable(Box::new(inner)))
320    }
321}
/// Parameters for tools
//
// Built via `ParamSpec::builder()`: when not set, `ty` falls back to
// `ParamType::default()` (`ParamType::String`) and `required` to `true`. The `ty` field
// is (de)serialized under the JSON key "type".
#[derive(Clone, Debug, Hash, Eq, PartialEq, Builder, Serialize, Deserialize)]
#[builder(setter(into))]
#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
pub struct ParamSpec {
    /// Name of the parameter
    pub name: String,
    /// Description of the parameter
    pub description: String,
    /// Json spec type of the parameter
    #[serde(rename = "type")]
    #[builder(default)]
    pub ty: ParamType,
    /// Whether the parameter is required
    #[builder(default = true)]
    pub required: bool,
}
339
340impl ParamSpec {
341    pub fn builder() -> ParamSpecBuilder {
342        ParamSpecBuilder::default()
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialize_param_type() {
        let param = ParamType::Nullable(Box::new(ParamType::String));
        let serialized = serde_json::to_string(&param).unwrap();
        assert_eq!(serialized, r#"["string","null"]"#);

        let deserialized: ParamType = serde_json::from_str(&serialized).unwrap();
        assert_eq!(param, deserialized);
    }

    #[test]
    fn test_deserialize_param_type() {
        let serialized = r#"["string","null"]"#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(
            deserialized,
            ParamType::Nullable(Box::new(ParamType::String))
        );

        let serialized = r#""string""#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(deserialized, ParamType::String);
    }

    /// The `"null"` marker may appear first as well as second.
    #[test]
    fn test_deserialize_nullable_accepts_null_first() {
        let deserialized: ParamType = serde_json::from_str(r#"["null","number"]"#).unwrap();
        assert_eq!(deserialized, ParamType::Nullable(Box::new(ParamType::Number)));
    }

    /// Every primitive serializes to its bare name and round-trips.
    #[test]
    fn test_primitive_param_types_round_trip() {
        let cases = &[
            (ParamType::String, r#""string""#),
            (ParamType::Number, r#""number""#),
            (ParamType::Boolean, r#""boolean""#),
            (ParamType::Array, r#""array""#),
        ];
        for (ty, expected) in cases {
            assert_eq!(serde_json::to_string(ty).unwrap(), *expected);
            let back: ParamType = serde_json::from_str(expected).unwrap();
            assert_eq!(back, *ty);
        }
    }

    /// Malformed type descriptions are rejected rather than silently coerced.
    #[test]
    fn test_deserialize_rejects_malformed_param_types() {
        let inputs = &[
            r#""integer""#,
            r#""null""#,
            r#"[]"#,
            r#"["string"]"#,
            r#"["string","null","null"]"#,
            r#"["string","number"]"#,
            r#"["null","null"]"#,
            r#"["object","null"]"#,
        ];
        for input in inputs {
            assert!(
                serde_json::from_str::<ParamType>(input).is_err(),
                "expected an error for {input}"
            );
        }
    }

    /// A `Nullable` wrapping another `Nullable` has no encoding and must fail loudly.
    #[test]
    fn test_serialize_nested_nullable_fails() {
        let nested = ParamType::Nullable(Box::new(ParamType::Nullable(Box::new(
            ParamType::String,
        ))));
        assert!(serde_json::to_string(&nested).is_err());
    }

    /// `ParamSpec::ty` is exposed under the JSON key "type" and round-trips.
    #[test]
    fn test_param_spec_uses_type_key() {
        let spec = ParamSpec::builder()
            .name("query")
            .description("search query")
            .ty(ParamType::Nullable(Box::new(ParamType::String)))
            .required(false)
            .build()
            .unwrap();
        let json = serde_json::to_value(&spec).unwrap();
        assert_eq!(json["type"], serde_json::json!(["string", "null"]));
        let back: ParamSpec = serde_json::from_value(json).unwrap();
        assert_eq!(back, spec);
    }
}