Skip to main content

xai_rust/
chat.rs

1//! Chat helper functions matching the Python SDK's ergonomic API.
2//!
3//! This module provides standalone functions for creating messages and content,
4//! similar to the Python SDK's `xai_sdk.chat` module.
5//!
6//! # Example
7//!
8//! ```rust
9//! use xai_rust::chat::{user, system, assistant, image, text, file};
10//!
11//! // Simple text messages
12//! let msg1 = system("You are a helpful assistant.");
13//! let msg2 = user("Hello!");
14//! let msg3 = assistant("Hi there!");
15//!
16//! // Message with image
17//! let msg4 = user((
18//!     text("What's in this image?"),
19//!     image("https://example.com/image.jpg"),
20//! ));
21//!
22//! // Message with file
23//! let msg5 = user((
24//!     text("Summarize this document"),
25//!     file("file-abc123"),
26//! ));
27//! ```
28
29use crate::models::content::{ContentPart, FileContent, ImageDetail, ImageUrlContent};
30use crate::models::message::{Message, MessageContent, Role};
31
32// ============================================================================
33// Message Constructors
34// ============================================================================
35
36/// Create a system message.
37///
38/// # Example
39///
40/// ```rust
41/// use xai_rust::chat::system;
42///
43/// let msg = system("You are a helpful assistant.");
44/// ```
45pub fn system(content: impl Into<String>) -> Message {
46    Message::new(Role::System, MessageContent::Text(content.into()))
47}
48
49/// Create a user message.
50///
51/// Accepts either a string or content parts.
52///
53/// # Examples
54///
55/// ```rust
56/// use xai_rust::chat::{user, text, image};
57///
58/// // Simple text
59/// let msg1 = user("Hello!");
60///
61/// // With image
62/// let msg2 = user((text("What's this?"), image("https://example.com/img.jpg")));
63/// ```
64pub fn user(content: impl IntoMessageContent) -> Message {
65    Message::new(Role::User, content.into_message_content())
66}
67
68/// Create an assistant message.
69///
70/// # Example
71///
72/// ```rust
73/// use xai_rust::chat::assistant;
74///
75/// let msg = assistant("I'm here to help!");
76/// ```
77pub fn assistant(content: impl Into<String>) -> Message {
78    Message::new(Role::Assistant, MessageContent::Text(content.into()))
79}
80
81/// Create a developer message (alias for system).
82///
83/// # Example
84///
85/// ```rust
86/// use xai_rust::chat::developer;
87///
88/// let msg = developer("You are a coding assistant.");
89/// ```
90pub fn developer(content: impl Into<String>) -> Message {
91    Message::new(Role::Developer, MessageContent::Text(content.into()))
92}
93
94/// Create a tool result message.
95///
96/// # Example
97///
98/// ```rust
99/// use xai_rust::chat::tool_result;
100///
101/// let msg = tool_result("call_123", r#"{"temperature": 72}"#);
102/// ```
103pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Message {
104    Message::tool(tool_call_id, content)
105}
106
107// ============================================================================
108// Content Constructors
109// ============================================================================
110
111/// Create text content.
112///
113/// # Example
114///
115/// ```rust
116/// use xai_rust::chat::text;
117///
118/// let content = text("Hello, world!");
119/// ```
120pub fn text(content: impl Into<String>) -> ContentPart {
121    ContentPart::Text {
122        text: content.into(),
123    }
124}
125
126/// Create image content from a URL.
127///
128/// # Examples
129///
130/// ```rust
131/// use xai_rust::chat::image;
132///
133/// // From URL
134/// let img = image("https://example.com/photo.jpg");
135///
136/// // From base64 data URL
137/// let img = image("data:image/jpeg;base64,/9j/4AAQ...");
138/// ```
139pub fn image(url: impl Into<String>) -> ContentPart {
140    ContentPart::ImageUrl {
141        image_url: ImageUrlContent::new(url),
142    }
143}
144
145/// Create image content with a specific detail level.
146///
147/// # Example
148///
149/// ```rust
150/// use xai_rust::chat::image_with_detail;
151/// use xai_rust::ImageDetail;
152///
153/// let img = image_with_detail("https://example.com/photo.jpg", ImageDetail::High);
154/// ```
155pub fn image_with_detail(url: impl Into<String>, detail: ImageDetail) -> ContentPart {
156    ContentPart::ImageUrl {
157        image_url: ImageUrlContent::new(url).with_detail(detail),
158    }
159}
160
161/// Create image content from base64 data.
162///
163/// # Example
164///
165/// ```rust
166/// use xai_rust::chat::image_base64;
167///
168/// let img = image_base64("/9j/4AAQ...", "image/jpeg");
169/// ```
170pub fn image_base64(data: impl Into<String>, mime_type: impl Into<String>) -> ContentPart {
171    ContentPart::ImageUrl {
172        image_url: ImageUrlContent::from_base64(&data.into(), &mime_type.into()),
173    }
174}
175
176/// Create file content from a file ID.
177///
178/// # Example
179///
180/// ```rust
181/// use xai_rust::chat::file;
182///
183/// let f = file("file-abc123");
184/// ```
185pub fn file(file_id: impl Into<String>) -> ContentPart {
186    ContentPart::File {
187        file: FileContent::new(file_id),
188    }
189}
190
191/// Create inline file content.
192///
193/// # Example
194///
195/// ```rust
196/// use xai_rust::chat::file_inline;
197///
198/// let f = file_inline("data.json", r#"{"key": "value"}"#);
199/// ```
200pub fn file_inline(filename: impl Into<String>, data: impl Into<String>) -> ContentPart {
201    ContentPart::File {
202        file: FileContent::inline(filename, data),
203    }
204}
205
206// ============================================================================
207// Trait for flexible content input
208// ============================================================================
209
210/// Trait for types that can be converted to message content.
211///
212/// This allows `user()` to accept both strings and content parts.
213pub trait IntoMessageContent {
214    /// Convert to message content.
215    fn into_message_content(self) -> MessageContent;
216}
217
218impl IntoMessageContent for String {
219    fn into_message_content(self) -> MessageContent {
220        MessageContent::Text(self)
221    }
222}
223
224impl IntoMessageContent for &str {
225    fn into_message_content(self) -> MessageContent {
226        MessageContent::Text(self.to_string())
227    }
228}
229
230impl IntoMessageContent for ContentPart {
231    fn into_message_content(self) -> MessageContent {
232        MessageContent::Parts(vec![self])
233    }
234}
235
236impl IntoMessageContent for Vec<ContentPart> {
237    fn into_message_content(self) -> MessageContent {
238        MessageContent::Parts(self)
239    }
240}
241
242// Tuple implementations for combining content parts.
243// Up to 4 content parts can be combined in a single tuple.
244// For more parts, use `Vec<ContentPart>` instead.
245impl IntoMessageContent for (ContentPart,) {
246    fn into_message_content(self) -> MessageContent {
247        MessageContent::Parts(vec![self.0])
248    }
249}
250
251impl IntoMessageContent for (ContentPart, ContentPart) {
252    fn into_message_content(self) -> MessageContent {
253        MessageContent::Parts(vec![self.0, self.1])
254    }
255}
256
257impl IntoMessageContent for (ContentPart, ContentPart, ContentPart) {
258    fn into_message_content(self) -> MessageContent {
259        MessageContent::Parts(vec![self.0, self.1, self.2])
260    }
261}
262
263impl IntoMessageContent for (ContentPart, ContentPart, ContentPart, ContentPart) {
264    fn into_message_content(self) -> MessageContent {
265        MessageContent::Parts(vec![self.0, self.1, self.2, self.3])
266    }
267}
268
269// ============================================================================
270// Tool Helpers
271// ============================================================================
272
273/// Create a function tool definition.
274///
275/// # Example
276///
277/// ```rust
278/// use xai_rust::chat::tool;
279/// use serde_json::json;
280///
281/// let weather_tool = tool(
282///     "get_weather",
283///     "Get the current weather for a location",
284///     json!({
285///         "type": "object",
286///         "properties": {
287///             "location": {"type": "string", "description": "City name"}
288///         },
289///         "required": ["location"]
290///     }),
291/// );
292/// ```
293pub fn tool(
294    name: impl Into<String>,
295    description: impl Into<String>,
296    parameters: serde_json::Value,
297) -> crate::models::tool::Tool {
298    crate::models::tool::Tool::function(name, description, parameters)
299}
300
301/// Create a function tool with strict schema enforcement.
302///
303/// # Example
304///
305/// ```rust
306/// use xai_rust::chat::tool_strict;
307/// use serde_json::json;
308///
309/// let tool = tool_strict(
310///     "get_data",
311///     "Fetch data",
312///     json!({"type": "object", "properties": {}}),
313/// );
314/// ```
315pub fn tool_strict(
316    name: impl Into<String>,
317    description: impl Into<String>,
318    parameters: serde_json::Value,
319) -> crate::models::tool::Tool {
320    crate::models::tool::Tool::Function {
321        function: crate::models::tool::FunctionDefinition {
322            name: name.into(),
323            description: Some(description.into()),
324            parameters,
325            strict: Some(true),
326        },
327    }
328}
329
330/// Create a web search tool.
331///
332/// # Example
333///
334/// ```rust
335/// use xai_rust::chat::web_search;
336///
337/// let tool = web_search();
338/// ```
339pub fn web_search() -> crate::models::tool::Tool {
340    crate::models::tool::Tool::web_search()
341}
342
343/// Create a web search tool with allowed domains.
344///
345/// # Example
346///
347/// ```rust
348/// use xai_rust::chat::web_search_allow;
349///
350/// let tool = web_search_allow(vec!["wikipedia.org".into(), "docs.rs".into()]);
351/// ```
352pub fn web_search_allow(domains: Vec<String>) -> crate::models::tool::Tool {
353    crate::models::tool::Tool::web_search_filtered(
354        crate::models::tool::WebSearchFilters::allow_domains(domains),
355    )
356}
357
358/// Create a web search tool with excluded domains.
359///
360/// # Example
361///
362/// ```rust
363/// use xai_rust::chat::web_search_exclude;
364///
365/// let tool = web_search_exclude(vec!["example.com".into()]);
366/// ```
367pub fn web_search_exclude(domains: Vec<String>) -> crate::models::tool::Tool {
368    crate::models::tool::Tool::web_search_filtered(
369        crate::models::tool::WebSearchFilters::exclude_domains(domains),
370    )
371}
372
373/// Create an X (Twitter) search tool.
374///
375/// # Example
376///
377/// ```rust
378/// use xai_rust::chat::x_search;
379///
380/// let tool = x_search();
381/// ```
382pub fn x_search() -> crate::models::tool::Tool {
383    crate::models::tool::Tool::x_search()
384}
385
386/// Create a code interpreter tool.
387///
388/// # Example
389///
390/// ```rust
391/// use xai_rust::chat::code_interpreter;
392///
393/// let tool = code_interpreter();
394/// ```
395pub fn code_interpreter() -> crate::models::tool::Tool {
396    crate::models::tool::Tool::code_interpreter()
397}
398
399/// Create a collections search tool.
400///
401/// # Example
402///
403/// ```rust
404/// use xai_rust::chat::collections_search;
405///
406/// let tool = collections_search(vec!["collection-123".into()]);
407/// ```
408pub fn collections_search(collection_ids: Vec<String>) -> crate::models::tool::Tool {
409    crate::models::tool::Tool::collections_search(collection_ids)
410}
411
412/// Create an MCP (Model Context Protocol) tool.
413///
414/// # Example
415///
416/// ```rust
417/// use xai_rust::chat::mcp;
418///
419/// let tool = mcp("https://mcp-server.example.com");
420/// ```
421pub fn mcp(server_url: impl Into<String>) -> crate::models::tool::Tool {
422    crate::models::tool::Tool::mcp(server_url)
423}
424
425// ============================================================================
426// Tool Choice Helpers
427// ============================================================================
428
429/// Specify that the model should automatically decide whether to use tools.
430pub fn tool_choice_auto() -> crate::models::tool::ToolChoice {
431    crate::models::tool::ToolChoice::auto()
432}
433
434/// Specify that the model must use at least one tool.
435pub fn tool_choice_required() -> crate::models::tool::ToolChoice {
436    crate::models::tool::ToolChoice::required()
437}
438
439/// Specify that the model should not use any tools.
440pub fn tool_choice_none() -> crate::models::tool::ToolChoice {
441    crate::models::tool::ToolChoice::none()
442}
443
444/// Force the model to use a specific function.
445///
446/// # Example
447///
448/// ```rust
449/// use xai_rust::chat::required_tool;
450///
451/// let choice = required_tool("get_weather");
452/// ```
453pub fn required_tool(function_name: impl Into<String>) -> crate::models::tool::ToolChoice {
454    crate::models::tool::ToolChoice::function(function_name)
455}
456
/// Unit tests for the helper constructors in this module.
///
/// Each test builds a value through one helper and checks the role or variant
/// it produced, plus the key fields where they are easy to inspect.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::content::ContentPart;
    use crate::models::message::{MessageContent, Role};
    use crate::models::tool::{Tool, ToolChoice};

    // ── Message constructor helpers ────────────────────────────────────

    #[test]
    fn system_creates_system_message() {
        let msg = system("You are helpful");
        assert_eq!(msg.role, Role::System);
        assert_eq!(msg.text(), "You are helpful");
        assert!(msg.name.is_none());
        assert!(msg.tool_call_id.is_none());
    }

    #[test]
    fn user_creates_user_message_from_str() {
        let msg = user("Hello!");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.text(), "Hello!");
    }

    #[test]
    fn user_creates_user_message_from_string() {
        let msg = user(String::from("Hello!"));
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.text(), "Hello!");
    }

    #[test]
    fn assistant_creates_assistant_message() {
        let msg = assistant("Hi there!");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.text(), "Hi there!");
    }

    #[test]
    fn developer_creates_developer_message() {
        // `developer` must keep `Role::Developer`; it is not rewritten to `Role::System`.
        let msg = developer("You are a coder");
        assert_eq!(msg.role, Role::Developer);
        assert_eq!(msg.text(), "You are a coder");
    }

    #[test]
    fn tool_result_creates_tool_message() {
        let msg = tool_result("call_123", r#"{"temp":72}"#);
        assert_eq!(msg.role, Role::Tool);
        assert_eq!(msg.text(), r#"{"temp":72}"#);
        assert_eq!(msg.tool_call_id.as_deref(), Some("call_123"));
    }

    // ── Content constructor helpers ────────────────────────────────────

    #[test]
    fn text_creates_text_part() {
        let part = text("Hello");
        assert_eq!(part.as_text().unwrap(), "Hello");
    }

    #[test]
    fn image_creates_image_url_part() {
        let part = image("https://example.com/img.jpg");
        if let ContentPart::ImageUrl { image_url } = part {
            assert_eq!(image_url.url, "https://example.com/img.jpg");
            // Plain `image()` must leave the detail level unset.
            assert!(image_url.detail.is_none());
        } else {
            panic!("Expected ImageUrl variant");
        }
    }

    #[test]
    fn image_with_detail_creates_image_part() {
        let part = image_with_detail(
            "https://example.com/img.jpg",
            crate::models::content::ImageDetail::High,
        );
        if let ContentPart::ImageUrl { image_url } = part {
            assert_eq!(
                image_url.detail,
                Some(crate::models::content::ImageDetail::High)
            );
        } else {
            panic!("Expected ImageUrl variant");
        }
    }

    #[test]
    fn image_base64_creates_data_url() {
        // Checks the data-URL prefix and suffix rather than the exact string.
        let part = image_base64("abc123", "image/png");
        if let ContentPart::ImageUrl { image_url } = part {
            assert!(image_url.url.starts_with("data:image/png;base64,"));
            assert!(image_url.url.ends_with("abc123"));
        } else {
            panic!("Expected ImageUrl variant");
        }
    }

    #[test]
    fn file_creates_file_part() {
        let part = file("file-abc123");
        if let ContentPart::File { file: fc } = part {
            assert_eq!(fc.file_id.as_deref(), Some("file-abc123"));
        } else {
            panic!("Expected File variant");
        }
    }

    #[test]
    fn file_inline_creates_inline_file_part() {
        // Inline files carry a filename and data, but no file ID.
        let part = file_inline("test.json", r#"{"x": 1}"#);
        if let ContentPart::File { file: fc } = part {
            assert!(fc.file_id.is_none());
            assert_eq!(fc.filename.as_deref(), Some("test.json"));
            assert_eq!(fc.file_data.as_deref(), Some(r#"{"x": 1}"#));
        } else {
            panic!("Expected File variant");
        }
    }

    // ── IntoMessageContent trait ──────────────────────────────────────

    #[test]
    fn into_message_content_str() {
        let content: MessageContent = "hello".into_message_content();
        assert_eq!(content.as_text().unwrap(), "hello");
    }

    #[test]
    fn into_message_content_string() {
        let content: MessageContent = String::from("hello").into_message_content();
        assert_eq!(content.as_text().unwrap(), "hello");
    }

    #[test]
    fn into_message_content_single_part() {
        // A single part is wrapped as Parts, not converted to Text.
        let part = ContentPart::text("part");
        let content = part.into_message_content();
        if let MessageContent::Parts(parts) = content {
            assert_eq!(parts.len(), 1);
            assert_eq!(parts[0].as_text().unwrap(), "part");
        } else {
            panic!("Expected Parts variant");
        }
    }

    #[test]
    fn into_message_content_vec_parts() {
        let parts = vec![ContentPart::text("a"), ContentPart::text("b")];
        let content = parts.into_message_content();
        if let MessageContent::Parts(p) = content {
            assert_eq!(p.len(), 2);
        } else {
            panic!("Expected Parts variant");
        }
    }

    // The tuple tests below cover every arity implemented in this module (1 to 4).

    #[test]
    fn into_message_content_tuple_1() {
        let content = (ContentPart::text("a"),).into_message_content();
        if let MessageContent::Parts(p) = content {
            assert_eq!(p.len(), 1);
        } else {
            panic!("Expected Parts variant");
        }
    }

    #[test]
    fn into_message_content_tuple_2() {
        // Also checks that tuple order is preserved: text first, then image.
        let content =
            (text("What's this?"), image("https://example.com/img.jpg")).into_message_content();
        if let MessageContent::Parts(parts) = content {
            assert_eq!(parts.len(), 2);
            assert_eq!(parts[0].as_text().unwrap(), "What's this?");
            assert!(matches!(parts[1], ContentPart::ImageUrl { .. }));
        } else {
            panic!("Expected Parts variant");
        }
    }

    #[test]
    fn into_message_content_tuple_3() {
        let content = (text("a"), text("b"), text("c")).into_message_content();
        if let MessageContent::Parts(p) = content {
            assert_eq!(p.len(), 3);
        } else {
            panic!("Expected Parts variant");
        }
    }

    #[test]
    fn into_message_content_tuple_4() {
        let content = (text("a"), text("b"), text("c"), text("d")).into_message_content();
        if let MessageContent::Parts(p) = content {
            assert_eq!(p.len(), 4);
        } else {
            panic!("Expected Parts variant");
        }
    }

    #[test]
    fn user_accepts_tuple_content() {
        let msg = user((text("Look at this"), image("https://example.com/photo.jpg")));
        assert_eq!(msg.role, Role::User);
        if let MessageContent::Parts(parts) = &msg.content {
            assert_eq!(parts.len(), 2);
        } else {
            panic!("Expected Parts content");
        }
    }

    // ── Tool helpers ──────────────────────────────────────────────────

    #[test]
    fn tool_helper_creates_function_tool() {
        let t = tool(
            "test_fn",
            "A test function",
            serde_json::json!({"type": "object"}),
        );
        assert!(matches!(t, Tool::Function { .. }));
    }

    #[test]
    fn tool_strict_creates_strict_function_tool() {
        let t = tool_strict("test_fn", "strict function", serde_json::json!({}));
        if let Tool::Function { function } = t {
            assert_eq!(function.strict, Some(true));
        } else {
            panic!("Expected Function variant");
        }
    }

    #[test]
    fn web_search_helper() {
        let t = web_search();
        assert!(matches!(t, Tool::WebSearch { .. }));
    }

    #[test]
    fn web_search_allow_helper() {
        let t = web_search_allow(vec!["docs.rs".to_string()]);
        if let Tool::WebSearch { filters, .. } = t {
            assert!(filters.is_some());
            let f = filters.unwrap();
            assert_eq!(f.allowed_domains.unwrap(), vec!["docs.rs"]);
        } else {
            panic!("Expected WebSearch variant");
        }
    }

    #[test]
    fn web_search_exclude_helper() {
        let t = web_search_exclude(vec!["example.com".to_string()]);
        if let Tool::WebSearch { filters, .. } = t {
            let f = filters.unwrap();
            assert_eq!(f.excluded_domains.unwrap(), vec!["example.com"]);
        } else {
            panic!("Expected WebSearch variant");
        }
    }

    #[test]
    fn x_search_helper() {
        let t = x_search();
        assert!(matches!(t, Tool::XSearch { .. }));
    }

    #[test]
    fn code_interpreter_helper() {
        let t = code_interpreter();
        assert!(matches!(t, Tool::CodeInterpreter {}));
    }

    #[test]
    fn collections_search_helper() {
        let t = collections_search(vec!["col-1".to_string()]);
        assert!(matches!(t, Tool::CollectionsSearch { .. }));
    }

    #[test]
    fn mcp_helper() {
        let t = mcp("https://mcp.example.com");
        assert!(matches!(t, Tool::Mcp { .. }));
    }

    // ── Tool choice helpers ──────────────────────────────────────────

    #[test]
    fn tool_choice_auto_helper() {
        let tc = tool_choice_auto();
        assert!(matches!(tc, ToolChoice::Auto(_)));
    }

    #[test]
    fn tool_choice_required_helper() {
        // NOTE(review): this asserts the same `ToolChoice::Auto(_)` pattern as
        // `tool_choice_auto_helper`, so it cannot tell "required" apart from "auto"
        // or "none". Presumably `Auto` wraps a mode value; consider asserting on
        // that inner value. TODO confirm against the `ToolChoice` definition.
        let tc = tool_choice_required();
        assert!(matches!(tc, ToolChoice::Auto(_)));
    }

    #[test]
    fn tool_choice_none_helper() {
        // NOTE(review): same limitation as `tool_choice_required_helper`; only the
        // outer `Auto(_)` variant is checked, not the wrapped mode.
        let tc = tool_choice_none();
        assert!(matches!(tc, ToolChoice::Auto(_)));
    }

    #[test]
    fn required_tool_helper() {
        let tc = required_tool("get_weather");
        if let ToolChoice::Specific(spec) = tc {
            assert_eq!(spec.tool_type, "function");
            assert_eq!(spec.function.unwrap().name, "get_weather");
        } else {
            panic!("Expected Specific variant");
        }
    }
}