swiftide_core/chat_completion/
tools.rs

1use std::borrow::Cow;
2use std::fmt;
3
4use derive_builder::Builder;
5use serde::de::{Deserializer, Error as DeError, SeqAccess, Unexpected, Visitor};
6use serde::ser::{Error as SerError, SerializeSeq, Serializer};
7use serde::{Deserialize, Serialize};
8
9/// Output of a `ToolCall` which will be added as a message for the agent to use.
10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, strum_macros::EnumIs)]
11#[non_exhaustive]
12pub enum ToolOutput {
13    /// Adds the result of the toolcall to messages
14    Text(String),
15
16    /// Indicates that the toolcall requires feedback, i.e. in a human-in-the-loop
17    FeedbackRequired(Option<serde_json::Value>),
18
19    /// Indicates that the toolcall failed, but can be handled by the llm
20    Fail(String),
21
22    /// Stops an agent with an optional message
23    Stop(Option<Cow<'static, str>>),
24
25    /// Indicates that the agent failed and should stop
26    AgentFailed(Option<Cow<'static, str>>),
27}
28
29impl ToolOutput {
30    pub fn text(text: impl Into<String>) -> Self {
31        ToolOutput::Text(text.into())
32    }
33
34    pub fn feedback_required(feedback: Option<serde_json::Value>) -> Self {
35        ToolOutput::FeedbackRequired(feedback)
36    }
37
38    pub fn stop() -> Self {
39        ToolOutput::Stop(None)
40    }
41
42    pub fn stop_with_args(output: impl Into<Cow<'static, str>>) -> Self {
43        ToolOutput::Stop(Some(output.into()))
44    }
45
46    pub fn agent_failed(output: impl Into<Cow<'static, str>>) -> Self {
47        ToolOutput::AgentFailed(Some(output.into()))
48    }
49
50    pub fn fail(text: impl Into<String>) -> Self {
51        ToolOutput::Fail(text.into())
52    }
53
54    pub fn content(&self) -> Option<&str> {
55        match self {
56            ToolOutput::Fail(s) | ToolOutput::Text(s) => Some(s),
57            _ => None,
58        }
59    }
60
61    /// Get the inner text if the output is a `Text` variant.
62    pub fn as_text(&self) -> Option<&str> {
63        match self {
64            ToolOutput::Text(s) => Some(s),
65            _ => None,
66        }
67    }
68
69    /// Get the inner text if the output is a `Fail` variant.
70    pub fn as_fail(&self) -> Option<&str> {
71        match self {
72            ToolOutput::Fail(s) => Some(s),
73            _ => None,
74        }
75    }
76
77    /// Get the inner text if the output is a `Stop` variant.
78    pub fn as_stop(&self) -> Option<&str> {
79        match self {
80            ToolOutput::Stop(args) => args.as_deref(),
81            _ => None,
82        }
83    }
84
85    /// Get the inner text if the output is an `AgentFailed` variant.
86    pub fn as_agent_failed(&self) -> Option<&str> {
87        match self {
88            ToolOutput::AgentFailed(args) => args.as_deref(),
89            _ => None,
90        }
91    }
92
93    /// Get the inner feedback if the output is a `FeedbackRequired` variant.
94    pub fn as_feedback_required(&self) -> Option<&serde_json::Value> {
95        match self {
96            ToolOutput::FeedbackRequired(args) => args.as_ref(),
97            _ => None,
98        }
99    }
100}
101
102impl<S: AsRef<str>> From<S> for ToolOutput {
103    fn from(value: S) -> Self {
104        ToolOutput::Text(value.as_ref().to_string())
105    }
106}
107impl std::fmt::Display for ToolOutput {
108    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109        match self {
110            ToolOutput::Text(value) => write!(f, "{value}"),
111            ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"),
112            ToolOutput::Stop(args) => write!(f, "Stop {}", args.as_deref().unwrap_or_default()),
113            ToolOutput::FeedbackRequired(_) => {
114                write!(f, "Feedback required")
115            }
116            ToolOutput::AgentFailed(args) => write!(
117                f,
118                "Agent failed with output: {}",
119                args.as_deref().unwrap_or_default()
120            ),
121        }
122    }
123}
124
125/// A tool call that can be executed by the executor
126#[derive(Clone, Debug, Builder, PartialEq, Serialize, Deserialize, Eq)]
127#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
128#[builder(setter(into, strip_option))]
129pub struct ToolCall {
130    id: String,
131    name: String,
132    #[builder(default)]
133    args: Option<String>,
134}
135
136/// Hash is used for finding tool calls that have been retried by agents
137impl std::hash::Hash for ToolCall {
138    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
139        self.name.hash(state);
140        self.args.hash(state);
141    }
142}
143
144impl std::fmt::Display for ToolCall {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        write!(
147            f,
148            "{id}#{name} {args}",
149            id = self.id,
150            name = self.name,
151            args = self.args.as_deref().unwrap_or("")
152        )
153    }
154}
155
156impl ToolCall {
157    pub fn builder() -> ToolCallBuilder {
158        ToolCallBuilder::default()
159    }
160
161    pub fn id(&self) -> &str {
162        &self.id
163    }
164
165    pub fn name(&self) -> &str {
166        &self.name
167    }
168
169    pub fn args(&self) -> Option<&str> {
170        self.args.as_deref()
171    }
172
173    pub fn with_args(&mut self, args: Option<String>) {
174        self.args = args;
175    }
176}
177
178impl ToolCallBuilder {
179    pub fn maybe_args<T: Into<Option<String>>>(&mut self, args: T) -> &mut Self {
180        self.args = Some(args.into());
181        self
182    }
183
184    pub fn maybe_id<T: Into<Option<String>>>(&mut self, id: T) -> &mut Self {
185        self.id = id.into();
186        self
187    }
188
189    pub fn maybe_name<T: Into<Option<String>>>(&mut self, name: T) -> &mut Self {
190        self.name = name.into();
191        self
192    }
193}
194
195/// A typed tool specification intended to be usable for multiple LLMs
196///
197/// i.e. the json spec `OpenAI` uses to define their tools
198#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, Builder, Serialize, Deserialize)]
199#[builder(setter(into), derive(Debug, Serialize, Deserialize))]
200#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
201pub struct ToolSpec {
202    /// Name of the tool
203    pub name: String,
204    /// Description passed to the LLM for the tool
205    pub description: String,
206
207    #[builder(default)]
208    /// Optional parameters for the tool
209    pub parameters: Vec<ParamSpec>,
210}
211
212impl ToolSpec {
213    pub fn builder() -> ToolSpecBuilder {
214        ToolSpecBuilder::default()
215    }
216}
217
218#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, strum_macros::AsRefStr)]
219#[strum(serialize_all = "camelCase")]
220#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
221pub enum ParamType {
222    #[default]
223    String,
224    Number,
225    Boolean,
226    Array,
227    Nullable(Box<ParamType>),
228}
229
/// The non-nullable type names a parameter can have: the same variants as [`ParamType`]
/// minus its `Nullable` wrapper.
///
/// A fieldless public enum, so it derives the standard value traits (it previously had
/// none, not even `Debug`).
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub enum InnerParamType {
    String,
    Number,
    Boolean,
    Array,
}
236
237impl Serialize for ParamType {
238    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
239    where
240        S: Serializer,
241    {
242        match self {
243            // Non-nullable => single string
244            ParamType::String => serializer.serialize_str("string"),
245            ParamType::Number => serializer.serialize_str("number"),
246            ParamType::Boolean => serializer.serialize_str("boolean"),
247            ParamType::Array => serializer.serialize_str("array"),
248
249            // Nullable => an array of exactly two items, e.g. ["string", "null"]
250            ParamType::Nullable(inner) => {
251                // If you want to forbid nested nullables:
252                if let ParamType::Nullable(_) = inner.as_ref() {
253                    return Err(serde::ser::Error::custom("Nested Nullable not supported"));
254                }
255                // Otherwise, produce an array like `["string", "null"]`.
256                let mut seq = serializer.serialize_seq(Some(2))?;
257                seq.serialize_element(&primitive_variant_str(inner).map_err(S::Error::custom)?)?;
258                seq.serialize_element("null")?;
259                seq.end()
260            }
261        }
262    }
263}
264
265fn primitive_variant_str(pt: &ParamType) -> Result<&'static str, &'static str> {
266    match pt {
267        ParamType::String => Ok("string"),
268        ParamType::Number => Ok("number"),
269        ParamType::Boolean => Ok("boolean"),
270        ParamType::Array => Ok("array"),
271        ParamType::Nullable(_) => Err("Nested Nullable found"),
272    }
273}
274
275impl<'de> Deserialize<'de> for ParamType {
276    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
277    where
278        D: Deserializer<'de>,
279    {
280        deserializer.deserialize_any(ParamTypeVisitor)
281    }
282}
283
284struct ParamTypeVisitor;
285
286impl<'de> Visitor<'de> for ParamTypeVisitor {
287    type Value = ParamType;
288
289    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
290        write!(
291            formatter,
292            "a string (e.g. \"string\", \"number\") \
293             or a 2-element array [<type>, \"null\"]"
294        )
295    }
296
297    // Single strings => simple ParamType
298    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
299    where
300        E: DeError,
301    {
302        match value {
303            "string" => Ok(ParamType::String),
304            "number" => Ok(ParamType::Number),
305            "boolean" => Ok(ParamType::Boolean),
306            "array" => Ok(ParamType::Array),
307            other => Err(E::unknown_variant(
308                other,
309                &["string", "number", "boolean", "array"],
310            )),
311        }
312    }
313
314    // Arrays => expect exactly 2 items, one must be "null"
315    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
316    where
317        A: SeqAccess<'de>,
318    {
319        let mut items = Vec::new();
320        while let Some(item) = seq.next_element::<String>()? {
321            items.push(item);
322        }
323
324        // Must have exactly 2 elements
325        if items.len() != 2 {
326            return Err(A::Error::invalid_length(items.len(), &"2"));
327        }
328
329        let mut first = items[0].as_str();
330        let mut second = items[1].as_str();
331
332        // If the first is "null", swap so second is "null" and first is the real type
333        if first == "null" {
334            std::mem::swap(&mut first, &mut second);
335        }
336
337        // Now 'second' must be "null".
338        if second != "null" {
339            return Err(A::Error::invalid_value(
340                Unexpected::Str(second),
341                &"expected exactly one 'null' in [<type>, 'null']",
342            ));
343        }
344
345        // 'first' must be a known primitive
346        let inner = match first {
347            "string" => ParamType::String,
348            "number" => ParamType::Number,
349            "boolean" => ParamType::Boolean,
350            "array" => ParamType::Array,
351            other => {
352                return Err(A::Error::unknown_variant(
353                    other,
354                    &["string", "number", "boolean", "array", "null"],
355                ));
356            }
357        };
358
359        Ok(ParamType::Nullable(Box::new(inner)))
360    }
361}
362/// Parameters for tools
363#[derive(Clone, Debug, Hash, Eq, PartialEq, Builder, Serialize, Deserialize)]
364#[builder(setter(into))]
365#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
366pub struct ParamSpec {
367    /// Name of the parameter
368    pub name: String,
369    /// Description of the parameter
370    pub description: String,
371    /// Json spec type of the parameter
372    #[serde(rename = "type")]
373    #[builder(default)]
374    pub ty: ParamType,
375    /// Whether the parameter is required
376    #[builder(default = true)]
377    pub required: bool,
378}
379
380impl ParamSpec {
381    pub fn builder() -> ParamSpecBuilder {
382        ParamSpecBuilder::default()
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialize_param_type() {
        let param = ParamType::Nullable(Box::new(ParamType::String));
        let serialized = serde_json::to_string(&param).unwrap();
        assert_eq!(serialized, r#"["string","null"]"#);

        let deserialized: ParamType = serde_json::from_str(&serialized).unwrap();
        assert_eq!(param, deserialized);
    }

    #[test]
    fn test_deserialize_param_type() {
        let serialized = r#"["string","null"]"#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(
            deserialized,
            ParamType::Nullable(Box::new(ParamType::String))
        );

        let serialized = r#""string""#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(deserialized, ParamType::String);
    }

    /// Every primitive serializes to its bare name and parses back unchanged.
    #[test]
    fn test_primitive_param_types_round_trip() {
        for (ty, json) in [
            (ParamType::String, r#""string""#),
            (ParamType::Number, r#""number""#),
            (ParamType::Boolean, r#""boolean""#),
            (ParamType::Array, r#""array""#),
        ] {
            assert_eq!(serde_json::to_string(&ty).unwrap(), json);
            assert_eq!(serde_json::from_str::<ParamType>(json).unwrap(), ty);
        }
    }

    /// `"null"` may appear before the type name.
    #[test]
    fn test_deserialize_nullable_with_null_first() {
        let deserialized: ParamType = serde_json::from_str(r#"["null","number"]"#).unwrap();
        assert_eq!(
            deserialized,
            ParamType::Nullable(Box::new(ParamType::Number))
        );
    }

    /// Unknown names, wrong lengths and arrays without exactly one "null" are rejected.
    #[test]
    fn test_deserialize_rejects_malformed_param_types() {
        for input in [
            r#""integer""#,
            r#"[]"#,
            r#"["string"]"#,
            r#"["string","number"]"#,
            r#"["null","null"]"#,
            r#"["object","null"]"#,
            r#"["string","null","null"]"#,
        ] {
            assert!(
                serde_json::from_str::<ParamType>(input).is_err(),
                "expected {input} to be rejected"
            );
        }
    }

    /// A `Nullable` inside a `Nullable` has no wire representation.
    #[test]
    fn test_serialize_rejects_nested_nullable() {
        let nested =
            ParamType::Nullable(Box::new(ParamType::Nullable(Box::new(ParamType::String))));
        assert!(serde_json::to_string(&nested).is_err());
    }

    /// `ty` is written under the `"type"` key and the builder defaults apply.
    #[test]
    fn test_param_spec_uses_type_key_and_defaults() {
        let spec = ParamSpec::builder()
            .name("query")
            .description("Search query")
            .build()
            .unwrap();
        assert_eq!(spec.ty, ParamType::String);
        assert!(spec.required);

        let value = serde_json::to_value(&spec).unwrap();
        assert_eq!(value["type"], serde_json::json!("string"));
        assert_eq!(value["required"], serde_json::json!(true));
    }

    #[test]
    fn test_tool_output_accessors_and_display() {
        assert_eq!(ToolOutput::text("ok").content(), Some("ok"));
        assert_eq!(ToolOutput::fail("boom").content(), Some("boom"));
        assert_eq!(ToolOutput::stop().content(), None);
        assert_eq!(ToolOutput::stop_with_args("done").as_stop(), Some("done"));
        assert_eq!(ToolOutput::agent_failed("oops").as_agent_failed(), Some("oops"));
        assert_eq!(ToolOutput::fail("boom").to_string(), "Tool call failed: boom");
        assert_eq!(ToolOutput::from("hi"), ToolOutput::Text("hi".to_string()));
    }

    /// Calls that differ only by id are unequal but hash identically.
    #[test]
    fn test_tool_call_hash_ignores_id() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let hash_of = |call: &ToolCall| {
            let mut hasher = DefaultHasher::new();
            call.hash(&mut hasher);
            hasher.finish()
        };

        let first = ToolCall::builder()
            .id("1")
            .name("search")
            .args("{}")
            .build()
            .unwrap();
        let retry = ToolCall::builder()
            .id("2")
            .name("search")
            .args("{}")
            .build()
            .unwrap();

        assert_ne!(first, retry);
        assert_eq!(hash_of(&first), hash_of(&retry));
    }
}