swiftide_core/chat_completion/
tools.rs1use std::borrow::Cow;
2use std::fmt;
3
4use derive_builder::Builder;
5use serde::de::{Deserializer, Error as DeError, SeqAccess, Unexpected, Visitor};
6use serde::ser::{Error as SerError, SerializeSeq, Serializer};
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, strum_macros::EnumIs)]
11#[non_exhaustive]
12pub enum ToolOutput {
13 Text(String),
15
16 FeedbackRequired(Option<serde_json::Value>),
18
19 Fail(String),
21
22 Stop(Option<Cow<'static, str>>),
24
25 AgentFailed(Option<Cow<'static, str>>),
27}
28
29impl ToolOutput {
30 pub fn text(text: impl Into<String>) -> Self {
31 ToolOutput::Text(text.into())
32 }
33
34 pub fn feedback_required(feedback: Option<serde_json::Value>) -> Self {
35 ToolOutput::FeedbackRequired(feedback)
36 }
37
38 pub fn stop() -> Self {
39 ToolOutput::Stop(None)
40 }
41
42 pub fn stop_with_args(output: impl Into<Cow<'static, str>>) -> Self {
43 ToolOutput::Stop(Some(output.into()))
44 }
45
46 pub fn agent_failed(output: impl Into<Cow<'static, str>>) -> Self {
47 ToolOutput::AgentFailed(Some(output.into()))
48 }
49
50 pub fn fail(text: impl Into<String>) -> Self {
51 ToolOutput::Fail(text.into())
52 }
53
54 pub fn content(&self) -> Option<&str> {
55 match self {
56 ToolOutput::Fail(s) | ToolOutput::Text(s) => Some(s),
57 _ => None,
58 }
59 }
60}
61
62impl<S: AsRef<str>> From<S> for ToolOutput {
63 fn from(value: S) -> Self {
64 ToolOutput::Text(value.as_ref().to_string())
65 }
66}
67impl std::fmt::Display for ToolOutput {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 match self {
70 ToolOutput::Text(value) => write!(f, "{value}"),
71 ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"),
72 ToolOutput::Stop(args) => write!(f, "Stop {}", args.as_deref().unwrap_or_default()),
73 ToolOutput::FeedbackRequired(_) => {
74 write!(f, "Feedback required")
75 }
76 ToolOutput::AgentFailed(args) => write!(
77 f,
78 "Agent failed with output: {}",
79 args.as_deref().unwrap_or_default()
80 ),
81 }
82 }
83}
84
85#[derive(Clone, Debug, Builder, PartialEq, Serialize, Deserialize, Eq)]
87#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
88#[builder(setter(into, strip_option))]
89pub struct ToolCall {
90 id: String,
91 name: String,
92 #[builder(default)]
93 args: Option<String>,
94}
95
96impl std::hash::Hash for ToolCall {
98 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
99 self.name.hash(state);
100 self.args.hash(state);
101 }
102}
103
104impl std::fmt::Display for ToolCall {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 write!(
107 f,
108 "{id}#{name} {args}",
109 id = self.id,
110 name = self.name,
111 args = self.args.as_deref().unwrap_or("")
112 )
113 }
114}
115
116impl ToolCall {
117 pub fn builder() -> ToolCallBuilder {
118 ToolCallBuilder::default()
119 }
120
121 pub fn id(&self) -> &str {
122 &self.id
123 }
124
125 pub fn name(&self) -> &str {
126 &self.name
127 }
128
129 pub fn args(&self) -> Option<&str> {
130 self.args.as_deref()
131 }
132
133 pub fn with_args(&mut self, args: Option<String>) {
134 self.args = args;
135 }
136}
137
138impl ToolCallBuilder {
139 pub fn maybe_args<T: Into<Option<String>>>(&mut self, args: T) -> &mut Self {
140 self.args = Some(args.into());
141 self
142 }
143
144 pub fn maybe_id<T: Into<Option<String>>>(&mut self, id: T) -> &mut Self {
145 self.id = id.into();
146 self
147 }
148
149 pub fn maybe_name<T: Into<Option<String>>>(&mut self, name: T) -> &mut Self {
150 self.name = name.into();
151 self
152 }
153}
154
155#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, Builder, Serialize, Deserialize)]
159#[builder(setter(into), derive(Debug, Serialize, Deserialize))]
160#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
161pub struct ToolSpec {
162 pub name: String,
164 pub description: String,
166
167 #[builder(default)]
168 pub parameters: Vec<ParamSpec>,
170}
171
172impl ToolSpec {
173 pub fn builder() -> ToolSpecBuilder {
174 ToolSpecBuilder::default()
175 }
176}
177
178#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, strum_macros::AsRefStr)]
179#[strum(serialize_all = "camelCase")]
180#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
181pub enum ParamType {
182 #[default]
183 String,
184 Number,
185 Boolean,
186 Array,
187 Nullable(Box<ParamType>),
188}
189
/// The non-nullable subset of [`ParamType`]'s variants.
///
/// Derives the usual value-type traits so callers can compare, copy, hash and
/// debug-print it; previously it derived nothing.
///
/// NOTE(review): not referenced by the code visible in this file — confirm
/// whether external callers rely on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InnerParamType {
    String,
    Number,
    Boolean,
    Array,
}
196
197impl Serialize for ParamType {
198 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
199 where
200 S: Serializer,
201 {
202 match self {
203 ParamType::String => serializer.serialize_str("string"),
205 ParamType::Number => serializer.serialize_str("number"),
206 ParamType::Boolean => serializer.serialize_str("boolean"),
207 ParamType::Array => serializer.serialize_str("array"),
208
209 ParamType::Nullable(inner) => {
211 if let ParamType::Nullable(_) = inner.as_ref() {
213 return Err(serde::ser::Error::custom("Nested Nullable not supported"));
214 }
215 let mut seq = serializer.serialize_seq(Some(2))?;
217 seq.serialize_element(&primitive_variant_str(inner).map_err(S::Error::custom)?)?;
218 seq.serialize_element("null")?;
219 seq.end()
220 }
221 }
222 }
223}
224
225fn primitive_variant_str(pt: &ParamType) -> Result<&'static str, &'static str> {
226 match pt {
227 ParamType::String => Ok("string"),
228 ParamType::Number => Ok("number"),
229 ParamType::Boolean => Ok("boolean"),
230 ParamType::Array => Ok("array"),
231 ParamType::Nullable(_) => Err("Nested Nullable found"),
232 }
233}
234
235impl<'de> Deserialize<'de> for ParamType {
236 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
237 where
238 D: Deserializer<'de>,
239 {
240 deserializer.deserialize_any(ParamTypeVisitor)
241 }
242}
243
244struct ParamTypeVisitor;
245
246impl<'de> Visitor<'de> for ParamTypeVisitor {
247 type Value = ParamType;
248
249 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
250 write!(
251 formatter,
252 "a string (e.g. \"string\", \"number\") \
253 or a 2-element array [<type>, \"null\"]"
254 )
255 }
256
257 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
259 where
260 E: DeError,
261 {
262 match value {
263 "string" => Ok(ParamType::String),
264 "number" => Ok(ParamType::Number),
265 "boolean" => Ok(ParamType::Boolean),
266 "array" => Ok(ParamType::Array),
267 other => Err(E::unknown_variant(
268 other,
269 &["string", "number", "boolean", "array"],
270 )),
271 }
272 }
273
274 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
276 where
277 A: SeqAccess<'de>,
278 {
279 let mut items = Vec::new();
280 while let Some(item) = seq.next_element::<String>()? {
281 items.push(item);
282 }
283
284 if items.len() != 2 {
286 return Err(A::Error::invalid_length(items.len(), &"2"));
287 }
288
289 let mut first = items[0].as_str();
290 let mut second = items[1].as_str();
291
292 if first == "null" {
294 std::mem::swap(&mut first, &mut second);
295 }
296
297 if second != "null" {
299 return Err(A::Error::invalid_value(
300 Unexpected::Str(second),
301 &"expected exactly one 'null' in [<type>, 'null']",
302 ));
303 }
304
305 let inner = match first {
307 "string" => ParamType::String,
308 "number" => ParamType::Number,
309 "boolean" => ParamType::Boolean,
310 "array" => ParamType::Array,
311 other => {
312 return Err(A::Error::unknown_variant(
313 other,
314 &["string", "number", "boolean", "array", "null"],
315 ));
316 }
317 };
318
319 Ok(ParamType::Nullable(Box::new(inner)))
320 }
321}
322#[derive(Clone, Debug, Hash, Eq, PartialEq, Builder, Serialize, Deserialize)]
324#[builder(setter(into))]
325#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
326pub struct ParamSpec {
327 pub name: String,
329 pub description: String,
331 #[serde(rename = "type")]
333 #[builder(default)]
334 pub ty: ParamType,
335 #[builder(default = true)]
337 pub required: bool,
338}
339
340impl ParamSpec {
341 pub fn builder() -> ParamSpecBuilder {
342 ParamSpecBuilder::default()
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialize_param_type() {
        let param = ParamType::Nullable(Box::new(ParamType::String));
        // Was `¶m` (a mis-encoded `&param`), which does not compile.
        let serialized = serde_json::to_string(&param).unwrap();
        assert_eq!(serialized, r#"["string","null"]"#);

        let deserialized: ParamType = serde_json::from_str(&serialized).unwrap();
        assert_eq!(param, deserialized);
    }

    #[test]
    fn test_deserialize_param_type() {
        let serialized = r#"["string","null"]"#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(
            deserialized,
            ParamType::Nullable(Box::new(ParamType::String))
        );

        let serialized = r#""string""#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(deserialized, ParamType::String);
    }

    #[test]
    fn test_param_type_roundtrips_every_primitive() {
        for ty in [
            ParamType::String,
            ParamType::Number,
            ParamType::Boolean,
            ParamType::Array,
        ] {
            let plain = serde_json::to_string(&ty).unwrap();
            assert_eq!(serde_json::from_str::<ParamType>(&plain).unwrap(), ty);

            let nullable = ParamType::Nullable(Box::new(ty.clone()));
            let encoded = serde_json::to_string(&nullable).unwrap();
            assert_eq!(serde_json::from_str::<ParamType>(&encoded).unwrap(), nullable);
        }
    }

    #[test]
    fn test_deserialize_nullable_accepts_null_first() {
        let deserialized: ParamType = serde_json::from_str(r#"["null","boolean"]"#).unwrap();
        assert_eq!(
            deserialized,
            ParamType::Nullable(Box::new(ParamType::Boolean))
        );
    }

    #[test]
    fn test_deserialize_rejects_malformed_input() {
        for input in [
            r#"[]"#,
            r#"["string"]"#,
            r#"["string","null","null"]"#,
            r#"["null","null"]"#,
            r#"["string","number"]"#,
            r#"["nullable","null"]"#,
            r#""nullable""#,
        ] {
            assert!(
                serde_json::from_str::<ParamType>(input).is_err(),
                "{input} should be rejected"
            );
        }
    }

    #[test]
    fn test_serialize_rejects_nested_nullable() {
        let nested = ParamType::Nullable(Box::new(ParamType::Nullable(Box::new(
            ParamType::String,
        ))));
        assert!(serde_json::to_string(&nested).is_err());
    }
}