1use std::borrow::Cow;
2use std::fmt;
3
4use derive_builder::Builder;
5use serde::de::{Deserializer, Error as DeError, SeqAccess, Unexpected, Visitor};
6use serde::ser::{Error as SerError, SerializeSeq, Serializer};
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, strum_macros::EnumIs)]
11#[non_exhaustive]
12pub enum ToolOutput {
13 Text(String),
15
16 FeedbackRequired(Option<serde_json::Value>),
18
19 Fail(String),
21
22 Stop(Option<Cow<'static, str>>),
24
25 AgentFailed(Option<Cow<'static, str>>),
27}
28
29impl ToolOutput {
30 pub fn text(text: impl Into<String>) -> Self {
31 ToolOutput::Text(text.into())
32 }
33
34 pub fn feedback_required(feedback: Option<serde_json::Value>) -> Self {
35 ToolOutput::FeedbackRequired(feedback)
36 }
37
38 pub fn stop() -> Self {
39 ToolOutput::Stop(None)
40 }
41
42 pub fn stop_with_args(output: impl Into<Cow<'static, str>>) -> Self {
43 ToolOutput::Stop(Some(output.into()))
44 }
45
46 pub fn agent_failed(output: impl Into<Cow<'static, str>>) -> Self {
47 ToolOutput::AgentFailed(Some(output.into()))
48 }
49
50 pub fn fail(text: impl Into<String>) -> Self {
51 ToolOutput::Fail(text.into())
52 }
53
54 pub fn content(&self) -> Option<&str> {
55 match self {
56 ToolOutput::Fail(s) | ToolOutput::Text(s) => Some(s),
57 _ => None,
58 }
59 }
60
61 pub fn as_text(&self) -> Option<&str> {
63 match self {
64 ToolOutput::Text(s) => Some(s),
65 _ => None,
66 }
67 }
68
69 pub fn as_fail(&self) -> Option<&str> {
71 match self {
72 ToolOutput::Fail(s) => Some(s),
73 _ => None,
74 }
75 }
76
77 pub fn as_stop(&self) -> Option<&str> {
79 match self {
80 ToolOutput::Stop(args) => args.as_deref(),
81 _ => None,
82 }
83 }
84
85 pub fn as_agent_failed(&self) -> Option<&str> {
87 match self {
88 ToolOutput::AgentFailed(args) => args.as_deref(),
89 _ => None,
90 }
91 }
92
93 pub fn as_feedback_required(&self) -> Option<&serde_json::Value> {
95 match self {
96 ToolOutput::FeedbackRequired(args) => args.as_ref(),
97 _ => None,
98 }
99 }
100}
101
102impl<S: AsRef<str>> From<S> for ToolOutput {
103 fn from(value: S) -> Self {
104 ToolOutput::Text(value.as_ref().to_string())
105 }
106}
107impl std::fmt::Display for ToolOutput {
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 match self {
110 ToolOutput::Text(value) => write!(f, "{value}"),
111 ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"),
112 ToolOutput::Stop(args) => write!(f, "Stop {}", args.as_deref().unwrap_or_default()),
113 ToolOutput::FeedbackRequired(_) => {
114 write!(f, "Feedback required")
115 }
116 ToolOutput::AgentFailed(args) => write!(
117 f,
118 "Agent failed with output: {}",
119 args.as_deref().unwrap_or_default()
120 ),
121 }
122 }
123}
124
125#[derive(Clone, Debug, Builder, PartialEq, Serialize, Deserialize, Eq)]
127#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
128#[builder(setter(into, strip_option))]
129pub struct ToolCall {
130 id: String,
131 name: String,
132 #[builder(default)]
133 args: Option<String>,
134}
135
136impl std::hash::Hash for ToolCall {
138 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
139 self.name.hash(state);
140 self.args.hash(state);
141 }
142}
143
144impl std::fmt::Display for ToolCall {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 write!(
147 f,
148 "{id}#{name} {args}",
149 id = self.id,
150 name = self.name,
151 args = self.args.as_deref().unwrap_or("")
152 )
153 }
154}
155
156impl ToolCall {
157 pub fn builder() -> ToolCallBuilder {
158 ToolCallBuilder::default()
159 }
160
161 pub fn id(&self) -> &str {
162 &self.id
163 }
164
165 pub fn name(&self) -> &str {
166 &self.name
167 }
168
169 pub fn args(&self) -> Option<&str> {
170 self.args.as_deref()
171 }
172
173 pub fn with_args(&mut self, args: Option<String>) {
174 self.args = args;
175 }
176}
177
178impl ToolCallBuilder {
179 pub fn maybe_args<T: Into<Option<String>>>(&mut self, args: T) -> &mut Self {
180 self.args = Some(args.into());
181 self
182 }
183
184 pub fn maybe_id<T: Into<Option<String>>>(&mut self, id: T) -> &mut Self {
185 self.id = id.into();
186 self
187 }
188
189 pub fn maybe_name<T: Into<Option<String>>>(&mut self, name: T) -> &mut Self {
190 self.name = name.into();
191 self
192 }
193}
194
195#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, Builder, Serialize, Deserialize)]
199#[builder(setter(into), derive(Debug, Serialize, Deserialize))]
200#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
201pub struct ToolSpec {
202 pub name: String,
204 pub description: String,
206
207 #[builder(default)]
208 pub parameters: Vec<ParamSpec>,
210}
211
212impl ToolSpec {
213 pub fn builder() -> ToolSpecBuilder {
214 ToolSpecBuilder::default()
215 }
216}
217
218#[derive(Clone, Debug, Hash, Eq, PartialEq, Default, strum_macros::AsRefStr)]
219#[strum(serialize_all = "camelCase")]
220#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
221pub enum ParamType {
222 #[default]
223 String,
224 Number,
225 Boolean,
226 Array,
227 Nullable(Box<ParamType>),
228}
229
/// The primitive parameter kinds, i.e. the [`ParamType`] variants without `Nullable`.
// NOTE(review): not referenced anywhere in the visible part of this file — confirm it is still used.
pub enum InnerParamType {
    /// String kind.
    String,
    /// Number kind.
    Number,
    /// Boolean kind.
    Boolean,
    /// Array kind.
    Array,
}
236
237impl Serialize for ParamType {
238 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
239 where
240 S: Serializer,
241 {
242 match self {
243 ParamType::String => serializer.serialize_str("string"),
245 ParamType::Number => serializer.serialize_str("number"),
246 ParamType::Boolean => serializer.serialize_str("boolean"),
247 ParamType::Array => serializer.serialize_str("array"),
248
249 ParamType::Nullable(inner) => {
251 if let ParamType::Nullable(_) = inner.as_ref() {
253 return Err(serde::ser::Error::custom("Nested Nullable not supported"));
254 }
255 let mut seq = serializer.serialize_seq(Some(2))?;
257 seq.serialize_element(&primitive_variant_str(inner).map_err(S::Error::custom)?)?;
258 seq.serialize_element("null")?;
259 seq.end()
260 }
261 }
262 }
263}
264
265fn primitive_variant_str(pt: &ParamType) -> Result<&'static str, &'static str> {
266 match pt {
267 ParamType::String => Ok("string"),
268 ParamType::Number => Ok("number"),
269 ParamType::Boolean => Ok("boolean"),
270 ParamType::Array => Ok("array"),
271 ParamType::Nullable(_) => Err("Nested Nullable found"),
272 }
273}
274
275impl<'de> Deserialize<'de> for ParamType {
276 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
277 where
278 D: Deserializer<'de>,
279 {
280 deserializer.deserialize_any(ParamTypeVisitor)
281 }
282}
283
284struct ParamTypeVisitor;
285
286impl<'de> Visitor<'de> for ParamTypeVisitor {
287 type Value = ParamType;
288
289 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
290 write!(
291 formatter,
292 "a string (e.g. \"string\", \"number\") \
293 or a 2-element array [<type>, \"null\"]"
294 )
295 }
296
297 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
299 where
300 E: DeError,
301 {
302 match value {
303 "string" => Ok(ParamType::String),
304 "number" => Ok(ParamType::Number),
305 "boolean" => Ok(ParamType::Boolean),
306 "array" => Ok(ParamType::Array),
307 other => Err(E::unknown_variant(
308 other,
309 &["string", "number", "boolean", "array"],
310 )),
311 }
312 }
313
314 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
316 where
317 A: SeqAccess<'de>,
318 {
319 let mut items = Vec::new();
320 while let Some(item) = seq.next_element::<String>()? {
321 items.push(item);
322 }
323
324 if items.len() != 2 {
326 return Err(A::Error::invalid_length(items.len(), &"2"));
327 }
328
329 let mut first = items[0].as_str();
330 let mut second = items[1].as_str();
331
332 if first == "null" {
334 std::mem::swap(&mut first, &mut second);
335 }
336
337 if second != "null" {
339 return Err(A::Error::invalid_value(
340 Unexpected::Str(second),
341 &"expected exactly one 'null' in [<type>, 'null']",
342 ));
343 }
344
345 let inner = match first {
347 "string" => ParamType::String,
348 "number" => ParamType::Number,
349 "boolean" => ParamType::Boolean,
350 "array" => ParamType::Array,
351 other => {
352 return Err(A::Error::unknown_variant(
353 other,
354 &["string", "number", "boolean", "array", "null"],
355 ));
356 }
357 };
358
359 Ok(ParamType::Nullable(Box::new(inner)))
360 }
361}
362#[derive(Clone, Debug, Hash, Eq, PartialEq, Builder, Serialize, Deserialize)]
364#[builder(setter(into))]
365#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
366pub struct ParamSpec {
367 pub name: String,
369 pub description: String,
371 #[serde(rename = "type")]
373 #[builder(default)]
374 pub ty: ParamType,
375 #[builder(default = true)]
377 pub required: bool,
378}
379
380impl ParamSpec {
381 pub fn builder() -> ParamSpecBuilder {
382 ParamSpecBuilder::default()
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// A nullable type serializes to `[<type>, "null"]` and round-trips through JSON.
    #[test]
    fn test_serialize_param_type() {
        let param = ParamType::Nullable(Box::new(ParamType::String));
        let serialized = serde_json::to_string(&param).unwrap();
        assert_eq!(serialized, r#"["string","null"]"#);

        let deserialized: ParamType = serde_json::from_str(&serialized).unwrap();
        assert_eq!(param, deserialized);
    }

    /// Both the array form and the bare-string form deserialize.
    #[test]
    fn test_deserialize_param_type() {
        let serialized = r#"["string","null"]"#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(
            deserialized,
            ParamType::Nullable(Box::new(ParamType::String))
        );

        let serialized = r#""string""#;
        let deserialized: ParamType = serde_json::from_str(serialized).unwrap();
        assert_eq!(deserialized, ParamType::String);
    }
}