zai_rs/model/
tools.rs

1//! Tool definitions and configurations for the model API.
2//!
3//! This module defines the various tools that can be used by the assistant,
4//! including function calling, retrieval systems, web search, and MCP tools.
5
6use super::model_validate::validate_json_schema_value;
7use crate::tool::web_search::request::{ContentSize, SearchEngine, SearchRecencyFilter};
8use serde::Serialize;
9use std::collections::HashMap;
10use validator::*;
11
12/// Controls thinking/reasoning capabilities in AI models.
13///
14/// This enum determines whether a model should engage in step-by-step reasoning
15/// when processing requests. Thinking mode can improve accuracy for complex tasks
16/// but may increase response time and token usage.
17///
18/// ## Variants
19///
20/// - `Enabled` - Model performs explicit reasoning steps before responding
21/// - `Disabled` - Model responds directly without showing reasoning process
22///
23/// ## Usage
24///
25/// ```rust,ignore
26/// let client = ChatCompletion::new(model, messages, api_key)
27///     .with_thinking(ThinkingType::Enabled);
28/// ```
29///
30/// ## Model Compatibility
31///
32/// Thinking capabilities are available only on models that implement the
33/// `ThinkEnable` trait, such as GLM-4.5 series models.
34#[derive(Debug, Clone, Serialize)]
35#[serde(rename_all = "lowercase")]
36#[serde(tag = "type")]
37pub enum ThinkingType {
38    /// Enable thinking capabilities for enhanced reasoning.
39    ///
40    /// When enabled, the model will show its reasoning process step-by-step,
41    /// which can improve accuracy for complex logical or analytical tasks.
42    Enabled,
43
44    /// Disable thinking capabilities for direct responses.
45    ///
46    /// When disabled, the model responds directly without showing intermediate
47    /// reasoning steps, resulting in faster responses and lower token usage.
48    Disabled,
49}
50
51/// Available tools that AI assistants can invoke during conversations.
52///
53/// This enum defines the different categories of external tools and capabilities
54/// that can be made available to AI models. Each tool type serves specific purposes
55/// and has its own configuration requirements.
56///
57/// ## Tool Categories
58///
59/// ### Function Tools
60/// Custom user-defined functions that the AI can call with structured parameters.
61/// Useful for integrating external APIs, databases, or business logic.
62///
63/// ### Retrieval Tools
64/// Access to knowledge bases, document collections, or information retrieval systems.
65/// Enables the AI to query structured knowledge sources.
66///
67/// ### Web Search Tools
68/// Internet search capabilities for accessing current information.
69/// Allows the AI to perform web searches and retrieve up-to-date information.
70///
71/// ### MCP Tools
72/// Model Context Protocol tools for standardized tool integration.
73/// Provides a standardized interface for tool communication.
74///
75/// ## Usage
76///
77/// ```rust,ignore
78/// // Function tool
79/// let function_tool = Tools::Function {
80///     function: Function::new("get_weather", "Get weather data", parameters)
81/// };
82///
83/// // Web search tool
84/// let search_tool = Tools::WebSearch {
85///     web_search: WebSearch::new(SearchEngine::SearchPro)
86///         .with_enable(true)
87///         .with_count(10)
88/// };
89/// ```
90#[derive(Debug, Clone, Serialize)]
91#[serde(tag = "type")]
92#[serde(rename_all = "snake_case")]
93pub enum Tools {
94    /// Custom function calling tool with parameters.
95    ///
96    /// Allows the AI to invoke user-defined functions with structured arguments.
97    /// Functions must be pre-defined with JSON schemas for parameter validation.
98    Function { function: Function },
99
100    /// Knowledge retrieval system access tools.
101    ///
102    /// Provides access to knowledge bases, document collections, or other
103    /// structured information sources that the AI can query.
104    Retrieval { retrieval: Retrieval },
105
106    /// Web search capabilities for internet access.
107    ///
108    /// Enables the AI to perform web searches and access current information
109    /// from the internet. Supports various search engines and configurations.
110    WebSearch { web_search: WebSearch },
111
112    /// Model Context Protocol (MCP) tools.
113    ///
114    /// Standardized tools that follow the Model Context Protocol specification,
115    /// providing a consistent interface for tool integration and communication.
116    MCP { mcp: MCP },
117}
118
119/// Definition of a callable function tool.
120///
121/// This structure defines a function that can be called by the assistant,
122/// including its name, description, and parameter schema.
123///
124/// # Validation
125///
126/// * `name` - Must be between 1 and 64 characters
127/// * `parameters` - Must be a valid JSON schema
128#[derive(Debug, Clone, Serialize, Validate)]
129pub struct Function {
130    /// The name of the function. Must be between 1 and 64 characters.
131    #[validate(length(min = 1, max = 64))]
132    pub name: String,
133
134    /// A description of what the function does.
135    pub description: String,
136
137    /// JSON schema describing the function's parameters.
138    /// Server expects an object; keep as Value to avoid double-encoding strings.
139    #[serde(skip_serializing_if = "Option::is_none")]
140    #[validate(custom(function = "validate_json_schema_value"))]
141    pub parameters: Option<serde_json::Value>,
142}
143
144impl Function {
145    /// Creates a new function call definition.
146    ///
147    /// # Arguments
148    ///
149    /// * `name` - The name of the function
150    /// * `description` - A description of what the function does
151    /// * `parameters` - JSON schema string describing the function parameters
152    ///
153    /// # Returns
154    ///
155    /// A new `Function` instance.
156    ///
157    /// # Examples
158    ///
159    /// ```rust,ignore
160    /// let func = Function::new(
161    ///     "get_weather",
162    ///     "Get current weather for a location",
163    ///     r#"{"type": "object", "properties": {"location": {"type": "string"}}}"#
164    /// );
165    /// ```
166    pub fn new(
167        name: impl Into<String>,
168        description: impl Into<String>,
169        parameters: serde_json::Value,
170    ) -> Self {
171        Self {
172            name: name.into(),
173            description: description.into(),
174            parameters: Some(parameters),
175        }
176    }
177}
178
179/// Configuration for retrieval tool capabilities.
180///
181/// This structure represents a retrieval tool that can access knowledge bases
182/// or document collections. Currently a placeholder for future expansion.
183#[derive(Debug, Clone, Serialize)]
184pub struct Retrieval {
185    knowledge_id: String,
186    #[serde(skip_serializing_if = "Option::is_none")]
187    prompt_template: Option<String>,
188}
189
190impl Retrieval {
191    /// Creates a new `Retrieval` instance.
192    pub fn new(knowledge_id: impl Into<String>, prompt_template: Option<String>) -> Self {
193        Self {
194            knowledge_id: knowledge_id.into(),
195            prompt_template,
196        }
197    }
198}
199
200/// Configuration for web search tool capabilities.
201///
202/// The order in which search results are returned.
203#[derive(Debug, Clone, Serialize)]
204#[serde(rename_all = "snake_case")]
205pub enum ResultSequence {
206    Before,
207    After,
208}
209
210/// This structure represents a web search tool that can perform internet searches.
211/// Fields mirror the external web_search schema.
212#[derive(Debug, Clone, Serialize, Validate)]
213pub struct WebSearch {
214    /// Search engine type (required). Supported: search_std, search_pro, search_pro_sogou, search_pro_quark.
215    pub search_engine: SearchEngine,
216
217    /// Whether to enable web search. Default is false.
218    #[serde(skip_serializing_if = "Option::is_none")]
219    pub enable: Option<bool>,
220
221    /// Force-triggered search query string.
222    #[serde(skip_serializing_if = "Option::is_none")]
223    pub search_query: Option<String>,
224
225    /// Whether to perform search intent detection. true: execute only when intent is detected; false: skip detection and search directly.
226    #[serde(skip_serializing_if = "Option::is_none")]
227    pub search_intent: Option<bool>,
228
229    /// Number of results to return (1-50).
230    #[serde(skip_serializing_if = "Option::is_none")]
231    #[validate(range(min = 1, max = 50))]
232    pub count: Option<u32>,
233
234    /// Whitelist domain filter, e.g., "www.example.com".
235    #[serde(skip_serializing_if = "Option::is_none")]
236    pub search_domain_filter: Option<String>,
237
238    /// Time range filter.
239    #[serde(skip_serializing_if = "Option::is_none")]
240    pub search_recency_filter: Option<SearchRecencyFilter>,
241
242    /// Snippet summary size: medium or high.
243    #[serde(skip_serializing_if = "Option::is_none")]
244    pub content_size: Option<ContentSize>,
245
246    /// Return sequence for search results: before or after.
247    #[serde(skip_serializing_if = "Option::is_none")]
248    pub result_sequence: Option<ResultSequence>,
249
250    /// Whether to include detailed search source information.
251    #[serde(skip_serializing_if = "Option::is_none")]
252    pub search_result: Option<bool>,
253
254    /// Whether an answer requires search results to be returned.
255    #[serde(skip_serializing_if = "Option::is_none")]
256    pub require_search: Option<bool>,
257
258    /// Custom prompt to post-process search results.
259    #[serde(skip_serializing_if = "Option::is_none")]
260    pub search_prompt: Option<String>,
261}
262
263impl WebSearch {
264    /// Create a WebSearch config with the required search engine; other fields are optional.
265    pub fn new(search_engine: SearchEngine) -> Self {
266        Self {
267            search_engine,
268            enable: None,
269            search_query: None,
270            search_intent: None,
271            count: None,
272            search_domain_filter: None,
273            search_recency_filter: None,
274            content_size: None,
275            result_sequence: None,
276            search_result: None,
277            require_search: None,
278            search_prompt: None,
279        }
280    }
281
282    /// Enable or disable web search.
283    pub fn with_enable(mut self, enable: bool) -> Self {
284        self.enable = Some(enable);
285        self
286    }
287    /// Set a forced search query.
288    pub fn with_search_query(mut self, query: impl Into<String>) -> Self {
289        self.search_query = Some(query.into());
290        self
291    }
292    /// Set search intent detection behavior.
293    pub fn with_search_intent(mut self, search_intent: bool) -> Self {
294        self.search_intent = Some(search_intent);
295        self
296    }
297    /// Set results count (1-50).
298    pub fn with_count(mut self, count: u32) -> Self {
299        self.count = Some(count);
300        self
301    }
302    /// Restrict to a whitelist domain.
303    pub fn with_search_domain_filter(mut self, domain: impl Into<String>) -> Self {
304        self.search_domain_filter = Some(domain.into());
305        self
306    }
307    /// Set time range filter.
308    pub fn with_search_recency_filter(mut self, filter: SearchRecencyFilter) -> Self {
309        self.search_recency_filter = Some(filter);
310        self
311    }
312    /// Set content size.
313    pub fn with_content_size(mut self, size: ContentSize) -> Self {
314        self.content_size = Some(size);
315        self
316    }
317    /// Set result sequence.
318    pub fn with_result_sequence(mut self, seq: ResultSequence) -> Self {
319        self.result_sequence = Some(seq);
320        self
321    }
322    /// Toggle returning detailed search source info.
323    pub fn with_search_result(mut self, enable: bool) -> Self {
324        self.search_result = Some(enable);
325        self
326    }
327    /// Require search results for answering.
328    pub fn with_require_search(mut self, require: bool) -> Self {
329        self.require_search = Some(require);
330        self
331    }
332    /// Set a custom prompt to post-process search results.
333    pub fn with_search_prompt(mut self, prompt: impl Into<String>) -> Self {
334        self.search_prompt = Some(prompt.into());
335        self
336    }
337}
338///
339/// Represents the MCP connection configuration. When connecting to Zhipu's MCP server
340/// using an MCP code, fill `server_label` with that code and leave `server_url` empty.
341#[derive(Debug, Clone, Serialize, Validate)]
342pub struct MCP {
343    /// MCP server identifier (required). If connecting to Zhipu MCP via code, put the code here.
344    #[validate(length(min = 1))]
345    pub server_label: String,
346
347    /// MCP server URL.
348    #[serde(skip_serializing_if = "Option::is_none")]
349    #[validate(url)]
350    pub server_url: Option<String>,
351
352    /// Transport type. Default: streamable-http.
353    #[serde(skip_serializing_if = "Option::is_none")]
354    pub transport_type: Option<MCPTransportType>,
355
356    /// Allowed tool names.
357    #[serde(skip_serializing_if = "Vec::is_empty")]
358    pub allowed_tools: Vec<String>,
359
360    /// Authentication headers required by the MCP server.
361    #[serde(skip_serializing_if = "Option::is_none")]
362    pub headers: Option<HashMap<String, String>>,
363}
364
365impl MCP {
366    /// Create a new MCP config with required server_label and default transport type.
367    pub fn new(server_label: impl Into<String>) -> Self {
368        Self {
369            server_label: server_label.into(),
370            server_url: None,
371            transport_type: Some(MCPTransportType::StreamableHttp),
372            allowed_tools: Vec::new(),
373            headers: None,
374        }
375    }
376
377    /// Set the MCP server URL.
378    pub fn with_server_url(mut self, url: impl Into<String>) -> Self {
379        self.server_url = Some(url.into());
380        self
381    }
382    /// Set the MCP transport type.
383    pub fn with_transport_type(mut self, transport: MCPTransportType) -> Self {
384        self.transport_type = Some(transport);
385        self
386    }
387    /// Replace the allowed tool list.
388    pub fn with_allowed_tools(mut self, tools: impl Into<Vec<String>>) -> Self {
389        self.allowed_tools = tools.into();
390        self
391    }
392    /// Add a single allowed tool.
393    pub fn add_allowed_tool(mut self, tool: impl Into<String>) -> Self {
394        self.allowed_tools.push(tool.into());
395        self
396    }
397    /// Set authentication headers map.
398    pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
399        self.headers = Some(headers);
400        self
401    }
402    /// Add or update a single header entry.
403    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
404        let mut map = self.headers.unwrap_or_default();
405        map.insert(key.into(), value.into());
406        self.headers = Some(map);
407        self
408    }
409}
410
411/// Allowed MCP transport types.
412#[derive(Debug, Clone, Serialize)]
413#[serde(rename_all = "kebab-case")]
414pub enum MCPTransportType {
415    Sse,
416    StreamableHttp,
417}
418
419/// Specifies the format for the model's response.
420///
421/// This enum controls how the model should structure its output, either as
422/// plain text or as a structured JSON object.
423///
424/// # Variants
425///
426/// * `Text` - Plain text response format
427/// * `JsonObject` - Structured JSON object response format
428#[derive(Debug, Clone, Copy, Serialize)]
429#[serde(rename_all = "snake_case")]
430#[serde(tag = "type")]
431pub enum ResponseFormat {
432    /// Plain text response format.
433    Text,
434    /// Structured JSON object response format.
435    JsonObject,
436}