Skip to main content

simple_agent_type/
message.rs

//! Message types for LLM interactions.
//!
//! Provides role-based conversation messages (plain-text or multimodal content)
//! compatible with OpenAI's chat message format, plus helpers to parse them from JSON.
4
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::str::FromStr;
8use thiserror::Error;
9
10use crate::tool::ToolCall;
11
12/// Role of a message in a conversation.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum Role {
16    /// User message
17    User,
18    /// Assistant (LLM) message
19    Assistant,
20    /// System instruction message
21    System,
22    /// Tool/function call result
23    #[serde(rename = "tool")]
24    Tool,
25}
26
27#[derive(Debug, Clone, PartialEq, Eq, Error)]
28#[error("invalid message role '{role}' (expected: system|user|assistant|tool)")]
29/// Error returned when parsing an unknown message role string.
30pub struct ParseRoleError {
31    /// Original role string that failed to parse.
32    pub role: String,
33}
34
35impl Role {
36    /// Returns this role as its canonical lowercase string value.
37    pub fn as_str(self) -> &'static str {
38        match self {
39            Self::System => "system",
40            Self::User => "user",
41            Self::Assistant => "assistant",
42            Self::Tool => "tool",
43        }
44    }
45}
46
47impl FromStr for Role {
48    type Err = ParseRoleError;
49
50    fn from_str(s: &str) -> Result<Self, Self::Err> {
51        match s {
52            "system" => Ok(Self::System),
53            "user" => Ok(Self::User),
54            "assistant" => Ok(Self::Assistant),
55            "tool" => Ok(Self::Tool),
56            _ => Err(ParseRoleError {
57                role: s.to_string(),
58            }),
59        }
60    }
61}
62
63/// Content of a message — either plain text or a list of multimodal parts.
64#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
65#[serde(untagged)]
66pub enum MessageContent {
67    /// Plain text content.
68    Text(String),
69    /// Multimodal content parts (text, images, video, etc.).
70    Parts(Vec<ContentPart>),
71}
72
73impl From<String> for MessageContent {
74    fn from(s: String) -> Self {
75        Self::Text(s)
76    }
77}
78
79impl From<&str> for MessageContent {
80    fn from(s: &str) -> Self {
81        Self::Text(s.to_string())
82    }
83}
84
85impl MessageContent {
86    /// Returns the text content length for validation purposes.
87    /// For Parts, sums the text lengths of all text parts.
88    pub fn text_len(&self) -> usize {
89        match self {
90            Self::Text(s) => s.len(),
91            Self::Parts(parts) => parts
92                .iter()
93                .map(|p| match p {
94                    ContentPart::Text { text } => text.len(),
95                    _ => 0,
96                })
97                .sum(),
98        }
99    }
100
101    /// Returns true if the content contains a null byte.
102    pub fn contains_null(&self) -> bool {
103        match self {
104            Self::Text(s) => s.contains('\0'),
105            Self::Parts(parts) => parts.iter().any(|p| match p {
106                ContentPart::Text { text } => text.contains('\0'),
107                _ => false,
108            }),
109        }
110    }
111}
112
113/// A single content part in a multimodal message.
114#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
115#[serde(tag = "type")]
116pub enum ContentPart {
117    /// Text content.
118    #[serde(rename = "text")]
119    Text {
120        /// The text string.
121        text: String,
122    },
123    /// Image URL content.
124    #[serde(rename = "image_url")]
125    ImageUrl {
126        /// The image URL and optional detail level.
127        image_url: ImageUrlContent,
128    },
129    /// Video URL content.
130    #[serde(rename = "video_url")]
131    Video {
132        /// The video URL.
133        url: String,
134    },
135}
136
137impl ContentPart {
138    /// Create a text content part.
139    pub fn text(text: impl Into<String>) -> Self {
140        Self::Text { text: text.into() }
141    }
142
143    /// Create an image URL content part.
144    pub fn image_url(url: impl Into<String>) -> Self {
145        Self::ImageUrl {
146            image_url: ImageUrlContent {
147                url: url.into(),
148                detail: None,
149            },
150        }
151    }
152}
153
154/// Image URL content with optional detail level.
155#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
156pub struct ImageUrlContent {
157    /// The image URL.
158    pub url: String,
159    /// Optional detail level (e.g. "low", "high", "auto").
160    #[serde(skip_serializing_if = "Option::is_none")]
161    pub detail: Option<String>,
162}
163
164/// A message in a conversation.
165#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
166pub struct Message {
167    /// Role of the message sender
168    pub role: Role,
169    /// Content of the message
170    pub content: MessageContent,
171    /// Optional name (for multi-user conversations or tool calls)
172    #[serde(skip_serializing_if = "Option::is_none")]
173    pub name: Option<String>,
174    /// Tool call ID (for tool role messages)
175    #[serde(skip_serializing_if = "Option::is_none")]
176    pub tool_call_id: Option<String>,
177    /// Tool calls emitted by the assistant.
178    #[serde(skip_serializing_if = "Option::is_none")]
179    pub tool_calls: Option<Vec<ToolCall>>,
180}
181
182impl Message {
183    /// Create a user message.
184    ///
185    /// # Example
186    /// ```
187    /// use simple_agent_type::message::{Message, Role};
188    ///
189    /// let msg = Message::user("Hello!");
190    /// assert_eq!(msg.role, Role::User);
191    /// assert_eq!(msg.content_text(), "Hello!");
192    /// ```
193    pub fn user(content: impl Into<String>) -> Self {
194        Self {
195            role: Role::User,
196            content: MessageContent::Text(content.into()),
197            name: None,
198            tool_call_id: None,
199            tool_calls: None,
200        }
201    }
202
203    /// Create an assistant message.
204    ///
205    /// # Example
206    /// ```
207    /// use simple_agent_type::message::{Message, Role};
208    ///
209    /// let msg = Message::assistant("Hi there!");
210    /// assert_eq!(msg.role, Role::Assistant);
211    /// ```
212    pub fn assistant(content: impl Into<String>) -> Self {
213        Self {
214            role: Role::Assistant,
215            content: MessageContent::Text(content.into()),
216            name: None,
217            tool_call_id: None,
218            tool_calls: None,
219        }
220    }
221
222    /// Create a system message.
223    ///
224    /// # Example
225    /// ```
226    /// use simple_agent_type::message::{Message, Role};
227    ///
228    /// let msg = Message::system("You are a helpful assistant.");
229    /// assert_eq!(msg.role, Role::System);
230    /// ```
231    pub fn system(content: impl Into<String>) -> Self {
232        Self {
233            role: Role::System,
234            content: MessageContent::Text(content.into()),
235            name: None,
236            tool_call_id: None,
237            tool_calls: None,
238        }
239    }
240
241    /// Create a tool message.
242    ///
243    /// # Example
244    /// ```
245    /// use simple_agent_type::message::{Message, Role};
246    ///
247    /// let msg = Message::tool("result", "call_123");
248    /// assert_eq!(msg.role, Role::Tool);
249    /// assert_eq!(msg.tool_call_id, Some("call_123".to_string()));
250    /// ```
251    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
252        Self {
253            role: Role::Tool,
254            content: MessageContent::Text(content.into()),
255            name: None,
256            tool_call_id: Some(tool_call_id.into()),
257            tool_calls: None,
258        }
259    }
260
261    /// Create a user message with multimodal content parts.
262    pub fn user_parts(parts: Vec<ContentPart>) -> Self {
263        Self {
264            role: Role::User,
265            content: MessageContent::Parts(parts),
266            name: None,
267            tool_call_id: None,
268            tool_calls: None,
269        }
270    }
271
272    /// Extract the first text string from the message content.
273    ///
274    /// For `MessageContent::Text`, returns the string directly.
275    /// For `MessageContent::Parts`, returns the text of the first `Text` part.
276    /// Returns `""` if no text is found.
277    pub fn content_text(&self) -> &str {
278        match &self.content {
279            MessageContent::Text(s) => s.as_str(),
280            MessageContent::Parts(parts) => parts
281                .iter()
282                .find_map(|p| match p {
283                    ContentPart::Text { text } => Some(text.as_str()),
284                    _ => None,
285                })
286                .unwrap_or(""),
287        }
288    }
289
290    /// Set the name field (builder pattern).
291    ///
292    /// # Example
293    /// ```
294    /// use simple_agent_type::message::Message;
295    ///
296    /// let msg = Message::user("Hello").with_name("Alice");
297    /// assert_eq!(msg.name, Some("Alice".to_string()));
298    /// ```
299    pub fn with_name(mut self, name: impl Into<String>) -> Self {
300        self.name = Some(name.into());
301        self
302    }
303
304    /// Set tool calls for assistant messages.
305    pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
306        self.tool_calls = Some(tool_calls);
307        self
308    }
309}
310
311#[derive(Debug, Clone, Deserialize)]
312struct MessageInputWire {
313    role: Role,
314    content: MessageContent,
315    #[serde(default)]
316    name: Option<String>,
317    #[serde(default, alias = "toolCallId")]
318    tool_call_id: Option<String>,
319    #[serde(default)]
320    tool_calls: Option<Vec<ToolCall>>,
321}
322
323/// Parses a JSON value containing an array of message objects into typed messages.
324pub fn parse_messages_value(value: &Value) -> Result<Vec<Message>, String> {
325    let wire_messages: Vec<MessageInputWire> = serde_json::from_value(value.clone())
326        .map_err(|e| format!("messages must be a list of message objects: {e}"))?;
327    if wire_messages.is_empty() {
328        return Err("messages cannot be empty".to_string());
329    }
330
331    wire_messages
332        .into_iter()
333        .enumerate()
334        .map(|(idx, wire)| {
335            if wire.content.text_len() == 0 {
336                return Err(format!("message[{idx}].content cannot be empty"));
337            }
338
339            let content = wire.content;
340
341            let mut msg = match wire.role {
342                Role::System => Message {
343                    role: Role::System,
344                    content,
345                    name: None,
346                    tool_call_id: None,
347                    tool_calls: None,
348                },
349                Role::User => Message {
350                    role: Role::User,
351                    content,
352                    name: None,
353                    tool_call_id: None,
354                    tool_calls: None,
355                },
356                Role::Assistant => {
357                    let mut m = Message {
358                        role: Role::Assistant,
359                        content,
360                        name: None,
361                        tool_call_id: None,
362                        tool_calls: None,
363                    };
364                    if let Some(calls) = wire.tool_calls {
365                        if !calls.is_empty() {
366                            m = m.with_tool_calls(calls);
367                        }
368                    }
369                    m
370                }
371                Role::Tool => {
372                    let call_id = wire.tool_call_id.ok_or_else(|| {
373                        format!("message[{idx}].tool_call_id is required for tool role")
374                    })?;
375                    Message {
376                        role: Role::Tool,
377                        content,
378                        name: None,
379                        tool_call_id: Some(call_id),
380                        tool_calls: None,
381                    }
382                }
383            };
384
385            if let Some(name) = wire.name {
386                if !name.is_empty() {
387                    msg = msg.with_name(name);
388                }
389            }
390
391            Ok(msg)
392        })
393        .collect()
394}
395
396/// Parses a JSON string containing an array of message objects.
397pub fn parse_messages_json(messages_json: &str) -> Result<Vec<Message>, String> {
398    let value: Value =
399        serde_json::from_str(messages_json).map_err(|e| format!("invalid messages json: {e}"))?;
400    parse_messages_value(&value)
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_user() {
        let msg = Message::user("test");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, MessageContent::Text("test".to_string()));
        assert_eq!(msg.content_text(), "test");
        assert_eq!(msg.name, None);
        assert_eq!(msg.tool_call_id, None);
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_assistant() {
        let msg = Message::assistant("response");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.content_text(), "response");
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_system() {
        let msg = Message::system("instruction");
        assert_eq!(msg.role, Role::System);
        assert_eq!(msg.content_text(), "instruction");
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_tool() {
        let msg = Message::tool("result", "call_123");
        assert_eq!(msg.role, Role::Tool);
        assert_eq!(msg.content_text(), "result");
        assert_eq!(msg.tool_call_id, Some("call_123".to_string()));
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_with_name() {
        let msg = Message::user("test").with_name("Alice");
        assert_eq!(msg.name, Some("Alice".to_string()));
    }

    #[test]
    fn test_role_serialization() {
        let json = serde_json::to_string(&Role::User).unwrap();
        assert_eq!(json, "\"user\"");

        let json = serde_json::to_string(&Role::Assistant).unwrap();
        assert_eq!(json, "\"assistant\"");

        let json = serde_json::to_string(&Role::System).unwrap();
        assert_eq!(json, "\"system\"");

        let json = serde_json::to_string(&Role::Tool).unwrap();
        assert_eq!(json, "\"tool\"");
    }

    #[test]
    fn test_role_from_str_round_trip() {
        for role in [Role::System, Role::User, Role::Assistant, Role::Tool] {
            assert_eq!(role.as_str().parse::<Role>(), Ok(role));
        }
        let err = "Admin".parse::<Role>().unwrap_err();
        assert_eq!(err.role, "Admin");
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message::user("Hello");
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(msg, parsed);
    }

    #[test]
    fn test_message_optional_fields_not_serialized() {
        let msg = Message::user("test");
        let json = serde_json::to_value(&msg).unwrap();
        assert!(json.get("name").is_none());
        assert!(json.get("tool_call_id").is_none());
        assert!(json.get("tool_calls").is_none());
    }

    #[test]
    fn test_message_with_name_serialized() {
        let msg = Message::user("test").with_name("Alice");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json.get("name").and_then(|v| v.as_str()), Some("Alice"));
    }

    #[test]
    fn test_message_user_text() {
        let msg = Message::user("hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content_text(), "hello");
    }

    #[test]
    fn test_message_multimodal() {
        let msg = Message::user_parts(vec![
            ContentPart::text("what is this?"),
            ContentPart::image_url("https://example.com/img.jpg"),
        ]);
        assert_eq!(msg.content_text(), "what is this?");
    }

    #[test]
    fn test_message_content_serialization() {
        let msg = Message::user("hello");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["content"], "hello");
        let msg2 = Message::user_parts(vec![ContentPart::text("hi")]);
        let json2 = serde_json::to_value(&msg2).unwrap();
        assert!(json2["content"].is_array());
    }

    #[test]
    fn test_message_content_from_string() {
        let content: MessageContent = "hello".into();
        assert_eq!(content, MessageContent::Text("hello".to_string()));

        let content: MessageContent = String::from("world").into();
        assert_eq!(content, MessageContent::Text("world".to_string()));
    }

    #[test]
    fn test_message_content_text_len_and_null() {
        let content = MessageContent::Parts(vec![
            ContentPart::text("ab"),
            ContentPart::image_url("https://example.com/a.png"),
            ContentPart::text("c\0"),
        ]);
        assert_eq!(content.text_len(), 4);
        assert!(content.contains_null());
        assert!(!MessageContent::from("plain").contains_null());
    }

    #[test]
    fn test_parse_messages_json_basic() {
        let msgs = parse_messages_json(
            r#"[{"role":"system","content":"be nice"},{"role":"user","content":"hi","name":"Alice"}]"#,
        )
        .unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0], Message::system("be nice"));
        assert_eq!(msgs[1], Message::user("hi").with_name("Alice"));
    }

    #[test]
    fn test_parse_messages_multimodal_parts() {
        let msgs = parse_messages_json(
            r#"[{"role":"user","content":[{"type":"text","text":"look"},{"type":"image_url","image_url":{"url":"https://example.com/a.png"}}]}]"#,
        )
        .unwrap();
        assert_eq!(
            msgs[0],
            Message::user_parts(vec![
                ContentPart::text("look"),
                ContentPart::image_url("https://example.com/a.png"),
            ])
        );
    }

    #[test]
    fn test_parse_messages_tool_requires_call_id() {
        let err = parse_messages_json(r#"[{"role":"tool","content":"42"}]"#).unwrap_err();
        assert_eq!(err, "message[0].tool_call_id is required for tool role");

        let msgs =
            parse_messages_json(r#"[{"role":"tool","content":"42","toolCallId":"call_1"}]"#)
                .unwrap();
        assert_eq!(msgs[0], Message::tool("42", "call_1"));
    }

    #[test]
    fn test_parse_messages_rejects_invalid_input() {
        assert_eq!(parse_messages_json("[]").unwrap_err(), "messages cannot be empty");
        assert!(parse_messages_json("not json")
            .unwrap_err()
            .starts_with("invalid messages json"));
        assert!(parse_messages_json(r#"{"role":"user","content":"hi"}"#)
            .unwrap_err()
            .starts_with("messages must be a list of message objects"));
        assert_eq!(
            parse_messages_json(r#"[{"role":"user","content":""}]"#).unwrap_err(),
            "message[0].content cannot be empty"
        );
        assert!(parse_messages_json(r#"[{"role":"admin","content":"x"}]"#).is_err());
    }

    #[test]
    fn test_parse_messages_drops_empty_name() {
        let msgs = parse_messages_json(r#"[{"role":"user","content":"hi","name":""}]"#).unwrap();
        assert_eq!(msgs[0].name, None);
    }
}