Skip to main content

xai_rust/models/
tool.rs

1//! Tool definitions for function calling and server-side tools.
2
3use serde::{Deserialize, Serialize};
4
5/// A tool that can be used by the model.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7#[serde(tag = "type", rename_all = "snake_case")]
8pub enum Tool {
9    /// A custom function tool.
10    Function {
11        /// The function definition.
12        function: FunctionDefinition,
13    },
14    /// Web search tool.
15    WebSearch {
16        /// Search filters.
17        #[serde(skip_serializing_if = "Option::is_none", flatten)]
18        filters: Option<WebSearchFilters>,
19        /// Enable image understanding during search.
20        #[serde(skip_serializing_if = "Option::is_none")]
21        enable_image_understanding: Option<bool>,
22    },
23    /// X (Twitter) search tool.
24    XSearch {
25        /// Allowed X handles.
26        #[serde(skip_serializing_if = "Option::is_none")]
27        allowed_x_handles: Option<Vec<String>>,
28        /// Excluded X handles.
29        #[serde(skip_serializing_if = "Option::is_none")]
30        excluded_x_handles: Option<Vec<String>>,
31        /// Start date for search (YYYY-MM-DD).
32        #[serde(skip_serializing_if = "Option::is_none")]
33        from_date: Option<String>,
34        /// End date for search (YYYY-MM-DD).
35        #[serde(skip_serializing_if = "Option::is_none")]
36        to_date: Option<String>,
37        /// Enable image understanding.
38        #[serde(skip_serializing_if = "Option::is_none")]
39        enable_image_understanding: Option<bool>,
40        /// Enable video understanding.
41        #[serde(skip_serializing_if = "Option::is_none")]
42        enable_video_understanding: Option<bool>,
43    },
44    /// Code interpreter tool.
45    CodeInterpreter {},
46    /// Collections search tool.
47    CollectionsSearch {
48        /// Collection IDs to search.
49        #[serde(skip_serializing_if = "Option::is_none")]
50        collection_ids: Option<Vec<String>>,
51    },
52    /// Remote MCP server tool.
53    Mcp {
54        /// MCP server configuration.
55        server: McpServer,
56        /// Specific tools to enable.
57        #[serde(skip_serializing_if = "Option::is_none")]
58        allowed_tools: Option<Vec<String>>,
59    },
60}
61
62impl Tool {
63    /// Create a function tool.
64    pub fn function(
65        name: impl Into<String>,
66        description: impl Into<String>,
67        parameters: serde_json::Value,
68    ) -> Self {
69        Self::Function {
70            function: FunctionDefinition {
71                name: name.into(),
72                description: Some(description.into()),
73                parameters,
74                strict: None,
75            },
76        }
77    }
78
79    /// Create a web search tool with default settings.
80    pub fn web_search() -> Self {
81        Self::WebSearch {
82            filters: None,
83            enable_image_understanding: None,
84        }
85    }
86
87    /// Create a web search tool with filters.
88    pub fn web_search_filtered(filters: WebSearchFilters) -> Self {
89        Self::WebSearch {
90            filters: Some(filters),
91            enable_image_understanding: None,
92        }
93    }
94
95    /// Create an X search tool with default settings.
96    pub fn x_search() -> Self {
97        Self::XSearch {
98            allowed_x_handles: None,
99            excluded_x_handles: None,
100            from_date: None,
101            to_date: None,
102            enable_image_understanding: None,
103            enable_video_understanding: None,
104        }
105    }
106
107    /// Create a code interpreter tool.
108    pub fn code_interpreter() -> Self {
109        Self::CodeInterpreter {}
110    }
111
112    /// Create a collections search tool.
113    pub fn collections_search(collection_ids: Vec<String>) -> Self {
114        Self::CollectionsSearch {
115            collection_ids: Some(collection_ids),
116        }
117    }
118
119    /// Create an MCP tool.
120    pub fn mcp(server_url: impl Into<String>) -> Self {
121        Self::Mcp {
122            server: McpServer {
123                url: server_url.into(),
124                headers: None,
125            },
126            allowed_tools: None,
127        }
128    }
129}
130
131/// Function definition for custom tools.
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct FunctionDefinition {
134    /// The name of the function.
135    pub name: String,
136    /// Description of what the function does.
137    #[serde(skip_serializing_if = "Option::is_none")]
138    pub description: Option<String>,
139    /// JSON Schema for the function parameters.
140    pub parameters: serde_json::Value,
141    /// Whether to strictly enforce the schema.
142    #[serde(skip_serializing_if = "Option::is_none")]
143    pub strict: Option<bool>,
144}
145
146impl FunctionDefinition {
147    /// Create a new function definition.
148    pub fn new(name: impl Into<String>, parameters: serde_json::Value) -> Self {
149        Self {
150            name: name.into(),
151            description: None,
152            parameters,
153            strict: None,
154        }
155    }
156
157    /// Set the description.
158    pub fn with_description(mut self, description: impl Into<String>) -> Self {
159        self.description = Some(description.into());
160        self
161    }
162
163    /// Enable strict mode.
164    pub fn strict(mut self) -> Self {
165        self.strict = Some(true);
166        self
167    }
168}
169
170/// Web search filters.
171#[derive(Debug, Clone, Default, Serialize, Deserialize)]
172pub struct WebSearchFilters {
173    /// Only search these domains.
174    #[serde(skip_serializing_if = "Option::is_none")]
175    pub allowed_domains: Option<Vec<String>>,
176    /// Exclude these domains from search.
177    #[serde(skip_serializing_if = "Option::is_none")]
178    pub excluded_domains: Option<Vec<String>>,
179}
180
181impl WebSearchFilters {
182    /// Create filters with allowed domains.
183    pub fn allow_domains(domains: Vec<String>) -> Self {
184        Self {
185            allowed_domains: Some(domains),
186            excluded_domains: None,
187        }
188    }
189
190    /// Create filters with excluded domains.
191    pub fn exclude_domains(domains: Vec<String>) -> Self {
192        Self {
193            allowed_domains: None,
194            excluded_domains: Some(domains),
195        }
196    }
197}
198
199/// MCP server configuration.
200#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct McpServer {
202    /// The URL of the MCP server.
203    pub url: String,
204    /// Custom headers for authentication.
205    #[serde(skip_serializing_if = "Option::is_none")]
206    pub headers: Option<std::collections::HashMap<String, String>>,
207}
208
209/// MCP tool configuration.
210#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct McpTool {
212    /// MCP server configuration.
213    pub server: McpServer,
214    /// Allowed tools from the server.
215    #[serde(skip_serializing_if = "Option::is_none")]
216    pub allowed_tools: Option<Vec<String>>,
217}
218
219/// A tool call made by the model.
220#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct ToolCall {
222    /// Unique ID for this tool call.
223    pub id: String,
224    /// Type of tool call (e.g., `"function"`).
225    #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
226    pub call_type: Option<String>,
227    /// Function call details.
228    #[serde(skip_serializing_if = "Option::is_none")]
229    pub function: Option<FunctionCall>,
230}
231
232/// Function call details.
233#[derive(Debug, Clone, Serialize, Deserialize)]
234pub struct FunctionCall {
235    /// Name of the function to call.
236    pub name: String,
237    /// Arguments as a JSON string.
238    pub arguments: String,
239}
240
241impl FunctionCall {
242    /// Parse the arguments as JSON.
243    pub fn parse_arguments<T: serde::de::DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
244        serde_json::from_str(&self.arguments)
245    }
246}
247
248/// Tool choice specification.
249#[derive(Debug, Clone, Serialize, Deserialize)]
250#[serde(untagged)]
251pub enum ToolChoice {
252    /// Let the model decide.
253    Auto(ToolChoiceAuto),
254    /// Force a specific tool.
255    Specific(ToolChoiceSpecific),
256}
257
258/// Auto tool choice modes.
259#[derive(Debug, Clone, Serialize, Deserialize)]
260#[serde(rename_all = "lowercase")]
261pub enum ToolChoiceAuto {
262    /// Model decides whether to use tools.
263    Auto,
264    /// Model must use at least one tool.
265    Required,
266    /// Model should not use tools.
267    None,
268}
269
270/// Specific tool choice.
271#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct ToolChoiceSpecific {
273    /// Type of tool.
274    #[serde(rename = "type")]
275    pub tool_type: String,
276    /// Function specification (for function tools).
277    #[serde(skip_serializing_if = "Option::is_none")]
278    pub function: Option<ToolChoiceFunction>,
279}
280
281/// Function specification for tool choice.
282#[derive(Debug, Clone, Serialize, Deserialize)]
283pub struct ToolChoiceFunction {
284    /// Name of the function to use.
285    pub name: String,
286}
287
288impl ToolChoice {
289    /// Let the model decide.
290    pub fn auto() -> Self {
291        Self::Auto(ToolChoiceAuto::Auto)
292    }
293
294    /// Require tool use.
295    pub fn required() -> Self {
296        Self::Auto(ToolChoiceAuto::Required)
297    }
298
299    /// Disable tools.
300    pub fn none() -> Self {
301        Self::Auto(ToolChoiceAuto::None)
302    }
303
304    /// Force a specific function.
305    pub fn function(name: impl Into<String>) -> Self {
306        Self::Specific(ToolChoiceSpecific {
307            tool_type: "function".to_string(),
308            function: Some(ToolChoiceFunction { name: name.into() }),
309        })
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ── Tool enum serde roundtrips ─────────────────────────────────────

    #[test]
    fn tool_function_roundtrip() {
        let t = Tool::function(
            "get_weather",
            "Get weather for a location",
            json!({"type": "object", "properties": {"city": {"type": "string"}}}),
        );
        let json_val = serde_json::to_value(&t).unwrap();
        assert_eq!(json_val["type"], "function");
        assert_eq!(json_val["function"]["name"], "get_weather");
        assert_eq!(
            json_val["function"]["description"],
            "Get weather for a location"
        );
        // `Tool::function` leaves `strict` unset, so it must not be serialized.
        assert!(json_val["function"].get("strict").is_none());

        let back: Tool = serde_json::from_value(json_val).unwrap();
        if let Tool::Function { function } = back {
            assert_eq!(function.name, "get_weather");
            assert_eq!(
                function.description.as_deref(),
                Some("Get weather for a location")
            );
            assert_eq!(function.parameters["properties"]["city"]["type"], "string");
            assert!(function.strict.is_none());
        } else {
            panic!("Expected Function variant");
        }
    }

    #[test]
    fn tool_web_search_roundtrip() {
        let t = Tool::web_search();
        let json_val = serde_json::to_value(&t).unwrap();
        // Unset options are skipped, leaving only the tag.
        assert_eq!(json_val, json!({"type": "web_search"}));

        let back: Tool = serde_json::from_value(json_val).unwrap();
        if let Tool::WebSearch {
            filters,
            enable_image_understanding,
        } = back
        {
            // A flattened `Option` may come back as `None` or as empty filters;
            // either way no domain restriction may appear out of nowhere.
            if let Some(filters) = filters {
                assert!(filters.allowed_domains.is_none());
                assert!(filters.excluded_domains.is_none());
            }
            assert!(enable_image_understanding.is_none());
        } else {
            panic!("Expected WebSearch variant");
        }
    }

    #[test]
    fn tool_web_search_filtered_roundtrip() {
        let t =
            Tool::web_search_filtered(WebSearchFilters::allow_domains(vec!["docs.rs".to_string()]));
        let json_val = serde_json::to_value(&t).unwrap();
        assert_eq!(json_val["type"], "web_search");
        assert_eq!(json_val["allowed_domains"], json!(["docs.rs"]));
        // Filters are flattened: there is no nested `filters` key, and the
        // unset exclusion list is skipped.
        assert!(json_val.get("filters").is_none());
        assert!(json_val.get("excluded_domains").is_none());

        let back: Tool = serde_json::from_value(json_val).unwrap();
        if let Tool::WebSearch { filters, .. } = back {
            let filters = filters.unwrap();
            assert!(filters.excluded_domains.is_none());
            assert_eq!(filters.allowed_domains.unwrap(), vec!["docs.rs"]);
        } else {
            panic!("Expected WebSearch variant");
        }
    }

    #[test]
    fn tool_x_search_roundtrip() {
        let t = Tool::x_search();
        let json_val = serde_json::to_value(&t).unwrap();
        assert_eq!(json_val, json!({"type": "x_search"}));

        let back: Tool = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back, Tool::XSearch { .. }));
    }

    #[test]
    fn tool_x_search_with_options_roundtrip() {
        let t = Tool::XSearch {
            allowed_x_handles: Some(vec!["xai".to_string()]),
            excluded_x_handles: None,
            from_date: Some("2024-01-01".to_string()),
            to_date: Some("2024-12-31".to_string()),
            enable_image_understanding: Some(true),
            enable_video_understanding: None,
        };
        let json_val = serde_json::to_value(&t).unwrap();
        assert_eq!(
            json_val,
            json!({
                "type": "x_search",
                "allowed_x_handles": ["xai"],
                "from_date": "2024-01-01",
                "to_date": "2024-12-31",
                "enable_image_understanding": true
            })
        );

        let back: Tool = serde_json::from_value(json_val).unwrap();
        if let Tool::XSearch {
            allowed_x_handles,
            excluded_x_handles,
            from_date,
            to_date,
            enable_image_understanding,
            enable_video_understanding,
        } = back
        {
            assert_eq!(allowed_x_handles.unwrap(), vec!["xai"]);
            assert!(excluded_x_handles.is_none());
            assert_eq!(from_date.as_deref(), Some("2024-01-01"));
            assert_eq!(to_date.as_deref(), Some("2024-12-31"));
            assert_eq!(enable_image_understanding, Some(true));
            assert!(enable_video_understanding.is_none());
        } else {
            panic!("Expected XSearch variant");
        }
    }

    #[test]
    fn tool_code_interpreter_roundtrip() {
        let t = Tool::code_interpreter();
        let json_val = serde_json::to_value(&t).unwrap();
        assert_eq!(json_val, json!({"type": "code_interpreter"}));

        let back: Tool = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back, Tool::CodeInterpreter {}));
    }

    #[test]
    fn tool_collections_search_roundtrip() {
        let t = Tool::collections_search(vec!["col-1".to_string(), "col-2".to_string()]);
        let json_val = serde_json::to_value(&t).unwrap();
        assert_eq!(json_val["type"], "collections_search");
        assert_eq!(json_val["collection_ids"], json!(["col-1", "col-2"]));

        let back: Tool = serde_json::from_value(json_val).unwrap();
        if let Tool::CollectionsSearch { collection_ids } = back {
            assert_eq!(collection_ids.unwrap(), vec!["col-1", "col-2"]);
        } else {
            panic!("Expected CollectionsSearch variant");
        }
    }

    #[test]
    fn tool_collections_search_without_ids_deserializes() {
        let back: Tool = serde_json::from_value(json!({"type": "collections_search"})).unwrap();
        assert!(matches!(
            back,
            Tool::CollectionsSearch {
                collection_ids: None
            }
        ));
    }

    #[test]
    fn tool_mcp_roundtrip() {
        let t = Tool::mcp("https://mcp.example.com");
        let json_val = serde_json::to_value(&t).unwrap();
        assert_eq!(json_val["type"], "mcp");
        assert_eq!(json_val["server"]["url"], "https://mcp.example.com");
        assert!(json_val["server"].get("headers").is_none());
        assert!(json_val.get("allowed_tools").is_none());

        let back: Tool = serde_json::from_value(json_val).unwrap();
        if let Tool::Mcp {
            server,
            allowed_tools,
        } = back
        {
            assert_eq!(server.url, "https://mcp.example.com");
            assert!(server.headers.is_none());
            assert!(allowed_tools.is_none());
        } else {
            panic!("Expected Mcp variant");
        }
    }

    #[test]
    fn tool_mcp_with_headers_and_allowed_tools_roundtrip() {
        let mut headers = std::collections::HashMap::new();
        headers.insert("Authorization".to_string(), "Bearer token".to_string());
        let t = Tool::Mcp {
            server: McpServer {
                url: "https://mcp.example.com".to_string(),
                headers: Some(headers),
            },
            allowed_tools: Some(vec!["search".to_string()]),
        };
        let json_val = serde_json::to_value(&t).unwrap();
        assert_eq!(json_val["server"]["headers"]["Authorization"], "Bearer token");
        assert_eq!(json_val["allowed_tools"], json!(["search"]));

        let back: Tool = serde_json::from_value(json_val).unwrap();
        if let Tool::Mcp {
            server,
            allowed_tools,
        } = back
        {
            assert_eq!(server.headers.unwrap()["Authorization"], "Bearer token");
            assert_eq!(allowed_tools.unwrap(), vec!["search"]);
        } else {
            panic!("Expected Mcp variant");
        }
    }

    // ── ToolCall serde ─────────────────────────────────────────────────

    #[test]
    fn tool_call_with_call_type_roundtrip() {
        let tc = ToolCall {
            id: "call_abc".to_string(),
            call_type: Some("function".to_string()),
            function: Some(FunctionCall {
                name: "get_weather".to_string(),
                arguments: r#"{"city":"NYC"}"#.to_string(),
            }),
        };
        let json_val = serde_json::to_value(&tc).unwrap();
        assert_eq!(json_val["id"], "call_abc");
        assert_eq!(json_val["type"], "function");
        assert_eq!(json_val["function"]["name"], "get_weather");

        let back: ToolCall = serde_json::from_value(json_val).unwrap();
        assert_eq!(back.id, "call_abc");
        assert_eq!(back.call_type.as_deref(), Some("function"));
        assert_eq!(back.function.as_ref().unwrap().name, "get_weather");
    }

    #[test]
    fn tool_call_without_call_type_roundtrip() {
        // call_type is Option<String>; verify it defaults to None when absent
        let json_val = json!({
            "id": "call_xyz",
            "function": {
                "name": "do_stuff",
                "arguments": "{}"
            }
        });
        let tc: ToolCall = serde_json::from_value(json_val).unwrap();
        assert_eq!(tc.id, "call_xyz");
        assert!(tc.call_type.is_none());
    }

    #[test]
    fn tool_call_none_type_skipped_on_serialize() {
        let tc = ToolCall {
            id: "call_1".to_string(),
            call_type: None,
            function: None,
        };
        let json_val = serde_json::to_value(&tc).unwrap();
        // "type" key should be absent because of skip_serializing_if
        assert!(json_val.get("type").is_none());
        // `function` is skipped as well, so only the id remains.
        assert_eq!(json_val, json!({"id": "call_1"}));
    }

    // ── FunctionCall parse_arguments ───────────────────────────────────

    #[test]
    fn function_call_parse_arguments() {
        let fc = FunctionCall {
            name: "test".to_string(),
            arguments: r#"{"x": 42}"#.to_string(),
        };
        let parsed: serde_json::Value = fc.parse_arguments().unwrap();
        assert_eq!(parsed["x"], 42);
    }

    #[test]
    fn function_call_parse_arguments_error() {
        let fc = FunctionCall {
            name: "test".to_string(),
            arguments: "not json".to_string(),
        };
        assert!(fc.parse_arguments::<serde_json::Value>().is_err());
    }

    // ── ToolChoice serde roundtrips ────────────────────────────────────

    #[test]
    fn tool_choice_auto_roundtrip() {
        let choice = ToolChoice::auto();
        let json_val = serde_json::to_value(&choice).unwrap();
        assert_eq!(json_val, json!("auto"));

        let back: ToolChoice = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back, ToolChoice::Auto(ToolChoiceAuto::Auto)));
    }

    #[test]
    fn tool_choice_required_roundtrip() {
        let choice = ToolChoice::required();
        let json_val = serde_json::to_value(&choice).unwrap();
        assert_eq!(json_val, json!("required"));

        let back: ToolChoice = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back, ToolChoice::Auto(ToolChoiceAuto::Required)));
    }

    #[test]
    fn tool_choice_none_roundtrip() {
        let choice = ToolChoice::none();
        let json_val = serde_json::to_value(&choice).unwrap();
        assert_eq!(json_val, json!("none"));

        let back: ToolChoice = serde_json::from_value(json_val).unwrap();
        assert!(matches!(back, ToolChoice::Auto(ToolChoiceAuto::None)));
    }

    #[test]
    fn tool_choice_rejects_unknown_mode() {
        assert!(serde_json::from_value::<ToolChoice>(json!("sometimes")).is_err());
    }

    #[test]
    fn tool_choice_specific_function_roundtrip() {
        let choice = ToolChoice::function("get_weather");
        let json_val = serde_json::to_value(&choice).unwrap();
        assert_eq!(json_val["type"], "function");
        assert_eq!(json_val["function"]["name"], "get_weather");

        let back: ToolChoice = serde_json::from_value(json_val).unwrap();
        if let ToolChoice::Specific(spec) = back {
            assert_eq!(spec.tool_type, "function");
            assert_eq!(spec.function.unwrap().name, "get_weather");
        } else {
            panic!("Expected Specific variant");
        }
    }

    // ── FunctionDefinition builder ─────────────────────────────────────

    #[test]
    fn function_definition_builder() {
        let fd = FunctionDefinition::new("test", json!({}))
            .with_description("A test function")
            .strict();
        assert_eq!(fd.name, "test");
        assert_eq!(fd.description.as_deref(), Some("A test function"));
        assert_eq!(fd.strict, Some(true));
    }

    #[test]
    fn function_definition_omits_unset_optionals() {
        let fd = FunctionDefinition::new("bare", json!({}));
        let json_val = serde_json::to_value(&fd).unwrap();
        assert_eq!(json_val, json!({"name": "bare", "parameters": {}}));
    }

    #[test]
    fn function_definition_roundtrip() {
        let fd = FunctionDefinition {
            name: "search".to_string(),
            description: Some("Search the web".to_string()),
            parameters: json!({"type": "object"}),
            strict: Some(true),
        };
        let json_val = serde_json::to_value(&fd).unwrap();
        let back: FunctionDefinition = serde_json::from_value(json_val).unwrap();
        assert_eq!(back.name, "search");
        assert_eq!(back.description.as_deref(), Some("Search the web"));
        assert_eq!(back.parameters, json!({"type": "object"}));
        assert_eq!(back.strict, Some(true));
    }
}