1use serde::{Deserialize, Serialize};
4
/// A tool that can be offered in a request.
///
/// Serialized as an internally tagged JSON object: the variant name, in
/// `snake_case`, is written to a `"type"` key next to the variant's own fields
/// (e.g. `{"type": "web_search", ...}`). Every `Option` field is omitted from
/// the output when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    /// A caller-defined function: `{"type": "function", "function": {...}}`.
    Function {
        /// Name, optional description and JSON parameter schema.
        function: FunctionDefinition,
    },
    /// Web search (`"type": "web_search"`).
    WebSearch {
        /// Optional domain filters. Flattened, so `allowed_domains` /
        /// `excluded_domains` appear directly on the tool object instead of
        /// under a nested `filters` key.
        ///
        /// NOTE(review): serde's `flatten` on an `Option` yields `Some(..)`
        /// whenever the inner struct deserializes. Every field of
        /// `WebSearchFilters` is optional, so a `web_search` object without
        /// filter keys presumably deserializes to
        /// `Some(WebSearchFilters::default())` rather than `None`. Confirm
        /// this before relying on `filters.is_none()` after deserializing.
        #[serde(skip_serializing_if = "Option::is_none", flatten)]
        filters: Option<WebSearchFilters>,
        /// Sent as `enable_image_understanding`; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        enable_image_understanding: Option<bool>,
    },
    /// X search (`"type": "x_search"`). Field names are sent verbatim; their
    /// filtering semantics are defined by the receiving API, not here.
    XSearch {
        /// Sent as `allowed_x_handles`; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        allowed_x_handles: Option<Vec<String>>,
        /// Sent as `excluded_x_handles`; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        excluded_x_handles: Option<Vec<String>>,
        /// Sent verbatim as a string; no date format is validated here.
        /// TODO(review): confirm the date format the API expects.
        #[serde(skip_serializing_if = "Option::is_none")]
        from_date: Option<String>,
        /// Sent verbatim as a string; no date format is validated here.
        /// TODO(review): confirm the date format the API expects.
        #[serde(skip_serializing_if = "Option::is_none")]
        to_date: Option<String>,
        /// Sent as `enable_image_understanding`; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        enable_image_understanding: Option<bool>,
        /// Sent as `enable_video_understanding`; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        enable_video_understanding: Option<bool>,
    },
    /// Code interpreter. It carries no options and serializes as
    /// `{"type": "code_interpreter"}`.
    CodeInterpreter {},
    /// Collections search (`"type": "collections_search"`).
    CollectionsSearch {
        /// Collection IDs, sent as `collection_ids`; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        collection_ids: Option<Vec<String>>,
    },
    /// A tool backed by an MCP server (`"type": "mcp"`).
    Mcp {
        /// Serialized as a nested `server` object, e.g. `{"server": {"url": ...}}`.
        server: McpServer,
        /// Sent as `allowed_tools`; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        allowed_tools: Option<Vec<String>>,
    },
}
61
62impl Tool {
63 pub fn function(
65 name: impl Into<String>,
66 description: impl Into<String>,
67 parameters: serde_json::Value,
68 ) -> Self {
69 Self::Function {
70 function: FunctionDefinition {
71 name: name.into(),
72 description: Some(description.into()),
73 parameters,
74 strict: None,
75 },
76 }
77 }
78
79 pub fn web_search() -> Self {
81 Self::WebSearch {
82 filters: None,
83 enable_image_understanding: None,
84 }
85 }
86
87 pub fn web_search_filtered(filters: WebSearchFilters) -> Self {
89 Self::WebSearch {
90 filters: Some(filters),
91 enable_image_understanding: None,
92 }
93 }
94
95 pub fn x_search() -> Self {
97 Self::XSearch {
98 allowed_x_handles: None,
99 excluded_x_handles: None,
100 from_date: None,
101 to_date: None,
102 enable_image_understanding: None,
103 enable_video_understanding: None,
104 }
105 }
106
107 pub fn code_interpreter() -> Self {
109 Self::CodeInterpreter {}
110 }
111
112 pub fn collections_search(collection_ids: Vec<String>) -> Self {
114 Self::CollectionsSearch {
115 collection_ids: Some(collection_ids),
116 }
117 }
118
119 pub fn mcp(server_url: impl Into<String>) -> Self {
121 Self::Mcp {
122 server: McpServer {
123 url: server_url.into(),
124 headers: None,
125 },
126 allowed_tools: None,
127 }
128 }
129}
130
/// A function definition, as carried by [`Tool::Function`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name.
    pub name: String,
    /// Optional description text; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Parameter schema as arbitrary JSON. It is not validated here; the tests
    /// pass JSON-Schema-style objects such as `{"type": "object", ...}`.
    pub parameters: serde_json::Value,
    /// Optional `strict` flag, set to `Some(true)` by
    /// [`FunctionDefinition::strict`]; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
145
146impl FunctionDefinition {
147 pub fn new(name: impl Into<String>, parameters: serde_json::Value) -> Self {
149 Self {
150 name: name.into(),
151 description: None,
152 parameters,
153 strict: None,
154 }
155 }
156
157 pub fn with_description(mut self, description: impl Into<String>) -> Self {
159 self.description = Some(description.into());
160 self
161 }
162
163 pub fn strict(mut self) -> Self {
165 self.strict = Some(true);
166 self
167 }
168}
169
/// Domain filters for [`Tool::WebSearch`].
///
/// Both lists are optional and omitted from the JSON when `None`. Inside
/// `Tool::WebSearch` this struct is flattened, so its keys sit directly on
/// the tool object. The derived `Default` leaves both lists `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WebSearchFilters {
    /// Sent as `allowed_domains`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_domains: Option<Vec<String>>,
    /// Sent as `excluded_domains`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub excluded_domains: Option<Vec<String>>,
}
180
181impl WebSearchFilters {
182 pub fn allow_domains(domains: Vec<String>) -> Self {
184 Self {
185 allowed_domains: Some(domains),
186 excluded_domains: None,
187 }
188 }
189
190 pub fn exclude_domains(domains: Vec<String>) -> Self {
192 Self {
193 allowed_domains: None,
194 excluded_domains: Some(domains),
195 }
196 }
197}
198
/// Connection details for an MCP server; serialized as the nested `server`
/// object of [`Tool::Mcp`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServer {
    /// Server URL, sent verbatim (not validated here).
    pub url: String,
    /// Optional name/value map, presumably sent to the server as HTTP headers
    /// (TODO confirm); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<std::collections::HashMap<String, String>>,
}
208
/// Standalone MCP tool configuration: a server plus an optional list of
/// allowed tool names.
///
/// NOTE(review): this duplicates the fields of `Tool::Mcp`; confirm whether
/// the standalone type is still used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpTool {
    /// Server the tools come from.
    pub server: McpServer,
    /// Sent as `allowed_tools`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_tools: Option<Vec<String>>,
}
218
/// A single tool call, presumably as emitted in a model response (TODO
/// confirm against the caller).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Value of the JSON `"type"` key (e.g. `"function"`). `default` lets
    /// payloads without a `"type"` key deserialize to `None`, and `None` is
    /// omitted again when serializing.
    #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
    pub call_type: Option<String>,
    /// The invoked function and its raw arguments; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionCall>,
}
231
/// Name and raw arguments of a function invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Name of the function being called.
    pub name: String,
    /// Arguments as a JSON-encoded *string* (e.g. `{"city":"NYC"}`), not a
    /// parsed object; decode it with [`FunctionCall::parse_arguments`].
    pub arguments: String,
}
240
241impl FunctionCall {
242 pub fn parse_arguments<T: serde::de::DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
244 serde_json::from_str(&self.arguments)
245 }
246}
247
/// Tool-choice setting: either a bare mode string or an object naming a
/// specific tool.
///
/// Untagged: [`ToolChoice::Auto`] serializes as a plain string (`"auto"`,
/// `"required"` or `"none"`), while [`ToolChoice::Specific`] serializes as an
/// object such as `{"type": "function", "function": {"name": "..."}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    // Declaration order matters for an untagged enum: serde tries variants in
    // order, so plain strings match `Auto` and objects fall through to
    // `Specific`. Do not reorder.
    /// A bare mode string.
    Auto(ToolChoiceAuto),
    /// An object selecting a specific tool.
    Specific(ToolChoiceSpecific),
}
257
258#[derive(Debug, Clone, Serialize, Deserialize)]
260#[serde(rename_all = "lowercase")]
261pub enum ToolChoiceAuto {
262 Auto,
264 Required,
266 None,
268}
269
/// Object form of [`ToolChoice`], e.g.
/// `{"type": "function", "function": {"name": "get_weather"}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceSpecific {
    /// Serialized under the `"type"` key; [`ToolChoice::function`] sets it to
    /// `"function"`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The selected function; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<ToolChoiceFunction>,
}
280
/// Names the function selected by a [`ToolChoiceSpecific`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
    /// Function name, serialized as `{"name": ...}`.
    pub name: String,
}
287
288impl ToolChoice {
289 pub fn auto() -> Self {
291 Self::Auto(ToolChoiceAuto::Auto)
292 }
293
294 pub fn required() -> Self {
296 Self::Auto(ToolChoiceAuto::Required)
297 }
298
299 pub fn none() -> Self {
301 Self::Auto(ToolChoiceAuto::None)
302 }
303
304 pub fn function(name: impl Into<String>) -> Self {
306 Self::Specific(ToolChoiceSpecific {
307 tool_type: "function".to_string(),
308 function: Some(ToolChoiceFunction { name: name.into() }),
309 })
310 }
311}
312
// Serde wire-format tests. Each test pins the JSON shape a type serializes to.
// The "roundtrip" tests also check that the JSON deserializes back into the
// expected variant.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ---- Tool --------------------------------------------------------------

    #[test]
    fn tool_function_roundtrip() {
        let t = Tool::function(
            "get_weather",
            "Get weather for a location",
            json!({"type": "object", "properties": {"city": {"type": "string"}}}),
        );
        let json_val = serde_json::to_value(&t).unwrap();
        assert_eq!(json_val["type"], "function");
        assert_eq!(json_val["function"]["name"], "get_weather");
        assert_eq!(
            json_val["function"]["description"],
            "Get weather for a location"
        );
        // `Tool::function` leaves `strict` unset, so the key must be omitted.
        assert!(json_val["function"].get("strict").is_none());

        let back: Tool = serde_json::from_value(json_val).unwrap();
        if let Tool::Function { function } = back {
            assert_eq!(function.name, "get_weather");
        } else {
            panic!("Expected Function variant");
        }
    }

    #[test]
    fn tool_web_search_roundtrip() {
        let t = Tool::web_search();
        let json_val = serde_json::to_value(&t).unwrap();
        // No filters and no flags: only the tag is emitted.
        assert_eq!(json_val, json!({"type": "web_search"}));

        let back: Tool = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back, Tool::WebSearch { .. }));
    }

    #[test]
    fn tool_web_search_filtered_roundtrip() {
        let t =
            Tool::web_search_filtered(WebSearchFilters::allow_domains(vec!["docs.rs".to_string()]));
        let json_val = serde_json::to_value(&t).unwrap();
        assert_eq!(json_val["type"], "web_search");
        assert_eq!(json_val["allowed_domains"], json!(["docs.rs"]));

        let back: Tool = serde_json::from_value(json_val).unwrap();
        if let Tool::WebSearch { filters, .. } = back {
            let filters = filters.unwrap();
            assert_eq!(filters.allowed_domains.unwrap(), vec!["docs.rs"]);
        } else {
            panic!("Expected WebSearch variant");
        }
    }

    #[test]
    fn tool_web_search_excluded_domains_flattened() {
        let t = Tool::web_search_filtered(WebSearchFilters::exclude_domains(vec![
            "example.com".to_string(),
        ]));
        let json_val = serde_json::to_value(&t).unwrap();
        // Flattened filters land directly on the tool object; the unset
        // `allowed_domains` list is omitted.
        assert_eq!(
            json_val,
            json!({"type": "web_search", "excluded_domains": ["example.com"]})
        );
    }

    #[test]
    fn tool_x_search_roundtrip() {
        let t = Tool::x_search();
        let json_val = serde_json::to_value(&t).unwrap();
        // Every filter is `None`, so only the tag is emitted.
        assert_eq!(json_val, json!({"type": "x_search"}));

        let back: Tool = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back, Tool::XSearch { .. }));
    }

    #[test]
    fn tool_code_interpreter_roundtrip() {
        let t = Tool::code_interpreter();
        let json_val = serde_json::to_value(&t).unwrap();
        assert_eq!(json_val, json!({"type": "code_interpreter"}));

        let back: Tool = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back, Tool::CodeInterpreter {}));
    }

    #[test]
    fn tool_collections_search_roundtrip() {
        let t = Tool::collections_search(vec!["col-1".to_string(), "col-2".to_string()]);
        let json_val = serde_json::to_value(&t).unwrap();
        assert_eq!(json_val["type"], "collections_search");
        assert_eq!(json_val["collection_ids"], json!(["col-1", "col-2"]));

        let back: Tool = serde_json::from_value(json_val).unwrap();
        if let Tool::CollectionsSearch { collection_ids } = back {
            assert_eq!(collection_ids.unwrap(), vec!["col-1", "col-2"]);
        } else {
            panic!("Expected CollectionsSearch variant");
        }
    }

    #[test]
    fn tool_mcp_roundtrip() {
        let t = Tool::mcp("https://mcp.example.com");
        let json_val = serde_json::to_value(&t).unwrap();
        // `headers` and `allowed_tools` are unset and therefore omitted.
        assert_eq!(
            json_val,
            json!({"type": "mcp", "server": {"url": "https://mcp.example.com"}})
        );

        let back: Tool = serde_json::from_value(json_val).unwrap();
        if let Tool::Mcp { server, .. } = back {
            assert_eq!(server.url, "https://mcp.example.com");
        } else {
            panic!("Expected Mcp variant");
        }
    }

    // ---- ToolCall / FunctionCall -------------------------------------------

    #[test]
    fn tool_call_with_call_type_roundtrip() {
        let tc = ToolCall {
            id: "call_abc".to_string(),
            call_type: Some("function".to_string()),
            function: Some(FunctionCall {
                name: "get_weather".to_string(),
                arguments: r#"{"city":"NYC"}"#.to_string(),
            }),
        };
        let json_val = serde_json::to_value(&tc).unwrap();
        assert_eq!(json_val["id"], "call_abc");
        assert_eq!(json_val["type"], "function");
        assert_eq!(json_val["function"]["name"], "get_weather");

        let back: ToolCall = serde_json::from_value(json_val).unwrap();
        assert_eq!(back.id, "call_abc");
        assert_eq!(back.call_type.as_deref(), Some("function"));
        assert_eq!(back.function.as_ref().unwrap().name, "get_weather");
    }

    #[test]
    fn tool_call_without_call_type_roundtrip() {
        // No "type" key at all: `#[serde(default)]` must yield `None`.
        let json_val = json!({
            "id": "call_xyz",
            "function": {
                "name": "do_stuff",
                "arguments": "{}"
            }
        });
        let tc: ToolCall = serde_json::from_value(json_val).unwrap();
        assert_eq!(tc.id, "call_xyz");
        assert!(tc.call_type.is_none());
    }

    #[test]
    fn tool_call_none_type_skipped_on_serialize() {
        let tc = ToolCall {
            id: "call_1".to_string(),
            call_type: None,
            function: None,
        };
        let json_val = serde_json::to_value(&tc).unwrap();
        assert!(json_val.get("type").is_none());
        assert!(json_val.get("function").is_none());
        assert_eq!(json_val, json!({"id": "call_1"}));
    }

    #[test]
    fn function_call_parse_arguments() {
        let fc = FunctionCall {
            name: "test".to_string(),
            arguments: r#"{"x": 42}"#.to_string(),
        };
        let parsed: serde_json::Value = fc.parse_arguments().unwrap();
        assert_eq!(parsed["x"], 42);
    }

    #[test]
    fn function_call_parse_arguments_error() {
        let fc = FunctionCall {
            name: "test".to_string(),
            arguments: "not json".to_string(),
        };
        assert!(fc.parse_arguments::<serde_json::Value>().is_err());
    }

    // ---- ToolChoice --------------------------------------------------------

    #[test]
    fn tool_choice_auto_roundtrip() {
        let choice = ToolChoice::auto();
        let json_val = serde_json::to_value(&choice).unwrap();
        assert_eq!(json_val, json!("auto"));

        let back: ToolChoice = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back, ToolChoice::Auto(ToolChoiceAuto::Auto)));
    }

    #[test]
    fn tool_choice_required_roundtrip() {
        let choice = ToolChoice::required();
        let json_val = serde_json::to_value(&choice).unwrap();
        assert_eq!(json_val, json!("required"));

        let back: ToolChoice = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back, ToolChoice::Auto(ToolChoiceAuto::Required)));
    }

    #[test]
    fn tool_choice_none_roundtrip() {
        let choice = ToolChoice::none();
        let json_val = serde_json::to_value(&choice).unwrap();
        assert_eq!(json_val, json!("none"));

        let back: ToolChoice = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back, ToolChoice::Auto(ToolChoiceAuto::None)));
    }

    #[test]
    fn tool_choice_specific_function_roundtrip() {
        let choice = ToolChoice::function("get_weather");
        let json_val = serde_json::to_value(&choice).unwrap();
        assert_eq!(json_val["type"], "function");
        assert_eq!(json_val["function"]["name"], "get_weather");

        let back: ToolChoice = serde_json::from_value(json_val).unwrap();
        if let ToolChoice::Specific(spec) = back {
            assert_eq!(spec.tool_type, "function");
            assert_eq!(spec.function.unwrap().name, "get_weather");
        } else {
            panic!("Expected Specific variant");
        }
    }

    // ---- FunctionDefinition ------------------------------------------------

    #[test]
    fn function_definition_builder() {
        let fd = FunctionDefinition::new("test", json!({}))
            .with_description("A test function")
            .strict();
        assert_eq!(fd.name, "test");
        assert_eq!(fd.description.as_deref(), Some("A test function"));
        assert_eq!(fd.strict, Some(true));
    }

    #[test]
    fn function_definition_roundtrip() {
        let fd = FunctionDefinition {
            name: "search".to_string(),
            description: Some("Search the web".to_string()),
            parameters: json!({"type": "object"}),
            strict: Some(true),
        };
        let json_val = serde_json::to_value(&fd).unwrap();
        let back: FunctionDefinition = serde_json::from_value(json_val).unwrap();
        assert_eq!(back.name, "search");
        assert_eq!(back.description.as_deref(), Some("Search the web"));
        assert_eq!(back.parameters, json!({"type": "object"}));
        assert_eq!(back.strict, Some(true));
    }
}