1use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::str::FromStr;
8use thiserror::Error;
9
10use crate::tool::ToolCall;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum Role {
16 User,
18 Assistant,
20 System,
22 #[serde(rename = "tool")]
24 Tool,
25}
26
/// Error returned by the [`FromStr`] implementation of [`Role`] when the input
/// is not exactly one of `system`, `user`, `assistant` or `tool` (case-sensitive).
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[error("invalid message role '{role}' (expected: system|user|assistant|tool)")]
pub struct ParseRoleError {
    /// The rejected input string, copied verbatim.
    pub role: String,
}
34
35impl Role {
36 pub fn as_str(self) -> &'static str {
38 match self {
39 Self::System => "system",
40 Self::User => "user",
41 Self::Assistant => "assistant",
42 Self::Tool => "tool",
43 }
44 }
45}
46
47impl FromStr for Role {
48 type Err = ParseRoleError;
49
50 fn from_str(s: &str) -> Result<Self, Self::Err> {
51 match s {
52 "system" => Ok(Self::System),
53 "user" => Ok(Self::User),
54 "assistant" => Ok(Self::Assistant),
55 "tool" => Ok(Self::Tool),
56 _ => Err(ParseRoleError {
57 role: s.to_string(),
58 }),
59 }
60 }
61}
62
/// Body of a [`Message`]: either a plain string or a list of typed parts.
///
/// `#[serde(untagged)]` means no wrapper or tag appears in JSON. A JSON string
/// maps to [`MessageContent::Text`] and a JSON array of [`ContentPart`] objects
/// maps to [`MessageContent::Parts`]. On deserialization serde tries the
/// variants in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain text content, serialized as a bare JSON string.
    Text(String),
    /// Multimodal content, serialized as a JSON array of parts.
    Parts(Vec<ContentPart>),
}
72
73impl From<String> for MessageContent {
74 fn from(s: String) -> Self {
75 Self::Text(s)
76 }
77}
78
79impl From<&str> for MessageContent {
80 fn from(s: &str) -> Self {
81 Self::Text(s.to_string())
82 }
83}
84
85impl MessageContent {
86 pub fn text_len(&self) -> usize {
89 match self {
90 Self::Text(s) => s.len(),
91 Self::Parts(parts) => parts
92 .iter()
93 .map(|p| match p {
94 ContentPart::Text { text } => text.len(),
95 _ => 0,
96 })
97 .sum(),
98 }
99 }
100
101 pub fn contains_null(&self) -> bool {
103 match self {
104 Self::Text(s) => s.contains('\0'),
105 Self::Parts(parts) => parts.iter().any(|p| match p {
106 ContentPart::Text { text } => text.contains('\0'),
107 _ => false,
108 }),
109 }
110 }
111}
112
/// One element of [`MessageContent::Parts`], internally tagged by a `"type"` field.
///
/// JSON shapes produced and accepted:
/// - `{"type": "text", "text": "..."}`
/// - `{"type": "image_url", "image_url": {"url": "...", "detail": "..."}}`
/// - `{"type": "video_url", "url": "..."}`
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum ContentPart {
    /// A text segment.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// An image reference; the URL is nested inside an `image_url` object.
    #[serde(rename = "image_url")]
    ImageUrl {
        image_url: ImageUrlContent,
    },
    /// A video reference.
    // NOTE(review): unlike `ImageUrl`, the URL here is a flat `url` field next to
    // `type`, not a nested `video_url` object — confirm this matches the upstream API.
    #[serde(rename = "video_url")]
    Video {
        url: String,
    },
}
136
137impl ContentPart {
138 pub fn text(text: impl Into<String>) -> Self {
140 Self::Text { text: text.into() }
141 }
142
143 pub fn image_url(url: impl Into<String>) -> Self {
145 Self::ImageUrl {
146 image_url: ImageUrlContent {
147 url: url.into(),
148 detail: None,
149 },
150 }
151 }
152}
153
/// Payload of [`ContentPart::ImageUrl`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ImageUrlContent {
    /// The image URL.
    pub url: String,
    /// Optional `detail` string, omitted from serialized JSON when `None`.
    /// [`ContentPart::image_url`] leaves it unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
163
/// A single chat message.
///
/// Serializes with fields in declaration order; the three optional fields are
/// left out of the JSON entirely when they are `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Message {
    /// Who the message is attributed to.
    pub role: Role,
    /// Plain-text or multimodal body.
    pub content: MessageContent,
    /// Optional name attached to the message (see [`Message::with_name`]).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool-call identifier; set by [`Message::tool`] and required for `tool`
    /// messages by [`parse_messages_value`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls attached to the message (see [`Message::with_tool_calls`]).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
181
182impl Message {
183 pub fn user(content: impl Into<String>) -> Self {
194 Self {
195 role: Role::User,
196 content: MessageContent::Text(content.into()),
197 name: None,
198 tool_call_id: None,
199 tool_calls: None,
200 }
201 }
202
203 pub fn assistant(content: impl Into<String>) -> Self {
213 Self {
214 role: Role::Assistant,
215 content: MessageContent::Text(content.into()),
216 name: None,
217 tool_call_id: None,
218 tool_calls: None,
219 }
220 }
221
222 pub fn system(content: impl Into<String>) -> Self {
232 Self {
233 role: Role::System,
234 content: MessageContent::Text(content.into()),
235 name: None,
236 tool_call_id: None,
237 tool_calls: None,
238 }
239 }
240
241 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
252 Self {
253 role: Role::Tool,
254 content: MessageContent::Text(content.into()),
255 name: None,
256 tool_call_id: Some(tool_call_id.into()),
257 tool_calls: None,
258 }
259 }
260
261 pub fn user_parts(parts: Vec<ContentPart>) -> Self {
263 Self {
264 role: Role::User,
265 content: MessageContent::Parts(parts),
266 name: None,
267 tool_call_id: None,
268 tool_calls: None,
269 }
270 }
271
272 pub fn content_text(&self) -> &str {
278 match &self.content {
279 MessageContent::Text(s) => s.as_str(),
280 MessageContent::Parts(parts) => parts
281 .iter()
282 .find_map(|p| match p {
283 ContentPart::Text { text } => Some(text.as_str()),
284 _ => None,
285 })
286 .unwrap_or(""),
287 }
288 }
289
290 pub fn with_name(mut self, name: impl Into<String>) -> Self {
300 self.name = Some(name.into());
301 self
302 }
303
304 pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
306 self.tool_calls = Some(tool_calls);
307 self
308 }
309}
310
/// Wire-format view of one incoming message, deserialized by [`parse_messages_value`].
///
/// It is intentionally looser than [`Message`]: any role may carry any of the
/// optional fields. Role-specific validation and dropping of irrelevant fields
/// happen after deserialization. Missing optional fields deserialize as `None`.
#[derive(Debug, Clone, Deserialize)]
struct MessageInputWire {
    role: Role,
    content: MessageContent,
    #[serde(default)]
    name: Option<String>,
    // `toolCallId` (camelCase) is accepted as an alias for the snake_case key.
    #[serde(default, alias = "toolCallId")]
    tool_call_id: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ToolCall>>,
}
322
323pub fn parse_messages_value(value: &Value) -> Result<Vec<Message>, String> {
325 let wire_messages: Vec<MessageInputWire> = serde_json::from_value(value.clone())
326 .map_err(|e| format!("messages must be a list of message objects: {e}"))?;
327 if wire_messages.is_empty() {
328 return Err("messages cannot be empty".to_string());
329 }
330
331 wire_messages
332 .into_iter()
333 .enumerate()
334 .map(|(idx, wire)| {
335 if wire.content.text_len() == 0 {
336 return Err(format!("message[{idx}].content cannot be empty"));
337 }
338
339 let content = wire.content;
340
341 let mut msg = match wire.role {
342 Role::System => Message {
343 role: Role::System,
344 content,
345 name: None,
346 tool_call_id: None,
347 tool_calls: None,
348 },
349 Role::User => Message {
350 role: Role::User,
351 content,
352 name: None,
353 tool_call_id: None,
354 tool_calls: None,
355 },
356 Role::Assistant => {
357 let mut m = Message {
358 role: Role::Assistant,
359 content,
360 name: None,
361 tool_call_id: None,
362 tool_calls: None,
363 };
364 if let Some(calls) = wire.tool_calls {
365 if !calls.is_empty() {
366 m = m.with_tool_calls(calls);
367 }
368 }
369 m
370 }
371 Role::Tool => {
372 let call_id = wire.tool_call_id.ok_or_else(|| {
373 format!("message[{idx}].tool_call_id is required for tool role")
374 })?;
375 Message {
376 role: Role::Tool,
377 content,
378 name: None,
379 tool_call_id: Some(call_id),
380 tool_calls: None,
381 }
382 }
383 };
384
385 if let Some(name) = wire.name {
386 if !name.is_empty() {
387 msg = msg.with_name(name);
388 }
389 }
390
391 Ok(msg)
392 })
393 .collect()
394}
395
396pub fn parse_messages_json(messages_json: &str) -> Result<Vec<Message>, String> {
398 let value: Value =
399 serde_json::from_str(messages_json).map_err(|e| format!("invalid messages json: {e}"))?;
400 parse_messages_value(&value)
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_user() {
        let msg = Message::user("test");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, MessageContent::Text("test".to_string()));
        assert_eq!(msg.content_text(), "test");
        assert_eq!(msg.name, None);
        assert_eq!(msg.tool_call_id, None);
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_assistant() {
        let msg = Message::assistant("response");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.content_text(), "response");
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_system() {
        let msg = Message::system("instruction");
        assert_eq!(msg.role, Role::System);
        assert_eq!(msg.content_text(), "instruction");
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_tool() {
        let msg = Message::tool("result", "call_123");
        assert_eq!(msg.role, Role::Tool);
        assert_eq!(msg.content_text(), "result");
        assert_eq!(msg.tool_call_id, Some("call_123".to_string()));
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_with_name() {
        let msg = Message::user("test").with_name("Alice");
        assert_eq!(msg.name, Some("Alice".to_string()));
    }

    #[test]
    fn test_role_serialization() {
        let json = serde_json::to_string(&Role::User).unwrap();
        assert_eq!(json, "\"user\"");

        let json = serde_json::to_string(&Role::Assistant).unwrap();
        assert_eq!(json, "\"assistant\"");

        let json = serde_json::to_string(&Role::System).unwrap();
        assert_eq!(json, "\"system\"");

        let json = serde_json::to_string(&Role::Tool).unwrap();
        assert_eq!(json, "\"tool\"");
    }

    #[test]
    fn test_role_as_str_round_trips_through_from_str() {
        for role in [Role::System, Role::User, Role::Assistant, Role::Tool] {
            assert_eq!(role.as_str().parse::<Role>(), Ok(role));
        }
    }

    #[test]
    fn test_role_from_str_rejects_unknown() {
        let err = "User".parse::<Role>().unwrap_err();
        assert_eq!(err.role, "User");
        assert_eq!(
            err.to_string(),
            "invalid message role 'User' (expected: system|user|assistant|tool)"
        );
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message::user("Hello");
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(msg, parsed);
    }

    #[test]
    fn test_message_optional_fields_not_serialized() {
        let msg = Message::user("test");
        let json = serde_json::to_value(&msg).unwrap();
        assert!(json.get("name").is_none());
        assert!(json.get("tool_call_id").is_none());
        assert!(json.get("tool_calls").is_none());
    }

    #[test]
    fn test_message_with_name_serialized() {
        let msg = Message::user("test").with_name("Alice");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json.get("name").and_then(|v| v.as_str()), Some("Alice"));
    }

    #[test]
    fn test_message_user_text() {
        let msg = Message::user("hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content_text(), "hello");
    }

    #[test]
    fn test_message_multimodal() {
        let msg = Message::user_parts(vec![
            ContentPart::text("what is this?"),
            ContentPart::image_url("https://example.com/img.jpg"),
        ]);
        assert_eq!(msg.content_text(), "what is this?");
    }

    #[test]
    fn test_message_content_serialization() {
        let msg = Message::user("hello");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["content"], "hello");
        let msg2 = Message::user_parts(vec![ContentPart::text("hi")]);
        let json2 = serde_json::to_value(&msg2).unwrap();
        assert!(json2["content"].is_array());
    }

    #[test]
    fn test_message_content_from_string() {
        let content: MessageContent = "hello".into();
        assert_eq!(content, MessageContent::Text("hello".to_string()));

        let content: MessageContent = String::from("world").into();
        assert_eq!(content, MessageContent::Text("world".to_string()));
    }

    #[test]
    fn test_parse_messages_json_basic() {
        let msgs = parse_messages_json(
            r#"[{"role":"system","content":"be nice"},{"role":"user","content":"hi","name":"Alice"}]"#,
        )
        .unwrap();
        assert_eq!(
            msgs,
            vec![
                Message::system("be nice"),
                Message::user("hi").with_name("Alice"),
            ]
        );
    }

    #[test]
    fn test_parse_messages_value_multimodal() {
        let value = serde_json::json!([{
            "role": "user",
            "content": [
                {"type": "text", "text": "what?"},
                {"type": "image_url", "image_url": {"url": "https://x/y.png"}}
            ]
        }]);
        let msgs = parse_messages_value(&value).unwrap();
        assert_eq!(
            msgs,
            vec![Message::user_parts(vec![
                ContentPart::text("what?"),
                ContentPart::image_url("https://x/y.png"),
            ])]
        );
    }

    #[test]
    fn test_parse_messages_rejects_invalid_json() {
        let err = parse_messages_json("not json").unwrap_err();
        assert!(err.starts_with("invalid messages json:"), "{err}");
    }

    #[test]
    fn test_parse_messages_rejects_non_list_and_unknown_role() {
        let err = parse_messages_json(r#"{"role":"user","content":"hi"}"#).unwrap_err();
        assert!(err.starts_with("messages must be a list of message objects:"), "{err}");

        let err = parse_messages_json(r#"[{"role":"bot","content":"hi"}]"#).unwrap_err();
        assert!(err.starts_with("messages must be a list of message objects:"), "{err}");
    }

    #[test]
    fn test_parse_messages_rejects_empty_list() {
        assert_eq!(parse_messages_json("[]").unwrap_err(), "messages cannot be empty");
    }

    #[test]
    fn test_parse_messages_rejects_empty_content() {
        let err = parse_messages_json(
            r#"[{"role":"user","content":"hi"},{"role":"user","content":""}]"#,
        )
        .unwrap_err();
        assert_eq!(err, "message[1].content cannot be empty");
    }

    #[test]
    fn test_parse_messages_tool_requires_call_id() {
        let err = parse_messages_json(r#"[{"role":"tool","content":"42"}]"#).unwrap_err();
        assert_eq!(err, "message[0].tool_call_id is required for tool role");
    }

    #[test]
    fn test_parse_messages_tool_call_id_alias() {
        let msgs =
            parse_messages_json(r#"[{"role":"tool","content":"42","toolCallId":"call_1"}]"#)
                .unwrap();
        assert_eq!(msgs, vec![Message::tool("42", "call_1")]);
    }

    #[test]
    fn test_parse_messages_drops_empty_and_irrelevant_fields() {
        let msgs = parse_messages_json(
            r#"[{"role":"assistant","content":"ok","name":"","tool_calls":[]},
                {"role":"user","content":"hi","tool_call_id":"x","tool_calls":[]}]"#,
        )
        .unwrap();
        assert_eq!(msgs, vec![Message::assistant("ok"), Message::user("hi")]);
    }
}