// swiftide_core/chat_completion/tools.rs
use std::fmt;
2
3use derive_builder::Builder;
4use serde::de::{Deserializer, Error as DeError, SeqAccess, Unexpected, Visitor};
5use serde::ser::{Error as SerError, SerializeSeq, Serializer};
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10#[non_exhaustive]
11pub enum ToolOutput {
12 Text(String),
14
15 Fail(String),
17 Stop,
19}
20
21impl ToolOutput {
22 pub fn content(&self) -> Option<&str> {
23 match self {
24 ToolOutput::Fail(s) | ToolOutput::Text(s) => Some(s),
25 _ => None,
26 }
27 }
28}
29
30impl<T: AsRef<str>> From<T> for ToolOutput {
31 fn from(s: T) -> Self {
32 ToolOutput::Text(s.as_ref().to_string())
33 }
34}
35
36impl std::fmt::Display for ToolOutput {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 match self {
39 ToolOutput::Text(value) => write!(f, "{value}"),
40 ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"),
41 ToolOutput::Stop => write!(f, "Stop"),
42 }
43 }
44}
45
46#[derive(Clone, Debug, Builder, PartialEq, Serialize, Deserialize)]
48#[builder(setter(into, strip_option))]
49pub struct ToolCall {
50 id: String,
51 name: String,
52 #[builder(default)]
53 args: Option<String>,
54}
55
56impl std::hash::Hash for &ToolCall {
58 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
59 self.name.hash(state);
60 self.args.hash(state);
61 }
62}
63
64impl std::fmt::Display for ToolCall {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 write!(
67 f,
68 "{id}#{name} {args}",
69 id = self.id,
70 name = self.name,
71 args = self.args.as_deref().unwrap_or("")
72 )
73 }
74}
75
76impl ToolCall {
77 pub fn builder() -> ToolCallBuilder {
78 ToolCallBuilder::default()
79 }
80
81 pub fn id(&self) -> &str {
82 &self.id
83 }
84
85 pub fn name(&self) -> &str {
86 &self.name
87 }
88
89 pub fn args(&self) -> Option<&str> {
90 self.args.as_deref()
91 }
92}
93
94impl ToolCallBuilder {
95 pub fn maybe_args<T: Into<Option<String>>>(&mut self, args: T) -> &mut Self {
96 self.args = Some(args.into());
97 self
98 }
99
100 pub fn maybe_id<T: Into<Option<String>>>(&mut self, id: T) -> &mut Self {
101 self.id = id.into();
102 self
103 }
104
105 pub fn maybe_name<T: Into<Option<String>>>(&mut self, name: T) -> &mut Self {
106 self.name = name.into();
107 self
108 }
109}
110
111#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, Builder)]
115#[builder(setter(into))]
116pub struct ToolSpec {
117 pub name: String,
119 pub description: String,
121
122 #[builder(default)]
123 pub parameters: Vec<ParamSpec>,
125}
126
127impl ToolSpec {
128 pub fn builder() -> ToolSpecBuilder {
129 ToolSpecBuilder::default()
130 }
131}
132
133#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, strum_macros::AsRefStr)]
134#[strum(serialize_all = "camelCase")]
135pub enum ParamType {
136 #[default]
137 String,
138 Number,
139 Boolean,
140 Array,
141 Nullable(Box<ParamType>),
142}
143
/// The primitive parameter types: the same names as the non-`Nullable` variants of
/// `ParamType`.
// NOTE(review): not referenced elsewhere in this file; presumably reserved as the payload
// type for nullable parameters — confirm before relying on it.
pub enum InnerParamType {
    /// A string value.
    String,
    /// A numeric value.
    Number,
    /// A boolean value.
    Boolean,
    /// An array value.
    Array,
}
150
151impl Serialize for ParamType {
152 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
153 where
154 S: Serializer,
155 {
156 match self {
157 ParamType::String => serializer.serialize_str("string"),
159 ParamType::Number => serializer.serialize_str("number"),
160 ParamType::Boolean => serializer.serialize_str("boolean"),
161 ParamType::Array => serializer.serialize_str("array"),
162
163 ParamType::Nullable(inner) => {
165 if let ParamType::Nullable(_) = inner.as_ref() {
167 return Err(serde::ser::Error::custom("Nested Nullable not supported"));
168 }
169 let mut seq = serializer.serialize_seq(Some(2))?;
171 seq.serialize_element(&primitive_variant_str(inner).map_err(S::Error::custom)?)?;
172 seq.serialize_element("null")?;
173 seq.end()
174 }
175 }
176 }
177}
178
179fn primitive_variant_str(pt: &ParamType) -> Result<&'static str, &'static str> {
180 match pt {
181 ParamType::String => Ok("string"),
182 ParamType::Number => Ok("number"),
183 ParamType::Boolean => Ok("boolean"),
184 ParamType::Array => Ok("array"),
185 ParamType::Nullable(_) => Err("Nested Nullable found"),
186 }
187}
188
189impl<'de> Deserialize<'de> for ParamType {
190 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
191 where
192 D: Deserializer<'de>,
193 {
194 deserializer.deserialize_any(ParamTypeVisitor)
195 }
196}
197
198struct ParamTypeVisitor;
199
200impl<'de> Visitor<'de> for ParamTypeVisitor {
201 type Value = ParamType;
202
203 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
204 write!(
205 formatter,
206 "a string (e.g. \"string\", \"number\") \
207 or a 2-element array [<type>, \"null\"]"
208 )
209 }
210
211 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
213 where
214 E: DeError,
215 {
216 match value {
217 "string" => Ok(ParamType::String),
218 "number" => Ok(ParamType::Number),
219 "boolean" => Ok(ParamType::Boolean),
220 "array" => Ok(ParamType::Array),
221 other => Err(E::unknown_variant(
222 other,
223 &["string", "number", "boolean", "array"],
224 )),
225 }
226 }
227
228 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
230 where
231 A: SeqAccess<'de>,
232 {
233 let mut items = Vec::new();
234 while let Some(item) = seq.next_element::<String>()? {
235 items.push(item);
236 }
237
238 if items.len() != 2 {
240 return Err(A::Error::invalid_length(items.len(), &"2"));
241 }
242
243 let mut first = items[0].as_str();
244 let mut second = items[1].as_str();
245
246 if first == "null" {
248 std::mem::swap(&mut first, &mut second);
249 }
250
251 if second != "null" {
253 return Err(A::Error::invalid_value(
254 Unexpected::Str(second),
255 &"expected exactly one 'null' in [<type>, 'null']",
256 ));
257 }
258
259 let inner = match first {
261 "string" => ParamType::String,
262 "number" => ParamType::Number,
263 "boolean" => ParamType::Boolean,
264 "array" => ParamType::Array,
265 other => {
266 return Err(A::Error::unknown_variant(
267 other,
268 &["string", "number", "boolean", "array", "null"],
269 ));
270 }
271 };
272
273 Ok(ParamType::Nullable(Box::new(inner)))
274 }
275}
276#[derive(Clone, Debug, Hash, Eq, PartialEq, Builder)]
278#[builder(setter(into))]
279pub struct ParamSpec {
280 pub name: String,
282 pub description: String,
284 #[builder(default)]
286 pub ty: ParamType,
287 #[builder(default = true)]
289 pub required: bool,
290}
291
292impl ParamSpec {
293 pub fn builder() -> ParamSpecBuilder {
294 ParamSpecBuilder::default()
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// A nullable string serializes to `["string","null"]` and survives a round trip.
    #[test]
    fn test_serialize_param_type() {
        let param = ParamType::Nullable(Box::new(ParamType::String));
        let serialized = serde_json::to_string(&param).unwrap();
        assert_eq!(serialized, r#"["string","null"]"#);

        let deserialized: ParamType = serde_json::from_str(&serialized).unwrap();
        assert_eq!(param, deserialized);
    }

    /// Both the two-element sequence form and the bare string form deserialize.
    #[test]
    fn test_deserialize_param_type() {
        let serialized = r#"["string","null"]"#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(
            deserialized,
            ParamType::Nullable(Box::new(ParamType::String))
        );

        let serialized = r#""string""#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(deserialized, ParamType::String);
    }
}