swiftide_core/chat_completion/tools.rs

1use std::fmt;
2
3use derive_builder::Builder;
4use serde::de::{Deserializer, Error as DeError, SeqAccess, Unexpected, Visitor};
5use serde::ser::{Error as SerError, SerializeSeq, Serializer};
6use serde::{Deserialize, Serialize};
7
8/// Output of a `ToolCall` which will be added as a message for the agent to use.
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, strum_macros::EnumIs)]
10#[non_exhaustive]
11pub enum ToolOutput {
12    /// Adds the result of the toolcall to messages
13    Text(String),
14
15    /// Indicates that the toolcall requires feedback, i.e. in a human-in-the-loop
16    FeedbackRequired(Option<serde_json::Value>),
17
18    /// Indicates that the toolcall failed, but can be handled by the llm
19    Fail(String),
20    /// Stops an agent
21    Stop,
22}
23
24impl ToolOutput {
25    pub fn text(text: impl Into<String>) -> Self {
26        ToolOutput::Text(text.into())
27    }
28
29    pub fn feedback_required(feedback: Option<serde_json::Value>) -> Self {
30        ToolOutput::FeedbackRequired(feedback)
31    }
32
33    pub fn stop() -> Self {
34        ToolOutput::Stop
35    }
36
37    pub fn fail(text: impl Into<String>) -> Self {
38        ToolOutput::Fail(text.into())
39    }
40
41    pub fn content(&self) -> Option<&str> {
42        match self {
43            ToolOutput::Fail(s) | ToolOutput::Text(s) => Some(s),
44            _ => None,
45        }
46    }
47}
48
49impl<T: AsRef<str>> From<T> for ToolOutput {
50    fn from(s: T) -> Self {
51        ToolOutput::Text(s.as_ref().to_string())
52    }
53}
54
55impl std::fmt::Display for ToolOutput {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        match self {
58            ToolOutput::Text(value) => write!(f, "{value}"),
59            ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"),
60            ToolOutput::Stop => write!(f, "Stop"),
61            ToolOutput::FeedbackRequired(_) => {
62                write!(f, "Feedback required")
63            }
64        }
65    }
66}
67
68/// A tool call that can be executed by the executor
69#[derive(Clone, Debug, Builder, PartialEq, Serialize, Deserialize, Eq)]
70#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
71#[builder(setter(into, strip_option))]
72pub struct ToolCall {
73    id: String,
74    name: String,
75    #[builder(default)]
76    args: Option<String>,
77}
78
79/// Hash is used for finding tool calls that have been retried by agents
80impl std::hash::Hash for ToolCall {
81    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
82        self.name.hash(state);
83        self.args.hash(state);
84    }
85}
86
87impl std::fmt::Display for ToolCall {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        write!(
90            f,
91            "{id}#{name} {args}",
92            id = self.id,
93            name = self.name,
94            args = self.args.as_deref().unwrap_or("")
95        )
96    }
97}
98
99impl ToolCall {
100    pub fn builder() -> ToolCallBuilder {
101        ToolCallBuilder::default()
102    }
103
104    pub fn id(&self) -> &str {
105        &self.id
106    }
107
108    pub fn name(&self) -> &str {
109        &self.name
110    }
111
112    pub fn args(&self) -> Option<&str> {
113        self.args.as_deref()
114    }
115
116    pub fn with_args(&mut self, args: Option<String>) {
117        self.args = args;
118    }
119}
120
121impl ToolCallBuilder {
122    pub fn maybe_args<T: Into<Option<String>>>(&mut self, args: T) -> &mut Self {
123        self.args = Some(args.into());
124        self
125    }
126
127    pub fn maybe_id<T: Into<Option<String>>>(&mut self, id: T) -> &mut Self {
128        self.id = id.into();
129        self
130    }
131
132    pub fn maybe_name<T: Into<Option<String>>>(&mut self, name: T) -> &mut Self {
133        self.name = name.into();
134        self
135    }
136}
137
138/// A typed tool specification intended to be usable for multiple LLMs
139///
140/// i.e. the json spec `OpenAI` uses to define their tools
141#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, Builder, Serialize, Deserialize)]
142#[builder(setter(into), derive(Debug, Serialize, Deserialize))]
143#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
144pub struct ToolSpec {
145    /// Name of the tool
146    pub name: String,
147    /// Description passed to the LLM for the tool
148    pub description: String,
149
150    #[builder(default)]
151    /// Optional parameters for the tool
152    pub parameters: Vec<ParamSpec>,
153}
154
155impl ToolSpec {
156    pub fn builder() -> ToolSpecBuilder {
157        ToolSpecBuilder::default()
158    }
159}
160
161#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, strum_macros::AsRefStr)]
162#[strum(serialize_all = "camelCase")]
163#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
164pub enum ParamType {
165    #[default]
166    String,
167    Number,
168    Boolean,
169    Array,
170    Nullable(Box<ParamType>),
171}
172
/// The non-nullable primitive parameter types.
///
/// These mirror every [`ParamType`] variant except `Nullable`. The enum is not
/// referenced elsewhere in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InnerParamType {
    /// JSON-schema `"string"`.
    String,
    /// JSON-schema `"number"`.
    Number,
    /// JSON-schema `"boolean"`.
    Boolean,
    /// JSON-schema `"array"`.
    Array,
}
179
180impl Serialize for ParamType {
181    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
182    where
183        S: Serializer,
184    {
185        match self {
186            // Non-nullable => single string
187            ParamType::String => serializer.serialize_str("string"),
188            ParamType::Number => serializer.serialize_str("number"),
189            ParamType::Boolean => serializer.serialize_str("boolean"),
190            ParamType::Array => serializer.serialize_str("array"),
191
192            // Nullable => an array of exactly two items, e.g. ["string", "null"]
193            ParamType::Nullable(inner) => {
194                // If you want to forbid nested nullables:
195                if let ParamType::Nullable(_) = inner.as_ref() {
196                    return Err(serde::ser::Error::custom("Nested Nullable not supported"));
197                }
198                // Otherwise, produce an array like `["string", "null"]`.
199                let mut seq = serializer.serialize_seq(Some(2))?;
200                seq.serialize_element(&primitive_variant_str(inner).map_err(S::Error::custom)?)?;
201                seq.serialize_element("null")?;
202                seq.end()
203            }
204        }
205    }
206}
207
208fn primitive_variant_str(pt: &ParamType) -> Result<&'static str, &'static str> {
209    match pt {
210        ParamType::String => Ok("string"),
211        ParamType::Number => Ok("number"),
212        ParamType::Boolean => Ok("boolean"),
213        ParamType::Array => Ok("array"),
214        ParamType::Nullable(_) => Err("Nested Nullable found"),
215    }
216}
217
218impl<'de> Deserialize<'de> for ParamType {
219    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
220    where
221        D: Deserializer<'de>,
222    {
223        deserializer.deserialize_any(ParamTypeVisitor)
224    }
225}
226
227struct ParamTypeVisitor;
228
229impl<'de> Visitor<'de> for ParamTypeVisitor {
230    type Value = ParamType;
231
232    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
233        write!(
234            formatter,
235            "a string (e.g. \"string\", \"number\") \
236             or a 2-element array [<type>, \"null\"]"
237        )
238    }
239
240    // Single strings => simple ParamType
241    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
242    where
243        E: DeError,
244    {
245        match value {
246            "string" => Ok(ParamType::String),
247            "number" => Ok(ParamType::Number),
248            "boolean" => Ok(ParamType::Boolean),
249            "array" => Ok(ParamType::Array),
250            other => Err(E::unknown_variant(
251                other,
252                &["string", "number", "boolean", "array"],
253            )),
254        }
255    }
256
257    // Arrays => expect exactly 2 items, one must be "null"
258    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
259    where
260        A: SeqAccess<'de>,
261    {
262        let mut items = Vec::new();
263        while let Some(item) = seq.next_element::<String>()? {
264            items.push(item);
265        }
266
267        // Must have exactly 2 elements
268        if items.len() != 2 {
269            return Err(A::Error::invalid_length(items.len(), &"2"));
270        }
271
272        let mut first = items[0].as_str();
273        let mut second = items[1].as_str();
274
275        // If the first is "null", swap so second is "null" and first is the real type
276        if first == "null" {
277            std::mem::swap(&mut first, &mut second);
278        }
279
280        // Now 'second' must be "null".
281        if second != "null" {
282            return Err(A::Error::invalid_value(
283                Unexpected::Str(second),
284                &"expected exactly one 'null' in [<type>, 'null']",
285            ));
286        }
287
288        // 'first' must be a known primitive
289        let inner = match first {
290            "string" => ParamType::String,
291            "number" => ParamType::Number,
292            "boolean" => ParamType::Boolean,
293            "array" => ParamType::Array,
294            other => {
295                return Err(A::Error::unknown_variant(
296                    other,
297                    &["string", "number", "boolean", "array", "null"],
298                ));
299            }
300        };
301
302        Ok(ParamType::Nullable(Box::new(inner)))
303    }
304}
305/// Parameters for tools
306#[derive(Clone, Debug, Hash, Eq, PartialEq, Builder, Serialize, Deserialize)]
307#[builder(setter(into))]
308#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
309pub struct ParamSpec {
310    /// Name of the parameter
311    pub name: String,
312    /// Description of the parameter
313    pub description: String,
314    /// Json spec type of the parameter
315    #[serde(rename = "type")]
316    #[builder(default)]
317    pub ty: ParamType,
318    /// Whether the parameter is required
319    #[builder(default = true)]
320    pub required: bool,
321}
322
323impl ParamSpec {
324    pub fn builder() -> ParamSpecBuilder {
325        ParamSpecBuilder::default()
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialize_param_type() {
        let param = ParamType::Nullable(Box::new(ParamType::String));
        let serialized = serde_json::to_string(&param).unwrap();
        assert_eq!(serialized, r#"["string","null"]"#);

        let deserialized: ParamType = serde_json::from_str(&serialized).unwrap();
        assert_eq!(param, deserialized);
    }

    #[test]
    fn test_deserialize_param_type() {
        let serialized = r#"["string","null"]"#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(
            deserialized,
            ParamType::Nullable(Box::new(ParamType::String))
        );

        let serialized = r#""string""#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(deserialized, ParamType::String);
    }

    #[test]
    fn test_primitive_param_types_round_trip() {
        for (ty, json) in [
            (ParamType::String, r#""string""#),
            (ParamType::Number, r#""number""#),
            (ParamType::Boolean, r#""boolean""#),
            (ParamType::Array, r#""array""#),
        ] {
            assert_eq!(serde_json::to_string(&ty).unwrap(), json);
            assert_eq!(serde_json::from_str::<ParamType>(json).unwrap(), ty);
        }
    }

    #[test]
    fn test_deserialize_nullable_accepts_null_first() {
        let deserialized: ParamType = serde_json::from_str(r#"["null","number"]"#).unwrap();
        assert_eq!(
            deserialized,
            ParamType::Nullable(Box::new(ParamType::Number))
        );
    }

    #[test]
    fn test_deserialize_rejects_malformed_types() {
        for bad in [
            r#""object""#,
            r#"["string"]"#,
            r#"["string","null","null"]"#,
            r#"["string","number"]"#,
            r#"["null","null"]"#,
        ] {
            assert!(
                serde_json::from_str::<ParamType>(bad).is_err(),
                "expected an error for {bad}"
            );
        }
    }

    #[test]
    fn test_serialize_rejects_nested_nullable() {
        let nested =
            ParamType::Nullable(Box::new(ParamType::Nullable(Box::new(ParamType::String))));
        assert!(serde_json::to_string(&nested).is_err());
    }

    #[test]
    fn test_tool_output_content() {
        assert_eq!(ToolOutput::text("ok").content(), Some("ok"));
        assert_eq!(ToolOutput::fail("bad").content(), Some("bad"));
        assert_eq!(ToolOutput::stop().content(), None);
        assert_eq!(ToolOutput::feedback_required(None).content(), None);
    }

    #[test]
    fn test_tool_call_hash_ignores_id() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        fn digest(call: &ToolCall) -> u64 {
            let mut hasher = DefaultHasher::new();
            call.hash(&mut hasher);
            hasher.finish()
        }

        let first = ToolCall::builder()
            .id("1")
            .name("search")
            .args("{}")
            .build()
            .unwrap();
        let retry = ToolCall::builder()
            .id("2")
            .name("search")
            .args("{}")
            .build()
            .unwrap();

        assert_ne!(first, retry);
        assert_eq!(digest(&first), digest(&retry));
    }
}