Skip to main content

simple_agent_type/
message.rs

//! Message types for LLM interactions.
//!
//! Provides role-based messages compatible with OpenAI's message format.

5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::str::FromStr;
8use thiserror::Error;
9
10use crate::tool::ToolCall;
11
12/// Role of a message in a conversation.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum Role {
16    /// User message
17    User,
18    /// Assistant (LLM) message
19    Assistant,
20    /// System instruction message
21    System,
22    /// Tool/function call result
23    #[serde(rename = "tool")]
24    Tool,
25}
26
27#[derive(Debug, Clone, PartialEq, Eq, Error)]
28#[error("invalid message role '{role}' (expected: system|user|assistant|tool)")]
29/// Error returned when parsing an unknown message role string.
30pub struct ParseRoleError {
31    /// Original role string that failed to parse.
32    pub role: String,
33}
34
35impl Role {
36    /// Returns this role as its canonical lowercase string value.
37    pub fn as_str(self) -> &'static str {
38        match self {
39            Self::System => "system",
40            Self::User => "user",
41            Self::Assistant => "assistant",
42            Self::Tool => "tool",
43        }
44    }
45}
46
47impl FromStr for Role {
48    type Err = ParseRoleError;
49
50    fn from_str(s: &str) -> Result<Self, Self::Err> {
51        match s {
52            "system" => Ok(Self::System),
53            "user" => Ok(Self::User),
54            "assistant" => Ok(Self::Assistant),
55            "tool" => Ok(Self::Tool),
56            _ => Err(ParseRoleError {
57                role: s.to_string(),
58            }),
59        }
60    }
61}
62
63/// A message in a conversation.
64#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
65pub struct Message {
66    /// Role of the message sender
67    pub role: Role,
68    /// Content of the message
69    pub content: String,
70    /// Optional name (for multi-user conversations or tool calls)
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub name: Option<String>,
73    /// Tool call ID (for tool role messages)
74    #[serde(skip_serializing_if = "Option::is_none")]
75    pub tool_call_id: Option<String>,
76    /// Tool calls emitted by the assistant.
77    #[serde(skip_serializing_if = "Option::is_none")]
78    pub tool_calls: Option<Vec<ToolCall>>,
79}
80
81impl Message {
82    /// Create a user message.
83    ///
84    /// # Example
85    /// ```
86    /// use simple_agent_type::message::{Message, Role};
87    ///
88    /// let msg = Message::user("Hello!");
89    /// assert_eq!(msg.role, Role::User);
90    /// assert_eq!(msg.content, "Hello!");
91    /// ```
92    pub fn user(content: impl Into<String>) -> Self {
93        Self {
94            role: Role::User,
95            content: content.into(),
96            name: None,
97            tool_call_id: None,
98            tool_calls: None,
99        }
100    }
101
102    /// Create an assistant message.
103    ///
104    /// # Example
105    /// ```
106    /// use simple_agent_type::message::{Message, Role};
107    ///
108    /// let msg = Message::assistant("Hi there!");
109    /// assert_eq!(msg.role, Role::Assistant);
110    /// ```
111    pub fn assistant(content: impl Into<String>) -> Self {
112        Self {
113            role: Role::Assistant,
114            content: content.into(),
115            name: None,
116            tool_call_id: None,
117            tool_calls: None,
118        }
119    }
120
121    /// Create a system message.
122    ///
123    /// # Example
124    /// ```
125    /// use simple_agent_type::message::{Message, Role};
126    ///
127    /// let msg = Message::system("You are a helpful assistant.");
128    /// assert_eq!(msg.role, Role::System);
129    /// ```
130    pub fn system(content: impl Into<String>) -> Self {
131        Self {
132            role: Role::System,
133            content: content.into(),
134            name: None,
135            tool_call_id: None,
136            tool_calls: None,
137        }
138    }
139
140    /// Create a tool message.
141    ///
142    /// # Example
143    /// ```
144    /// use simple_agent_type::message::{Message, Role};
145    ///
146    /// let msg = Message::tool("result", "call_123");
147    /// assert_eq!(msg.role, Role::Tool);
148    /// assert_eq!(msg.tool_call_id, Some("call_123".to_string()));
149    /// ```
150    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
151        Self {
152            role: Role::Tool,
153            content: content.into(),
154            name: None,
155            tool_call_id: Some(tool_call_id.into()),
156            tool_calls: None,
157        }
158    }
159
160    /// Set the name field (builder pattern).
161    ///
162    /// # Example
163    /// ```
164    /// use simple_agent_type::message::Message;
165    ///
166    /// let msg = Message::user("Hello").with_name("Alice");
167    /// assert_eq!(msg.name, Some("Alice".to_string()));
168    /// ```
169    pub fn with_name(mut self, name: impl Into<String>) -> Self {
170        self.name = Some(name.into());
171        self
172    }
173
174    /// Set tool calls for assistant messages.
175    pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
176        self.tool_calls = Some(tool_calls);
177        self
178    }
179}
180
181#[derive(Debug, Clone, Deserialize)]
182struct MessageInputWire {
183    role: Role,
184    content: String,
185    #[serde(default)]
186    name: Option<String>,
187    #[serde(default, alias = "toolCallId")]
188    tool_call_id: Option<String>,
189    #[serde(default)]
190    tool_calls: Option<Vec<ToolCall>>,
191}
192
193/// Parses a JSON value containing an array of message objects into typed messages.
194pub fn parse_messages_value(value: &Value) -> Result<Vec<Message>, String> {
195    let wire_messages: Vec<MessageInputWire> = serde_json::from_value(value.clone())
196        .map_err(|e| format!("messages must be a list of message objects: {e}"))?;
197    if wire_messages.is_empty() {
198        return Err("messages cannot be empty".to_string());
199    }
200
201    wire_messages
202        .into_iter()
203        .enumerate()
204        .map(|(idx, wire)| {
205            if wire.content.is_empty() {
206                return Err(format!("message[{idx}].content cannot be empty"));
207            }
208
209            let mut msg = match wire.role {
210                Role::System => Message::system(wire.content),
211                Role::User => Message::user(wire.content),
212                Role::Assistant => {
213                    let mut m = Message::assistant(wire.content);
214                    if let Some(calls) = wire.tool_calls {
215                        if !calls.is_empty() {
216                            m = m.with_tool_calls(calls);
217                        }
218                    }
219                    m
220                }
221                Role::Tool => {
222                    let call_id = wire.tool_call_id.ok_or_else(|| {
223                        format!("message[{idx}].tool_call_id is required for tool role")
224                    })?;
225                    Message::tool(wire.content, call_id)
226                }
227            };
228
229            if let Some(name) = wire.name {
230                if !name.is_empty() {
231                    msg = msg.with_name(name);
232                }
233            }
234
235            Ok(msg)
236        })
237        .collect()
238}
239
240/// Parses a JSON string containing an array of message objects.
241pub fn parse_messages_json(messages_json: &str) -> Result<Vec<Message>, String> {
242    let value: Value =
243        serde_json::from_str(messages_json).map_err(|e| format!("invalid messages json: {e}"))?;
244    parse_messages_value(&value)
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_user() {
        let msg = Message::user("test");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "test");
        assert_eq!(msg.name, None);
        assert_eq!(msg.tool_call_id, None);
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_assistant() {
        let msg = Message::assistant("response");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.content, "response");
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_system() {
        let msg = Message::system("instruction");
        assert_eq!(msg.role, Role::System);
        assert_eq!(msg.content, "instruction");
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_tool() {
        let msg = Message::tool("result", "call_123");
        assert_eq!(msg.role, Role::Tool);
        assert_eq!(msg.content, "result");
        assert_eq!(msg.tool_call_id, Some("call_123".to_string()));
        assert_eq!(msg.tool_calls, None);
    }

    #[test]
    fn test_message_with_name() {
        let msg = Message::user("test").with_name("Alice");
        assert_eq!(msg.name, Some("Alice".to_string()));
    }

    #[test]
    fn test_role_serialization() {
        let json = serde_json::to_string(&Role::User).unwrap();
        assert_eq!(json, "\"user\"");

        let json = serde_json::to_string(&Role::Assistant).unwrap();
        assert_eq!(json, "\"assistant\"");

        let json = serde_json::to_string(&Role::System).unwrap();
        assert_eq!(json, "\"system\"");

        let json = serde_json::to_string(&Role::Tool).unwrap();
        assert_eq!(json, "\"tool\"");
    }

    #[test]
    fn test_role_as_str_round_trips_through_from_str() {
        for role in [Role::System, Role::User, Role::Assistant, Role::Tool] {
            assert_eq!(role.as_str().parse::<Role>(), Ok(role));
        }
    }

    #[test]
    fn test_role_from_str_rejects_unknown() {
        let err = "Admin".parse::<Role>().unwrap_err();
        assert_eq!(err.role, "Admin");
        assert_eq!(
            err.to_string(),
            "invalid message role 'Admin' (expected: system|user|assistant|tool)"
        );
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message::user("Hello");
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(msg, parsed);
    }

    #[test]
    fn test_message_optional_fields_not_serialized() {
        let msg = Message::user("test");
        let json = serde_json::to_value(&msg).unwrap();
        assert!(json.get("name").is_none());
        assert!(json.get("tool_call_id").is_none());
        assert!(json.get("tool_calls").is_none());
    }

    #[test]
    fn test_message_with_name_serialized() {
        let msg = Message::user("test").with_name("Alice");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json.get("name").and_then(|v| v.as_str()), Some("Alice"));
    }

    #[test]
    fn test_parse_messages_json_basic() {
        let msgs = parse_messages_json(
            r#"[{"role":"system","content":"be brief"},{"role":"user","content":"hi","name":"alice"}]"#,
        )
        .unwrap();
        assert_eq!(
            msgs,
            vec![
                Message::system("be brief"),
                Message::user("hi").with_name("alice"),
            ]
        );
    }

    #[test]
    fn test_parse_messages_accepts_camel_case_tool_call_id() {
        let msgs =
            parse_messages_json(r#"[{"role":"tool","content":"42","toolCallId":"call_1"}]"#)
                .unwrap();
        assert_eq!(msgs, vec![Message::tool("42", "call_1")]);
    }

    #[test]
    fn test_parse_messages_drops_empty_name() {
        let msgs = parse_messages_json(r#"[{"role":"user","content":"hi","name":""}]"#).unwrap();
        assert_eq!(msgs[0].name, None);
    }

    #[test]
    fn test_parse_messages_errors() {
        assert_eq!(
            parse_messages_json("[]"),
            Err("messages cannot be empty".to_string())
        );
        assert!(parse_messages_json("not json")
            .unwrap_err()
            .starts_with("invalid messages json:"));
        assert!(parse_messages_json(r#"{"role":"user","content":"hi"}"#)
            .unwrap_err()
            .starts_with("messages must be a list of message objects:"));
        assert!(parse_messages_json(r#"[{"role":"bogus","content":"hi"}]"#)
            .unwrap_err()
            .starts_with("messages must be a list of message objects:"));
        assert_eq!(
            parse_messages_json(r#"[{"role":"user","content":""}]"#),
            Err("message[0].content cannot be empty".to_string())
        );
        assert_eq!(
            parse_messages_json(
                r#"[{"role":"user","content":"a"},{"role":"tool","content":"42"}]"#
            ),
            Err("message[1].tool_call_id is required for tool role".to_string())
        );
    }
}
330}