swiftide_core/chat_completion/
tools.rs1use derive_builder::Builder;
2use schemars::Schema;
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, strum_macros::EnumIs)]
7#[non_exhaustive]
8pub enum ToolOutput {
9 Text(String),
11
12 FeedbackRequired(Option<serde_json::Value>),
14
15 Fail(String),
17
18 Stop(Option<serde_json::Value>),
20
21 AgentFailed(Option<serde_json::Value>),
23}
24
25impl ToolOutput {
26 pub fn text(text: impl Into<String>) -> Self {
27 ToolOutput::Text(text.into())
28 }
29
30 pub fn feedback_required(feedback: Option<serde_json::Value>) -> Self {
31 ToolOutput::FeedbackRequired(feedback)
32 }
33
34 pub fn stop() -> Self {
35 ToolOutput::Stop(None)
36 }
37
38 pub fn stop_with_args(output: impl Into<serde_json::Value>) -> Self {
39 ToolOutput::Stop(Some(output.into()))
40 }
41
42 pub fn agent_failed(output: impl Into<serde_json::Value>) -> Self {
43 ToolOutput::AgentFailed(Some(output.into()))
44 }
45
46 pub fn fail(text: impl Into<String>) -> Self {
47 ToolOutput::Fail(text.into())
48 }
49
50 pub fn content(&self) -> Option<&str> {
51 match self {
52 ToolOutput::Fail(s) | ToolOutput::Text(s) => Some(s),
53 _ => None,
54 }
55 }
56
57 pub fn as_text(&self) -> Option<&str> {
59 match self {
60 ToolOutput::Text(s) => Some(s),
61 _ => None,
62 }
63 }
64
65 pub fn as_fail(&self) -> Option<&str> {
67 match self {
68 ToolOutput::Fail(s) => Some(s),
69 _ => None,
70 }
71 }
72
73 pub fn as_stop(&self) -> Option<&serde_json::Value> {
75 match self {
76 ToolOutput::Stop(args) => args.as_ref(),
77 _ => None,
78 }
79 }
80
81 pub fn as_agent_failed(&self) -> Option<&serde_json::Value> {
83 match self {
84 ToolOutput::AgentFailed(args) => args.as_ref(),
85 _ => None,
86 }
87 }
88
89 pub fn as_feedback_required(&self) -> Option<&serde_json::Value> {
91 match self {
92 ToolOutput::FeedbackRequired(args) => args.as_ref(),
93 _ => None,
94 }
95 }
96}
97
98impl<S: AsRef<str>> From<S> for ToolOutput {
99 fn from(value: S) -> Self {
100 ToolOutput::Text(value.as_ref().to_string())
101 }
102}
103impl std::fmt::Display for ToolOutput {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 match self {
106 ToolOutput::Text(value) => write!(f, "{value}"),
107 ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"),
108 ToolOutput::Stop(args) => {
109 if let Some(value) = args {
110 write!(f, "Stop {value}")
111 } else {
112 write!(f, "Stop")
113 }
114 }
115 ToolOutput::FeedbackRequired(_) => {
116 write!(f, "Feedback required")
117 }
118 ToolOutput::AgentFailed(args) => write!(
119 f,
120 "Agent failed with output: {}",
121 args.as_ref().unwrap_or_default()
122 ),
123 }
124 }
125}
126
127#[derive(Clone, Debug, Builder, PartialEq, Serialize, Deserialize, Eq)]
129#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
130#[builder(setter(into, strip_option))]
131pub struct ToolCall {
132 id: String,
133 name: String,
134 #[builder(default)]
135 args: Option<String>,
136}
137
138impl std::hash::Hash for ToolCall {
140 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
141 self.name.hash(state);
142 self.args.hash(state);
143 }
144}
145
146impl std::fmt::Display for ToolCall {
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 write!(
149 f,
150 "{id}#{name} {args}",
151 id = self.id,
152 name = self.name,
153 args = self.args.as_deref().unwrap_or("")
154 )
155 }
156}
157
158impl ToolCall {
159 pub fn builder() -> ToolCallBuilder {
160 ToolCallBuilder::default()
161 }
162
163 pub fn id(&self) -> &str {
164 &self.id
165 }
166
167 pub fn name(&self) -> &str {
168 &self.name
169 }
170
171 pub fn args(&self) -> Option<&str> {
172 self.args.as_deref()
173 }
174
175 pub fn with_args(&mut self, args: Option<String>) {
176 self.args = args;
177 }
178}
179
180impl ToolCallBuilder {
181 pub fn maybe_args<T: Into<Option<String>>>(&mut self, args: T) -> &mut Self {
182 self.args = Some(args.into());
183 self
184 }
185
186 pub fn maybe_id<T: Into<Option<String>>>(&mut self, id: T) -> &mut Self {
187 self.id = id.into();
188 self
189 }
190
191 pub fn maybe_name<T: Into<Option<String>>>(&mut self, name: T) -> &mut Self {
192 self.name = name.into();
193 self
194 }
195}
196
197#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Builder, Default)]
201#[builder(setter(into), derive(Debug, Serialize, Deserialize))]
202#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
203#[serde(deny_unknown_fields)]
204pub struct ToolSpec {
205 pub name: String,
207 pub description: String,
209
210 #[builder(default, setter(strip_option))]
211 #[serde(skip_serializing_if = "Option::is_none")]
212 pub parameters_schema: Option<Schema>,
214}
215
216impl ToolSpec {
217 pub fn builder() -> ToolSpecBuilder {
218 ToolSpecBuilder::default()
219 }
220}
221
222impl Eq for ToolSpec {}
223
224impl std::hash::Hash for ToolSpec {
225 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
226 self.name.hash(state);
227 self.description.hash(state);
228 if let Some(schema) = &self.parameters_schema
229 && let Ok(serialized) = serde_json::to_vec(schema)
230 {
231 serialized.hash(state);
232 }
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
    struct ExampleArgs {
        value: String,
    }

    /// Builds the `example` spec (with `ExampleArgs` as its schema) shared by the tests.
    fn example_spec() -> ToolSpec {
        ToolSpec::builder()
            .name("example")
            .description("An example tool")
            .parameters_schema(schemars::schema_for!(ExampleArgs))
            .build()
            .unwrap()
    }

    #[test]
    fn tool_spec_serializes_schema() {
        let spec = example_spec();
        let json = serde_json::to_value(&spec).unwrap();

        let name = json.get("name").and_then(serde_json::Value::as_str);
        assert_eq!(name, Some("example"));
        assert!(json.get("parameters_schema").is_some());
    }

    #[test]
    fn tool_spec_is_hashable() {
        let spec = example_spec();
        let set: HashSet<ToolSpec> = std::iter::once(spec.clone()).collect();

        assert!(set.contains(&spec));
    }
}