swiftide_core/chat_completion/
tools.rs1use std::fmt;
2
3use derive_builder::Builder;
4use serde::de::{Deserializer, Error as DeError, SeqAccess, Unexpected, Visitor};
5use serde::ser::{Error as SerError, SerializeSeq, Serializer};
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, strum_macros::EnumIs)]
10#[non_exhaustive]
11pub enum ToolOutput {
12 Text(String),
14
15 FeedbackRequired(Option<serde_json::Value>),
17
18 Fail(String),
20 Stop,
22}
23
24impl ToolOutput {
25 pub fn text(text: impl Into<String>) -> Self {
26 ToolOutput::Text(text.into())
27 }
28
29 pub fn feedback_required(feedback: Option<serde_json::Value>) -> Self {
30 ToolOutput::FeedbackRequired(feedback)
31 }
32
33 pub fn stop() -> Self {
34 ToolOutput::Stop
35 }
36
37 pub fn fail(text: impl Into<String>) -> Self {
38 ToolOutput::Fail(text.into())
39 }
40
41 pub fn content(&self) -> Option<&str> {
42 match self {
43 ToolOutput::Fail(s) | ToolOutput::Text(s) => Some(s),
44 _ => None,
45 }
46 }
47}
48
49impl<T: AsRef<str>> From<T> for ToolOutput {
50 fn from(s: T) -> Self {
51 ToolOutput::Text(s.as_ref().to_string())
52 }
53}
54
55impl std::fmt::Display for ToolOutput {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 match self {
58 ToolOutput::Text(value) => write!(f, "{value}"),
59 ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"),
60 ToolOutput::Stop => write!(f, "Stop"),
61 ToolOutput::FeedbackRequired(_) => {
62 write!(f, "Feedback required")
63 }
64 }
65 }
66}
67
68#[derive(Clone, Debug, Builder, PartialEq, Serialize, Deserialize, Eq)]
70#[builder(setter(into, strip_option))]
71pub struct ToolCall {
72 id: String,
73 name: String,
74 #[builder(default)]
75 args: Option<String>,
76}
77
78impl std::hash::Hash for ToolCall {
80 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
81 self.name.hash(state);
82 self.args.hash(state);
83 }
84}
85
86impl std::fmt::Display for ToolCall {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(
89 f,
90 "{id}#{name} {args}",
91 id = self.id,
92 name = self.name,
93 args = self.args.as_deref().unwrap_or("")
94 )
95 }
96}
97
98impl ToolCall {
99 pub fn builder() -> ToolCallBuilder {
100 ToolCallBuilder::default()
101 }
102
103 pub fn id(&self) -> &str {
104 &self.id
105 }
106
107 pub fn name(&self) -> &str {
108 &self.name
109 }
110
111 pub fn args(&self) -> Option<&str> {
112 self.args.as_deref()
113 }
114
115 pub fn with_args(&mut self, args: Option<String>) {
116 self.args = args;
117 }
118}
119
120impl ToolCallBuilder {
121 pub fn maybe_args<T: Into<Option<String>>>(&mut self, args: T) -> &mut Self {
122 self.args = Some(args.into());
123 self
124 }
125
126 pub fn maybe_id<T: Into<Option<String>>>(&mut self, id: T) -> &mut Self {
127 self.id = id.into();
128 self
129 }
130
131 pub fn maybe_name<T: Into<Option<String>>>(&mut self, name: T) -> &mut Self {
132 self.name = name.into();
133 self
134 }
135}
136
137#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, Builder)]
141#[builder(setter(into))]
142pub struct ToolSpec {
143 pub name: String,
145 pub description: String,
147
148 #[builder(default)]
149 pub parameters: Vec<ParamSpec>,
151}
152
153impl ToolSpec {
154 pub fn builder() -> ToolSpecBuilder {
155 ToolSpecBuilder::default()
156 }
157}
158
159#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, strum_macros::AsRefStr)]
160#[strum(serialize_all = "camelCase")]
161pub enum ParamType {
162 #[default]
163 String,
164 Number,
165 Boolean,
166 Array,
167 Nullable(Box<ParamType>),
168}
169
/// The primitive parameter types: the same set as the non-`Nullable` variants
/// of `ParamType`.
///
/// The standard derives are provided so the public type can be printed,
/// compared, copied and used as a hash key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InnerParamType {
    /// A string value.
    String,
    /// A numeric value.
    Number,
    /// A boolean value.
    Boolean,
    /// An array value.
    Array,
}
176
177impl Serialize for ParamType {
178 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
179 where
180 S: Serializer,
181 {
182 match self {
183 ParamType::String => serializer.serialize_str("string"),
185 ParamType::Number => serializer.serialize_str("number"),
186 ParamType::Boolean => serializer.serialize_str("boolean"),
187 ParamType::Array => serializer.serialize_str("array"),
188
189 ParamType::Nullable(inner) => {
191 if let ParamType::Nullable(_) = inner.as_ref() {
193 return Err(serde::ser::Error::custom("Nested Nullable not supported"));
194 }
195 let mut seq = serializer.serialize_seq(Some(2))?;
197 seq.serialize_element(&primitive_variant_str(inner).map_err(S::Error::custom)?)?;
198 seq.serialize_element("null")?;
199 seq.end()
200 }
201 }
202 }
203}
204
205fn primitive_variant_str(pt: &ParamType) -> Result<&'static str, &'static str> {
206 match pt {
207 ParamType::String => Ok("string"),
208 ParamType::Number => Ok("number"),
209 ParamType::Boolean => Ok("boolean"),
210 ParamType::Array => Ok("array"),
211 ParamType::Nullable(_) => Err("Nested Nullable found"),
212 }
213}
214
215impl<'de> Deserialize<'de> for ParamType {
216 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
217 where
218 D: Deserializer<'de>,
219 {
220 deserializer.deserialize_any(ParamTypeVisitor)
221 }
222}
223
224struct ParamTypeVisitor;
225
226impl<'de> Visitor<'de> for ParamTypeVisitor {
227 type Value = ParamType;
228
229 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
230 write!(
231 formatter,
232 "a string (e.g. \"string\", \"number\") \
233 or a 2-element array [<type>, \"null\"]"
234 )
235 }
236
237 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
239 where
240 E: DeError,
241 {
242 match value {
243 "string" => Ok(ParamType::String),
244 "number" => Ok(ParamType::Number),
245 "boolean" => Ok(ParamType::Boolean),
246 "array" => Ok(ParamType::Array),
247 other => Err(E::unknown_variant(
248 other,
249 &["string", "number", "boolean", "array"],
250 )),
251 }
252 }
253
254 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
256 where
257 A: SeqAccess<'de>,
258 {
259 let mut items = Vec::new();
260 while let Some(item) = seq.next_element::<String>()? {
261 items.push(item);
262 }
263
264 if items.len() != 2 {
266 return Err(A::Error::invalid_length(items.len(), &"2"));
267 }
268
269 let mut first = items[0].as_str();
270 let mut second = items[1].as_str();
271
272 if first == "null" {
274 std::mem::swap(&mut first, &mut second);
275 }
276
277 if second != "null" {
279 return Err(A::Error::invalid_value(
280 Unexpected::Str(second),
281 &"expected exactly one 'null' in [<type>, 'null']",
282 ));
283 }
284
285 let inner = match first {
287 "string" => ParamType::String,
288 "number" => ParamType::Number,
289 "boolean" => ParamType::Boolean,
290 "array" => ParamType::Array,
291 other => {
292 return Err(A::Error::unknown_variant(
293 other,
294 &["string", "number", "boolean", "array", "null"],
295 ));
296 }
297 };
298
299 Ok(ParamType::Nullable(Box::new(inner)))
300 }
301}
302#[derive(Clone, Debug, Hash, Eq, PartialEq, Builder)]
304#[builder(setter(into))]
305pub struct ParamSpec {
306 pub name: String,
308 pub description: String,
310 #[builder(default)]
312 pub ty: ParamType,
313 #[builder(default = true)]
315 pub required: bool,
316}
317
318impl ParamSpec {
319 pub fn builder() -> ParamSpecBuilder {
320 ParamSpecBuilder::default()
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// A nullable string serializes to `["string","null"]` and deserializes
    /// back to the same value.
    #[test]
    fn test_serialize_param_type() {
        let param = ParamType::Nullable(Box::new(ParamType::String));
        let serialized = serde_json::to_string(&param).unwrap();
        assert_eq!(serialized, r#"["string","null"]"#);

        let deserialized: ParamType = serde_json::from_str(&serialized).unwrap();
        assert_eq!(param, deserialized);
    }

    /// Both accepted input shapes deserialize: the two-element array form and
    /// the bare string form.
    #[test]
    fn test_deserialize_param_type() {
        let serialized = r#"["string","null"]"#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(
            deserialized,
            ParamType::Nullable(Box::new(ParamType::String))
        );

        let serialized = r#""string""#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(deserialized, ParamType::String);
    }
}