zai_rs/model/tools.rs
1//! Tool definitions and configurations for the model API.
2//!
3//! This module defines the various tools that can be used by the assistant,
4//! including function calling, retrieval systems, web search, and MCP tools.
5
6use super::model_validate::validate_json_schema_value;
7use crate::tool::web_search::request::{ContentSize, SearchEngine, SearchRecencyFilter};
8use serde::Serialize;
9use std::collections::HashMap;
10use validator::*;
11
12/// Controls thinking/reasoning capabilities in AI models.
13///
14/// This enum determines whether a model should engage in step-by-step reasoning
15/// when processing requests. Thinking mode can improve accuracy for complex tasks
16/// but may increase response time and token usage.
17///
18/// ## Variants
19///
20/// - `Enabled` - Model performs explicit reasoning steps before responding
21/// - `Disabled` - Model responds directly without showing reasoning process
22///
23/// ## Usage
24///
25/// ```rust,ignore
26/// let client = ChatCompletion::new(model, messages, api_key)
27/// .with_thinking(ThinkingType::Enabled);
28/// ```
29///
30/// ## Model Compatibility
31///
32/// Thinking capabilities are available only on models that implement the
33/// `ThinkEnable` trait, such as GLM-4.5 series models.
34#[derive(Debug, Clone, Serialize)]
35#[serde(rename_all = "lowercase")]
36#[serde(tag = "type")]
37pub enum ThinkingType {
38 /// Enable thinking capabilities for enhanced reasoning.
39 ///
40 /// When enabled, the model will show its reasoning process step-by-step,
41 /// which can improve accuracy for complex logical or analytical tasks.
42 Enabled,
43
44 /// Disable thinking capabilities for direct responses.
45 ///
46 /// When disabled, the model responds directly without showing intermediate
47 /// reasoning steps, resulting in faster responses and lower token usage.
48 Disabled,
49}
50
51/// Available tools that AI assistants can invoke during conversations.
52///
53/// This enum defines the different categories of external tools and capabilities
54/// that can be made available to AI models. Each tool type serves specific purposes
55/// and has its own configuration requirements.
56///
57/// ## Tool Categories
58///
59/// ### Function Tools
60/// Custom user-defined functions that the AI can call with structured parameters.
61/// Useful for integrating external APIs, databases, or business logic.
62///
63/// ### Retrieval Tools
64/// Access to knowledge bases, document collections, or information retrieval systems.
65/// Enables the AI to query structured knowledge sources.
66///
67/// ### Web Search Tools
68/// Internet search capabilities for accessing current information.
69/// Allows the AI to perform web searches and retrieve up-to-date information.
70///
71/// ### MCP Tools
72/// Model Context Protocol tools for standardized tool integration.
73/// Provides a standardized interface for tool communication.
74///
75/// ## Usage
76///
77/// ```rust,ignore
78/// // Function tool
79/// let function_tool = Tools::Function {
80/// function: Function::new("get_weather", "Get weather data", parameters)
81/// };
82///
83/// // Web search tool
84/// let search_tool = Tools::WebSearch {
85/// web_search: WebSearch::new(SearchEngine::SearchPro)
86/// .with_enable(true)
87/// .with_count(10)
88/// };
89/// ```
90#[derive(Debug, Clone, Serialize)]
91#[serde(tag = "type")]
92#[serde(rename_all = "snake_case")]
93pub enum Tools {
94 /// Custom function calling tool with parameters.
95 ///
96 /// Allows the AI to invoke user-defined functions with structured arguments.
97 /// Functions must be pre-defined with JSON schemas for parameter validation.
98 Function { function: Function },
99
100 /// Knowledge retrieval system access tools.
101 ///
102 /// Provides access to knowledge bases, document collections, or other
103 /// structured information sources that the AI can query.
104 Retrieval { retrieval: Retrieval },
105
106 /// Web search capabilities for internet access.
107 ///
108 /// Enables the AI to perform web searches and access current information
109 /// from the internet. Supports various search engines and configurations.
110 WebSearch { web_search: WebSearch },
111
112 /// Model Context Protocol (MCP) tools.
113 ///
114 /// Standardized tools that follow the Model Context Protocol specification,
115 /// providing a consistent interface for tool integration and communication.
116 MCP { mcp: MCP },
117}
118
119/// Definition of a callable function tool.
120///
121/// This structure defines a function that can be called by the assistant,
122/// including its name, description, and parameter schema.
123///
124/// # Validation
125///
126/// * `name` - Must be between 1 and 64 characters
127/// * `parameters` - Must be a valid JSON schema
128#[derive(Debug, Clone, Serialize, Validate)]
129pub struct Function {
130 /// The name of the function. Must be between 1 and 64 characters.
131 #[validate(length(min = 1, max = 64))]
132 pub name: String,
133
134 /// A description of what the function does.
135 pub description: String,
136
137 /// JSON schema describing the function's parameters.
138 /// Server expects an object; keep as Value to avoid double-encoding strings.
139 #[serde(skip_serializing_if = "Option::is_none")]
140 #[validate(custom(function = "validate_json_schema_value"))]
141 pub parameters: Option<serde_json::Value>,
142}
143
144impl Function {
145 /// Creates a new function call definition.
146 ///
147 /// # Arguments
148 ///
149 /// * `name` - The name of the function
150 /// * `description` - A description of what the function does
151 /// * `parameters` - JSON schema string describing the function parameters
152 ///
153 /// # Returns
154 ///
155 /// A new `Function` instance.
156 ///
157 /// # Examples
158 ///
159 /// ```rust,ignore
160 /// let func = Function::new(
161 /// "get_weather",
162 /// "Get current weather for a location",
163 /// r#"{"type": "object", "properties": {"location": {"type": "string"}}}"#
164 /// );
165 /// ```
166 pub fn new(
167 name: impl Into<String>,
168 description: impl Into<String>,
169 parameters: serde_json::Value,
170 ) -> Self {
171 Self {
172 name: name.into(),
173 description: description.into(),
174 parameters: Some(parameters),
175 }
176 }
177}
178
179/// Configuration for retrieval tool capabilities.
180///
181/// This structure represents a retrieval tool that can access knowledge bases
182/// or document collections. Currently a placeholder for future expansion.
183#[derive(Debug, Clone, Serialize)]
184pub struct Retrieval {
185 knowledge_id: String,
186 #[serde(skip_serializing_if = "Option::is_none")]
187 prompt_template: Option<String>,
188}
189
190impl Retrieval {
191 /// Creates a new `Retrieval` instance.
192 pub fn new(knowledge_id: impl Into<String>, prompt_template: Option<String>) -> Self {
193 Self {
194 knowledge_id: knowledge_id.into(),
195 prompt_template,
196 }
197 }
198}
199
200/// Configuration for web search tool capabilities.
201///
202/// The order in which search results are returned.
203#[derive(Debug, Clone, Serialize)]
204#[serde(rename_all = "snake_case")]
205pub enum ResultSequence {
206 Before,
207 After,
208}
209
210/// This structure represents a web search tool that can perform internet searches.
211/// Fields mirror the external web_search schema.
212#[derive(Debug, Clone, Serialize, Validate)]
213pub struct WebSearch {
214 /// Search engine type (required). Supported: search_std, search_pro, search_pro_sogou, search_pro_quark.
215 pub search_engine: SearchEngine,
216
217 /// Whether to enable web search. Default is false.
218 #[serde(skip_serializing_if = "Option::is_none")]
219 pub enable: Option<bool>,
220
221 /// Force-triggered search query string.
222 #[serde(skip_serializing_if = "Option::is_none")]
223 pub search_query: Option<String>,
224
225 /// Whether to perform search intent detection. true: execute only when intent is detected; false: skip detection and search directly.
226 #[serde(skip_serializing_if = "Option::is_none")]
227 pub search_intent: Option<bool>,
228
229 /// Number of results to return (1-50).
230 #[serde(skip_serializing_if = "Option::is_none")]
231 #[validate(range(min = 1, max = 50))]
232 pub count: Option<u32>,
233
234 /// Whitelist domain filter, e.g., "www.example.com".
235 #[serde(skip_serializing_if = "Option::is_none")]
236 pub search_domain_filter: Option<String>,
237
238 /// Time range filter.
239 #[serde(skip_serializing_if = "Option::is_none")]
240 pub search_recency_filter: Option<SearchRecencyFilter>,
241
242 /// Snippet summary size: medium or high.
243 #[serde(skip_serializing_if = "Option::is_none")]
244 pub content_size: Option<ContentSize>,
245
246 /// Return sequence for search results: before or after.
247 #[serde(skip_serializing_if = "Option::is_none")]
248 pub result_sequence: Option<ResultSequence>,
249
250 /// Whether to include detailed search source information.
251 #[serde(skip_serializing_if = "Option::is_none")]
252 pub search_result: Option<bool>,
253
254 /// Whether an answer requires search results to be returned.
255 #[serde(skip_serializing_if = "Option::is_none")]
256 pub require_search: Option<bool>,
257
258 /// Custom prompt to post-process search results.
259 #[serde(skip_serializing_if = "Option::is_none")]
260 pub search_prompt: Option<String>,
261}
262
263impl WebSearch {
264 /// Create a WebSearch config with the required search engine; other fields are optional.
265 pub fn new(search_engine: SearchEngine) -> Self {
266 Self {
267 search_engine,
268 enable: None,
269 search_query: None,
270 search_intent: None,
271 count: None,
272 search_domain_filter: None,
273 search_recency_filter: None,
274 content_size: None,
275 result_sequence: None,
276 search_result: None,
277 require_search: None,
278 search_prompt: None,
279 }
280 }
281
282 /// Enable or disable web search.
283 pub fn with_enable(mut self, enable: bool) -> Self {
284 self.enable = Some(enable);
285 self
286 }
287 /// Set a forced search query.
288 pub fn with_search_query(mut self, query: impl Into<String>) -> Self {
289 self.search_query = Some(query.into());
290 self
291 }
292 /// Set search intent detection behavior.
293 pub fn with_search_intent(mut self, search_intent: bool) -> Self {
294 self.search_intent = Some(search_intent);
295 self
296 }
297 /// Set results count (1-50).
298 pub fn with_count(mut self, count: u32) -> Self {
299 self.count = Some(count);
300 self
301 }
302 /// Restrict to a whitelist domain.
303 pub fn with_search_domain_filter(mut self, domain: impl Into<String>) -> Self {
304 self.search_domain_filter = Some(domain.into());
305 self
306 }
307 /// Set time range filter.
308 pub fn with_search_recency_filter(mut self, filter: SearchRecencyFilter) -> Self {
309 self.search_recency_filter = Some(filter);
310 self
311 }
312 /// Set content size.
313 pub fn with_content_size(mut self, size: ContentSize) -> Self {
314 self.content_size = Some(size);
315 self
316 }
317 /// Set result sequence.
318 pub fn with_result_sequence(mut self, seq: ResultSequence) -> Self {
319 self.result_sequence = Some(seq);
320 self
321 }
322 /// Toggle returning detailed search source info.
323 pub fn with_search_result(mut self, enable: bool) -> Self {
324 self.search_result = Some(enable);
325 self
326 }
327 /// Require search results for answering.
328 pub fn with_require_search(mut self, require: bool) -> Self {
329 self.require_search = Some(require);
330 self
331 }
332 /// Set a custom prompt to post-process search results.
333 pub fn with_search_prompt(mut self, prompt: impl Into<String>) -> Self {
334 self.search_prompt = Some(prompt.into());
335 self
336 }
337}
338///
339/// Represents the MCP connection configuration. When connecting to Zhipu's MCP server
340/// using an MCP code, fill `server_label` with that code and leave `server_url` empty.
341#[derive(Debug, Clone, Serialize, Validate)]
342pub struct MCP {
343 /// MCP server identifier (required). If connecting to Zhipu MCP via code, put the code here.
344 #[validate(length(min = 1))]
345 pub server_label: String,
346
347 /// MCP server URL.
348 #[serde(skip_serializing_if = "Option::is_none")]
349 #[validate(url)]
350 pub server_url: Option<String>,
351
352 /// Transport type. Default: streamable-http.
353 #[serde(skip_serializing_if = "Option::is_none")]
354 pub transport_type: Option<MCPTransportType>,
355
356 /// Allowed tool names.
357 #[serde(skip_serializing_if = "Vec::is_empty")]
358 pub allowed_tools: Vec<String>,
359
360 /// Authentication headers required by the MCP server.
361 #[serde(skip_serializing_if = "Option::is_none")]
362 pub headers: Option<HashMap<String, String>>,
363}
364
365impl MCP {
366 /// Create a new MCP config with required server_label and default transport type.
367 pub fn new(server_label: impl Into<String>) -> Self {
368 Self {
369 server_label: server_label.into(),
370 server_url: None,
371 transport_type: Some(MCPTransportType::StreamableHttp),
372 allowed_tools: Vec::new(),
373 headers: None,
374 }
375 }
376
377 /// Set the MCP server URL.
378 pub fn with_server_url(mut self, url: impl Into<String>) -> Self {
379 self.server_url = Some(url.into());
380 self
381 }
382 /// Set the MCP transport type.
383 pub fn with_transport_type(mut self, transport: MCPTransportType) -> Self {
384 self.transport_type = Some(transport);
385 self
386 }
387 /// Replace the allowed tool list.
388 pub fn with_allowed_tools(mut self, tools: impl Into<Vec<String>>) -> Self {
389 self.allowed_tools = tools.into();
390 self
391 }
392 /// Add a single allowed tool.
393 pub fn add_allowed_tool(mut self, tool: impl Into<String>) -> Self {
394 self.allowed_tools.push(tool.into());
395 self
396 }
397 /// Set authentication headers map.
398 pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
399 self.headers = Some(headers);
400 self
401 }
402 /// Add or update a single header entry.
403 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
404 let mut map = self.headers.unwrap_or_default();
405 map.insert(key.into(), value.into());
406 self.headers = Some(map);
407 self
408 }
409}
410
411/// Allowed MCP transport types.
412#[derive(Debug, Clone, Serialize)]
413#[serde(rename_all = "kebab-case")]
414pub enum MCPTransportType {
415 Sse,
416 StreamableHttp,
417}
418
419/// Specifies the format for the model's response.
420///
421/// This enum controls how the model should structure its output, either as
422/// plain text or as a structured JSON object.
423///
424/// # Variants
425///
426/// * `Text` - Plain text response format
427/// * `JsonObject` - Structured JSON object response format
428#[derive(Debug, Clone, Copy, Serialize)]
429#[serde(rename_all = "snake_case")]
430#[serde(tag = "type")]
431pub enum ResponseFormat {
432 /// Plain text response format.
433 Text,
434 /// Structured JSON object response format.
435 JsonObject,
436}