1use crate::proto::IncludeOption;
8use crate::tools::{Tool, ToolChoice};
9use serde_json::Value as JsonValue;
10
/// Generation options that can be combined with a message list via
/// [`ChatRequest::from_messages_with_options`].
///
/// Every field is optional: `None` (or an empty `stop_sequences`) leaves the
/// corresponding setting unset on the resulting request.
#[derive(Default, Clone, Debug)]
pub struct CompletionOptions {
    /// Model identifier (e.g. `"grok-2"`).
    pub model: Option<String>,
    /// Sampling temperature, stored as given (no range check here).
    pub temperature: Option<f32>,
    /// Maximum number of tokens to generate.
    pub max_tokens: Option<u32>,
    /// Nucleus-sampling (`top_p`) value, stored as given.
    pub top_p: Option<f32>,
    /// Frequency penalty.
    pub frequency_penalty: Option<f32>,
    /// Presence penalty.
    pub presence_penalty: Option<f32>,
    /// Stop sequences; copied into the request's `stop` list.
    pub stop_sequences: Vec<String>,
    /// Tools made available to the model.
    pub tools: Option<Vec<Tool>>,
    /// Tool-selection policy for `tools`.
    pub tool_choice: Option<ToolChoice>,
    /// Requested output format.
    pub response_format: Option<ResponseFormat>,
}
43
44impl CompletionOptions {
45 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn with_model(mut self, model: impl Into<String>) -> Self {
50 self.model = Some(model.into());
51 self
52 }
53
54 pub fn with_temperature(mut self, temp: f32) -> Self {
55 self.temperature = Some(temp);
56 self
57 }
58
59 pub fn with_max_tokens(mut self, tokens: u32) -> Self {
60 self.max_tokens = Some(tokens);
61 self
62 }
63}
64
/// A chat completion request assembled with builder methods.
///
/// All fields are private: populate them through the message builders and the
/// `with_*` / `add_*` setters, and read them back through the accessor methods.
#[derive(Default, Clone, Debug)]
pub struct ChatRequest {
    /// Conversation history, in the order the builder methods appended it.
    messages: Vec<Message>,
    /// Model identifier, if set.
    model: Option<String>,
    /// Maximum number of tokens to generate.
    max_tokens: Option<u32>,
    /// Sampling temperature; `with_temperature` clamps it to `[0.0, 2.0]`.
    temperature: Option<f32>,
    /// Nucleus-sampling value; `with_top_p` clamps it to `[0.0, 1.0]`.
    top_p: Option<f32>,
    /// Stop sequences (exposed as `stop_sequences()`).
    stop: Vec<String>,
    /// Requested reasoning-effort level.
    reasoning_effort: Option<ReasoningEffort>,
    /// Search configuration (exposed as `search_config()`).
    search: Option<SearchConfig>,
    /// Sampling seed.
    seed: Option<i32>,
    /// Requested output format.
    response_format: Option<ResponseFormat>,
    /// Tools made available to the model.
    tools: Option<Vec<Tool>>,
    /// Tool-selection policy.
    tool_choice: Option<ToolChoice>,
    /// Value of the `user` field, if set.
    user: Option<String>,
    /// Set to `true` by `with_logprobs`; `false` by default.
    logprobs: bool,
    /// Top-logprobs count passed to `with_logprobs`.
    top_logprobs: Option<i32>,
    /// Frequency penalty.
    frequency_penalty: Option<f32>,
    /// Presence penalty.
    presence_penalty: Option<f32>,
    /// Whether parallel tool calls are allowed, if specified.
    parallel_tool_calls: Option<bool>,
    /// ID of a previous response, if set.
    previous_response_id: Option<String>,
    /// `store_messages` flag; `false` by default.
    store_messages: bool,
    /// `use_encrypted_content` flag; `false` by default.
    use_encrypted_content: bool,
    /// Turn limit; `with_max_turns` asserts it is at least 1.
    max_turns: Option<i32>,
    /// Extra [`IncludeOption`]s requested, in insertion order.
    include: Vec<IncludeOption>,
}
108
/// One entry in a chat conversation.
///
/// Derives `PartialEq`/`Eq` so messages can be compared by value (e.g. with
/// `assert_eq!`) instead of only pattern-matched.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Message {
    /// System / instruction message.
    System(String),
    /// User message: plain text or multimodal content.
    User(MessageContent),
    /// Assistant message (e.g. an earlier model reply).
    Assistant(String),
    /// Result of a tool call.
    Tool {
        /// Identifier of the tool call this result answers.
        tool_call_id: String,
        /// Tool output, as a string (often serialized JSON).
        content: String,
    },
}

/// Body of a user message.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MessageContent {
    /// Plain text.
    Text(String),
    /// Ordered mix of text, image and file parts.
    MultiModal(Vec<ContentPart>),
}

/// One part of a multimodal user message.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ContentPart {
    /// A text segment.
    Text(String),
    /// An image referenced by URL.
    ImageUrl {
        /// Image location.
        url: String,
        /// Optional detail level; `None` leaves it unspecified.
        detail: Option<ImageDetail>,
    },
    /// A previously uploaded file, referenced by ID.
    File {
        /// File identifier (e.g. `"file-abc123"`).
        file_id: String,
    },
}

/// Requested detail level for an image part.
///
/// A field-less enum, so it is cheap to `Copy` and usable as a hash key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ImageDetail {
    /// Let the service pick the level.
    Auto,
    /// Low detail.
    Low,
    /// High detail.
    High,
}

impl From<String> for MessageContent {
    /// Wraps owned text as [`MessageContent::Text`].
    fn from(text: String) -> Self {
        MessageContent::Text(text)
    }
}

impl From<&str> for MessageContent {
    /// Copies borrowed text into [`MessageContent::Text`].
    fn from(text: &str) -> Self {
        MessageContent::Text(text.to_string())
    }
}
188
/// Requested reasoning-effort level, set via `ChatRequest::with_reasoning_effort`.
///
/// A field-less enum, so it derives `Copy`, equality and `Hash`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ReasoningEffort {
    /// Low effort.
    Low,
    /// Medium effort.
    Medium,
    /// High effort.
    High,
}
202
/// Search settings attached to a request.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SearchConfig {
    /// Whether search is off, on, or automatic.
    pub mode: SearchMode,
    /// Sources to search.
    pub sources: Vec<SearchSource>,
    /// Optional cap on the number of results.
    pub max_results: Option<u32>,
}

/// Search activation mode.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SearchMode {
    /// Search disabled.
    Off,
    /// Search enabled.
    On,
    /// Automatic (presumably the model decides when to search — confirm).
    Auto,
}

/// A source that search may draw from.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SearchSource {
    /// General web search.
    Web,
    /// X (formerly Twitter).
    X,
    /// News sources.
    News,
}
237
238#[derive(Clone, Debug)]
240pub enum ResponseFormat {
241 Text,
243 JsonObject,
245 JsonSchema(JsonValue),
247}
248
249impl ChatRequest {
250 pub fn new() -> Self {
251 Self::default()
252 }
253
254 pub fn user_message(mut self, content: impl Into<MessageContent>) -> Self {
255 self.messages.push(Message::User(content.into()));
256 self
257 }
258
259 pub fn system_message(mut self, content: impl Into<String>) -> Self {
260 self.messages.push(Message::System(content.into()));
261 self
262 }
263
264 pub fn assistant_message(mut self, content: impl Into<String>) -> Self {
265 self.messages.push(Message::Assistant(content.into()));
266 self
267 }
268
269 pub fn tool_result(
304 mut self,
305 tool_call_id: impl Into<String>,
306 content: impl Into<String>,
307 ) -> Self {
308 self.messages.push(Message::Tool {
309 tool_call_id: tool_call_id.into(),
310 content: content.into(),
311 });
312 self
313 }
314
315 pub fn tool_result_json(
344 self,
345 tool_call_id: impl Into<String>,
346 content: &serde_json::Value,
347 ) -> Self {
348 self.tool_result(tool_call_id, content.to_string())
349 }
350
351 pub fn user_multimodal(mut self, parts: Vec<ContentPart>) -> Self {
352 self.messages
353 .push(Message::User(MessageContent::MultiModal(parts)));
354 self
355 }
356
357 pub fn user_with_image(
358 mut self,
359 text: impl Into<String>,
360 image_url: impl Into<String>,
361 ) -> Self {
362 self.messages
363 .push(Message::User(MessageContent::MultiModal(vec![
364 ContentPart::Text(text.into()),
365 ContentPart::ImageUrl {
366 url: image_url.into(),
367 detail: None,
368 },
369 ])));
370 self
371 }
372
373 pub fn with_model(mut self, model: impl Into<String>) -> Self {
374 self.model = Some(model.into());
375 self
376 }
377
378 pub fn with_temperature(mut self, temp: f32) -> Self {
379 self.temperature = Some(temp.clamp(0.0, 2.0));
380 self
381 }
382
383 pub fn with_top_p(mut self, top_p: f32) -> Self {
384 self.top_p = Some(top_p.clamp(0.0, 1.0));
385 self
386 }
387
388 pub fn with_max_tokens(mut self, tokens: u32) -> Self {
389 self.max_tokens = Some(tokens);
390 self
391 }
392
393 pub fn with_reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
394 self.reasoning_effort = Some(effort);
395 self
396 }
397
398 pub fn with_web_search(mut self) -> Self {
399 self.search = Some(SearchConfig {
400 mode: SearchMode::Auto,
401 sources: vec![SearchSource::Web],
402 max_results: Some(5),
403 });
404 self
405 }
406
407 pub fn with_json_output(mut self) -> Self {
408 self.response_format = Some(ResponseFormat::JsonObject);
409 self
410 }
411
412 pub fn with_json_schema(mut self, schema: JsonValue) -> Self {
413 self.response_format = Some(ResponseFormat::JsonSchema(schema));
414 self
415 }
416
417 pub fn with_structured_output(self, schema: JsonValue) -> Self {
419 self.with_json_schema(schema)
420 }
421
422 pub fn with_seed(mut self, seed: i32) -> Self {
423 self.seed = Some(seed);
424 self
425 }
426
427 pub fn add_stop_sequence(mut self, seq: impl Into<String>) -> Self {
428 self.stop.push(seq.into());
429 self
430 }
431
432 pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
433 self.tools = Some(tools);
434 self
435 }
436
437 pub fn add_tool(mut self, tool: Tool) -> Self {
438 if let Some(ref mut tools) = self.tools {
439 tools.push(tool);
440 } else {
441 self.tools = Some(vec![tool]);
442 }
443 self
444 }
445
446 pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
447 self.tool_choice = Some(choice);
448 self
449 }
450
451 pub fn with_user(mut self, user: impl Into<String>) -> Self {
452 self.user = Some(user.into());
453 self
454 }
455
456 pub fn with_logprobs(mut self, top_logprobs: Option<i32>) -> Self {
457 self.logprobs = true;
458 self.top_logprobs = top_logprobs;
459 self
460 }
461
462 pub fn with_frequency_penalty(mut self, penalty: f32) -> Self {
463 self.frequency_penalty = Some(penalty);
464 self
465 }
466
467 pub fn with_presence_penalty(mut self, penalty: f32) -> Self {
468 self.presence_penalty = Some(penalty);
469 self
470 }
471
472 pub fn with_parallel_tool_calls(mut self, enabled: bool) -> Self {
473 self.parallel_tool_calls = Some(enabled);
474 self
475 }
476
477 pub fn with_previous_response_id(mut self, id: impl Into<String>) -> Self {
478 self.previous_response_id = Some(id.into());
479 self
480 }
481
482 pub fn with_store_messages(mut self, store: bool) -> Self {
483 self.store_messages = store;
484 self
485 }
486
487 pub fn with_use_encrypted_content(mut self, use_encrypted: bool) -> Self {
488 self.use_encrypted_content = use_encrypted;
489 self
490 }
491
492 pub fn with_max_turns(mut self, max_turns: i32) -> Self {
498 assert!(
499 max_turns >= 1,
500 "max_turns must be at least 1, got {max_turns}"
501 );
502 self.max_turns = Some(max_turns);
503 self
504 }
505
506 pub fn add_include_option(mut self, option: IncludeOption) -> Self {
509 self.include.push(option);
510 self
511 }
512
513 pub fn with_include_options(mut self, options: Vec<IncludeOption>) -> Self {
515 self.include = options;
516 self
517 }
518
519 pub fn user_with_file(mut self, text: impl Into<String>, file_id: impl Into<String>) -> Self {
521 self.messages
522 .push(Message::User(MessageContent::MultiModal(vec![
523 ContentPart::Text(text.into()),
524 ContentPart::File {
525 file_id: file_id.into(),
526 },
527 ])));
528 self
529 }
530
531 pub fn messages(&self) -> &[Message] {
533 &self.messages
534 }
535
536 pub fn model(&self) -> Option<&str> {
537 self.model.as_deref()
538 }
539
540 pub fn max_tokens(&self) -> Option<u32> {
541 self.max_tokens
542 }
543
544 pub fn temperature(&self) -> Option<f32> {
545 self.temperature
546 }
547
548 pub fn top_p(&self) -> Option<f32> {
549 self.top_p
550 }
551
552 pub fn stop_sequences(&self) -> &[String] {
553 &self.stop
554 }
555
556 pub fn reasoning_effort(&self) -> Option<&ReasoningEffort> {
557 self.reasoning_effort.as_ref()
558 }
559
560 pub fn search_config(&self) -> Option<&SearchConfig> {
561 self.search.as_ref()
562 }
563
564 pub fn seed(&self) -> Option<i32> {
565 self.seed
566 }
567
568 pub fn response_format(&self) -> Option<&ResponseFormat> {
569 self.response_format.as_ref()
570 }
571
572 pub fn tools(&self) -> Option<&[Tool]> {
573 self.tools.as_deref()
574 }
575
576 pub fn tool_choice(&self) -> Option<&ToolChoice> {
577 self.tool_choice.as_ref()
578 }
579
580 pub fn user(&self) -> Option<&str> {
581 self.user.as_deref()
582 }
583
584 pub fn logprobs(&self) -> bool {
585 self.logprobs
586 }
587
588 pub fn top_logprobs(&self) -> Option<i32> {
589 self.top_logprobs
590 }
591
592 pub fn frequency_penalty(&self) -> Option<f32> {
593 self.frequency_penalty
594 }
595
596 pub fn presence_penalty(&self) -> Option<f32> {
597 self.presence_penalty
598 }
599
600 pub fn parallel_tool_calls(&self) -> Option<bool> {
601 self.parallel_tool_calls
602 }
603
604 pub fn previous_response_id(&self) -> Option<&str> {
605 self.previous_response_id.as_deref()
606 }
607
608 pub fn store_messages(&self) -> bool {
609 self.store_messages
610 }
611
612 pub fn use_encrypted_content(&self) -> bool {
613 self.use_encrypted_content
614 }
615
616 pub fn max_turns(&self) -> Option<i32> {
617 self.max_turns
618 }
619
620 pub fn include_options(&self) -> &[IncludeOption] {
621 &self.include
622 }
623
624 pub fn from_messages(messages: Vec<Message>) -> Self {
626 Self {
627 messages,
628 ..Default::default()
629 }
630 }
631
632 pub fn from_messages_with_options(messages: Vec<Message>, options: CompletionOptions) -> Self {
635 Self {
636 messages,
637 model: options.model,
638 temperature: options.temperature,
639 max_tokens: options.max_tokens,
640 top_p: options.top_p,
641 frequency_penalty: options.frequency_penalty,
642 presence_penalty: options.presence_penalty,
643 stop: options.stop_sequences,
644 tools: options.tools,
645 tool_choice: options.tool_choice,
646 response_format: options.response_format,
647 ..Default::default()
648 }
649 }
650}
651
652impl SearchConfig {
653 pub fn web() -> Self {
654 Self {
655 mode: SearchMode::Auto,
656 sources: vec![SearchSource::Web],
657 max_results: Some(5),
658 }
659 }
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chat_request_builder() {
        let request = ChatRequest::new()
            .user_message("Hello, world!")
            .with_model("grok-2")
            .with_temperature(0.7)
            .with_max_tokens(100);

        assert_eq!(request.messages().len(), 1);
        assert_eq!(request.model(), Some("grok-2"));
        assert_eq!(request.temperature(), Some(0.7));
        assert_eq!(request.max_tokens(), Some(100));
    }

    #[test]
    fn test_multimodal_message() {
        let request = ChatRequest::new().user_multimodal(vec![
            ContentPart::Text("Describe this image".to_string()),
            ContentPart::ImageUrl {
                url: "https://example.com/image.jpg".to_string(),
                detail: Some(ImageDetail::High),
            },
        ]);

        assert_eq!(request.messages().len(), 1);
        match &request.messages()[0] {
            Message::User(MessageContent::MultiModal(parts)) => {
                assert_eq!(parts.len(), 2);
            }
            _ => panic!("Expected multimodal user message"),
        }
    }

    #[test]
    fn test_from_messages() {
        let messages = vec![
            Message::System("You are a helpful assistant".to_string()),
            Message::User(MessageContent::Text("Hello".to_string())),
        ];

        let request = ChatRequest::from_messages(messages);
        assert_eq!(request.messages().len(), 2);
    }

    #[test]
    fn test_from_messages_with_options() {
        let messages = vec![Message::User(MessageContent::Text("Test".to_string()))];
        let options = CompletionOptions::new()
            .with_model("grok-2")
            .with_temperature(0.8)
            .with_max_tokens(200);

        let request = ChatRequest::from_messages_with_options(messages, options);

        assert_eq!(request.messages().len(), 1);
        assert_eq!(request.model(), Some("grok-2"));
        assert_eq!(request.temperature(), Some(0.8));
        assert_eq!(request.max_tokens(), Some(200));
    }

    #[test]
    fn test_sampling_parameters() {
        let request = ChatRequest::new()
            .user_message("Test")
            .with_frequency_penalty(0.5)
            .with_presence_penalty(0.3)
            .with_top_p(0.9);

        assert_eq!(request.frequency_penalty(), Some(0.5));
        assert_eq!(request.presence_penalty(), Some(0.3));
        assert_eq!(request.top_p(), Some(0.9));
    }

    #[test]
    fn test_temperature_and_top_p_are_clamped() {
        let high = ChatRequest::new().with_temperature(3.5).with_top_p(1.5);
        assert_eq!(high.temperature(), Some(2.0));
        assert_eq!(high.top_p(), Some(1.0));

        let low = ChatRequest::new().with_temperature(-1.0).with_top_p(-0.2);
        assert_eq!(low.temperature(), Some(0.0));
        assert_eq!(low.top_p(), Some(0.0));
    }

    #[test]
    fn test_stop_sequences() {
        let request = ChatRequest::new()
            .user_message("Test")
            .add_stop_sequence("STOP")
            .add_stop_sequence("END");

        assert_eq!(request.stop_sequences(), &["STOP", "END"]);
    }

    #[test]
    fn test_logprobs() {
        let request = ChatRequest::new()
            .user_message("Test")
            .with_logprobs(Some(5));

        assert!(request.logprobs());
        assert_eq!(request.top_logprobs(), Some(5));
    }

    #[test]
    fn test_stored_messages() {
        let request = ChatRequest::new()
            .user_message("Test")
            .with_store_messages(true)
            .with_previous_response_id("resp_123");

        assert!(request.store_messages());
        assert_eq!(request.previous_response_id(), Some("resp_123"));
    }

    #[test]
    fn test_search_config() {
        let config = SearchConfig::web();
        assert!(matches!(config.mode, SearchMode::Auto));
        assert_eq!(config.sources.len(), 1);
        assert_eq!(config.max_results, Some(5));
    }

    #[test]
    fn test_web_search_preset() {
        let request = ChatRequest::new().user_message("News?").with_web_search();
        let config = request
            .search_config()
            .expect("with_web_search must set a search config");
        assert!(matches!(config.mode, SearchMode::Auto));
        assert!(matches!(config.sources.as_slice(), [SearchSource::Web]));
        assert_eq!(config.max_results, Some(5));
    }

    #[test]
    fn test_reasoning_effort() {
        let request = ChatRequest::new()
            .user_message("Complex problem")
            .with_reasoning_effort(ReasoningEffort::High);

        assert!(matches!(
            request.reasoning_effort(),
            Some(ReasoningEffort::High)
        ));
    }

    #[test]
    fn test_json_output() {
        let request = ChatRequest::new()
            .user_message("Generate JSON")
            .with_json_output();

        assert!(matches!(
            request.response_format(),
            Some(ResponseFormat::JsonObject)
        ));
    }

    #[test]
    fn test_max_turns() {
        let request = ChatRequest::new()
            .user_message("Research this topic")
            .with_max_turns(5);

        assert_eq!(request.max_turns(), Some(5));
    }

    #[test]
    fn test_max_turns_single_turn() {
        let request = ChatRequest::new()
            .user_message("Single turn")
            .with_max_turns(1);

        assert_eq!(request.max_turns(), Some(1));
    }

    #[test]
    #[should_panic(expected = "max_turns must be at least 1")]
    fn test_max_turns_validation_zero() {
        ChatRequest::new().user_message("Test").with_max_turns(0);
    }

    #[test]
    #[should_panic(expected = "max_turns must be at least 1")]
    fn test_max_turns_validation_negative() {
        ChatRequest::new().user_message("Test").with_max_turns(-1);
    }

    #[test]
    fn test_include_options_single() {
        let request = ChatRequest::new()
            .user_message("Test")
            .add_include_option(IncludeOption::WebSearchCallOutput);

        assert_eq!(request.include_options().len(), 1);
    }

    #[test]
    fn test_include_options_multiple() {
        let request = ChatRequest::new()
            .user_message("Test")
            .add_include_option(IncludeOption::WebSearchCallOutput)
            .add_include_option(IncludeOption::InlineCitations)
            .add_include_option(IncludeOption::XSearchCallOutput);

        assert_eq!(request.include_options().len(), 3);
    }

    #[test]
    fn test_with_include_options() {
        let options = vec![
            IncludeOption::WebSearchCallOutput,
            IncludeOption::CodeExecutionCallOutput,
            IncludeOption::InlineCitations,
        ];

        let request = ChatRequest::new()
            .user_message("Test")
            .with_include_options(options);

        assert_eq!(request.include_options().len(), 3);
    }

    #[test]
    fn test_user_with_file() {
        let request = ChatRequest::new().user_with_file("Analyze this document", "file-abc123");

        assert_eq!(request.messages().len(), 1);
        match &request.messages()[0] {
            Message::User(MessageContent::MultiModal(parts)) => {
                assert_eq!(parts.len(), 2);
                match &parts[0] {
                    ContentPart::Text(text) => assert_eq!(text, "Analyze this document"),
                    _ => panic!("Expected text part"),
                }
                match &parts[1] {
                    ContentPart::File { file_id } => assert_eq!(file_id, "file-abc123"),
                    _ => panic!("Expected file part"),
                }
            }
            _ => panic!("Expected multimodal user message"),
        }
    }

    #[test]
    fn test_user_with_image() {
        let request =
            ChatRequest::new().user_with_image("What is this?", "https://example.com/cat.png");

        assert_eq!(request.messages().len(), 1);
        match &request.messages()[0] {
            Message::User(MessageContent::MultiModal(parts)) => {
                assert_eq!(parts.len(), 2);
                match &parts[0] {
                    ContentPart::Text(text) => assert_eq!(text, "What is this?"),
                    _ => panic!("Expected text part"),
                }
                match &parts[1] {
                    ContentPart::ImageUrl { url, detail } => {
                        assert_eq!(url, "https://example.com/cat.png");
                        assert!(detail.is_none());
                    }
                    _ => panic!("Expected image part"),
                }
            }
            _ => panic!("Expected multimodal user message"),
        }
    }

    #[test]
    fn test_file_content_part() {
        let file_part = ContentPart::File {
            file_id: "file-xyz789".to_string(),
        };

        match file_part {
            ContentPart::File { file_id } => assert_eq!(file_id, "file-xyz789"),
            _ => panic!("Expected file content part"),
        }
    }

    #[test]
    fn test_multimodal_with_file_and_image() {
        let request = ChatRequest::new().user_multimodal(vec![
            ContentPart::Text("Compare these".to_string()),
            ContentPart::ImageUrl {
                url: "https://example.com/image1.jpg".to_string(),
                detail: Some(ImageDetail::High),
            },
            ContentPart::File {
                file_id: "file-doc123".to_string(),
            },
        ]);

        assert_eq!(request.messages().len(), 1);
        match &request.messages()[0] {
            Message::User(MessageContent::MultiModal(parts)) => {
                assert_eq!(parts.len(), 3);
            }
            _ => panic!("Expected multimodal user message"),
        }
    }

    #[test]
    fn test_combined_new_features() {
        let request = ChatRequest::new()
            .user_message("Research and analyze")
            .with_max_turns(10)
            .add_include_option(IncludeOption::WebSearchCallOutput)
            .add_include_option(IncludeOption::InlineCitations);

        assert_eq!(request.max_turns(), Some(10));
        assert_eq!(request.include_options().len(), 2);
    }

    #[test]
    fn test_tool_result_message() {
        let request = ChatRequest::new()
            .user_message("What's the weather?")
            .tool_result(
                "call_abc123",
                r#"{"temperature": 72, "condition": "sunny"}"#,
            );

        assert_eq!(request.messages().len(), 2);
        match &request.messages()[1] {
            Message::Tool {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "call_abc123");
                assert_eq!(content, r#"{"temperature": 72, "condition": "sunny"}"#);
            }
            _ => panic!("Expected tool result message"),
        }
    }

    #[test]
    fn test_tool_result_json() {
        let payload = serde_json::json!({ "result": 8 });
        let request = ChatRequest::new().tool_result_json("call_1", &payload);

        assert_eq!(request.messages().len(), 1);
        match &request.messages()[0] {
            Message::Tool {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "call_1");
                // `Value`'s Display impl emits compact JSON.
                assert_eq!(content, r#"{"result":8}"#);
            }
            _ => panic!("Expected tool result message"),
        }
    }

    #[test]
    fn test_multi_turn_tool_calling() {
        let request = ChatRequest::new()
            .user_message("Use the calculator to add 5 and 3")
            .assistant_message("I'll use the calculator tool.")
            .tool_result("call_1", r#"{"result": 8}"#)
            .assistant_message("The sum of 5 and 3 is 8.");

        assert_eq!(request.messages().len(), 4);

        assert!(matches!(request.messages()[0], Message::User(_)));
        assert!(matches!(request.messages()[1], Message::Assistant(_)));
        assert!(matches!(request.messages()[2], Message::Tool { .. }));
        assert!(matches!(request.messages()[3], Message::Assistant(_)));
    }

    #[test]
    fn test_tool_result_with_from_messages() {
        let messages = vec![
            Message::User(MessageContent::Text("Calculate 10 * 5".to_string())),
            Message::Tool {
                tool_call_id: "call_xyz".to_string(),
                content: r#"{"result": 50}"#.to_string(),
            },
        ];

        let request = ChatRequest::from_messages(messages);
        assert_eq!(request.messages().len(), 2);
    }
}