swiftide_core/chat_completion/
tools.rs1use std::borrow::Cow;
2
3use derive_builder::Builder;
4use schemars::Schema;
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, strum_macros::EnumIs)]
9#[non_exhaustive]
10pub enum ToolOutput {
11 Text(String),
13
14 FeedbackRequired(Option<serde_json::Value>),
16
17 Fail(String),
19
20 Stop(Option<serde_json::Value>),
22
23 AgentFailed(Option<Cow<'static, str>>),
25}
26
27impl ToolOutput {
28 pub fn text(text: impl Into<String>) -> Self {
29 ToolOutput::Text(text.into())
30 }
31
32 pub fn feedback_required(feedback: Option<serde_json::Value>) -> Self {
33 ToolOutput::FeedbackRequired(feedback)
34 }
35
36 pub fn stop() -> Self {
37 ToolOutput::Stop(None)
38 }
39
40 pub fn stop_with_args(output: impl Into<serde_json::Value>) -> Self {
41 ToolOutput::Stop(Some(output.into()))
42 }
43
44 pub fn agent_failed(output: impl Into<Cow<'static, str>>) -> Self {
45 ToolOutput::AgentFailed(Some(output.into()))
46 }
47
48 pub fn fail(text: impl Into<String>) -> Self {
49 ToolOutput::Fail(text.into())
50 }
51
52 pub fn content(&self) -> Option<&str> {
53 match self {
54 ToolOutput::Fail(s) | ToolOutput::Text(s) => Some(s),
55 _ => None,
56 }
57 }
58
59 pub fn as_text(&self) -> Option<&str> {
61 match self {
62 ToolOutput::Text(s) => Some(s),
63 _ => None,
64 }
65 }
66
67 pub fn as_fail(&self) -> Option<&str> {
69 match self {
70 ToolOutput::Fail(s) => Some(s),
71 _ => None,
72 }
73 }
74
75 pub fn as_stop(&self) -> Option<&serde_json::Value> {
77 match self {
78 ToolOutput::Stop(args) => args.as_ref(),
79 _ => None,
80 }
81 }
82
83 pub fn as_agent_failed(&self) -> Option<&str> {
85 match self {
86 ToolOutput::AgentFailed(args) => args.as_deref(),
87 _ => None,
88 }
89 }
90
91 pub fn as_feedback_required(&self) -> Option<&serde_json::Value> {
93 match self {
94 ToolOutput::FeedbackRequired(args) => args.as_ref(),
95 _ => None,
96 }
97 }
98}
99
100impl<S: AsRef<str>> From<S> for ToolOutput {
101 fn from(value: S) -> Self {
102 ToolOutput::Text(value.as_ref().to_string())
103 }
104}
105impl std::fmt::Display for ToolOutput {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 match self {
108 ToolOutput::Text(value) => write!(f, "{value}"),
109 ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"),
110 ToolOutput::Stop(args) => {
111 if let Some(value) = args {
112 write!(f, "Stop {value}")
113 } else {
114 write!(f, "Stop")
115 }
116 }
117 ToolOutput::FeedbackRequired(_) => {
118 write!(f, "Feedback required")
119 }
120 ToolOutput::AgentFailed(args) => write!(
121 f,
122 "Agent failed with output: {}",
123 args.as_deref().unwrap_or_default()
124 ),
125 }
126 }
127}
128
129#[derive(Clone, Debug, Builder, PartialEq, Serialize, Deserialize, Eq)]
131#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
132#[builder(setter(into, strip_option))]
133pub struct ToolCall {
134 id: String,
135 name: String,
136 #[builder(default)]
137 args: Option<String>,
138}
139
140impl std::hash::Hash for ToolCall {
142 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
143 self.name.hash(state);
144 self.args.hash(state);
145 }
146}
147
148impl std::fmt::Display for ToolCall {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 write!(
151 f,
152 "{id}#{name} {args}",
153 id = self.id,
154 name = self.name,
155 args = self.args.as_deref().unwrap_or("")
156 )
157 }
158}
159
160impl ToolCall {
161 pub fn builder() -> ToolCallBuilder {
162 ToolCallBuilder::default()
163 }
164
165 pub fn id(&self) -> &str {
166 &self.id
167 }
168
169 pub fn name(&self) -> &str {
170 &self.name
171 }
172
173 pub fn args(&self) -> Option<&str> {
174 self.args.as_deref()
175 }
176
177 pub fn with_args(&mut self, args: Option<String>) {
178 self.args = args;
179 }
180}
181
182impl ToolCallBuilder {
183 pub fn maybe_args<T: Into<Option<String>>>(&mut self, args: T) -> &mut Self {
184 self.args = Some(args.into());
185 self
186 }
187
188 pub fn maybe_id<T: Into<Option<String>>>(&mut self, id: T) -> &mut Self {
189 self.id = id.into();
190 self
191 }
192
193 pub fn maybe_name<T: Into<Option<String>>>(&mut self, name: T) -> &mut Self {
194 self.name = name.into();
195 self
196 }
197}
198
199#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Builder, Default)]
203#[builder(setter(into), derive(Debug, Serialize, Deserialize))]
204#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
205#[serde(deny_unknown_fields)]
206pub struct ToolSpec {
207 pub name: String,
209 pub description: String,
211
212 #[builder(default, setter(strip_option))]
213 #[serde(skip_serializing_if = "Option::is_none")]
214 pub parameters_schema: Option<Schema>,
216}
217
218impl ToolSpec {
219 pub fn builder() -> ToolSpecBuilder {
220 ToolSpecBuilder::default()
221 }
222}
223
224impl Eq for ToolSpec {}
225
226impl std::hash::Hash for ToolSpec {
227 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
228 self.name.hash(state);
229 self.description.hash(state);
230 if let Some(schema) = &self.parameters_schema
231 && let Ok(serialized) = serde_json::to_vec(schema)
232 {
233 serialized.hash(state);
234 }
235 }
236}
237
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
    struct ExampleArgs {
        value: String,
    }

    /// Builds the spec shared by the tests, with a parameter schema derived from `ExampleArgs`.
    fn example_spec() -> ToolSpec {
        ToolSpec::builder()
            .name("example")
            .description("An example tool")
            .parameters_schema(schemars::schema_for!(ExampleArgs))
            .build()
            .expect("all required ToolSpec fields are set")
    }

    #[test]
    fn tool_spec_serializes_schema() {
        let json = serde_json::to_value(example_spec()).unwrap();
        assert_eq!(json["name"].as_str(), Some("example"));
        assert!(json.get("parameters_schema").is_some());
    }

    #[test]
    fn tool_spec_is_hashable() {
        let spec = example_spec();
        let set: HashSet<ToolSpec> = std::iter::once(spec.clone()).collect();
        assert!(set.contains(&spec));
    }
}
279}