1use std::collections::HashMap;
7
8use serde::Serialize;
9use validator::*;
10
11use super::model_validate::validate_json_schema_value;
12use crate::tool::web_search::request::{ContentSize, SearchEngine, SearchRecencyFilter};
13
14#[derive(Debug, Clone, Serialize)]
37#[serde(rename_all = "lowercase")]
38#[serde(tag = "type")]
39pub enum ThinkingType {
40 Enabled,
45
46 Disabled,
51}
52
53#[derive(Debug, Clone, Serialize)]
94#[serde(tag = "type")]
95#[serde(rename_all = "snake_case")]
96pub enum Tools {
97 Function { function: Function },
103
104 Retrieval { retrieval: Retrieval },
109
110 WebSearch { web_search: WebSearch },
115
116 #[serde(rename = "mcp")]
121 MCP { mcp: MCP },
122}
123
124#[derive(Debug, Clone, Serialize, Validate)]
134pub struct Function {
135 #[validate(length(min = 1, max = 64))]
137 pub name: String,
138
139 pub description: String,
141
142 #[serde(skip_serializing_if = "Option::is_none")]
146 #[validate(custom(function = "validate_json_schema_value"))]
147 pub parameters: Option<serde_json::Value>,
148}
149
150impl Function {
151 pub fn new(
173 name: impl Into<String>,
174 description: impl Into<String>,
175 parameters: serde_json::Value,
176 ) -> Self {
177 Self {
178 name: name.into(),
179 description: description.into(),
180 parameters: Some(parameters),
181 }
182 }
183}
184
185#[derive(Debug, Clone, Serialize)]
190pub struct Retrieval {
191 knowledge_id: String,
192 #[serde(skip_serializing_if = "Option::is_none")]
193 prompt_template: Option<String>,
194}
195
196impl Retrieval {
197 pub fn new(knowledge_id: impl Into<String>, prompt_template: Option<String>) -> Self {
199 Self {
200 knowledge_id: knowledge_id.into(),
201 prompt_template,
202 }
203 }
204}
205
206#[derive(Debug, Clone, Serialize, PartialEq)]
210#[serde(rename_all = "snake_case")]
211pub enum ResultSequence {
212 Before,
213 After,
214}
215
216#[derive(Debug, Clone, Serialize, Validate)]
219pub struct WebSearch {
220 pub search_engine: SearchEngine,
223
224 #[serde(skip_serializing_if = "Option::is_none")]
226 pub enable: Option<bool>,
227
228 #[serde(skip_serializing_if = "Option::is_none")]
230 pub search_query: Option<String>,
231
232 #[serde(skip_serializing_if = "Option::is_none")]
235 pub search_intent: Option<bool>,
236
237 #[serde(skip_serializing_if = "Option::is_none")]
239 #[validate(range(min = 1, max = 50))]
240 pub count: Option<u32>,
241
242 #[serde(skip_serializing_if = "Option::is_none")]
244 pub search_domain_filter: Option<String>,
245
246 #[serde(skip_serializing_if = "Option::is_none")]
248 pub search_recency_filter: Option<SearchRecencyFilter>,
249
250 #[serde(skip_serializing_if = "Option::is_none")]
252 pub content_size: Option<ContentSize>,
253
254 #[serde(skip_serializing_if = "Option::is_none")]
256 pub result_sequence: Option<ResultSequence>,
257
258 #[serde(skip_serializing_if = "Option::is_none")]
260 pub search_result: Option<bool>,
261
262 #[serde(skip_serializing_if = "Option::is_none")]
264 pub require_search: Option<bool>,
265
266 #[serde(skip_serializing_if = "Option::is_none")]
268 pub search_prompt: Option<String>,
269}
270
271impl WebSearch {
272 pub fn new(search_engine: SearchEngine) -> Self {
275 Self {
276 search_engine,
277 enable: None,
278 search_query: None,
279 search_intent: None,
280 count: None,
281 search_domain_filter: None,
282 search_recency_filter: None,
283 content_size: None,
284 result_sequence: None,
285 search_result: None,
286 require_search: None,
287 search_prompt: None,
288 }
289 }
290
291 pub fn with_enable(mut self, enable: bool) -> Self {
293 self.enable = Some(enable);
294 self
295 }
296 pub fn with_search_query(mut self, query: impl Into<String>) -> Self {
298 self.search_query = Some(query.into());
299 self
300 }
301 pub fn with_search_intent(mut self, search_intent: bool) -> Self {
303 self.search_intent = Some(search_intent);
304 self
305 }
306 pub fn with_count(mut self, count: u32) -> Self {
308 self.count = Some(count);
309 self
310 }
311 pub fn with_search_domain_filter(mut self, domain: impl Into<String>) -> Self {
313 self.search_domain_filter = Some(domain.into());
314 self
315 }
316 pub fn with_search_recency_filter(mut self, filter: SearchRecencyFilter) -> Self {
318 self.search_recency_filter = Some(filter);
319 self
320 }
321 pub fn with_content_size(mut self, size: ContentSize) -> Self {
323 self.content_size = Some(size);
324 self
325 }
326 pub fn with_result_sequence(mut self, seq: ResultSequence) -> Self {
328 self.result_sequence = Some(seq);
329 self
330 }
331 pub fn with_search_result(mut self, enable: bool) -> Self {
333 self.search_result = Some(enable);
334 self
335 }
336 pub fn with_require_search(mut self, require: bool) -> Self {
338 self.require_search = Some(require);
339 self
340 }
341 pub fn with_search_prompt(mut self, prompt: impl Into<String>) -> Self {
343 self.search_prompt = Some(prompt.into());
344 self
345 }
346}
347#[derive(Debug, Clone, Serialize, Validate)]
351pub struct MCP {
352 #[validate(length(min = 1))]
355 pub server_label: String,
356
357 #[serde(skip_serializing_if = "Option::is_none")]
359 #[validate(url)]
360 pub server_url: Option<String>,
361
362 #[serde(skip_serializing_if = "Option::is_none")]
364 pub transport_type: Option<MCPTransportType>,
365
366 #[serde(skip_serializing_if = "Vec::is_empty")]
368 pub allowed_tools: Vec<String>,
369
370 #[serde(skip_serializing_if = "Option::is_none")]
372 pub headers: Option<HashMap<String, String>>,
373}
374
375impl MCP {
376 pub fn new(server_label: impl Into<String>) -> Self {
379 Self {
380 server_label: server_label.into(),
381 server_url: None,
382 transport_type: Some(MCPTransportType::StreamableHttp),
383 allowed_tools: Vec::new(),
384 headers: None,
385 }
386 }
387
388 pub fn with_server_url(mut self, url: impl Into<String>) -> Self {
390 self.server_url = Some(url.into());
391 self
392 }
393 pub fn with_transport_type(mut self, transport: MCPTransportType) -> Self {
395 self.transport_type = Some(transport);
396 self
397 }
398 pub fn with_allowed_tools(mut self, tools: impl Into<Vec<String>>) -> Self {
400 self.allowed_tools = tools.into();
401 self
402 }
403 pub fn add_allowed_tool(mut self, tool: impl Into<String>) -> Self {
405 self.allowed_tools.push(tool.into());
406 self
407 }
408 pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
410 self.headers = Some(headers);
411 self
412 }
413 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
415 let mut map = self.headers.unwrap_or_default();
416 map.insert(key.into(), value.into());
417 self.headers = Some(map);
418 self
419 }
420}
421
422#[derive(Debug, Clone, Serialize, PartialEq)]
424#[serde(rename_all = "kebab-case")]
425pub enum MCPTransportType {
426 Sse,
427 StreamableHttp,
428}
429
430#[derive(Debug, Clone, Copy, Serialize)]
440#[serde(rename_all = "snake_case")]
441#[serde(tag = "type")]
442pub enum ResponseFormat {
443 Text,
445 JsonObject,
447}
448
#[cfg(test)]
mod tests {
    //! Unit tests for the request model types: constructors, builder
    //! setters, validation rules and serialized JSON shape.

    use super::*;

    // ---- ThinkingType -----------------------------------------------------

    #[test]
    fn test_thinking_type_enabled_serialization() {
        let thinking = ThinkingType::Enabled;
        let json = serde_json::to_string(&thinking).unwrap();
        assert!(json.contains("\"type\":\"enabled\""));
    }

    #[test]
    fn test_thinking_type_disabled_serialization() {
        let thinking = ThinkingType::Disabled;
        let json = serde_json::to_string(&thinking).unwrap();
        assert!(json.contains("\"type\":\"disabled\""));
    }

    // ---- Function ---------------------------------------------------------

    #[test]
    fn test_function_new() {
        let params = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });
        let func = Function::new("test_func", "A test function", params);

        assert_eq!(func.name, "test_func");
        assert_eq!(func.description, "A test function");
        assert!(func.parameters.is_some());
    }

    #[test]
    fn test_function_serialization() {
        let params = serde_json::json!({
            "type": "object",
            "properties": {
                "value": {"type": "number"}
            }
        });
        let func = Function::new("test_func", "A test function", params);
        let json = serde_json::to_string(&func).unwrap();

        assert!(json.contains("\"name\":\"test_func\""));
        assert!(json.contains("\"description\":\"A test function\""));
        assert!(json.contains("\"properties\""));
    }

    #[test]
    fn test_function_validation() {
        let params = serde_json::json!({
            "type": "object",
            "properties": {}
        });
        let func = Function::new("valid_name", "Description", params.clone());

        assert!(func.validate().is_ok());

        // Empty name violates the minimum length of 1.
        let invalid_name = Function::new("", "Description", params.clone());
        assert!(invalid_name.validate().is_err());

        // 65 characters exceeds the maximum length of 64.
        let long_name = Function::new("a".repeat(65), "Description", params);
        assert!(long_name.validate().is_err());
    }

    // ---- Retrieval --------------------------------------------------------

    #[test]
    fn test_retrieval_new() {
        let retrieval = Retrieval::new("kb_123", Some("template".to_string()));
        assert_eq!(retrieval.knowledge_id, "kb_123");
        assert_eq!(retrieval.prompt_template, Some("template".to_string()));
    }

    #[test]
    fn test_retrieval_new_without_template() {
        let retrieval = Retrieval::new("kb_456", None);
        assert_eq!(retrieval.knowledge_id, "kb_456");
        assert!(retrieval.prompt_template.is_none());
    }

    #[test]
    fn test_retrieval_serialization() {
        let retrieval = Retrieval::new("kb_789", None);
        let json = serde_json::to_string(&retrieval).unwrap();
        assert!(json.contains("\"knowledge_id\":\"kb_789\""));
        // An unset template must not appear in the output at all.
        assert!(!json.contains("prompt_template"));
    }

    // ---- WebSearch --------------------------------------------------------

    #[test]
    fn test_web_search_new() {
        let web_search = WebSearch::new(SearchEngine::SearchPro);
        assert_eq!(web_search.search_engine, SearchEngine::SearchPro);
        assert!(web_search.enable.is_none());
    }

    #[test]
    fn test_web_search_with_enable() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_enable(true);
        assert_eq!(web_search.enable, Some(true));
    }

    #[test]
    fn test_web_search_with_search_query() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_search_query("test query");
        assert_eq!(web_search.search_query, Some("test query".to_string()));
    }

    #[test]
    fn test_web_search_with_search_intent() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_search_intent(true);
        assert_eq!(web_search.search_intent, Some(true));
    }

    #[test]
    fn test_web_search_with_count() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_count(10);
        assert_eq!(web_search.count, Some(10));
    }

    #[test]
    fn test_web_search_count_validation() {
        let base = WebSearch::new(SearchEngine::SearchPro);

        // An unset count is not range-checked.
        assert!(base.clone().validate().is_ok());

        // Both ends of the 1 to 50 range are accepted.
        assert!(base.clone().with_count(1).validate().is_ok());
        assert!(base.clone().with_count(50).validate().is_ok());

        // Values just outside the range are rejected.
        assert!(base.clone().with_count(0).validate().is_err());
        assert!(base.with_count(51).validate().is_err());
    }

    #[test]
    fn test_web_search_with_search_domain_filter() {
        let web_search =
            WebSearch::new(SearchEngine::SearchPro).with_search_domain_filter("example.com");
        assert_eq!(
            web_search.search_domain_filter,
            Some("example.com".to_string())
        );
    }

    #[test]
    fn test_web_search_with_search_recency_filter() {
        let filter = SearchRecencyFilter::OneDay;
        let web_search =
            WebSearch::new(SearchEngine::SearchPro).with_search_recency_filter(filter.clone());
        assert_eq!(web_search.search_recency_filter, Some(filter));
    }

    #[test]
    fn test_web_search_with_content_size() {
        let size = ContentSize::Medium;
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_content_size(size.clone());
        assert_eq!(web_search.content_size, Some(size));
    }

    #[test]
    fn test_web_search_with_result_sequence() {
        let seq = ResultSequence::After;
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_result_sequence(seq.clone());
        assert_eq!(web_search.result_sequence, Some(seq));
    }

    #[test]
    fn test_web_search_with_search_result() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_search_result(true);
        assert_eq!(web_search.search_result, Some(true));
    }

    #[test]
    fn test_web_search_with_require_search() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_require_search(true);
        assert_eq!(web_search.require_search, Some(true));
    }

    #[test]
    fn test_web_search_with_search_prompt() {
        let web_search =
            WebSearch::new(SearchEngine::SearchPro).with_search_prompt("custom prompt");
        assert_eq!(web_search.search_prompt, Some("custom prompt".to_string()));
    }

    #[test]
    fn test_web_search_serialization() {
        let web_search = WebSearch::new(SearchEngine::SearchPro)
            .with_enable(true)
            .with_count(5);
        let json = serde_json::to_string(&web_search).unwrap();
        assert!(json.contains("\"search_engine\""));
        assert!(json.contains("\"enable\":true"));
        assert!(json.contains("\"count\":5"));
    }

    // ---- MCP --------------------------------------------------------------

    #[test]
    fn test_mcp_new() {
        let mcp = MCP::new("server_label");
        assert_eq!(mcp.server_label, "server_label");
        assert_eq!(mcp.transport_type, Some(MCPTransportType::StreamableHttp));
        assert!(mcp.allowed_tools.is_empty());
    }

    #[test]
    fn test_mcp_with_server_url() {
        let mcp = MCP::new("server_label").with_server_url("https://example.com");
        assert_eq!(mcp.server_url, Some("https://example.com".to_string()));
    }

    #[test]
    fn test_mcp_with_transport_type() {
        let mcp = MCP::new("server_label").with_transport_type(MCPTransportType::Sse);
        assert_eq!(mcp.transport_type, Some(MCPTransportType::Sse));
    }

    #[test]
    fn test_mcp_with_allowed_tools() {
        let mcp = MCP::new("server_label")
            .with_allowed_tools(vec!["tool1".to_string(), "tool2".to_string()]);
        assert_eq!(mcp.allowed_tools.len(), 2);
        assert!(mcp.allowed_tools.contains(&"tool1".to_string()));
    }

    #[test]
    fn test_mcp_add_allowed_tool() {
        let mcp = MCP::new("server_label")
            .add_allowed_tool("tool1")
            .add_allowed_tool("tool2");
        assert_eq!(mcp.allowed_tools.len(), 2);
    }

    #[test]
    fn test_mcp_with_headers() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".to_string(), "Bearer token".to_string());
        let mcp = MCP::new("server_label").with_headers(headers.clone());
        assert_eq!(mcp.headers, Some(headers));
    }

    #[test]
    fn test_mcp_with_header() {
        let mcp = MCP::new("server_label").with_header("Authorization", "Bearer token");
        let headers = mcp.headers.unwrap();
        assert_eq!(
            headers.get("Authorization"),
            Some(&"Bearer token".to_string())
        );
    }

    #[test]
    fn test_mcp_with_header_accumulates() {
        let mut initial = HashMap::new();
        initial.insert("Authorization".to_string(), "Bearer token".to_string());

        // Adding single headers must keep whatever was set before.
        let mcp = MCP::new("server_label")
            .with_headers(initial)
            .with_header("X-First", "1")
            .with_header("X-Second", "2");
        let headers = mcp.headers.unwrap();

        assert_eq!(headers.len(), 3);
        assert_eq!(
            headers.get("Authorization"),
            Some(&"Bearer token".to_string())
        );
        assert_eq!(headers.get("X-First"), Some(&"1".to_string()));
        assert_eq!(headers.get("X-Second"), Some(&"2".to_string()));
    }

    #[test]
    fn test_mcp_validation() {
        // A non-empty label with no URL is valid.
        assert!(MCP::new("server_label").validate().is_ok());

        // A well-formed URL is accepted.
        assert!(MCP::new("server_label")
            .with_server_url("https://example.com")
            .validate()
            .is_ok());

        // An empty label violates the minimum length of 1.
        assert!(MCP::new("").validate().is_err());

        // A malformed URL is rejected.
        assert!(MCP::new("server_label")
            .with_server_url("not a url")
            .validate()
            .is_err());
    }

    #[test]
    fn test_mcp_serialization() {
        let mcp = MCP::new("server_label")
            .with_server_url("https://example.com")
            .with_transport_type(MCPTransportType::Sse);
        let json = serde_json::to_string(&mcp).unwrap();
        assert!(json.contains("\"server_label\":\"server_label\""));
        assert!(json.contains("\"server_url\":\"https://example.com\""));
        assert!(json.contains("\"transport_type\":\"sse\""));
        // An empty tool list must not appear in the output at all.
        assert!(!json.contains("allowed_tools"));
    }

    // ---- MCPTransportType -------------------------------------------------

    #[test]
    fn test_mcp_transport_type_sse_serialization() {
        let transport = MCPTransportType::Sse;
        let json = serde_json::to_string(&transport).unwrap();
        assert!(json.contains("\"sse\""));
    }

    #[test]
    fn test_mcp_transport_type_streamable_http_serialization() {
        let transport = MCPTransportType::StreamableHttp;
        let json = serde_json::to_string(&transport).unwrap();
        assert!(json.contains("\"streamable-http\""));
    }

    // ---- ResponseFormat ---------------------------------------------------

    #[test]
    fn test_response_format_text_serialization() {
        let format = ResponseFormat::Text;
        let json = serde_json::to_string(&format).unwrap();
        assert!(json.contains("\"type\":\"text\""));
    }

    #[test]
    fn test_response_format_json_object_serialization() {
        let format = ResponseFormat::JsonObject;
        let json = serde_json::to_string(&format).unwrap();
        assert!(json.contains("\"type\":\"json_object\""));
    }

    // ---- Tools ------------------------------------------------------------

    #[test]
    fn test_tools_function_serialization() {
        let func = Function::new("test_func", "test", serde_json::json!({}));
        let tools = Tools::Function { function: func };
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"test_func\""));
    }

    #[test]
    fn test_tools_retrieval_serialization() {
        let retrieval = Retrieval::new("kb_123", None);
        let tools = Tools::Retrieval { retrieval };
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("\"type\":\"retrieval\""));
        assert!(json.contains("\"knowledge_id\":\"kb_123\""));
    }

    #[test]
    fn test_tools_web_search_serialization() {
        let web_search = WebSearch::new(SearchEngine::SearchPro);
        let tools = Tools::WebSearch { web_search };
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("\"type\":\"web_search\""));
        assert!(json.contains("\"search_engine\""));
    }

    #[test]
    fn test_tools_mcp_serialization() {
        let mcp = MCP::new("server_label");
        let tools = Tools::MCP { mcp };
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("\"type\":\"mcp\""));
        assert!(json.contains("\"server_label\":\"server_label\""));
    }

    // ---- ResultSequence ---------------------------------------------------

    #[test]
    fn test_result_sequence_before_serialization() {
        let seq = ResultSequence::Before;
        let json = serde_json::to_string(&seq).unwrap();
        assert!(json.contains("\"before\""));
    }

    #[test]
    fn test_result_sequence_after_serialization() {
        let seq = ResultSequence::After;
        let json = serde_json::to_string(&seq).unwrap();
        assert!(json.contains("\"after\""));
    }
}