Skip to main content

zai_rs/model/
tools.rs

1//! # Tool Definitions & Configurations
2//!
//! Defines the tool types that can be attached to chat requests, including
//! function calling, web search integration, retrieval tools, MCP servers,
//! the [`ThinkingType`] configuration and the [`ResponseFormat`] selector.
6//!
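//! # Key Types
//!
//! - [`ThinkingType`] — Controls reasoning mode for thinking-capable models
//! - [`Tools`] — Tagged enum of every tool kind that can be attached to a request
//! - [`Function`] — Defines a callable function with JSON-schema parameters
//! - [`WebSearch`] — Enables live web search within chat
//! - [`Retrieval`] — Enables knowledge-base retrieval
//! - [`MCP`] — Connects to a Model Context Protocol server
//! - [`ResponseFormat`] — Selects plain-text or JSON-object output
//! - `ToolChoice` — Controls tool-selection behaviour (`auto`, `none`, or
//!   specific function); not defined in this module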
15
16use std::collections::HashMap;
17
18use serde::Serialize;
19use validator::*;
20
21use super::model_validate::validate_json_schema_value;
22use crate::tool::web_search::request::{ContentSize, SearchEngine, SearchRecencyFilter};
23
24/// Controls thinking/reasoning capabilities in AI models.
25///
26/// This structure determines whether a model should engage in step-by-step
27/// reasoning when processing requests, and whether to preserve reasoning
28/// content across turns via `clear_thinking`. Thinking mode can improve
29/// accuracy for complex tasks but may increase response time and token usage.
30///
31/// ## Fields
32///
33/// - `mode` - Whether thinking is enabled or disabled
34/// - `clear_thinking` - When `false`, preserves `reasoning_content` across
35///   turns (recommended for Coding / Agent scenarios)
36///
37/// ## Usage
38///
39/// ```rust,ignore
40/// let client = ChatCompletion::new(model, messages, api_key)
41///     .with_thinking(ThinkingType::enabled());
42///
43/// // Preserve reasoning content across turns (Coding / Agent)
44/// let client = ChatCompletion::new(model, messages, api_key)
45///     .with_thinking(ThinkingType::enabled().with_clear_thinking(false));
46/// ```
47///
48/// ## Model Compatibility
49///
50/// Thinking capabilities are available only on models that implement the
51/// `ThinkEnable` trait, such as GLM-5.1, GLM-5, GLM-4.7, and GLM-4.5 series
52/// models.
53#[derive(Debug, Clone, Serialize)]
54pub struct ThinkingType {
55    /// Whether thinking is enabled or disabled.
56    #[serde(rename = "type")]
57    pub mode: ThinkingMode,
58
59    /// Whether to clear historical `reasoning_content`.
60    ///
61    /// - `true` (default for standard API): Clears reasoning content each turn.
62    /// - `false` (recommended for Coding / Agent): Preserves reasoning content
63    ///   across turns, enabling better context for multi-step tool calls.
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub clear_thinking: Option<bool>,
66}
67
68/// Thinking mode variants.
69#[derive(Debug, Clone, Serialize)]
70#[serde(rename_all = "lowercase")]
71pub enum ThinkingMode {
72    Enabled,
73    Disabled,
74}
75
76impl ThinkingType {
77    /// Create a new thinking configuration with enabled mode.
78    pub fn enabled() -> Self {
79        Self {
80            mode: ThinkingMode::Enabled,
81            clear_thinking: None,
82        }
83    }
84
85    /// Create a new thinking configuration with disabled mode.
86    pub fn disabled() -> Self {
87        Self {
88            mode: ThinkingMode::Disabled,
89            clear_thinking: None,
90        }
91    }
92
93    /// Set whether to clear historical reasoning content.
94    ///
95    /// Use `false` for Coding / Agent scenarios where reasoning content should
96    /// be preserved across turns.
97    pub fn with_clear_thinking(mut self, clear: bool) -> Self {
98        self.clear_thinking = Some(clear);
99        self
100    }
101}
102
103/// Available tools that AI assistants can invoke during conversations.
104///
105/// This enum defines the different categories of external tools and
106/// capabilities that can be made available to AI models. Each tool type serves
107/// specific purposes and has its own configuration requirements.
108///
109/// ## Tool Categories
110///
111/// ### Function Tools
112/// Custom user-defined functions that the AI can call with structured
113/// parameters. Useful for integrating external APIs, databases, or business
114/// logic.
115///
116/// ### Retrieval Tools
117/// Access to knowledge bases, document collections, or information retrieval
118/// systems. Enables the AI to query structured knowledge sources.
119///
120/// ### Web Search Tools
121/// Internet search capabilities for accessing current information.
122/// Allows the AI to perform web searches and retrieve up-to-date information.
123///
124/// ### MCP Tools
125/// Model Context Protocol tools for standardized tool integration.
126/// Provides a standardized interface for tool communication.
127///
128/// ## Usage
129///
130/// ```rust,ignore
131/// // Function tool
132/// let function_tool = Tools::Function {
133///     function: Function::new("get_weather", "Get weather data", parameters)
134/// };
135///
136/// // Web search tool
137/// let search_tool = Tools::WebSearch {
138///     web_search: WebSearch::new(SearchEngine::SearchPro)
139///         .with_enable(true)
140///         .with_count(10)
141/// };
142/// ```
143#[derive(Debug, Clone, Serialize)]
144#[serde(tag = "type")]
145#[serde(rename_all = "snake_case")]
146pub enum Tools {
147    /// Custom function calling tool with parameters.
148    ///
149    /// Allows the AI to invoke user-defined functions with structured
150    /// arguments. Functions must be pre-defined with JSON schemas for
151    /// parameter validation.
152    Function { function: Function },
153
154    /// Knowledge retrieval system access tools.
155    ///
156    /// Provides access to knowledge bases, document collections, or other
157    /// structured information sources that the AI can query.
158    Retrieval { retrieval: Retrieval },
159
160    /// Web search capabilities for internet access.
161    ///
162    /// Enables the AI to perform web searches and access current information
163    /// from the internet. Supports various search engines and configurations.
164    WebSearch { web_search: WebSearch },
165
166    /// Model Context Protocol (MCP) tools.
167    ///
168    /// Standardized tools that follow the Model Context Protocol specification,
169    /// providing a consistent interface for tool integration and communication.
170    #[serde(rename = "mcp")]
171    MCP { mcp: MCP },
172}
173
174/// Definition of a callable function tool.
175///
176/// This structure defines a function that can be called by the assistant,
177/// including its name, description, and parameter schema.
178///
179/// # Validation
180///
181/// * `name` - Must be between 1 and 64 characters
182/// * `parameters` - Must be a valid JSON schema
183#[derive(Debug, Clone, Serialize, Validate)]
184pub struct Function {
185    /// The name of the function. Must be between 1 and 64 characters.
186    #[validate(length(min = 1, max = 64))]
187    pub name: String,
188
189    /// A description of what the function does.
190    pub description: String,
191
192    /// JSON schema describing the function's parameters.
193    /// Server expects an object; keep as Value to avoid double-encoding
194    /// strings.
195    #[serde(skip_serializing_if = "Option::is_none")]
196    #[validate(custom(function = "validate_json_schema_value"))]
197    pub parameters: Option<serde_json::Value>,
198}
199
200impl Function {
201    /// Creates a new function call definition.
202    ///
203    /// # Arguments
204    ///
205    /// * `name` - The name of the function
206    /// * `description` - A description of what the function does
207    /// * `parameters` - JSON schema string describing the function parameters
208    ///
209    /// # Returns
210    ///
211    /// A new `Function` instance.
212    ///
213    /// # Examples
214    ///
215    /// ```rust,ignore
216    /// let func = Function::new(
217    ///     "get_weather",
218    ///     "Get current weather for a location",
219    ///     r#"{"type": "object", "properties": {"location": {"type": "string"}}}"#
220    /// );
221    /// ```
222    pub fn new(
223        name: impl Into<String>,
224        description: impl Into<String>,
225        parameters: serde_json::Value,
226    ) -> Self {
227        Self {
228            name: name.into(),
229            description: description.into(),
230            parameters: Some(parameters),
231        }
232    }
233}
234
235/// Configuration for retrieval tool capabilities.
236///
237/// This structure represents a retrieval tool that can access knowledge bases
238/// or document collections. Currently a placeholder for future expansion.
239#[derive(Debug, Clone, Serialize)]
240pub struct Retrieval {
241    knowledge_id: String,
242    #[serde(skip_serializing_if = "Option::is_none")]
243    prompt_template: Option<String>,
244}
245
246impl Retrieval {
247    /// Creates a new `Retrieval` instance.
248    pub fn new(knowledge_id: impl Into<String>, prompt_template: Option<String>) -> Self {
249        Self {
250            knowledge_id: knowledge_id.into(),
251            prompt_template,
252        }
253    }
254}
255
256/// Configuration for web search tool capabilities.
257///
258/// The order in which search results are returned.
259#[derive(Debug, Clone, Serialize, PartialEq)]
260#[serde(rename_all = "snake_case")]
261pub enum ResultSequence {
262    Before,
263    After,
264}
265
266/// This structure represents a web search tool that can perform internet
267/// searches. Fields mirror the external web_search schema.
268#[derive(Debug, Clone, Serialize, Validate)]
269pub struct WebSearch {
270    /// Search engine type (required). Supported: search_std, search_pro,
271    /// search_pro_sogou, search_pro_quark.
272    pub search_engine: SearchEngine,
273
274    /// Whether to enable web search. Default is false.
275    #[serde(skip_serializing_if = "Option::is_none")]
276    pub enable: Option<bool>,
277
278    /// Force-triggered search query string.
279    #[serde(skip_serializing_if = "Option::is_none")]
280    pub search_query: Option<String>,
281
282    /// Whether to perform search intent detection. true: execute only when
283    /// intent is detected; false: skip detection and search directly.
284    #[serde(skip_serializing_if = "Option::is_none")]
285    pub search_intent: Option<bool>,
286
287    /// Number of results to return (1-50).
288    #[serde(skip_serializing_if = "Option::is_none")]
289    #[validate(range(min = 1, max = 50))]
290    pub count: Option<u32>,
291
292    /// Whitelist domain filter, e.g., "www.example.com".
293    #[serde(skip_serializing_if = "Option::is_none")]
294    pub search_domain_filter: Option<String>,
295
296    /// Time range filter.
297    #[serde(skip_serializing_if = "Option::is_none")]
298    pub search_recency_filter: Option<SearchRecencyFilter>,
299
300    /// Snippet summary size: medium or high.
301    #[serde(skip_serializing_if = "Option::is_none")]
302    pub content_size: Option<ContentSize>,
303
304    /// Return sequence for search results: before or after.
305    #[serde(skip_serializing_if = "Option::is_none")]
306    pub result_sequence: Option<ResultSequence>,
307
308    /// Whether to include detailed search source information.
309    #[serde(skip_serializing_if = "Option::is_none")]
310    pub search_result: Option<bool>,
311
312    /// Whether an answer requires search results to be returned.
313    #[serde(skip_serializing_if = "Option::is_none")]
314    pub require_search: Option<bool>,
315
316    /// Custom prompt to post-process search results.
317    #[serde(skip_serializing_if = "Option::is_none")]
318    pub search_prompt: Option<String>,
319}
320
321impl WebSearch {
322    /// Create a WebSearch config with the required search engine; other fields
323    /// are optional.
324    pub fn new(search_engine: SearchEngine) -> Self {
325        Self {
326            search_engine,
327            enable: None,
328            search_query: None,
329            search_intent: None,
330            count: None,
331            search_domain_filter: None,
332            search_recency_filter: None,
333            content_size: None,
334            result_sequence: None,
335            search_result: None,
336            require_search: None,
337            search_prompt: None,
338        }
339    }
340
341    /// Enable or disable web search.
342    pub fn with_enable(mut self, enable: bool) -> Self {
343        self.enable = Some(enable);
344        self
345    }
346    /// Set a forced search query.
347    pub fn with_search_query(mut self, query: impl Into<String>) -> Self {
348        self.search_query = Some(query.into());
349        self
350    }
351    /// Set search intent detection behavior.
352    pub fn with_search_intent(mut self, search_intent: bool) -> Self {
353        self.search_intent = Some(search_intent);
354        self
355    }
356    /// Set results count (1-50).
357    pub fn with_count(mut self, count: u32) -> Self {
358        self.count = Some(count);
359        self
360    }
361    /// Restrict to a whitelist domain.
362    pub fn with_search_domain_filter(mut self, domain: impl Into<String>) -> Self {
363        self.search_domain_filter = Some(domain.into());
364        self
365    }
366    /// Set time range filter.
367    pub fn with_search_recency_filter(mut self, filter: SearchRecencyFilter) -> Self {
368        self.search_recency_filter = Some(filter);
369        self
370    }
371    /// Set content size.
372    pub fn with_content_size(mut self, size: ContentSize) -> Self {
373        self.content_size = Some(size);
374        self
375    }
376    /// Set result sequence.
377    pub fn with_result_sequence(mut self, seq: ResultSequence) -> Self {
378        self.result_sequence = Some(seq);
379        self
380    }
381    /// Toggle returning detailed search source info.
382    pub fn with_search_result(mut self, enable: bool) -> Self {
383        self.search_result = Some(enable);
384        self
385    }
386    /// Require search results for answering.
387    pub fn with_require_search(mut self, require: bool) -> Self {
388        self.require_search = Some(require);
389        self
390    }
391    /// Set a custom prompt to post-process search results.
392    pub fn with_search_prompt(mut self, prompt: impl Into<String>) -> Self {
393        self.search_prompt = Some(prompt.into());
394        self
395    }
396}
397/// Represents the MCP connection configuration. When connecting to Zhipu's MCP
398/// server using an MCP code, fill `server_label` with that code and leave
399/// `server_url` empty.
400#[derive(Debug, Clone, Serialize, Validate)]
401pub struct MCP {
402    /// MCP server identifier (required). If connecting to Zhipu MCP via code,
403    /// put the code here.
404    #[validate(length(min = 1))]
405    pub server_label: String,
406
407    /// MCP server URL.
408    #[serde(skip_serializing_if = "Option::is_none")]
409    #[validate(url)]
410    pub server_url: Option<String>,
411
412    /// Transport type. Default: streamable-http.
413    #[serde(skip_serializing_if = "Option::is_none")]
414    pub transport_type: Option<MCPTransportType>,
415
416    /// Allowed tool names.
417    #[serde(skip_serializing_if = "Vec::is_empty")]
418    pub allowed_tools: Vec<String>,
419
420    /// Authentication headers required by the MCP server.
421    #[serde(skip_serializing_if = "Option::is_none")]
422    pub headers: Option<HashMap<String, String>>,
423}
424
425impl MCP {
426    /// Create a new MCP config with required server_label and default transport
427    /// type.
428    pub fn new(server_label: impl Into<String>) -> Self {
429        Self {
430            server_label: server_label.into(),
431            server_url: None,
432            transport_type: Some(MCPTransportType::StreamableHttp),
433            allowed_tools: Vec::new(),
434            headers: None,
435        }
436    }
437
438    /// Set the MCP server URL.
439    pub fn with_server_url(mut self, url: impl Into<String>) -> Self {
440        self.server_url = Some(url.into());
441        self
442    }
443    /// Set the MCP transport type.
444    pub fn with_transport_type(mut self, transport: MCPTransportType) -> Self {
445        self.transport_type = Some(transport);
446        self
447    }
448    /// Replace the allowed tool list.
449    pub fn with_allowed_tools(mut self, tools: impl Into<Vec<String>>) -> Self {
450        self.allowed_tools = tools.into();
451        self
452    }
453    /// Add a single allowed tool.
454    pub fn add_allowed_tool(mut self, tool: impl Into<String>) -> Self {
455        self.allowed_tools.push(tool.into());
456        self
457    }
458    /// Set authentication headers map.
459    pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
460        self.headers = Some(headers);
461        self
462    }
463    /// Add or update a single header entry.
464    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
465        let mut map = self.headers.unwrap_or_default();
466        map.insert(key.into(), value.into());
467        self.headers = Some(map);
468        self
469    }
470}
471
472/// Allowed MCP transport types.
473#[derive(Debug, Clone, Serialize, PartialEq)]
474#[serde(rename_all = "kebab-case")]
475pub enum MCPTransportType {
476    Sse,
477    StreamableHttp,
478}
479
480/// Specifies the format for the model's response.
481///
482/// This enum controls how the model should structure its output, either as
483/// plain text or as a structured JSON object.
484///
485/// # Variants
486///
487/// * `Text` - Plain text response format
488/// * `JsonObject` - Structured JSON object response format
489#[derive(Debug, Clone, Copy, Serialize)]
490#[serde(rename_all = "snake_case")]
491#[serde(tag = "type")]
492pub enum ResponseFormat {
493    /// Plain text response format.
494    Text,
495    /// Structured JSON object response format.
496    JsonObject,
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    // ThinkingType tests
    #[test]
    fn test_thinking_type_enabled_serialization() {
        let thinking = ThinkingType::enabled();
        let json = serde_json::to_string(&thinking).unwrap();
        assert!(json.contains("\"type\":\"enabled\""));
        assert!(!json.contains("clear_thinking"));
    }

    #[test]
    fn test_thinking_type_disabled_serialization() {
        let thinking = ThinkingType::disabled();
        let json = serde_json::to_string(&thinking).unwrap();
        assert!(json.contains("\"type\":\"disabled\""));
        assert!(!json.contains("clear_thinking"));
    }

    #[test]
    fn test_thinking_type_with_clear_thinking_serialization() {
        let thinking = ThinkingType::enabled().with_clear_thinking(false);
        let json = serde_json::to_string(&thinking).unwrap();
        assert!(json.contains("\"type\":\"enabled\""));
        assert!(json.contains("\"clear_thinking\":false"));
    }

    #[test]
    fn test_thinking_type_disabled_with_clear_thinking() {
        let thinking = ThinkingType::disabled().with_clear_thinking(true);
        let json = serde_json::to_string(&thinking).unwrap();
        assert!(json.contains("\"type\":\"disabled\""));
        assert!(json.contains("\"clear_thinking\":true"));
    }

    // Function tests
    #[test]
    fn test_function_new() {
        let params = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });
        let func = Function::new("test_func", "A test function", params);

        assert_eq!(func.name, "test_func");
        assert_eq!(func.description, "A test function");
        assert!(func.parameters.is_some());
    }

    #[test]
    fn test_function_serialization() {
        let params = serde_json::json!({
            "type": "object",
            "properties": {
                "value": {"type": "number"}
            }
        });
        let func = Function::new("test_func", "A test function", params);
        let json = serde_json::to_string(&func).unwrap();

        assert!(json.contains("\"name\":\"test_func\""));
        assert!(json.contains("\"description\":\"A test function\""));
        assert!(json.contains("\"properties\""));
    }

    #[test]
    fn test_function_validation() {
        let params = serde_json::json!({
            "type": "object",
            "properties": {}
        });
        let func = Function::new("valid_name", "Description", params.clone());

        // Name length validation: 1-64 characters
        assert!(func.validate().is_ok());

        let invalid_name = Function::new("", "Description", params.clone());
        assert!(invalid_name.validate().is_err());

        let long_name = Function::new("a".repeat(65), "Description", params);
        assert!(long_name.validate().is_err());
    }

    // Retrieval tests
    #[test]
    fn test_retrieval_new() {
        let retrieval = Retrieval::new("kb_123", Some("template".to_string()));
        assert_eq!(retrieval.knowledge_id, "kb_123");
        assert_eq!(retrieval.prompt_template, Some("template".to_string()));
    }

    #[test]
    fn test_retrieval_new_without_template() {
        let retrieval = Retrieval::new("kb_456", None);
        assert_eq!(retrieval.knowledge_id, "kb_456");
        assert!(retrieval.prompt_template.is_none());
    }

    #[test]
    fn test_retrieval_serialization() {
        let retrieval = Retrieval::new("kb_789", None);
        let json = serde_json::to_string(&retrieval).unwrap();
        assert!(json.contains("\"knowledge_id\":\"kb_789\""));
        // prompt_template should be omitted when None
        assert!(!json.contains("prompt_template"));
    }

    // WebSearch tests
    #[test]
    fn test_web_search_new() {
        let web_search = WebSearch::new(SearchEngine::SearchPro);
        assert_eq!(web_search.search_engine, SearchEngine::SearchPro);
        assert!(web_search.enable.is_none());
    }

    #[test]
    fn test_web_search_with_enable() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_enable(true);
        assert_eq!(web_search.enable, Some(true));
    }

    #[test]
    fn test_web_search_with_search_query() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_search_query("test query");
        assert_eq!(web_search.search_query, Some("test query".to_string()));
    }

    #[test]
    fn test_web_search_with_search_intent() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_search_intent(true);
        assert_eq!(web_search.search_intent, Some(true));
    }

    #[test]
    fn test_web_search_with_count() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_count(10);
        assert_eq!(web_search.count, Some(10));
    }

    #[test]
    fn test_web_search_count_validation() {
        // `count` is optional; when set it must lie in 1..=50.
        assert!(WebSearch::new(SearchEngine::SearchPro).validate().is_ok());
        assert!(WebSearch::new(SearchEngine::SearchPro)
            .with_count(1)
            .validate()
            .is_ok());
        assert!(WebSearch::new(SearchEngine::SearchPro)
            .with_count(50)
            .validate()
            .is_ok());
        assert!(WebSearch::new(SearchEngine::SearchPro)
            .with_count(0)
            .validate()
            .is_err());
        assert!(WebSearch::new(SearchEngine::SearchPro)
            .with_count(51)
            .validate()
            .is_err());
    }

    #[test]
    fn test_web_search_with_search_domain_filter() {
        let web_search =
            WebSearch::new(SearchEngine::SearchPro).with_search_domain_filter("example.com");
        assert_eq!(
            web_search.search_domain_filter,
            Some("example.com".to_string())
        );
    }

    #[test]
    fn test_web_search_with_search_recency_filter() {
        let filter = SearchRecencyFilter::OneDay;
        let web_search =
            WebSearch::new(SearchEngine::SearchPro).with_search_recency_filter(filter.clone());
        assert_eq!(web_search.search_recency_filter, Some(filter));
    }

    #[test]
    fn test_web_search_with_content_size() {
        let size = ContentSize::Medium;
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_content_size(size.clone());
        assert_eq!(web_search.content_size, Some(size));
    }

    #[test]
    fn test_web_search_with_result_sequence() {
        let seq = ResultSequence::After;
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_result_sequence(seq.clone());
        assert_eq!(web_search.result_sequence, Some(seq));
    }

    #[test]
    fn test_web_search_with_search_result() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_search_result(true);
        assert_eq!(web_search.search_result, Some(true));
    }

    #[test]
    fn test_web_search_with_require_search() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_require_search(true);
        assert_eq!(web_search.require_search, Some(true));
    }

    #[test]
    fn test_web_search_with_search_prompt() {
        let web_search =
            WebSearch::new(SearchEngine::SearchPro).with_search_prompt("custom prompt");
        assert_eq!(web_search.search_prompt, Some("custom prompt".to_string()));
    }

    #[test]
    fn test_web_search_serialization() {
        let web_search = WebSearch::new(SearchEngine::SearchPro)
            .with_enable(true)
            .with_count(5);
        let json = serde_json::to_string(&web_search).unwrap();
        assert!(json.contains("\"search_engine\""));
        assert!(json.contains("\"enable\":true"));
        assert!(json.contains("\"count\":5"));
    }

    // MCP tests
    #[test]
    fn test_mcp_new() {
        let mcp = MCP::new("server_label");
        assert_eq!(mcp.server_label, "server_label");
        assert_eq!(mcp.transport_type, Some(MCPTransportType::StreamableHttp));
        assert!(mcp.allowed_tools.is_empty());
    }

    #[test]
    fn test_mcp_validation() {
        // server_label must be non-empty; server_url, when set, must be a URL.
        assert!(MCP::new("server_label").validate().is_ok());
        assert!(MCP::new("").validate().is_err());
        assert!(MCP::new("server_label")
            .with_server_url("https://example.com")
            .validate()
            .is_ok());
        assert!(MCP::new("server_label")
            .with_server_url("not a url")
            .validate()
            .is_err());
    }

    #[test]
    fn test_mcp_with_server_url() {
        let mcp = MCP::new("server_label").with_server_url("https://example.com");
        assert_eq!(mcp.server_url, Some("https://example.com".to_string()));
    }

    #[test]
    fn test_mcp_with_transport_type() {
        let mcp = MCP::new("server_label").with_transport_type(MCPTransportType::Sse);
        assert_eq!(mcp.transport_type, Some(MCPTransportType::Sse));
    }

    #[test]
    fn test_mcp_with_allowed_tools() {
        let mcp = MCP::new("server_label")
            .with_allowed_tools(vec!["tool1".to_string(), "tool2".to_string()]);
        assert_eq!(mcp.allowed_tools.len(), 2);
        assert!(mcp.allowed_tools.contains(&"tool1".to_string()));
    }

    #[test]
    fn test_mcp_add_allowed_tool() {
        let mcp = MCP::new("server_label")
            .add_allowed_tool("tool1")
            .add_allowed_tool("tool2");
        assert_eq!(mcp.allowed_tools.len(), 2);
    }

    #[test]
    fn test_mcp_with_headers() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".to_string(), "Bearer token".to_string());
        let mcp = MCP::new("server_label").with_headers(headers.clone());
        assert_eq!(mcp.headers, Some(headers));
    }

    #[test]
    fn test_mcp_with_header() {
        let mcp = MCP::new("server_label").with_header("Authorization", "Bearer token");
        let headers = mcp.headers.unwrap();
        assert_eq!(
            headers.get("Authorization"),
            Some(&"Bearer token".to_string())
        );
    }

    #[test]
    fn test_mcp_serialization() {
        let mcp = MCP::new("server_label")
            .with_server_url("https://example.com")
            .with_transport_type(MCPTransportType::Sse);
        let json = serde_json::to_string(&mcp).unwrap();
        assert!(json.contains("\"server_label\":\"server_label\""));
        assert!(json.contains("\"server_url\":\"https://example.com\""));
        assert!(json.contains("\"transport_type\":\"sse\""));
        // allowed_tools should be omitted when empty
        assert!(!json.contains("allowed_tools"));
    }

    // MCPTransportType tests
    #[test]
    fn test_mcp_transport_type_sse_serialization() {
        let transport = MCPTransportType::Sse;
        let json = serde_json::to_string(&transport).unwrap();
        assert!(json.contains("\"sse\""));
    }

    #[test]
    fn test_mcp_transport_type_streamable_http_serialization() {
        let transport = MCPTransportType::StreamableHttp;
        let json = serde_json::to_string(&transport).unwrap();
        assert!(json.contains("\"streamable-http\""));
    }

    // ResponseFormat tests
    #[test]
    fn test_response_format_text_serialization() {
        let format = ResponseFormat::Text;
        let json = serde_json::to_string(&format).unwrap();
        assert!(json.contains("\"type\":\"text\""));
    }

    #[test]
    fn test_response_format_json_object_serialization() {
        let format = ResponseFormat::JsonObject;
        let json = serde_json::to_string(&format).unwrap();
        assert!(json.contains("\"type\":\"json_object\""));
    }

    // Tools enum tests
    #[test]
    fn test_tools_function_serialization() {
        let func = Function::new("test_func", "test", serde_json::json!({}));
        let tools = Tools::Function { function: func };
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"test_func\""));
    }

    #[test]
    fn test_tools_retrieval_serialization() {
        let retrieval = Retrieval::new("kb_123", None);
        let tools = Tools::Retrieval { retrieval };
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("\"type\":\"retrieval\""));
        assert!(json.contains("\"knowledge_id\":\"kb_123\""));
    }

    #[test]
    fn test_tools_web_search_serialization() {
        let web_search = WebSearch::new(SearchEngine::SearchPro);
        let tools = Tools::WebSearch { web_search };
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("\"type\":\"web_search\""));
        assert!(json.contains("\"search_engine\""));
    }

    #[test]
    fn test_tools_mcp_serialization() {
        let mcp = MCP::new("server_label");
        let tools = Tools::MCP { mcp };
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("\"type\":\"mcp\""));
        assert!(json.contains("\"server_label\":\"server_label\""));
    }

    // ResultSequence tests
    #[test]
    fn test_result_sequence_before_serialization() {
        let seq = ResultSequence::Before;
        let json = serde_json::to_string(&seq).unwrap();
        assert!(json.contains("\"before\""));
    }

    #[test]
    fn test_result_sequence_after_serialization() {
        let seq = ResultSequence::After;
        let json = serde_json::to_string(&seq).unwrap();
        assert!(json.contains("\"after\""));
    }
}