1use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::str::FromStr;
8use thiserror::Error;
9
10use crate::tool::ToolCall;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum Role {
16 User,
18 Assistant,
20 System,
22 #[serde(rename = "tool")]
24 Tool,
25}
26
27#[derive(Debug, Clone, PartialEq, Eq, Error)]
28#[error("invalid message role '{role}' (expected: system|user|assistant|tool)")]
29pub struct ParseRoleError {
31 pub role: String,
33}
34
35impl Role {
36 pub fn as_str(self) -> &'static str {
38 match self {
39 Self::System => "system",
40 Self::User => "user",
41 Self::Assistant => "assistant",
42 Self::Tool => "tool",
43 }
44 }
45}
46
47impl FromStr for Role {
48 type Err = ParseRoleError;
49
50 fn from_str(s: &str) -> Result<Self, Self::Err> {
51 match s {
52 "system" => Ok(Self::System),
53 "user" => Ok(Self::User),
54 "assistant" => Ok(Self::Assistant),
55 "tool" => Ok(Self::Tool),
56 _ => Err(ParseRoleError {
57 role: s.to_string(),
58 }),
59 }
60 }
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
65#[serde(untagged)]
66pub enum MessageContent {
67 Text(String),
69 Parts(Vec<ContentPart>),
71}
72
73impl From<String> for MessageContent {
74 fn from(s: String) -> Self {
75 Self::Text(s)
76 }
77}
78
79impl From<&str> for MessageContent {
80 fn from(s: &str) -> Self {
81 Self::Text(s.to_string())
82 }
83}
84
85impl MessageContent {
86 pub fn text_len(&self) -> usize {
89 match self {
90 Self::Text(s) => s.len(),
91 Self::Parts(parts) => parts
92 .iter()
93 .map(|p| match p {
94 ContentPart::Text { text } => text.len(),
95 _ => 0,
96 })
97 .sum(),
98 }
99 }
100
101 pub fn is_empty_content(&self) -> bool {
104 match self {
105 Self::Text(s) => s.is_empty(),
106 Self::Parts(parts) => {
107 parts.is_empty()
108 || parts.iter().all(|p| match p {
109 ContentPart::Text { text } => text.is_empty(),
110 ContentPart::ImageUrl { image_url } => image_url.url.is_empty(),
111 ContentPart::Audio { input_audio } => input_audio.data.is_empty(),
112 ContentPart::Video { video } => video.data.is_empty(),
113 })
114 }
115 }
116 }
117
118 pub fn contains_null(&self) -> bool {
120 match self {
121 Self::Text(s) => s.contains('\0'),
122 Self::Parts(parts) => parts.iter().any(|p| match p {
123 ContentPart::Text { text } => text.contains('\0'),
124 ContentPart::ImageUrl { image_url } => image_url.url.contains('\0'),
125 ContentPart::Audio { input_audio } => {
126 input_audio.data.contains('\0') || input_audio.media_type.contains('\0')
127 }
128 ContentPart::Video { video } => {
129 video.data.contains('\0') || video.media_type.contains('\0')
130 }
131 }),
132 }
133 }
134}
135
/// MIME type strings for the media kinds handled by [`ContentPart`] constructors.
pub mod mime {
    // --- Images ---

    /// PNG image.
    pub const IMAGE_PNG: &str = "image/png";
    /// JPEG image.
    pub const IMAGE_JPEG: &str = "image/jpeg";
    /// WebP image.
    pub const IMAGE_WEBP: &str = "image/webp";
    /// GIF image.
    pub const IMAGE_GIF: &str = "image/gif";

    // --- Audio ---

    /// MP3 audio (registered type is `audio/mpeg`).
    pub const AUDIO_MP3: &str = "audio/mpeg";
    /// WAV audio.
    pub const AUDIO_WAV: &str = "audio/wav";
    /// FLAC audio.
    pub const AUDIO_FLAC: &str = "audio/flac";
    /// Ogg audio.
    pub const AUDIO_OGG: &str = "audio/ogg";

    // --- Video ---

    /// MP4 video.
    pub const VIDEO_MP4: &str = "video/mp4";
    /// WebM video.
    pub const VIDEO_WEBM: &str = "video/webm";
    /// QuickTime (`.mov`) video.
    pub const VIDEO_MOV: &str = "video/quicktime";
    /// Matroska (`.mkv`) video.
    pub const VIDEO_MKV: &str = "video/x-matroska";
}
163
164#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
170pub struct MediaContent {
171 pub media_type: String,
173 pub data: String,
175}
176
177#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
179#[serde(tag = "type")]
180pub enum ContentPart {
181 #[serde(rename = "text")]
183 Text {
184 text: String,
186 },
187 #[serde(rename = "image_url")]
189 ImageUrl {
190 image_url: ImageUrlContent,
192 },
193 #[serde(rename = "input_audio")]
195 Audio {
196 input_audio: MediaContent,
198 },
199 #[serde(rename = "video")]
201 Video {
202 video: MediaContent,
204 },
205}
206
207impl ContentPart {
208 pub fn text(text: impl Into<String>) -> Self {
210 Self::Text { text: text.into() }
211 }
212
213 pub fn image(media_type: impl Into<String>, data: impl Into<String>) -> Self {
218 let mt = media_type.into();
219 let d = data.into();
220 Self::ImageUrl {
221 image_url: ImageUrlContent {
222 url: format!("data:{mt};base64,{d}"),
223 detail: None,
224 },
225 }
226 }
227
228 pub fn audio(media_type: impl Into<String>, data: impl Into<String>) -> Self {
230 Self::Audio {
231 input_audio: MediaContent {
232 media_type: media_type.into(),
233 data: data.into(),
234 },
235 }
236 }
237
238 pub fn video(media_type: impl Into<String>, data: impl Into<String>) -> Self {
240 Self::Video {
241 video: MediaContent {
242 media_type: media_type.into(),
243 data: data.into(),
244 },
245 }
246 }
247
248 pub fn image_url(url: impl Into<String>) -> Self {
252 Self::ImageUrl {
253 image_url: ImageUrlContent {
254 url: url.into(),
255 detail: None,
256 },
257 }
258 }
259}
260
261#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
263pub struct ImageUrlContent {
264 pub url: String,
266 #[serde(skip_serializing_if = "Option::is_none")]
268 pub detail: Option<String>,
269}
270
271#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
273pub struct Message {
274 pub role: Role,
276 pub content: MessageContent,
278 #[serde(skip_serializing_if = "Option::is_none")]
280 pub name: Option<String>,
281 #[serde(skip_serializing_if = "Option::is_none")]
283 pub tool_call_id: Option<String>,
284 #[serde(skip_serializing_if = "Option::is_none")]
286 pub tool_calls: Option<Vec<ToolCall>>,
287}
288
289impl Message {
290 pub fn user(content: impl Into<String>) -> Self {
301 Self {
302 role: Role::User,
303 content: MessageContent::Text(content.into()),
304 name: None,
305 tool_call_id: None,
306 tool_calls: None,
307 }
308 }
309
310 pub fn assistant(content: impl Into<String>) -> Self {
320 Self {
321 role: Role::Assistant,
322 content: MessageContent::Text(content.into()),
323 name: None,
324 tool_call_id: None,
325 tool_calls: None,
326 }
327 }
328
329 pub fn system(content: impl Into<String>) -> Self {
339 Self {
340 role: Role::System,
341 content: MessageContent::Text(content.into()),
342 name: None,
343 tool_call_id: None,
344 tool_calls: None,
345 }
346 }
347
348 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
359 Self {
360 role: Role::Tool,
361 content: MessageContent::Text(content.into()),
362 name: None,
363 tool_call_id: Some(tool_call_id.into()),
364 tool_calls: None,
365 }
366 }
367
368 pub fn user_parts(parts: Vec<ContentPart>) -> Self {
370 Self {
371 role: Role::User,
372 content: MessageContent::Parts(parts),
373 name: None,
374 tool_call_id: None,
375 tool_calls: None,
376 }
377 }
378
379 pub fn content_text(&self) -> &str {
385 match &self.content {
386 MessageContent::Text(s) => s.as_str(),
387 MessageContent::Parts(parts) => parts
388 .iter()
389 .find_map(|p| match p {
390 ContentPart::Text { text } => Some(text.as_str()),
391 _ => None,
392 })
393 .unwrap_or(""),
394 }
395 }
396
397 pub fn with_name(mut self, name: impl Into<String>) -> Self {
407 self.name = Some(name.into());
408 self
409 }
410
411 pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
413 self.tool_calls = Some(tool_calls);
414 self
415 }
416}
417
418#[derive(Debug, Clone, Deserialize)]
419struct MessageInputWire {
420 role: Role,
421 content: MessageContent,
422 #[serde(default)]
423 name: Option<String>,
424 #[serde(default, alias = "toolCallId")]
425 tool_call_id: Option<String>,
426 #[serde(default)]
427 tool_calls: Option<Vec<ToolCall>>,
428}
429
430pub fn parse_messages_value(value: &Value) -> Result<Vec<Message>, String> {
432 let wire_messages: Vec<MessageInputWire> = serde_json::from_value(value.clone())
433 .map_err(|e| format!("messages must be a list of message objects: {e}"))?;
434 if wire_messages.is_empty() {
435 return Err("messages cannot be empty".to_string());
436 }
437
438 wire_messages
439 .into_iter()
440 .enumerate()
441 .map(|(idx, wire)| {
442 if wire.content.is_empty_content() {
443 return Err(format!("message[{idx}].content cannot be empty"));
444 }
445
446 let content = wire.content;
447
448 let mut msg = match wire.role {
449 Role::System => Message {
450 role: Role::System,
451 content,
452 name: None,
453 tool_call_id: None,
454 tool_calls: None,
455 },
456 Role::User => Message {
457 role: Role::User,
458 content,
459 name: None,
460 tool_call_id: None,
461 tool_calls: None,
462 },
463 Role::Assistant => {
464 let mut m = Message {
465 role: Role::Assistant,
466 content,
467 name: None,
468 tool_call_id: None,
469 tool_calls: None,
470 };
471 if let Some(calls) = wire.tool_calls {
472 if !calls.is_empty() {
473 m = m.with_tool_calls(calls);
474 }
475 }
476 m
477 }
478 Role::Tool => {
479 let call_id = wire.tool_call_id.ok_or_else(|| {
480 format!("message[{idx}].tool_call_id is required for tool role")
481 })?;
482 Message {
483 role: Role::Tool,
484 content,
485 name: None,
486 tool_call_id: Some(call_id),
487 tool_calls: None,
488 }
489 }
490 };
491
492 if let Some(name) = wire.name {
493 if !name.is_empty() {
494 msg = msg.with_name(name);
495 }
496 }
497
498 Ok(msg)
499 })
500 .collect()
501}
502
503pub fn parse_messages_json(messages_json: &str) -> Result<Vec<Message>, String> {
505 let value: Value =
506 serde_json::from_str(messages_json).map_err(|e| format!("invalid messages json: {e}"))?;
507 parse_messages_value(&value)
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_user() {
        let msg = Message::user("test");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, MessageContent::Text(String::from("test")));
        assert_eq!(msg.content_text(), "test");
        assert!(msg.name.is_none());
        assert!(msg.tool_call_id.is_none());
        assert!(msg.tool_calls.is_none());
    }

    #[test]
    fn test_message_assistant() {
        let msg = Message::assistant("response");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.content_text(), "response");
        assert!(msg.tool_calls.is_none());
    }

    #[test]
    fn test_message_system() {
        let msg = Message::system("instruction");
        assert_eq!(msg.role, Role::System);
        assert_eq!(msg.content_text(), "instruction");
        assert!(msg.tool_calls.is_none());
    }

    #[test]
    fn test_message_tool() {
        let msg = Message::tool("result", "call_123");
        assert_eq!(msg.role, Role::Tool);
        assert_eq!(msg.content_text(), "result");
        assert_eq!(msg.tool_call_id.as_deref(), Some("call_123"));
        assert!(msg.tool_calls.is_none());
    }

    #[test]
    fn test_message_with_name() {
        let named = Message::user("test").with_name("Alice");
        assert_eq!(named.name.as_deref(), Some("Alice"));
    }

    #[test]
    fn test_role_serialization() {
        let cases = [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::System, "\"system\""),
            (Role::Tool, "\"tool\""),
        ];
        for (role, expected) in cases {
            assert_eq!(serde_json::to_string(&role).unwrap(), expected);
        }
    }

    #[test]
    fn test_message_serialization() {
        let original = Message::user("Hello");
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Message = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_message_optional_fields_not_serialized() {
        let json = serde_json::to_value(Message::user("test")).unwrap();
        for key in ["name", "tool_call_id", "tool_calls"] {
            assert!(json.get(key).is_none(), "unexpected key `{key}` in {json}");
        }
    }

    #[test]
    fn test_message_with_name_serialized() {
        let json = serde_json::to_value(Message::user("test").with_name("Alice")).unwrap();
        assert_eq!(json["name"].as_str(), Some("Alice"));
    }

    #[test]
    fn test_message_user_text() {
        let msg = Message::user("hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content_text(), "hello");
    }

    #[test]
    fn test_message_multimodal() {
        let parts = vec![
            ContentPart::text("what is this?"),
            ContentPart::image_url("https://example.com/img.jpg"),
        ];
        assert_eq!(Message::user_parts(parts).content_text(), "what is this?");
    }

    #[test]
    fn test_content_part_image_inline_serde() {
        let part = ContentPart::image(mime::IMAGE_PNG, "abc");
        let value = serde_json::to_value(&part).unwrap();
        assert_eq!(value["type"], "image_url");
        let url = value["image_url"]["url"].as_str().unwrap();
        assert!(url.starts_with("data:image/png;base64,"));
        let round_trip: ContentPart = serde_json::from_value(value).unwrap();
        assert_eq!(round_trip, part);
    }

    #[test]
    fn test_content_part_audio_video_serde() {
        let parts = [
            ContentPart::audio(mime::AUDIO_WAV, "dGVzdA=="),
            ContentPart::video(mime::VIDEO_MP4, "dGVzdA=="),
        ];
        for part in parts {
            let encoded = serde_json::to_string(&part).unwrap();
            let decoded: ContentPart = serde_json::from_str(&encoded).unwrap();
            assert_eq!(decoded, part);
        }
    }

    #[test]
    fn test_message_parts_image_only_not_empty() {
        let image = ContentPart::image(mime::IMAGE_JPEG, "e30=");
        let msg = Message::user_parts(vec![image]);
        assert!(!msg.content.is_empty_content());
    }

    #[test]
    fn test_message_content_serialization() {
        let text_json = serde_json::to_value(Message::user("hello")).unwrap();
        assert_eq!(text_json["content"], "hello");

        let parts_msg = Message::user_parts(vec![ContentPart::text("hi")]);
        let parts_json = serde_json::to_value(parts_msg).unwrap();
        assert!(parts_json["content"].is_array());
    }

    #[test]
    fn test_message_content_from_string() {
        let from_slice: MessageContent = "hello".into();
        assert_eq!(from_slice, MessageContent::Text(String::from("hello")));

        let from_owned: MessageContent = String::from("world").into();
        assert_eq!(from_owned, MessageContent::Text(String::from("world")));
    }
}