xai_grpc_client/
tools.rs

1//! Tool calling support for Grok API.
2//!
3//! This module provides ergonomic Rust types for working with Grok's tool calling capabilities,
4//! including function calling, web search, X search, code execution, and more.
5//!
6//! # Tool Types
7//!
8//! The Grok API supports 7 different tool types:
9//! - **Function** - Client-side function calling (similar to OpenAI)
10//! - **WebSearch** - Server-side web search with domain filtering
//! - **XSearch** - Search X (Twitter) posts, filtered by author handle and date range
12//! - **CodeExecution** - Server-side Python code execution
13//! - **CollectionsSearch** - Search custom data collections
14//! - **MCP** - Model Context Protocol integration
15//! - **DocumentSearch** - Document retrieval from knowledge bases
16
17use serde_json::Value;
18use std::collections::HashMap;
19
20// Use shared proto module
21use crate::proto::{
22    self, CodeExecution as ProtoCodeExecution, CollectionsSearch as ProtoCollectionsSearch,
23    DocumentSearch as ProtoDocumentSearch, Function as ProtoFunction,
24    FunctionCall as ProtoFunctionCall, Mcp as ProtoMcp, ToolCall as ProtoToolCall, ToolCallStatus,
25    ToolCallType, ToolChoice as ProtoToolChoice, ToolMode, WebSearch as ProtoWebSearch,
26    XSearch as ProtoXSearch,
27};
28
29/// Tool that can be provided to the model for enhanced capabilities.
30///
31/// Tools allow the model to perform actions beyond text generation, such as
32/// calling functions, searching the web, or executing code.
33///
34/// # Examples
35///
36/// ```
37/// use xai_grpc_client::{Tool, FunctionTool, WebSearchTool};
38/// use serde_json::json;
39///
40/// // Function calling
41/// let weather_tool = Tool::Function(FunctionTool::new(
42///     "get_weather",
43///     "Get current weather"
44/// ).with_parameters(json!({
45///     "type": "object",
46///     "properties": {
47///         "location": {"type": "string"}
48///     }
49/// })));
50///
51/// // Web search
52/// let search_tool = Tool::WebSearch(WebSearchTool::new());
53/// ```
54#[derive(Clone, Debug)]
55pub enum Tool {
56    /// Client-side function calling (like OpenAI).
57    Function(FunctionTool),
58    /// Server-side web search with domain filters.
59    WebSearch(WebSearchTool),
60    /// Search X (Twitter) posts with engagement thresholds.
61    XSearch(XSearchTool),
62    /// Server-side Python code execution.
63    CodeExecution,
64    /// Search custom data collections.
65    CollectionsSearch(CollectionsSearchTool),
66    /// Model Context Protocol integration.
67    Mcp(McpTool),
68    /// Document retrieval from knowledge bases.
69    DocumentSearch(DocumentSearchTool),
70}
71
72impl Tool {
73    /// Convert to protobuf representation
74    pub fn to_proto(&self) -> proto::Tool {
75        let tool = match self {
76            Tool::Function(f) => proto::tool::Tool::Function(f.to_proto()),
77            Tool::WebSearch(w) => proto::tool::Tool::WebSearch(w.to_proto()),
78            Tool::XSearch(x) => proto::tool::Tool::XSearch(x.to_proto()),
79            Tool::CodeExecution => proto::tool::Tool::CodeExecution(ProtoCodeExecution {}),
80            Tool::CollectionsSearch(c) => proto::tool::Tool::CollectionsSearch(c.to_proto()),
81            Tool::Mcp(m) => proto::tool::Tool::Mcp(m.to_proto()),
82            Tool::DocumentSearch(d) => proto::tool::Tool::DocumentSearch(d.to_proto()),
83        };
84
85        proto::Tool { tool: Some(tool) }
86    }
87}
88
89/// Client-side function tool definition
90#[derive(Clone, Debug)]
91pub struct FunctionTool {
92    /// Name of the function
93    pub name: String,
94    /// Description of what the function does
95    pub description: String,
96    /// JSON Schema describing the function parameters
97    pub parameters: Value,
98}
99
100impl FunctionTool {
101    /// Create a new function tool
102    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
103        Self {
104            name: name.into(),
105            description: description.into(),
106            parameters: serde_json::json!({
107                "type": "object",
108                "properties": {},
109            }),
110        }
111    }
112
113    /// Set the parameters JSON schema
114    pub fn with_parameters(mut self, parameters: Value) -> Self {
115        self.parameters = parameters;
116        self
117    }
118
119    fn to_proto(&self) -> ProtoFunction {
120        ProtoFunction {
121            name: self.name.clone(),
122            description: self.description.clone(),
123            strict: false,
124            parameters: self.parameters.to_string(),
125        }
126    }
127}
128
129/// Web search tool configuration
130#[derive(Clone, Debug, Default)]
131pub struct WebSearchTool {
132    /// Domains to exclude from results (max 5)
133    pub excluded_domains: Vec<String>,
134    /// Domains to restrict results to (max 5)
135    pub allowed_domains: Vec<String>,
136    /// Enable image understanding in search results
137    pub enable_image_understanding: Option<bool>,
138}
139
140impl WebSearchTool {
141    /// Create a new web search tool
142    pub fn new() -> Self {
143        Self::default()
144    }
145
146    /// Exclude specific domains from search results
147    pub fn with_excluded_domains(mut self, domains: Vec<String>) -> Self {
148        self.excluded_domains = domains;
149        self
150    }
151
152    /// Restrict search to specific domains only
153    pub fn with_allowed_domains(mut self, domains: Vec<String>) -> Self {
154        self.allowed_domains = domains;
155        self
156    }
157
158    /// Enable image understanding in search results
159    pub fn with_image_understanding(mut self, enable: bool) -> Self {
160        self.enable_image_understanding = Some(enable);
161        self
162    }
163
164    fn to_proto(&self) -> ProtoWebSearch {
165        ProtoWebSearch {
166            excluded_domains: self.excluded_domains.clone(),
167            allowed_domains: self.allowed_domains.clone(),
168            enable_image_understanding: self.enable_image_understanding,
169        }
170    }
171}
172
173/// X (Twitter) search tool configuration
174#[derive(Clone, Debug, Default)]
175pub struct XSearchTool {
176    /// Start date for search results (ISO-8601)
177    pub from_date: Option<prost_types::Timestamp>,
178    /// End date for search results (ISO-8601)
179    pub to_date: Option<prost_types::Timestamp>,
180    /// Allowed X handles
181    pub allowed_x_handles: Vec<String>,
182    /// Excluded X handles
183    pub excluded_x_handles: Vec<String>,
184    /// Enable image understanding
185    pub enable_image_understanding: Option<bool>,
186    /// Enable video understanding
187    pub enable_video_understanding: Option<bool>,
188}
189
190impl XSearchTool {
191    /// Create a new X search tool
192    pub fn new() -> Self {
193        Self::default()
194    }
195
196    /// Set date range for search results
197    pub fn with_date_range(
198        mut self,
199        from: Option<prost_types::Timestamp>,
200        to: Option<prost_types::Timestamp>,
201    ) -> Self {
202        self.from_date = from;
203        self.to_date = to;
204        self
205    }
206
207    /// Set allowed X handles
208    pub fn with_allowed_handles(mut self, handles: Vec<String>) -> Self {
209        self.allowed_x_handles = handles;
210        self
211    }
212
213    /// Set excluded X handles
214    pub fn with_excluded_handles(mut self, handles: Vec<String>) -> Self {
215        self.excluded_x_handles = handles;
216        self
217    }
218
219    /// Enable media understanding
220    pub fn with_media_understanding(mut self, images: bool, videos: bool) -> Self {
221        self.enable_image_understanding = Some(images);
222        self.enable_video_understanding = Some(videos);
223        self
224    }
225
226    fn to_proto(&self) -> ProtoXSearch {
227        ProtoXSearch {
228            from_date: self.from_date,
229            to_date: self.to_date,
230            allowed_x_handles: self.allowed_x_handles.clone(),
231            excluded_x_handles: self.excluded_x_handles.clone(),
232            enable_image_understanding: self.enable_image_understanding,
233            enable_video_understanding: self.enable_video_understanding,
234        }
235    }
236}
237
238/// Collections search tool configuration
239#[derive(Clone, Debug)]
240pub struct CollectionsSearchTool {
241    /// Collection IDs to search (max 10)
242    pub collection_ids: Vec<String>,
243    /// Number of chunks to return per collection
244    pub limit: Option<i32>,
245}
246
247impl CollectionsSearchTool {
248    /// Create a new collections search tool
249    pub fn new(collection_ids: Vec<String>) -> Self {
250        Self {
251            collection_ids,
252            limit: None,
253        }
254    }
255
256    /// Set the limit of chunks to return
257    pub fn with_limit(mut self, limit: i32) -> Self {
258        self.limit = Some(limit);
259        self
260    }
261
262    fn to_proto(&self) -> ProtoCollectionsSearch {
263        ProtoCollectionsSearch {
264            collection_ids: self.collection_ids.clone(),
265            limit: self.limit,
266        }
267    }
268}
269
270/// Model Context Protocol server configuration
271#[derive(Clone, Debug)]
272pub struct McpTool {
273    /// Label for the MCP server
274    pub server_label: String,
275    /// Description of the server
276    pub server_description: String,
277    /// Server URL
278    pub server_url: String,
279    /// Allowed tool names
280    pub allowed_tool_names: Vec<String>,
281    /// Authorization token
282    pub authorization: Option<String>,
283    /// Extra headers to send
284    pub extra_headers: HashMap<String, String>,
285}
286
287impl McpTool {
288    /// Create a new MCP tool
289    pub fn new(server_url: impl Into<String>) -> Self {
290        Self {
291            server_label: String::new(),
292            server_description: String::new(),
293            server_url: server_url.into(),
294            allowed_tool_names: Vec::new(),
295            authorization: None,
296            extra_headers: HashMap::new(),
297        }
298    }
299
300    /// Set a label for the MCP server
301    pub fn with_label(mut self, label: impl Into<String>) -> Self {
302        self.server_label = label.into();
303        self
304    }
305
306    /// Set a description for the MCP server
307    pub fn with_description(mut self, description: impl Into<String>) -> Self {
308        self.server_description = description.into();
309        self
310    }
311
312    /// Set allowed tool names
313    pub fn with_allowed_tools(mut self, tools: Vec<String>) -> Self {
314        self.allowed_tool_names = tools;
315        self
316    }
317
318    /// Set authorization token
319    pub fn with_authorization(mut self, token: impl Into<String>) -> Self {
320        self.authorization = Some(token.into());
321        self
322    }
323
324    /// Add extra headers
325    pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
326        self.extra_headers = headers;
327        self
328    }
329
330    fn to_proto(&self) -> ProtoMcp {
331        ProtoMcp {
332            server_label: self.server_label.clone(),
333            server_description: self.server_description.clone(),
334            server_url: self.server_url.clone(),
335            allowed_tool_names: self.allowed_tool_names.clone(),
336            authorization: self.authorization.clone(),
337            extra_headers: self.extra_headers.clone(),
338        }
339    }
340}
341
342/// Document search tool configuration
343#[derive(Clone, Debug, Default)]
344pub struct DocumentSearchTool {
345    /// Number of files to limit search to
346    pub limit: Option<i32>,
347}
348
349impl DocumentSearchTool {
350    /// Create a new document search tool
351    pub fn new() -> Self {
352        Self::default()
353    }
354
355    /// Set the limit of files to search
356    pub fn with_limit(mut self, limit: i32) -> Self {
357        self.limit = Some(limit);
358        self
359    }
360
361    fn to_proto(&self) -> ProtoDocumentSearch {
362        ProtoDocumentSearch { limit: self.limit }
363    }
364}
365
366/// Strategy for how the model should use tools.
367///
368/// Controls whether the model can freely choose tools, must use a tool,
369/// or should call a specific function.
370#[derive(Clone, Debug)]
371pub enum ToolChoice {
372    /// Let the model decide whether to use tools.
373    Auto,
374    /// Require the model to use a tool.
375    Required,
376    /// Force the model to call a specific function.
377    Function(String),
378}
379
380impl ToolChoice {
381    /// Convert to protobuf representation
382    pub fn to_proto(&self) -> ProtoToolChoice {
383        let tool_choice = match self {
384            ToolChoice::Auto => proto::tool_choice::ToolChoice::Mode(ToolMode::Auto as i32),
385            ToolChoice::Required => proto::tool_choice::ToolChoice::Mode(ToolMode::Required as i32),
386            ToolChoice::Function(name) => {
387                proto::tool_choice::ToolChoice::FunctionName(name.clone())
388            }
389        };
390
391        ProtoToolChoice {
392            tool_choice: Some(tool_choice),
393        }
394    }
395}
396
397/// A tool call made by the model in a response.
398///
399/// Contains information about which tool was called, its status,
400/// and the function details including arguments.
401#[derive(Clone, Debug)]
402pub struct ToolCall {
403    /// Unique identifier for this tool call.
404    pub id: String,
405    /// Type of tool call (client-side or server-side).
406    pub call_type: ToolCallKind,
407    /// Status of the tool call execution.
408    pub status: ToolCallStatusKind,
409    /// Error message if the call failed.
410    pub error_message: Option<String>,
411    /// The actual function call details.
412    pub function: FunctionCall,
413}
414
415impl ToolCall {
416    /// Parse from protobuf representation
417    pub fn from_proto(proto: ProtoToolCall) -> Option<Self> {
418        let function = match proto.tool? {
419            proto::tool_call::Tool::Function(f) => FunctionCall {
420                name: f.name,
421                arguments: f.arguments,
422            },
423        };
424
425        Some(Self {
426            id: proto.id,
427            call_type: ToolCallKind::from_proto(proto.r#type),
428            status: ToolCallStatusKind::from_proto(proto.status),
429            error_message: proto.error_message,
430            function,
431        })
432    }
433
434    /// Convert to protobuf representation
435    pub fn to_proto(&self) -> ProtoToolCall {
436        ProtoToolCall {
437            id: self.id.clone(),
438            r#type: self.call_type.to_proto() as i32,
439            status: self.status.to_proto() as i32,
440            error_message: self.error_message.clone(),
441            tool: Some(proto::tool_call::Tool::Function(ProtoFunctionCall {
442                name: self.function.name.clone(),
443                arguments: self.function.arguments.clone(),
444            })),
445        }
446    }
447}
448
449/// Type of tool call
450#[derive(Clone, Debug, PartialEq, Eq)]
451pub enum ToolCallKind {
452    /// Client-side function (maps to OpenAI's function_call)
453    ClientSideTool,
454    /// Server-side web search
455    WebSearchTool,
456    /// Server-side X search
457    XSearchTool,
458    /// Server-side code execution
459    CodeExecutionTool,
460    /// Server-side collections search
461    CollectionsSearchTool,
462    /// Server-side MCP tool
463    McpTool,
464    /// Unknown or invalid type
465    Unknown,
466}
467
468impl ToolCallKind {
469    fn from_proto(value: i32) -> Self {
470        match value {
471            x if x == ToolCallType::ClientSideTool as i32 => ToolCallKind::ClientSideTool,
472            x if x == ToolCallType::WebSearchTool as i32 => ToolCallKind::WebSearchTool,
473            x if x == ToolCallType::XSearchTool as i32 => ToolCallKind::XSearchTool,
474            x if x == ToolCallType::CodeExecutionTool as i32 => ToolCallKind::CodeExecutionTool,
475            x if x == ToolCallType::CollectionsSearchTool as i32 => {
476                ToolCallKind::CollectionsSearchTool
477            }
478            x if x == ToolCallType::McpTool as i32 => ToolCallKind::McpTool,
479            _ => ToolCallKind::Unknown,
480        }
481    }
482
483    fn to_proto(&self) -> ToolCallType {
484        match self {
485            ToolCallKind::ClientSideTool => ToolCallType::ClientSideTool,
486            ToolCallKind::WebSearchTool => ToolCallType::WebSearchTool,
487            ToolCallKind::XSearchTool => ToolCallType::XSearchTool,
488            ToolCallKind::CodeExecutionTool => ToolCallType::CodeExecutionTool,
489            ToolCallKind::CollectionsSearchTool => ToolCallType::CollectionsSearchTool,
490            ToolCallKind::McpTool => ToolCallType::McpTool,
491            ToolCallKind::Unknown => ToolCallType::Invalid,
492        }
493    }
494}
495
496/// Status of a tool call
497#[derive(Clone, Debug, PartialEq, Eq)]
498pub enum ToolCallStatusKind {
499    /// Tool call is in progress
500    InProgress,
501    /// Tool call completed successfully
502    Completed,
503    /// Tool call incomplete
504    Incomplete,
505    /// Tool call failed
506    Failed,
507}
508
509impl ToolCallStatusKind {
510    fn from_proto(value: i32) -> Self {
511        match value {
512            x if x == ToolCallStatus::InProgress as i32 => ToolCallStatusKind::InProgress,
513            x if x == ToolCallStatus::Completed as i32 => ToolCallStatusKind::Completed,
514            x if x == ToolCallStatus::Incomplete as i32 => ToolCallStatusKind::Incomplete,
515            x if x == ToolCallStatus::Failed as i32 => ToolCallStatusKind::Failed,
516            _ => ToolCallStatusKind::InProgress, // default
517        }
518    }
519
520    fn to_proto(&self) -> ToolCallStatus {
521        match self {
522            ToolCallStatusKind::InProgress => ToolCallStatus::InProgress,
523            ToolCallStatusKind::Completed => ToolCallStatus::Completed,
524            ToolCallStatusKind::Incomplete => ToolCallStatus::Incomplete,
525            ToolCallStatusKind::Failed => ToolCallStatus::Failed,
526        }
527    }
528}
529
530/// Function call details
531#[derive(Clone, Debug)]
532pub struct FunctionCall {
533    /// Name of the function to call
534    pub name: String,
535    /// Arguments as JSON string
536    pub arguments: String,
537}
538
539impl FunctionCall {
540    /// Parse arguments as JSON
541    pub fn parse_arguments<T: serde::de::DeserializeOwned>(&self) -> serde_json::Result<T> {
542        serde_json::from_str(&self.arguments)
543    }
544
545    /// Get arguments as a JSON value
546    pub fn arguments_json(&self) -> serde_json::Result<Value> {
547        serde_json::from_str(&self.arguments)
548    }
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A fully populated tool call used by the round-trip tests.
    fn sample_call() -> ToolCall {
        ToolCall {
            id: "call_1".to_string(),
            call_type: ToolCallKind::WebSearchTool,
            status: ToolCallStatusKind::Failed,
            error_message: Some("timed out".to_string()),
            function: FunctionCall {
                name: "web_search".to_string(),
                arguments: r#"{"query": "rust"}"#.to_string(),
            },
        }
    }

    #[test]
    fn test_function_tool_creation() {
        let tool = FunctionTool::new("get_weather", "Get current weather");
        assert_eq!(tool.name, "get_weather");
        assert_eq!(tool.description, "Get current weather");
    }

    #[test]
    fn test_function_tool_with_parameters() {
        let params = json!({
            "type": "object",
            "properties": {
                "location": {"type": "string"}
            }
        });

        let tool = FunctionTool::new("get_weather", "Get weather").with_parameters(params.clone());

        assert_eq!(tool.parameters, params);
    }

    #[test]
    fn test_function_tool_to_proto_serializes_schema() {
        let wire = FunctionTool::new("get_weather", "Get weather").to_proto();
        assert_eq!(wire.name, "get_weather");
        assert_eq!(wire.description, "Get weather");
        assert!(!wire.strict);

        let schema: serde_json::Value = serde_json::from_str(&wire.parameters).unwrap();
        assert_eq!(schema, json!({"type": "object", "properties": {}}));
    }

    #[test]
    fn test_tool_to_proto_sets_matching_variant() {
        let wire = Tool::CodeExecution.to_proto();
        assert!(matches!(wire.tool, Some(proto::tool::Tool::CodeExecution(_))));

        let wire = Tool::WebSearch(WebSearchTool::new()).to_proto();
        assert!(matches!(wire.tool, Some(proto::tool::Tool::WebSearch(_))));
    }

    #[test]
    fn test_web_search_tool() {
        let tool = WebSearchTool::new().with_excluded_domains(vec!["spam.com".to_string()]);
        assert_eq!(tool.excluded_domains.len(), 1);
    }

    #[test]
    fn test_x_search_tool() {
        let tool = XSearchTool::new().with_allowed_handles(vec!["@rustlang".to_string()]);
        assert_eq!(tool.allowed_x_handles.len(), 1);
    }

    #[test]
    fn test_tool_choice_auto() {
        let wire = ToolChoice::Auto.to_proto();
        assert!(matches!(
            wire.tool_choice,
            Some(proto::tool_choice::ToolChoice::Mode(mode)) if mode == ToolMode::Auto as i32
        ));
    }

    #[test]
    fn test_tool_choice_required() {
        let wire = ToolChoice::Required.to_proto();
        assert!(matches!(
            wire.tool_choice,
            Some(proto::tool_choice::ToolChoice::Mode(mode)) if mode == ToolMode::Required as i32
        ));
    }

    #[test]
    fn test_tool_choice_function() {
        let wire = ToolChoice::Function("my_function".to_string()).to_proto();
        match wire.tool_choice {
            Some(proto::tool_choice::ToolChoice::FunctionName(name)) => {
                assert_eq!(name, "my_function")
            }
            _ => panic!("Expected FunctionName variant"),
        }
    }

    #[test]
    fn test_tool_call_round_trip() {
        let back = ToolCall::from_proto(sample_call().to_proto())
            .expect("to_proto always sets the tool payload");

        assert_eq!(back.id, "call_1");
        assert_eq!(back.call_type, ToolCallKind::WebSearchTool);
        assert_eq!(back.status, ToolCallStatusKind::Failed);
        assert_eq!(back.error_message.as_deref(), Some("timed out"));
        assert_eq!(back.function.name, "web_search");
        assert_eq!(back.function.arguments, r#"{"query": "rust"}"#);
    }

    #[test]
    fn test_tool_call_without_payload_is_rejected() {
        let mut wire = sample_call().to_proto();
        wire.tool = None;
        assert!(ToolCall::from_proto(wire).is_none());
    }

    #[test]
    fn test_unknown_wire_values() {
        assert_eq!(ToolCallKind::from_proto(-1), ToolCallKind::Unknown);
        assert_eq!(
            ToolCallKind::from_proto(ToolCallType::Invalid as i32),
            ToolCallKind::Unknown
        );
        assert_eq!(
            ToolCallStatusKind::from_proto(-1),
            ToolCallStatusKind::InProgress
        );
    }

    #[test]
    fn test_function_call_parse_arguments() {
        let call = FunctionCall {
            name: "test_fn".to_string(),
            arguments: r#"{"param": "value"}"#.to_string(),
        };

        let json = call.arguments_json().unwrap();
        assert_eq!(json["param"], "value");

        let typed: HashMap<String, String> = call.parse_arguments().unwrap();
        assert_eq!(typed["param"], "value");
    }

    #[test]
    fn test_mcp_tool() {
        let tool = McpTool::new("https://example.com/mcp").with_label("My MCP Server");
        assert_eq!(tool.server_url, "https://example.com/mcp");
        assert_eq!(tool.server_label, "My MCP Server");
    }

    #[test]
    fn test_collections_search_tool() {
        let tool = CollectionsSearchTool::new(vec!["coll_1".to_string()]).with_limit(10);

        assert_eq!(tool.collection_ids.len(), 1);
        assert_eq!(tool.limit, Some(10));
    }

    #[test]
    fn test_document_search_tool() {
        let tool = DocumentSearchTool::new().with_limit(20);

        assert_eq!(tool.limit, Some(20));
    }
}
642}