swiftide_core/chat_completion/
tools.rs

1use std::fmt;
2
3use derive_builder::Builder;
4use serde::de::{Deserializer, Error as DeError, SeqAccess, Unexpected, Visitor};
5use serde::ser::{Error as SerError, SerializeSeq, Serializer};
6use serde::{Deserialize, Serialize};
7
8/// Output of a `ToolCall` which will be added as a message for the agent to use.
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, strum_macros::EnumIs)]
10#[non_exhaustive]
11pub enum ToolOutput {
12    /// Adds the result of the toolcall to messages
13    Text(String),
14
15    /// Indicates that the toolcall requires feedback, i.e. in a human-in-the-loop
16    FeedbackRequired(Option<serde_json::Value>),
17
18    /// Indicates that the toolcall failed, but can be handled by the llm
19    Fail(String),
20    /// Stops an agent
21    Stop,
22}
23
24impl ToolOutput {
25    pub fn text(text: impl Into<String>) -> Self {
26        ToolOutput::Text(text.into())
27    }
28
29    pub fn feedback_required(feedback: Option<serde_json::Value>) -> Self {
30        ToolOutput::FeedbackRequired(feedback)
31    }
32
33    pub fn stop() -> Self {
34        ToolOutput::Stop
35    }
36
37    pub fn fail(text: impl Into<String>) -> Self {
38        ToolOutput::Fail(text.into())
39    }
40
41    pub fn content(&self) -> Option<&str> {
42        match self {
43            ToolOutput::Fail(s) | ToolOutput::Text(s) => Some(s),
44            _ => None,
45        }
46    }
47}
48
49impl<T: AsRef<str>> From<T> for ToolOutput {
50    fn from(s: T) -> Self {
51        ToolOutput::Text(s.as_ref().to_string())
52    }
53}
54
55impl std::fmt::Display for ToolOutput {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        match self {
58            ToolOutput::Text(value) => write!(f, "{value}"),
59            ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"),
60            ToolOutput::Stop => write!(f, "Stop"),
61            ToolOutput::FeedbackRequired(_) => {
62                write!(f, "Feedback required")
63            }
64        }
65    }
66}
67
68/// A tool call that can be executed by the executor
69#[derive(Clone, Debug, Builder, PartialEq, Serialize, Deserialize, Eq)]
70#[builder(setter(into, strip_option))]
71pub struct ToolCall {
72    id: String,
73    name: String,
74    #[builder(default)]
75    args: Option<String>,
76}
77
78/// Hash is used for finding tool calls that have been retried by agents
79impl std::hash::Hash for ToolCall {
80    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
81        self.name.hash(state);
82        self.args.hash(state);
83    }
84}
85
86impl std::fmt::Display for ToolCall {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        write!(
89            f,
90            "{id}#{name} {args}",
91            id = self.id,
92            name = self.name,
93            args = self.args.as_deref().unwrap_or("")
94        )
95    }
96}
97
98impl ToolCall {
99    pub fn builder() -> ToolCallBuilder {
100        ToolCallBuilder::default()
101    }
102
103    pub fn id(&self) -> &str {
104        &self.id
105    }
106
107    pub fn name(&self) -> &str {
108        &self.name
109    }
110
111    pub fn args(&self) -> Option<&str> {
112        self.args.as_deref()
113    }
114
115    pub fn with_args(&mut self, args: Option<String>) {
116        self.args = args;
117    }
118}
119
120impl ToolCallBuilder {
121    pub fn maybe_args<T: Into<Option<String>>>(&mut self, args: T) -> &mut Self {
122        self.args = Some(args.into());
123        self
124    }
125
126    pub fn maybe_id<T: Into<Option<String>>>(&mut self, id: T) -> &mut Self {
127        self.id = id.into();
128        self
129    }
130
131    pub fn maybe_name<T: Into<Option<String>>>(&mut self, name: T) -> &mut Self {
132        self.name = name.into();
133        self
134    }
135}
136
137/// A typed tool specification intended to be usable for multiple LLMs
138///
139/// i.e. the json spec `OpenAI` uses to define their tools
140#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, Builder)]
141#[builder(setter(into))]
142pub struct ToolSpec {
143    /// Name of the tool
144    pub name: String,
145    /// Description passed to the LLM for the tool
146    pub description: String,
147
148    #[builder(default)]
149    /// Optional parameters for the tool
150    pub parameters: Vec<ParamSpec>,
151}
152
153impl ToolSpec {
154    pub fn builder() -> ToolSpecBuilder {
155        ToolSpecBuilder::default()
156    }
157}
158
159#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, strum_macros::AsRefStr)]
160#[strum(serialize_all = "camelCase")]
161pub enum ParamType {
162    #[default]
163    String,
164    Number,
165    Boolean,
166    Array,
167    Nullable(Box<ParamType>),
168}
169
/// The non-nullable parameter types: the same set as the plain `ParamType` variants,
/// without the `Nullable` wrapper.
///
/// The usual value-type traits are derived so that this public, field-less enum can be
/// printed, copied, compared and used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InnerParamType {
    /// Textual value.
    String,
    /// Numeric value.
    Number,
    /// `true` / `false` value.
    Boolean,
    /// List value.
    Array,
}
176
177impl Serialize for ParamType {
178    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
179    where
180        S: Serializer,
181    {
182        match self {
183            // Non-nullable => single string
184            ParamType::String => serializer.serialize_str("string"),
185            ParamType::Number => serializer.serialize_str("number"),
186            ParamType::Boolean => serializer.serialize_str("boolean"),
187            ParamType::Array => serializer.serialize_str("array"),
188
189            // Nullable => an array of exactly two items, e.g. ["string", "null"]
190            ParamType::Nullable(inner) => {
191                // If you want to forbid nested nullables:
192                if let ParamType::Nullable(_) = inner.as_ref() {
193                    return Err(serde::ser::Error::custom("Nested Nullable not supported"));
194                }
195                // Otherwise, produce an array like `["string", "null"]`.
196                let mut seq = serializer.serialize_seq(Some(2))?;
197                seq.serialize_element(&primitive_variant_str(inner).map_err(S::Error::custom)?)?;
198                seq.serialize_element("null")?;
199                seq.end()
200            }
201        }
202    }
203}
204
205fn primitive_variant_str(pt: &ParamType) -> Result<&'static str, &'static str> {
206    match pt {
207        ParamType::String => Ok("string"),
208        ParamType::Number => Ok("number"),
209        ParamType::Boolean => Ok("boolean"),
210        ParamType::Array => Ok("array"),
211        ParamType::Nullable(_) => Err("Nested Nullable found"),
212    }
213}
214
215impl<'de> Deserialize<'de> for ParamType {
216    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
217    where
218        D: Deserializer<'de>,
219    {
220        deserializer.deserialize_any(ParamTypeVisitor)
221    }
222}
223
224struct ParamTypeVisitor;
225
226impl<'de> Visitor<'de> for ParamTypeVisitor {
227    type Value = ParamType;
228
229    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
230        write!(
231            formatter,
232            "a string (e.g. \"string\", \"number\") \
233             or a 2-element array [<type>, \"null\"]"
234        )
235    }
236
237    // Single strings => simple ParamType
238    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
239    where
240        E: DeError,
241    {
242        match value {
243            "string" => Ok(ParamType::String),
244            "number" => Ok(ParamType::Number),
245            "boolean" => Ok(ParamType::Boolean),
246            "array" => Ok(ParamType::Array),
247            other => Err(E::unknown_variant(
248                other,
249                &["string", "number", "boolean", "array"],
250            )),
251        }
252    }
253
254    // Arrays => expect exactly 2 items, one must be "null"
255    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
256    where
257        A: SeqAccess<'de>,
258    {
259        let mut items = Vec::new();
260        while let Some(item) = seq.next_element::<String>()? {
261            items.push(item);
262        }
263
264        // Must have exactly 2 elements
265        if items.len() != 2 {
266            return Err(A::Error::invalid_length(items.len(), &"2"));
267        }
268
269        let mut first = items[0].as_str();
270        let mut second = items[1].as_str();
271
272        // If the first is "null", swap so second is "null" and first is the real type
273        if first == "null" {
274            std::mem::swap(&mut first, &mut second);
275        }
276
277        // Now 'second' must be "null".
278        if second != "null" {
279            return Err(A::Error::invalid_value(
280                Unexpected::Str(second),
281                &"expected exactly one 'null' in [<type>, 'null']",
282            ));
283        }
284
285        // 'first' must be a known primitive
286        let inner = match first {
287            "string" => ParamType::String,
288            "number" => ParamType::Number,
289            "boolean" => ParamType::Boolean,
290            "array" => ParamType::Array,
291            other => {
292                return Err(A::Error::unknown_variant(
293                    other,
294                    &["string", "number", "boolean", "array", "null"],
295                ));
296            }
297        };
298
299        Ok(ParamType::Nullable(Box::new(inner)))
300    }
301}
302/// Parameters for tools
303#[derive(Clone, Debug, Hash, Eq, PartialEq, Builder)]
304#[builder(setter(into))]
305pub struct ParamSpec {
306    /// Name of the parameter
307    pub name: String,
308    /// Description of the parameter
309    pub description: String,
310    /// Json spec type of the parameter
311    #[builder(default)]
312    pub ty: ParamType,
313    /// Whether the parameter is required
314    #[builder(default = true)]
315    pub required: bool,
316}
317
318impl ParamSpec {
319    pub fn builder() -> ParamSpecBuilder {
320        ParamSpecBuilder::default()
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// A nullable string serializes to `["string","null"]` and round-trips back unchanged.
    #[test]
    fn test_serialize_param_type() {
        let nullable_string = ParamType::Nullable(Box::new(ParamType::String));

        let json = serde_json::to_string(&nullable_string).unwrap();
        assert_eq!(json, r#"["string","null"]"#);

        let round_tripped = serde_json::from_str::<ParamType>(&json).unwrap();
        assert_eq!(nullable_string, round_tripped);
    }

    /// Both wire shapes decode: the two-element array form and the bare string form.
    #[test]
    fn test_deserialize_param_type() {
        let from_array = serde_json::from_str::<ParamType>(r#"["string","null"]"#).unwrap();
        assert_eq!(
            from_array,
            ParamType::Nullable(Box::new(ParamType::String))
        );

        let from_string = serde_json::from_str::<ParamType>(r#""string""#).unwrap();
        assert_eq!(from_string, ParamType::String);
    }
}