// xai_rust/models/message.rs

//! Message types for chat interactions.

use serde::{Deserialize, Serialize};

use super::content::ContentPart;

7/// Role of a message in a conversation.
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
9#[serde(rename_all = "lowercase")]
10pub enum Role {
11    /// System message that sets the behavior of the assistant.
12    System,
13    /// User message (the human's input).
14    User,
15    /// Assistant message (the model's response).
16    Assistant,
17    /// Developer message (alias for system).
18    Developer,
19    /// Tool message containing tool call results.
20    Tool,
21}
22
23impl Role {
24    /// Check if this is a system-level role (system or developer).
25    pub fn is_system(&self) -> bool {
26        matches!(self, Role::System | Role::Developer)
27    }
28}
29
30/// Message content - either simple text or multiple content parts.
31#[derive(Debug, Clone, Serialize, Deserialize)]
32#[serde(untagged)]
33pub enum MessageContent {
34    /// Simple text content.
35    Text(String),
36    /// Multiple content parts (text, images, files).
37    Parts(Vec<ContentPart>),
38}
39
40impl MessageContent {
41    /// Create text content.
42    pub fn text(text: impl Into<String>) -> Self {
43        Self::Text(text.into())
44    }
45
46    /// Create content with multiple parts.
47    pub fn parts(parts: Vec<ContentPart>) -> Self {
48        Self::Parts(parts)
49    }
50
51    /// Get the text content if this is a simple text message.
52    pub fn as_text(&self) -> Option<&str> {
53        match self {
54            MessageContent::Text(text) => Some(text),
55            MessageContent::Parts(_) => None,
56        }
57    }
58
59    /// Get all text from this content (joining multiple text parts).
60    pub fn to_text(&self) -> String {
61        match self {
62            MessageContent::Text(text) => text.clone(),
63            MessageContent::Parts(parts) => parts
64                .iter()
65                .filter_map(|p| p.as_text())
66                .collect::<Vec<_>>()
67                .join(""),
68        }
69    }
70}
71
72impl From<String> for MessageContent {
73    fn from(text: String) -> Self {
74        Self::Text(text)
75    }
76}
77
78impl From<&str> for MessageContent {
79    fn from(text: &str) -> Self {
80        Self::Text(text.to_string())
81    }
82}
83
84impl From<Vec<ContentPart>> for MessageContent {
85    fn from(parts: Vec<ContentPart>) -> Self {
86        Self::Parts(parts)
87    }
88}
89
90/// A message in a conversation.
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct Message {
93    /// The role of the message author.
94    pub role: Role,
95    /// The content of the message.
96    pub content: MessageContent,
97    /// Optional name of the author (for multi-agent scenarios).
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub name: Option<String>,
100    /// Tool call ID (for tool messages).
101    #[serde(skip_serializing_if = "Option::is_none")]
102    pub tool_call_id: Option<String>,
103}
104
105impl Message {
106    /// Create a new message.
107    pub fn new(role: Role, content: impl Into<MessageContent>) -> Self {
108        Self {
109            role,
110            content: content.into(),
111            name: None,
112            tool_call_id: None,
113        }
114    }
115
116    /// Create a system message.
117    pub fn system(content: impl Into<String>) -> Self {
118        Self::new(Role::System, content.into())
119    }
120
121    /// Create a user message.
122    pub fn user(content: impl Into<MessageContent>) -> Self {
123        Self::new(Role::User, content)
124    }
125
126    /// Create an assistant message.
127    pub fn assistant(content: impl Into<String>) -> Self {
128        Self::new(Role::Assistant, content.into())
129    }
130
131    /// Create a tool result message.
132    pub fn tool(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
133        Self {
134            role: Role::Tool,
135            content: MessageContent::Text(content.into()),
136            name: None,
137            tool_call_id: Some(tool_call_id.into()),
138        }
139    }
140
141    /// Set the name of the message author.
142    pub fn with_name(mut self, name: impl Into<String>) -> Self {
143        self.name = Some(name.into());
144        self
145    }
146
147    /// Get the text content of this message.
148    pub fn text(&self) -> String {
149        self.content.to_text()
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    // ── Role serde roundtrip ───────────────────────────────────────────

    #[test]
    fn role_serializes_lowercase() {
        assert_eq!(serde_json::to_string(&Role::System).unwrap(), r#""system""#);
        assert_eq!(serde_json::to_string(&Role::User).unwrap(), r#""user""#);
        assert_eq!(
            serde_json::to_string(&Role::Assistant).unwrap(),
            r#""assistant""#
        );
        assert_eq!(
            serde_json::to_string(&Role::Developer).unwrap(),
            r#""developer""#
        );
        assert_eq!(serde_json::to_string(&Role::Tool).unwrap(), r#""tool""#);
    }

    #[test]
    fn role_roundtrip_all_variants() {
        for role in [
            Role::System,
            Role::User,
            Role::Assistant,
            Role::Developer,
            Role::Tool,
        ] {
            let json = serde_json::to_string(&role).unwrap();
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(back, role);
        }
    }

    #[test]
    fn role_is_system() {
        assert!(Role::System.is_system());
        assert!(Role::Developer.is_system());
        assert!(!Role::User.is_system());
        assert!(!Role::Assistant.is_system());
        assert!(!Role::Tool.is_system());
    }

    #[test]
    fn role_rejects_unknown() {
        let result = serde_json::from_str::<Role>(r#""moderator""#);
        assert!(result.is_err());
    }

    // ── MessageContent serde roundtrip ─────────────────────────────────

    #[test]
    fn message_content_text_roundtrip() {
        let content = MessageContent::Text("hello world".to_string());
        let json = serde_json::to_value(&content).unwrap();
        // untagged: simple string serialises as a plain JSON string
        assert_eq!(json, serde_json::json!("hello world"));

        let back: MessageContent = serde_json::from_value(json).unwrap();
        assert_eq!(back.as_text().unwrap(), "hello world");
    }

    #[test]
    fn message_content_parts_roundtrip() {
        let content =
            MessageContent::Parts(vec![ContentPart::text("part1"), ContentPart::text("part2")]);
        let json = serde_json::to_value(&content).unwrap();
        assert!(json.is_array());

        let back: MessageContent = serde_json::from_value(json).unwrap();
        assert_eq!(back.to_text(), "part1part2");
    }

    #[test]
    fn message_content_as_text_returns_none_for_parts() {
        let content = MessageContent::Parts(vec![ContentPart::text("x")]);
        assert!(content.as_text().is_none());
    }

    #[test]
    fn message_content_constructors() {
        assert_eq!(MessageContent::text("hi").as_text(), Some("hi"));
        assert_eq!(MessageContent::text("hi").to_text(), "hi");
        let parts = MessageContent::parts(vec![ContentPart::text("a")]);
        assert!(parts.as_text().is_none());
        assert_eq!(parts.to_text(), "a");
    }

    #[test]
    fn message_content_from_str() {
        let content: MessageContent = "hello".into();
        assert_eq!(content.as_text().unwrap(), "hello");
    }

    #[test]
    fn message_content_from_owned_string() {
        let content: MessageContent = String::from("hello").into();
        assert_eq!(content.as_text().unwrap(), "hello");
    }

    #[test]
    fn message_content_from_vec_parts() {
        let parts = vec![ContentPart::text("a"), ContentPart::text("b")];
        let content: MessageContent = parts.into();
        assert_eq!(content.to_text(), "ab");
    }

    // ── Message serde roundtrip ────────────────────────────────────────

    #[test]
    fn message_system_roundtrip() {
        let msg = Message::system("You are helpful");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "system");
        assert_eq!(json["content"], "You are helpful");

        let back: Message = serde_json::from_value(json).unwrap();
        assert_eq!(back.role, Role::System);
        assert_eq!(back.text(), "You are helpful");
    }

    #[test]
    fn message_user_roundtrip() {
        let msg = Message::user("What is 1+1?");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "user");

        let back: Message = serde_json::from_value(json).unwrap();
        assert_eq!(back.role, Role::User);
        assert_eq!(back.text(), "What is 1+1?");
    }

    #[test]
    fn message_assistant_roundtrip() {
        let msg = Message::assistant("The answer is 2");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "assistant");
        assert_eq!(json["content"], "The answer is 2");

        let back: Message = serde_json::from_value(json).unwrap();
        assert_eq!(back.role, Role::Assistant);
        assert_eq!(back.text(), "The answer is 2");
    }

    #[test]
    fn message_tool_roundtrip() {
        let msg = Message::tool("call_123", r#"{"result": 42}"#);
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "tool");
        assert_eq!(json["tool_call_id"], "call_123");
        // Tool output is plain text, so it serialises as a JSON string, not as nested JSON.
        assert_eq!(json["content"], r#"{"result": 42}"#);
        assert!(json.get("name").is_none());

        let back: Message = serde_json::from_value(json).unwrap();
        assert_eq!(back.role, Role::Tool);
        assert_eq!(back.tool_call_id.as_deref(), Some("call_123"));
        assert_eq!(back.text(), r#"{"result": 42}"#);
    }

    #[test]
    fn message_with_name_roundtrip() {
        let msg = Message::user("hi").with_name("alice");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["name"], "alice");

        let back: Message = serde_json::from_value(json).unwrap();
        assert_eq!(back.name.as_deref(), Some("alice"));
    }

    #[test]
    fn message_skips_none_fields() {
        let msg = Message::user("hi");
        let json = serde_json::to_value(&msg).unwrap();
        assert!(json.get("name").is_none());
        assert!(json.get("tool_call_id").is_none());
    }

    #[test]
    fn message_deserializes_without_optional_fields() {
        let msg: Message =
            serde_json::from_str(r#"{"role":"assistant","content":"hello"}"#).unwrap();
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.text(), "hello");
        assert!(msg.name.is_none());
        assert!(msg.tool_call_id.is_none());
    }
}