Skip to main content

simple_agent_type/
message.rs

1//! Message types for LLM interactions.
2//!
3//! Provides role-based messages compatible with OpenAI's message format.
4
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::str::FromStr;
8use thiserror::Error;
9
10use crate::tool::ToolCall;
11
12/// Role of a message in a conversation.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum Role {
16    /// User message
17    User,
18    /// Assistant (LLM) message
19    Assistant,
20    /// System instruction message
21    System,
22    /// Tool/function call result
23    #[serde(rename = "tool")]
24    Tool,
25}
26
27#[derive(Debug, Clone, PartialEq, Eq, Error)]
28#[error("invalid message role '{role}' (expected: system|user|assistant|tool)")]
29/// Error returned when parsing an unknown message role string.
30pub struct ParseRoleError {
31    /// Original role string that failed to parse.
32    pub role: String,
33}
34
35impl Role {
36    /// Returns this role as its canonical lowercase string value.
37    pub fn as_str(self) -> &'static str {
38        match self {
39            Self::System => "system",
40            Self::User => "user",
41            Self::Assistant => "assistant",
42            Self::Tool => "tool",
43        }
44    }
45}
46
47impl FromStr for Role {
48    type Err = ParseRoleError;
49
50    fn from_str(s: &str) -> Result<Self, Self::Err> {
51        match s {
52            "system" => Ok(Self::System),
53            "user" => Ok(Self::User),
54            "assistant" => Ok(Self::Assistant),
55            "tool" => Ok(Self::Tool),
56            _ => Err(ParseRoleError {
57                role: s.to_string(),
58            }),
59        }
60    }
61}
62
63/// Content of a message — either plain text or a list of multimodal parts.
64#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
65#[serde(untagged)]
66pub enum MessageContent {
67    /// Plain text content.
68    Text(String),
69    /// Multimodal content parts (text, images, video, etc.).
70    Parts(Vec<ContentPart>),
71}
72
73impl From<String> for MessageContent {
74    fn from(s: String) -> Self {
75        Self::Text(s)
76    }
77}
78
79impl From<&str> for MessageContent {
80    fn from(s: &str) -> Self {
81        Self::Text(s.to_string())
82    }
83}
84
85impl MessageContent {
86    /// Returns the text content length for validation purposes.
87    /// For Parts, sums the text lengths of all text parts.
88    pub fn text_len(&self) -> usize {
89        match self {
90            Self::Text(s) => s.len(),
91            Self::Parts(parts) => parts
92                .iter()
93                .map(|p| match p {
94                    ContentPart::Text { text } => text.len(),
95                    _ => 0,
96                })
97                .sum(),
98        }
99    }
100
101    /// Returns true if the message has no usable content (empty string, empty parts,
102    /// or every part is empty).
103    pub fn is_empty_content(&self) -> bool {
104        match self {
105            Self::Text(s) => s.is_empty(),
106            Self::Parts(parts) => {
107                parts.is_empty()
108                    || parts.iter().all(|p| match p {
109                        ContentPart::Text { text } => text.is_empty(),
110                        ContentPart::ImageUrl { image_url } => image_url.url.is_empty(),
111                        ContentPart::Audio { input_audio } => input_audio.data.is_empty(),
112                        ContentPart::Video { video } => video.data.is_empty(),
113                    })
114            }
115        }
116    }
117
118    /// Returns true if the content contains a null byte.
119    pub fn contains_null(&self) -> bool {
120        match self {
121            Self::Text(s) => s.contains('\0'),
122            Self::Parts(parts) => parts.iter().any(|p| match p {
123                ContentPart::Text { text } => text.contains('\0'),
124                ContentPart::ImageUrl { image_url } => image_url.url.contains('\0'),
125                ContentPart::Audio { input_audio } => {
126                    input_audio.data.contains('\0') || input_audio.media_type.contains('\0')
127                }
128                ContentPart::Video { video } => {
129                    video.data.contains('\0') || video.media_type.contains('\0')
130                }
131            }),
132        }
133    }
134}
135
/// MIME type strings for the media formats commonly used in multimodal
/// content, grouped by image, audio and video.
pub mod mime {
    /// `image/png` — PNG image.
    pub const IMAGE_PNG: &str = "image/png";
    /// `image/jpeg` — JPEG image.
    pub const IMAGE_JPEG: &str = "image/jpeg";
    /// `image/webp` — WebP image.
    pub const IMAGE_WEBP: &str = "image/webp";
    /// `image/gif` — GIF image.
    pub const IMAGE_GIF: &str = "image/gif";

    /// `audio/mpeg` — MP3 audio.
    pub const AUDIO_MP3: &str = "audio/mpeg";
    /// `audio/wav` — WAV audio.
    pub const AUDIO_WAV: &str = "audio/wav";
    /// `audio/flac` — FLAC audio.
    pub const AUDIO_FLAC: &str = "audio/flac";
    /// `audio/ogg` — OGG audio.
    pub const AUDIO_OGG: &str = "audio/ogg";

    /// `video/mp4` — MP4 video.
    pub const VIDEO_MP4: &str = "video/mp4";
    /// `video/webm` — WebM video.
    pub const VIDEO_WEBM: &str = "video/webm";
    /// `video/quicktime` — MOV (QuickTime) video.
    pub const VIDEO_MOV: &str = "video/quicktime";
    /// `video/x-matroska` — MKV (Matroska) video.
    pub const VIDEO_MKV: &str = "video/x-matroska";
}
163
164/// Base64-encoded media content with its MIME type.
165///
166/// Used by [`ContentPart::Audio`] and [`ContentPart::Video`] to carry inline
167/// media data. The same struct works for both input (user uploads) and output
168/// (model-generated audio/video).
169#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
170pub struct MediaContent {
171    /// MIME type (e.g. `"image/png"`, `"audio/wav"`). See [`mime`] for constants.
172    pub media_type: String,
173    /// Base64-encoded media bytes.
174    pub data: String,
175}
176
177/// A single content part in a multimodal message.
178#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
179#[serde(tag = "type")]
180pub enum ContentPart {
181    /// Text content.
182    #[serde(rename = "text")]
183    Text {
184        /// The text string.
185        text: String,
186    },
187    /// Image content (base64 data URI for OpenAI wire compatibility).
188    #[serde(rename = "image_url")]
189    ImageUrl {
190        /// The image URL and optional detail level.
191        image_url: ImageUrlContent,
192    },
193    /// Audio content (inline base64).
194    #[serde(rename = "input_audio")]
195    Audio {
196        /// The audio media payload.
197        input_audio: MediaContent,
198    },
199    /// Video content (inline base64).
200    #[serde(rename = "video")]
201    Video {
202        /// The video media payload.
203        video: MediaContent,
204    },
205}
206
207impl ContentPart {
208    /// Create a text content part.
209    pub fn text(text: impl Into<String>) -> Self {
210        Self::Text { text: text.into() }
211    }
212
213    /// Create an image content part from base64 data.
214    ///
215    /// Wraps the data as a `data:` URI in [`ImageUrlContent`] for OpenAI wire
216    /// compatibility.
217    pub fn image(media_type: impl Into<String>, data: impl Into<String>) -> Self {
218        let mt = media_type.into();
219        let d = data.into();
220        Self::ImageUrl {
221            image_url: ImageUrlContent {
222                url: format!("data:{mt};base64,{d}"),
223                detail: None,
224            },
225        }
226    }
227
228    /// Create an audio content part from base64 data.
229    pub fn audio(media_type: impl Into<String>, data: impl Into<String>) -> Self {
230        Self::Audio {
231            input_audio: MediaContent {
232                media_type: media_type.into(),
233                data: data.into(),
234            },
235        }
236    }
237
238    /// Create a video content part from base64 data.
239    pub fn video(media_type: impl Into<String>, data: impl Into<String>) -> Self {
240        Self::Video {
241            video: MediaContent {
242                media_type: media_type.into(),
243                data: data.into(),
244            },
245        }
246    }
247
248    /// Create an image URL content part (legacy).
249    ///
250    /// Prefer [`ContentPart::image`] for inline base64 data.
251    pub fn image_url(url: impl Into<String>) -> Self {
252        Self::ImageUrl {
253            image_url: ImageUrlContent {
254                url: url.into(),
255                detail: None,
256            },
257        }
258    }
259}
260
261/// Image URL content with optional detail level.
262#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
263pub struct ImageUrlContent {
264    /// The image URL (or `data:` URI for inline base64 images).
265    pub url: String,
266    /// Optional detail level (e.g. "low", "high", "auto").
267    #[serde(skip_serializing_if = "Option::is_none")]
268    pub detail: Option<String>,
269}
270
271/// A message in a conversation.
272#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
273pub struct Message {
274    /// Role of the message sender
275    pub role: Role,
276    /// Content of the message
277    pub content: MessageContent,
278    /// Optional name (for multi-user conversations or tool calls)
279    #[serde(skip_serializing_if = "Option::is_none")]
280    pub name: Option<String>,
281    /// Tool call ID (for tool role messages)
282    #[serde(skip_serializing_if = "Option::is_none")]
283    pub tool_call_id: Option<String>,
284    /// Tool calls emitted by the assistant.
285    #[serde(skip_serializing_if = "Option::is_none")]
286    pub tool_calls: Option<Vec<ToolCall>>,
287}
288
289impl Message {
290    /// Create a user message.
291    ///
292    /// # Example
293    /// ```
294    /// use simple_agent_type::message::{Message, Role};
295    ///
296    /// let msg = Message::user("Hello!");
297    /// assert_eq!(msg.role, Role::User);
298    /// assert_eq!(msg.content_text(), "Hello!");
299    /// ```
300    pub fn user(content: impl Into<String>) -> Self {
301        Self {
302            role: Role::User,
303            content: MessageContent::Text(content.into()),
304            name: None,
305            tool_call_id: None,
306            tool_calls: None,
307        }
308    }
309
310    /// Create an assistant message.
311    ///
312    /// # Example
313    /// ```
314    /// use simple_agent_type::message::{Message, Role};
315    ///
316    /// let msg = Message::assistant("Hi there!");
317    /// assert_eq!(msg.role, Role::Assistant);
318    /// ```
319    pub fn assistant(content: impl Into<String>) -> Self {
320        Self {
321            role: Role::Assistant,
322            content: MessageContent::Text(content.into()),
323            name: None,
324            tool_call_id: None,
325            tool_calls: None,
326        }
327    }
328
329    /// Create a system message.
330    ///
331    /// # Example
332    /// ```
333    /// use simple_agent_type::message::{Message, Role};
334    ///
335    /// let msg = Message::system("You are a helpful assistant.");
336    /// assert_eq!(msg.role, Role::System);
337    /// ```
338    pub fn system(content: impl Into<String>) -> Self {
339        Self {
340            role: Role::System,
341            content: MessageContent::Text(content.into()),
342            name: None,
343            tool_call_id: None,
344            tool_calls: None,
345        }
346    }
347
348    /// Create a tool message.
349    ///
350    /// # Example
351    /// ```
352    /// use simple_agent_type::message::{Message, Role};
353    ///
354    /// let msg = Message::tool("result", "call_123");
355    /// assert_eq!(msg.role, Role::Tool);
356    /// assert_eq!(msg.tool_call_id, Some("call_123".to_string()));
357    /// ```
358    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
359        Self {
360            role: Role::Tool,
361            content: MessageContent::Text(content.into()),
362            name: None,
363            tool_call_id: Some(tool_call_id.into()),
364            tool_calls: None,
365        }
366    }
367
368    /// Create a user message with multimodal content parts.
369    pub fn user_parts(parts: Vec<ContentPart>) -> Self {
370        Self {
371            role: Role::User,
372            content: MessageContent::Parts(parts),
373            name: None,
374            tool_call_id: None,
375            tool_calls: None,
376        }
377    }
378
379    /// Extract the first text string from the message content.
380    ///
381    /// For `MessageContent::Text`, returns the string directly.
382    /// For `MessageContent::Parts`, returns the text of the first `Text` part.
383    /// Returns `""` if no text is found.
384    pub fn content_text(&self) -> &str {
385        match &self.content {
386            MessageContent::Text(s) => s.as_str(),
387            MessageContent::Parts(parts) => parts
388                .iter()
389                .find_map(|p| match p {
390                    ContentPart::Text { text } => Some(text.as_str()),
391                    _ => None,
392                })
393                .unwrap_or(""),
394        }
395    }
396
397    /// Set the name field (builder pattern).
398    ///
399    /// # Example
400    /// ```
401    /// use simple_agent_type::message::Message;
402    ///
403    /// let msg = Message::user("Hello").with_name("Alice");
404    /// assert_eq!(msg.name, Some("Alice".to_string()));
405    /// ```
406    pub fn with_name(mut self, name: impl Into<String>) -> Self {
407        self.name = Some(name.into());
408        self
409    }
410
411    /// Set tool calls for assistant messages.
412    pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
413        self.tool_calls = Some(tool_calls);
414        self
415    }
416}
417
418#[derive(Debug, Clone, Deserialize)]
419struct MessageInputWire {
420    role: Role,
421    content: MessageContent,
422    #[serde(default)]
423    name: Option<String>,
424    #[serde(default, alias = "toolCallId")]
425    tool_call_id: Option<String>,
426    #[serde(default)]
427    tool_calls: Option<Vec<ToolCall>>,
428}
429
430/// Parses a JSON value containing an array of message objects into typed messages.
431pub fn parse_messages_value(value: &Value) -> Result<Vec<Message>, String> {
432    let wire_messages: Vec<MessageInputWire> = serde_json::from_value(value.clone())
433        .map_err(|e| format!("messages must be a list of message objects: {e}"))?;
434    if wire_messages.is_empty() {
435        return Err("messages cannot be empty".to_string());
436    }
437
438    wire_messages
439        .into_iter()
440        .enumerate()
441        .map(|(idx, wire)| {
442            if wire.content.is_empty_content() {
443                return Err(format!("message[{idx}].content cannot be empty"));
444            }
445
446            let content = wire.content;
447
448            let mut msg = match wire.role {
449                Role::System => Message {
450                    role: Role::System,
451                    content,
452                    name: None,
453                    tool_call_id: None,
454                    tool_calls: None,
455                },
456                Role::User => Message {
457                    role: Role::User,
458                    content,
459                    name: None,
460                    tool_call_id: None,
461                    tool_calls: None,
462                },
463                Role::Assistant => {
464                    let mut m = Message {
465                        role: Role::Assistant,
466                        content,
467                        name: None,
468                        tool_call_id: None,
469                        tool_calls: None,
470                    };
471                    if let Some(calls) = wire.tool_calls {
472                        if !calls.is_empty() {
473                            m = m.with_tool_calls(calls);
474                        }
475                    }
476                    m
477                }
478                Role::Tool => {
479                    let call_id = wire.tool_call_id.ok_or_else(|| {
480                        format!("message[{idx}].tool_call_id is required for tool role")
481                    })?;
482                    Message {
483                        role: Role::Tool,
484                        content,
485                        name: None,
486                        tool_call_id: Some(call_id),
487                        tool_calls: None,
488                    }
489                }
490            };
491
492            if let Some(name) = wire.name {
493                if !name.is_empty() {
494                    msg = msg.with_name(name);
495                }
496            }
497
498            Ok(msg)
499        })
500        .collect()
501}
502
503/// Parses a JSON string containing an array of message objects.
504pub fn parse_messages_json(messages_json: &str) -> Result<Vec<Message>, String> {
505    let value: Value =
506        serde_json::from_str(messages_json).map_err(|e| format!("invalid messages json: {e}"))?;
507    parse_messages_value(&value)
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected value for a plain-text message with no optional fields set.
    fn bare(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
            name: None,
            tool_call_id: None,
            tool_calls: None,
        }
    }

    // --- constructors -----------------------------------------------------

    #[test]
    fn test_message_user() {
        let msg = Message::user("test");
        assert_eq!(msg.content_text(), "test");
        // Whole-struct comparison covers role, content and all optional fields.
        assert_eq!(msg, bare(Role::User, "test"));
    }

    #[test]
    fn test_message_assistant() {
        let msg = Message::assistant("response");
        assert_eq!(Role::Assistant, msg.role);
        assert_eq!("response", msg.content_text());
        assert!(msg.tool_calls.is_none());
    }

    #[test]
    fn test_message_system() {
        let msg = Message::system("instruction");
        assert_eq!(Role::System, msg.role);
        assert_eq!("instruction", msg.content_text());
        assert!(msg.tool_calls.is_none());
    }

    #[test]
    fn test_message_tool() {
        let msg = Message::tool("result", "call_123");
        assert_eq!(Role::Tool, msg.role);
        assert_eq!("result", msg.content_text());
        assert_eq!(msg.tool_call_id.as_deref(), Some("call_123"));
        assert!(msg.tool_calls.is_none());
    }

    #[test]
    fn test_message_with_name() {
        let named = Message::user("test").with_name("Alice");
        assert_eq!(named.name.as_deref(), Some("Alice"));
    }

    // --- serialization ----------------------------------------------------

    #[test]
    fn test_role_serialization() {
        let cases = [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::System, "\"system\""),
            (Role::Tool, "\"tool\""),
        ];
        for (role, expected) in cases {
            assert_eq!(serde_json::to_string(&role).unwrap(), expected);
        }
    }

    #[test]
    fn test_message_serialization() {
        let original = Message::user("Hello");
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Message = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_message_optional_fields_not_serialized() {
        let json = serde_json::to_value(Message::user("test")).unwrap();
        for key in ["name", "tool_call_id", "tool_calls"] {
            assert!(json.get(key).is_none());
        }
    }

    #[test]
    fn test_message_with_name_serialized() {
        let json = serde_json::to_value(Message::user("test").with_name("Alice")).unwrap();
        let name = json.get("name").and_then(|v| v.as_str());
        assert_eq!(name, Some("Alice"));
    }

    #[test]
    fn test_message_user_text() {
        let msg = Message::user("hello");
        assert_eq!(Role::User, msg.role);
        assert_eq!("hello", msg.content_text());
    }

    // --- multimodal content -----------------------------------------------

    #[test]
    fn test_message_multimodal() {
        let parts = vec![
            ContentPart::text("what is this?"),
            ContentPart::image_url("https://example.com/img.jpg"),
        ];
        assert_eq!(Message::user_parts(parts).content_text(), "what is this?");
    }

    #[test]
    fn test_content_part_image_inline_serde() {
        let part = ContentPart::image(mime::IMAGE_PNG, "abc");
        let encoded = serde_json::to_value(&part).unwrap();
        assert_eq!(encoded["type"], "image_url");
        let url = encoded["image_url"]["url"].as_str().unwrap();
        assert!(url.starts_with("data:image/png;base64,"));
        let decoded: ContentPart = serde_json::from_value(encoded).unwrap();
        assert_eq!(decoded, part);
    }

    #[test]
    fn test_content_part_audio_video_serde() {
        let samples = [
            ContentPart::audio(mime::AUDIO_WAV, "dGVzdA=="),
            ContentPart::video(mime::VIDEO_MP4, "dGVzdA=="),
        ];
        for part in samples {
            let encoded = serde_json::to_string(&part).unwrap();
            let decoded: ContentPart = serde_json::from_str(&encoded).unwrap();
            assert_eq!(decoded, part);
        }
    }

    #[test]
    fn test_message_parts_image_only_not_empty() {
        let image = ContentPart::image(mime::IMAGE_JPEG, "e30=");
        let msg = Message::user_parts(vec![image]);
        assert!(!msg.content.is_empty_content());
    }

    #[test]
    fn test_message_content_serialization() {
        // Plain text serializes as a JSON string...
        let text_json = serde_json::to_value(Message::user("hello")).unwrap();
        assert_eq!(text_json["content"], "hello");
        // ...while parts serialize as a JSON array.
        let parts_msg = Message::user_parts(vec![ContentPart::text("hi")]);
        let parts_json = serde_json::to_value(&parts_msg).unwrap();
        assert!(parts_json["content"].is_array());
    }

    #[test]
    fn test_message_content_from_string() {
        let from_slice = MessageContent::from("hello");
        assert_eq!(from_slice, MessageContent::Text("hello".to_string()));

        let from_owned = MessageContent::from(String::from("world"));
        assert_eq!(from_owned, MessageContent::Text("world".to_string()));
    }
}
657
658        let content: MessageContent = String::from("world").into();
659        assert_eq!(content, MessageContent::Text("world".to_string()));
660    }
661}