1use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::str::FromStr;
8use thiserror::Error;
9
10use crate::tool::ToolCall;
11
/// The role attached to a chat [`Message`].
///
/// Serialized and deserialized by serde using the lowercase variant name
/// (`"user"`, `"assistant"`, `"system"`, `"tool"`), matching the spellings
/// returned by [`Role::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Role set by [`Message::user`]; wire name `"user"`.
    User,
    /// Role set by [`Message::assistant`]; wire name `"assistant"`.
    Assistant,
    /// Role set by [`Message::system`]; wire name `"system"`.
    System,
    /// Role set by [`Message::tool`]; wire name `"tool"`.
    // The explicit rename spells out the same name the container-level
    // `lowercase` rule already produces for this variant.
    #[serde(rename = "tool")]
    Tool,
}
26
/// Error produced when a string cannot be parsed into a [`Role`].
///
/// Returned by the `FromStr` implementation for [`Role`] for any input other
/// than `system`, `user`, `assistant` or `tool`.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[error("invalid message role '{role}' (expected: system|user|assistant|tool)")]
pub struct ParseRoleError {
    /// The rejected input, stored verbatim so it can be echoed in the error message.
    pub role: String,
}
34
35impl Role {
36 pub fn as_str(self) -> &'static str {
38 match self {
39 Self::System => "system",
40 Self::User => "user",
41 Self::Assistant => "assistant",
42 Self::Tool => "tool",
43 }
44 }
45}
46
47impl FromStr for Role {
48 type Err = ParseRoleError;
49
50 fn from_str(s: &str) -> Result<Self, Self::Err> {
51 match s {
52 "system" => Ok(Self::System),
53 "user" => Ok(Self::User),
54 "assistant" => Ok(Self::Assistant),
55 "tool" => Ok(Self::Tool),
56 _ => Err(ParseRoleError {
57 role: s.to_string(),
58 }),
59 }
60 }
61}
62
/// A single chat message.
///
/// The three optional fields are left out of the serialized form entirely
/// when they are `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Message {
    /// Role attached to this message.
    pub role: Role,
    /// Text content of the message.
    pub content: String,
    /// Optional name; set through [`Message::with_name`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Identifier supplied to [`Message::tool`]; `None` for messages built by
    /// the other constructors.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls attached through [`Message::with_tool_calls`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
80
81impl Message {
82 pub fn user(content: impl Into<String>) -> Self {
93 Self {
94 role: Role::User,
95 content: content.into(),
96 name: None,
97 tool_call_id: None,
98 tool_calls: None,
99 }
100 }
101
102 pub fn assistant(content: impl Into<String>) -> Self {
112 Self {
113 role: Role::Assistant,
114 content: content.into(),
115 name: None,
116 tool_call_id: None,
117 tool_calls: None,
118 }
119 }
120
121 pub fn system(content: impl Into<String>) -> Self {
131 Self {
132 role: Role::System,
133 content: content.into(),
134 name: None,
135 tool_call_id: None,
136 tool_calls: None,
137 }
138 }
139
140 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
151 Self {
152 role: Role::Tool,
153 content: content.into(),
154 name: None,
155 tool_call_id: Some(tool_call_id.into()),
156 tool_calls: None,
157 }
158 }
159
160 pub fn with_name(mut self, name: impl Into<String>) -> Self {
170 self.name = Some(name.into());
171 self
172 }
173
174 pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
176 self.tool_calls = Some(tool_calls);
177 self
178 }
179}
180
/// Deserialization-only shape of one incoming message, consumed by
/// [`parse_messages_value`] before validation.
///
/// `role` and `content` are required; the remaining fields fall back to
/// `None` when absent from the input.
#[derive(Debug, Clone, Deserialize)]
struct MessageInputWire {
    role: Role,
    content: String,
    #[serde(default)]
    name: Option<String>,
    // Accepted under either the snake_case key or the camelCase `toolCallId`.
    #[serde(default, alias = "toolCallId")]
    tool_call_id: Option<String>,
    // NOTE(review): unlike `tool_call_id`, this field has no camelCase alias,
    // so a `toolCalls` key would not populate it — confirm that is intended.
    #[serde(default)]
    tool_calls: Option<Vec<ToolCall>>,
}
192
193pub fn parse_messages_value(value: &Value) -> Result<Vec<Message>, String> {
195 let wire_messages: Vec<MessageInputWire> = serde_json::from_value(value.clone())
196 .map_err(|e| format!("messages must be a list of message objects: {e}"))?;
197 if wire_messages.is_empty() {
198 return Err("messages cannot be empty".to_string());
199 }
200
201 wire_messages
202 .into_iter()
203 .enumerate()
204 .map(|(idx, wire)| {
205 if wire.content.is_empty() {
206 return Err(format!("message[{idx}].content cannot be empty"));
207 }
208
209 let mut msg = match wire.role {
210 Role::System => Message::system(wire.content),
211 Role::User => Message::user(wire.content),
212 Role::Assistant => {
213 let mut m = Message::assistant(wire.content);
214 if let Some(calls) = wire.tool_calls {
215 if !calls.is_empty() {
216 m = m.with_tool_calls(calls);
217 }
218 }
219 m
220 }
221 Role::Tool => {
222 let call_id = wire.tool_call_id.ok_or_else(|| {
223 format!("message[{idx}].tool_call_id is required for tool role")
224 })?;
225 Message::tool(wire.content, call_id)
226 }
227 };
228
229 if let Some(name) = wire.name {
230 if !name.is_empty() {
231 msg = msg.with_name(name);
232 }
233 }
234
235 Ok(msg)
236 })
237 .collect()
238}
239
240pub fn parse_messages_json(messages_json: &str) -> Result<Vec<Message>, String> {
242 let value: Value =
243 serde_json::from_str(messages_json).map_err(|e| format!("invalid messages json: {e}"))?;
244 parse_messages_value(&value)
245}
246
#[cfg(test)]
mod tests {
    //! Unit tests for the message constructors, serde representation, role
    //! parsing and the `parse_messages_*` entry points.

    use super::*;

    #[test]
    fn test_message_user() {
        let msg = Message::user("test");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "test");
        assert_eq!(msg.name, None);
        assert_eq!(msg.tool_call_id, None);
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_assistant() {
        let msg = Message::assistant("response");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.content, "response");
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_system() {
        let msg = Message::system("instruction");
        assert_eq!(msg.role, Role::System);
        assert_eq!(msg.content, "instruction");
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_tool() {
        let msg = Message::tool("result", "call_123");
        assert_eq!(msg.role, Role::Tool);
        assert_eq!(msg.content, "result");
        assert_eq!(msg.tool_call_id, Some("call_123".to_string()));
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_with_name() {
        let msg = Message::user("test").with_name("Alice");
        assert_eq!(msg.name, Some("Alice".to_string()));
    }

    #[test]
    fn test_message_with_tool_calls_stores_list() {
        // An empty list keeps the test independent of `ToolCall`'s fields.
        let msg = Message::assistant("calling").with_tool_calls(Vec::new());
        assert_eq!(msg.tool_calls, Some(Vec::new()));
        assert_eq!(msg.content, "calling");
    }

    #[test]
    fn test_role_serialization() {
        let json = serde_json::to_string(&Role::User).unwrap();
        assert_eq!(json, "\"user\"");

        let json = serde_json::to_string(&Role::Assistant).unwrap();
        assert_eq!(json, "\"assistant\"");

        let json = serde_json::to_string(&Role::System).unwrap();
        assert_eq!(json, "\"system\"");

        let json = serde_json::to_string(&Role::Tool).unwrap();
        assert_eq!(json, "\"tool\"");
    }

    #[test]
    fn test_role_from_str_round_trips_as_str() {
        for role in [Role::System, Role::User, Role::Assistant, Role::Tool] {
            assert_eq!(role.as_str().parse::<Role>(), Ok(role));
        }
    }

    #[test]
    fn test_role_from_str_rejects_unknown() {
        // Matching is exact, so a differently-cased name is not accepted.
        let err = "User".parse::<Role>().unwrap_err();
        assert_eq!(err.role, "User");
        assert_eq!(
            err.to_string(),
            "invalid message role 'User' (expected: system|user|assistant|tool)"
        );
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message::user("Hello");
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(msg, parsed);
    }

    #[test]
    fn test_message_optional_fields_not_serialized() {
        let msg = Message::user("test");
        let json = serde_json::to_value(&msg).unwrap();
        assert!(json.get("name").is_none());
        assert!(json.get("tool_call_id").is_none());
        assert!(json.get("tool_calls").is_none());
    }

    #[test]
    fn test_message_with_name_serialized() {
        let msg = Message::user("test").with_name("Alice");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json.get("name").and_then(|v| v.as_str()), Some("Alice"));
    }

    #[test]
    fn test_parse_messages_json_basic() {
        let parsed = parse_messages_json(
            r#"[{"role":"system","content":"be brief"},{"role":"user","content":"hi","name":"Alice"}]"#,
        )
        .unwrap();
        assert_eq!(
            parsed,
            vec![
                Message::system("be brief"),
                Message::user("hi").with_name("Alice"),
            ]
        );
    }

    #[test]
    fn test_parse_messages_normalises_empty_optionals() {
        let parsed = parse_messages_json(
            r#"[{"role":"assistant","content":"ok","name":"","tool_calls":[]}]"#,
        )
        .unwrap();
        assert_eq!(parsed, vec![Message::assistant("ok")]);
    }

    #[test]
    fn test_parse_messages_rejects_empty_list() {
        assert_eq!(
            parse_messages_json("[]"),
            Err("messages cannot be empty".to_string())
        );
    }

    #[test]
    fn test_parse_messages_rejects_empty_content() {
        assert_eq!(
            parse_messages_json(r#"[{"role":"user","content":"x"},{"role":"user","content":""}]"#),
            Err("message[1].content cannot be empty".to_string())
        );
    }

    #[test]
    fn test_parse_messages_tool_requires_call_id() {
        let err = parse_messages_json(r#"[{"role":"tool","content":"42"}]"#).unwrap_err();
        assert_eq!(err, "message[0].tool_call_id is required for tool role");
    }

    #[test]
    fn test_parse_messages_tool_accepts_both_id_spellings() {
        let snake =
            parse_messages_json(r#"[{"role":"tool","content":"42","tool_call_id":"call_1"}]"#)
                .unwrap();
        let camel =
            parse_messages_json(r#"[{"role":"tool","content":"42","toolCallId":"call_1"}]"#)
                .unwrap();
        assert_eq!(snake, vec![Message::tool("42", "call_1")]);
        assert_eq!(camel, snake);
    }

    #[test]
    fn test_parse_messages_reports_malformed_input() {
        let err = parse_messages_json("not json").unwrap_err();
        assert!(err.starts_with("invalid messages json: "));

        // Valid JSON, but an object rather than a list of messages.
        let err = parse_messages_json(r#"{"role":"user","content":"hi"}"#).unwrap_err();
        assert!(err.starts_with("messages must be a list of message objects: "));
    }

    #[test]
    fn test_parse_messages_value_matches_json_entry_point() {
        let value: Value = serde_json::from_str(r#"[{"role":"user","content":"hi"}]"#).unwrap();
        assert_eq!(parse_messages_value(&value), Ok(vec![Message::user("hi")]));
    }
}