// swiftide_core/chat_completion/tools.rs

1use std::fmt;
2
3use derive_builder::Builder;
4use serde::de::{Deserializer, Error as DeError, SeqAccess, Unexpected, Visitor};
5use serde::ser::{Error as SerError, SerializeSeq, Serializer};
6use serde::{Deserialize, Serialize};
7
8/// Output of a `ToolCall` which will be added as a message for the agent to use.
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10#[non_exhaustive]
11pub enum ToolOutput {
12    /// Adds the result of the toolcall to messages
13    Text(String),
14
15    /// Indicates that the toolcall failed, but can be handled by the llm
16    Fail(String),
17    /// Stops an agent
18    Stop,
19}
20
21impl ToolOutput {
22    pub fn content(&self) -> Option<&str> {
23        match self {
24            ToolOutput::Fail(s) | ToolOutput::Text(s) => Some(s),
25            _ => None,
26        }
27    }
28}
29
30impl<T: AsRef<str>> From<T> for ToolOutput {
31    fn from(s: T) -> Self {
32        ToolOutput::Text(s.as_ref().to_string())
33    }
34}
35
36impl std::fmt::Display for ToolOutput {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        match self {
39            ToolOutput::Text(value) => write!(f, "{value}"),
40            ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"),
41            ToolOutput::Stop => write!(f, "Stop"),
42        }
43    }
44}
45
46/// A tool call that can be executed by the executor
47#[derive(Clone, Debug, Builder, PartialEq, Serialize, Deserialize)]
48#[builder(setter(into, strip_option))]
49pub struct ToolCall {
50    id: String,
51    name: String,
52    #[builder(default)]
53    args: Option<String>,
54}
55
56/// Hash is used for finding tool calls that have been retried by agents
57impl std::hash::Hash for &ToolCall {
58    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
59        self.name.hash(state);
60        self.args.hash(state);
61    }
62}
63
64impl std::fmt::Display for ToolCall {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        write!(
67            f,
68            "{id}#{name} {args}",
69            id = self.id,
70            name = self.name,
71            args = self.args.as_deref().unwrap_or("")
72        )
73    }
74}
75
76impl ToolCall {
77    pub fn builder() -> ToolCallBuilder {
78        ToolCallBuilder::default()
79    }
80
81    pub fn id(&self) -> &str {
82        &self.id
83    }
84
85    pub fn name(&self) -> &str {
86        &self.name
87    }
88
89    pub fn args(&self) -> Option<&str> {
90        self.args.as_deref()
91    }
92}
93
94impl ToolCallBuilder {
95    pub fn maybe_args<T: Into<Option<String>>>(&mut self, args: T) -> &mut Self {
96        self.args = Some(args.into());
97        self
98    }
99
100    pub fn maybe_id<T: Into<Option<String>>>(&mut self, id: T) -> &mut Self {
101        self.id = id.into();
102        self
103    }
104
105    pub fn maybe_name<T: Into<Option<String>>>(&mut self, name: T) -> &mut Self {
106        self.name = name.into();
107        self
108    }
109}
110
111/// A typed tool specification intended to be usable for multiple LLMs
112///
113/// i.e. the json spec `OpenAI` uses to define their tools
114#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, Builder)]
115#[builder(setter(into))]
116pub struct ToolSpec {
117    /// Name of the tool
118    pub name: String,
119    /// Description passed to the LLM for the tool
120    pub description: String,
121
122    #[builder(default)]
123    /// Optional parameters for the tool
124    pub parameters: Vec<ParamSpec>,
125}
126
127impl ToolSpec {
128    pub fn builder() -> ToolSpecBuilder {
129        ToolSpecBuilder::default()
130    }
131}
132
133#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, strum_macros::AsRefStr)]
134#[strum(serialize_all = "camelCase")]
135pub enum ParamType {
136    #[default]
137    String,
138    Number,
139    Boolean,
140    Array,
141    Nullable(Box<ParamType>),
142}
143
/// A primitive (non-nullable) JSON-schema type.
///
/// Its variants mirror the non-`Nullable` variants of [`ParamType`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InnerParamType {
    /// A JSON string.
    String,
    /// A JSON number.
    Number,
    /// A JSON boolean.
    Boolean,
    /// A JSON array.
    Array,
}
150
151impl Serialize for ParamType {
152    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
153    where
154        S: Serializer,
155    {
156        match self {
157            // Non-nullable => single string
158            ParamType::String => serializer.serialize_str("string"),
159            ParamType::Number => serializer.serialize_str("number"),
160            ParamType::Boolean => serializer.serialize_str("boolean"),
161            ParamType::Array => serializer.serialize_str("array"),
162
163            // Nullable => an array of exactly two items, e.g. ["string", "null"]
164            ParamType::Nullable(inner) => {
165                // If you want to forbid nested nullables:
166                if let ParamType::Nullable(_) = inner.as_ref() {
167                    return Err(serde::ser::Error::custom("Nested Nullable not supported"));
168                }
169                // Otherwise, produce an array like `["string", "null"]`.
170                let mut seq = serializer.serialize_seq(Some(2))?;
171                seq.serialize_element(&primitive_variant_str(inner).map_err(S::Error::custom)?)?;
172                seq.serialize_element("null")?;
173                seq.end()
174            }
175        }
176    }
177}
178
179fn primitive_variant_str(pt: &ParamType) -> Result<&'static str, &'static str> {
180    match pt {
181        ParamType::String => Ok("string"),
182        ParamType::Number => Ok("number"),
183        ParamType::Boolean => Ok("boolean"),
184        ParamType::Array => Ok("array"),
185        ParamType::Nullable(_) => Err("Nested Nullable found"),
186    }
187}
188
189impl<'de> Deserialize<'de> for ParamType {
190    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
191    where
192        D: Deserializer<'de>,
193    {
194        deserializer.deserialize_any(ParamTypeVisitor)
195    }
196}
197
198struct ParamTypeVisitor;
199
200impl<'de> Visitor<'de> for ParamTypeVisitor {
201    type Value = ParamType;
202
203    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
204        write!(
205            formatter,
206            "a string (e.g. \"string\", \"number\") \
207             or a 2-element array [<type>, \"null\"]"
208        )
209    }
210
211    // Single strings => simple ParamType
212    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
213    where
214        E: DeError,
215    {
216        match value {
217            "string" => Ok(ParamType::String),
218            "number" => Ok(ParamType::Number),
219            "boolean" => Ok(ParamType::Boolean),
220            "array" => Ok(ParamType::Array),
221            other => Err(E::unknown_variant(
222                other,
223                &["string", "number", "boolean", "array"],
224            )),
225        }
226    }
227
228    // Arrays => expect exactly 2 items, one must be "null"
229    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
230    where
231        A: SeqAccess<'de>,
232    {
233        let mut items = Vec::new();
234        while let Some(item) = seq.next_element::<String>()? {
235            items.push(item);
236        }
237
238        // Must have exactly 2 elements
239        if items.len() != 2 {
240            return Err(A::Error::invalid_length(items.len(), &"2"));
241        }
242
243        let mut first = items[0].as_str();
244        let mut second = items[1].as_str();
245
246        // If the first is "null", swap so second is "null" and first is the real type
247        if first == "null" {
248            std::mem::swap(&mut first, &mut second);
249        }
250
251        // Now 'second' must be "null".
252        if second != "null" {
253            return Err(A::Error::invalid_value(
254                Unexpected::Str(second),
255                &"expected exactly one 'null' in [<type>, 'null']",
256            ));
257        }
258
259        // 'first' must be a known primitive
260        let inner = match first {
261            "string" => ParamType::String,
262            "number" => ParamType::Number,
263            "boolean" => ParamType::Boolean,
264            "array" => ParamType::Array,
265            other => {
266                return Err(A::Error::unknown_variant(
267                    other,
268                    &["string", "number", "boolean", "array", "null"],
269                ));
270            }
271        };
272
273        Ok(ParamType::Nullable(Box::new(inner)))
274    }
275}
276/// Parameters for tools
277#[derive(Clone, Debug, Hash, Eq, PartialEq, Builder)]
278#[builder(setter(into))]
279pub struct ParamSpec {
280    /// Name of the parameter
281    pub name: String,
282    /// Description of the parameter
283    pub description: String,
284    /// Json spec type of the parameter
285    #[builder(default)]
286    pub ty: ParamType,
287    /// Whether the parameter is required
288    #[builder(default = true)]
289    pub required: bool,
290}
291
292impl ParamSpec {
293    pub fn builder() -> ParamSpecBuilder {
294        ParamSpecBuilder::default()
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialize_param_type() {
        let param = ParamType::Nullable(Box::new(ParamType::String));
        let serialized = serde_json::to_string(&param).unwrap();
        assert_eq!(serialized, r#"["string","null"]"#);

        let deserialized: ParamType = serde_json::from_str(&serialized).unwrap();
        assert_eq!(param, deserialized);
    }

    #[test]
    fn test_deserialize_param_type() {
        let serialized = r#"["string","null"]"#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(
            deserialized,
            ParamType::Nullable(Box::new(ParamType::String))
        );

        let serialized = r#""string""#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(deserialized, ParamType::String);
    }

    #[test]
    fn test_primitive_param_types_round_trip() {
        for (ty, json) in [
            (ParamType::String, r#""string""#),
            (ParamType::Number, r#""number""#),
            (ParamType::Boolean, r#""boolean""#),
            (ParamType::Array, r#""array""#),
        ] {
            assert_eq!(serde_json::to_string(&ty).unwrap(), json);
            assert_eq!(serde_json::from_str::<ParamType>(json).unwrap(), ty);
        }
    }

    #[test]
    fn test_deserialize_nullable_accepts_null_first() {
        let deserialized: ParamType = serde_json::from_str(r#"["null","number"]"#).unwrap();
        assert_eq!(
            deserialized,
            ParamType::Nullable(Box::new(ParamType::Number))
        );
    }

    #[test]
    fn test_deserialize_rejects_invalid_inputs() {
        for input in [
            r#""object""#,
            r#""null""#,
            r#"[]"#,
            r#"["string"]"#,
            r#"["string","null","null"]"#,
            r#"["string","number"]"#,
            r#"["null","null"]"#,
            r#"["object","null"]"#,
            r#"[1,"null"]"#,
        ] {
            assert!(
                serde_json::from_str::<ParamType>(input).is_err(),
                "expected {input} to be rejected"
            );
        }
    }

    #[test]
    fn test_serialize_rejects_nested_nullable() {
        let nested =
            ParamType::Nullable(Box::new(ParamType::Nullable(Box::new(ParamType::String))));
        assert!(serde_json::to_string(&nested).is_err());
    }
}