// swiftide_core/chat_completion/tools.rs
use std::fmt;

use derive_builder::Builder;
use serde::de::{Deserializer, Error as DeError, IgnoredAny, SeqAccess, Unexpected, Visitor};
use serde::ser::{Error as SerError, SerializeSeq, Serializer};
use serde::{Deserialize, Serialize};

8#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, strum_macros::EnumIs)]
10#[non_exhaustive]
11pub enum ToolOutput {
12 Text(String),
14
15 FeedbackRequired(Option<serde_json::Value>),
17
18 Fail(String),
20 Stop,
22}
23
24impl ToolOutput {
25 pub fn text(text: impl Into<String>) -> Self {
26 ToolOutput::Text(text.into())
27 }
28
29 pub fn feedback_required(feedback: Option<serde_json::Value>) -> Self {
30 ToolOutput::FeedbackRequired(feedback)
31 }
32
33 pub fn stop() -> Self {
34 ToolOutput::Stop
35 }
36
37 pub fn fail(text: impl Into<String>) -> Self {
38 ToolOutput::Fail(text.into())
39 }
40
41 pub fn content(&self) -> Option<&str> {
42 match self {
43 ToolOutput::Fail(s) | ToolOutput::Text(s) => Some(s),
44 _ => None,
45 }
46 }
47}
48
49impl<T: AsRef<str>> From<T> for ToolOutput {
50 fn from(s: T) -> Self {
51 ToolOutput::Text(s.as_ref().to_string())
52 }
53}
54
55impl std::fmt::Display for ToolOutput {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 match self {
58 ToolOutput::Text(value) => write!(f, "{value}"),
59 ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"),
60 ToolOutput::Stop => write!(f, "Stop"),
61 ToolOutput::FeedbackRequired(_) => {
62 write!(f, "Feedback required")
63 }
64 }
65 }
66}
67
68#[derive(Clone, Debug, Builder, PartialEq, Serialize, Deserialize, Eq)]
70#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
71#[builder(setter(into, strip_option))]
72pub struct ToolCall {
73 id: String,
74 name: String,
75 #[builder(default)]
76 args: Option<String>,
77}
78
79impl std::hash::Hash for ToolCall {
81 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
82 self.name.hash(state);
83 self.args.hash(state);
84 }
85}
86
87impl std::fmt::Display for ToolCall {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 write!(
90 f,
91 "{id}#{name} {args}",
92 id = self.id,
93 name = self.name,
94 args = self.args.as_deref().unwrap_or("")
95 )
96 }
97}
98
99impl ToolCall {
100 pub fn builder() -> ToolCallBuilder {
101 ToolCallBuilder::default()
102 }
103
104 pub fn id(&self) -> &str {
105 &self.id
106 }
107
108 pub fn name(&self) -> &str {
109 &self.name
110 }
111
112 pub fn args(&self) -> Option<&str> {
113 self.args.as_deref()
114 }
115
116 pub fn with_args(&mut self, args: Option<String>) {
117 self.args = args;
118 }
119}
120
121impl ToolCallBuilder {
122 pub fn maybe_args<T: Into<Option<String>>>(&mut self, args: T) -> &mut Self {
123 self.args = Some(args.into());
124 self
125 }
126
127 pub fn maybe_id<T: Into<Option<String>>>(&mut self, id: T) -> &mut Self {
128 self.id = id.into();
129 self
130 }
131
132 pub fn maybe_name<T: Into<Option<String>>>(&mut self, name: T) -> &mut Self {
133 self.name = name.into();
134 self
135 }
136}
137
138#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, Builder, Serialize, Deserialize)]
142#[builder(setter(into), derive(Debug, Serialize, Deserialize))]
143#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
144pub struct ToolSpec {
145 pub name: String,
147 pub description: String,
149
150 #[builder(default)]
151 pub parameters: Vec<ParamSpec>,
153}
154
155impl ToolSpec {
156 pub fn builder() -> ToolSpecBuilder {
157 ToolSpecBuilder::default()
158 }
159}
160
161#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, strum_macros::AsRefStr)]
162#[strum(serialize_all = "camelCase")]
163#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
164pub enum ParamType {
165 #[default]
166 String,
167 Number,
168 Boolean,
169 Array,
170 Nullable(Box<ParamType>),
171}
172
/// The non-nullable subset of `ParamType`'s variants.
///
/// NOTE(review): nothing in this file uses this type; confirm it is still needed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InnerParamType {
    String,
    Number,
    Boolean,
    Array,
}

180impl Serialize for ParamType {
181 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
182 where
183 S: Serializer,
184 {
185 match self {
186 ParamType::String => serializer.serialize_str("string"),
188 ParamType::Number => serializer.serialize_str("number"),
189 ParamType::Boolean => serializer.serialize_str("boolean"),
190 ParamType::Array => serializer.serialize_str("array"),
191
192 ParamType::Nullable(inner) => {
194 if let ParamType::Nullable(_) = inner.as_ref() {
196 return Err(serde::ser::Error::custom("Nested Nullable not supported"));
197 }
198 let mut seq = serializer.serialize_seq(Some(2))?;
200 seq.serialize_element(&primitive_variant_str(inner).map_err(S::Error::custom)?)?;
201 seq.serialize_element("null")?;
202 seq.end()
203 }
204 }
205 }
206}
207
208fn primitive_variant_str(pt: &ParamType) -> Result<&'static str, &'static str> {
209 match pt {
210 ParamType::String => Ok("string"),
211 ParamType::Number => Ok("number"),
212 ParamType::Boolean => Ok("boolean"),
213 ParamType::Array => Ok("array"),
214 ParamType::Nullable(_) => Err("Nested Nullable found"),
215 }
216}
217
218impl<'de> Deserialize<'de> for ParamType {
219 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
220 where
221 D: Deserializer<'de>,
222 {
223 deserializer.deserialize_any(ParamTypeVisitor)
224 }
225}
226
/// Stateless `serde` visitor that builds a `ParamType` from a string or a sequence.
struct ParamTypeVisitor;

229impl<'de> Visitor<'de> for ParamTypeVisitor {
230 type Value = ParamType;
231
232 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
233 write!(
234 formatter,
235 "a string (e.g. \"string\", \"number\") \
236 or a 2-element array [<type>, \"null\"]"
237 )
238 }
239
240 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
242 where
243 E: DeError,
244 {
245 match value {
246 "string" => Ok(ParamType::String),
247 "number" => Ok(ParamType::Number),
248 "boolean" => Ok(ParamType::Boolean),
249 "array" => Ok(ParamType::Array),
250 other => Err(E::unknown_variant(
251 other,
252 &["string", "number", "boolean", "array"],
253 )),
254 }
255 }
256
257 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
259 where
260 A: SeqAccess<'de>,
261 {
262 let mut items = Vec::new();
263 while let Some(item) = seq.next_element::<String>()? {
264 items.push(item);
265 }
266
267 if items.len() != 2 {
269 return Err(A::Error::invalid_length(items.len(), &"2"));
270 }
271
272 let mut first = items[0].as_str();
273 let mut second = items[1].as_str();
274
275 if first == "null" {
277 std::mem::swap(&mut first, &mut second);
278 }
279
280 if second != "null" {
282 return Err(A::Error::invalid_value(
283 Unexpected::Str(second),
284 &"expected exactly one 'null' in [<type>, 'null']",
285 ));
286 }
287
288 let inner = match first {
290 "string" => ParamType::String,
291 "number" => ParamType::Number,
292 "boolean" => ParamType::Boolean,
293 "array" => ParamType::Array,
294 other => {
295 return Err(A::Error::unknown_variant(
296 other,
297 &["string", "number", "boolean", "array", "null"],
298 ));
299 }
300 };
301
302 Ok(ParamType::Nullable(Box::new(inner)))
303 }
304}
305#[derive(Clone, Debug, Hash, Eq, PartialEq, Builder, Serialize, Deserialize)]
307#[builder(setter(into))]
308#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
309pub struct ParamSpec {
310 pub name: String,
312 pub description: String,
314 #[serde(rename = "type")]
316 #[builder(default)]
317 pub ty: ParamType,
318 #[builder(default = true)]
320 pub required: bool,
321}
322
323impl ParamSpec {
324 pub fn builder() -> ParamSpecBuilder {
325 ParamSpecBuilder::default()
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialize_param_type() {
        let param = ParamType::Nullable(Box::new(ParamType::String));
        // `&param` was previously garbled into `¶m` (an `&para;` entity), which did not compile.
        let serialized = serde_json::to_string(&param).unwrap();
        assert_eq!(serialized, r#"["string","null"]"#);

        let deserialized: ParamType = serde_json::from_str(&serialized).unwrap();
        assert_eq!(param, deserialized);
    }

    #[test]
    fn test_deserialize_param_type() {
        let serialized = r#"["string","null"]"#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(
            deserialized,
            ParamType::Nullable(Box::new(ParamType::String))
        );

        let serialized = r#""string""#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(deserialized, ParamType::String);
    }

    #[test]
    fn test_primitive_round_trip() {
        for (param, json) in [
            (ParamType::String, r#""string""#),
            (ParamType::Number, r#""number""#),
            (ParamType::Boolean, r#""boolean""#),
            (ParamType::Array, r#""array""#),
        ] {
            assert_eq!(serde_json::to_string(&param).unwrap(), json);
            let back: ParamType = serde_json::from_str(json).unwrap();
            assert_eq!(back, param);
        }
    }

    #[test]
    fn test_serialize_rejects_nested_nullable() {
        let nested = ParamType::Nullable(Box::new(ParamType::Nullable(Box::new(
            ParamType::Number,
        ))));
        assert!(serde_json::to_string(&nested).is_err());
    }

    #[test]
    fn test_deserialize_accepts_null_first() {
        let deserialized: ParamType = serde_json::from_str(r#"["null","number"]"#).unwrap();
        assert_eq!(
            deserialized,
            ParamType::Nullable(Box::new(ParamType::Number))
        );
    }

    #[test]
    fn test_deserialize_rejects_malformed_input() {
        for input in [
            r#"["null","null"]"#,
            r#"["string","number"]"#,
            r#"["string"]"#,
            r#"["string","null","null"]"#,
            r#"[]"#,
            r#""object""#,
            r#""null""#,
        ] {
            assert!(
                serde_json::from_str::<ParamType>(input).is_err(),
                "{input} should be rejected"
            );
        }
    }

    #[test]
    fn test_tool_output_content() {
        assert_eq!(ToolOutput::text("hi").content(), Some("hi"));
        assert_eq!(ToolOutput::fail("bad").content(), Some("bad"));
        assert_eq!(ToolOutput::stop().content(), None);
        assert_eq!(ToolOutput::feedback_required(None).content(), None);
    }
}