1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::time::{Duration, SystemTime};
4
5#[derive(Serialize, Deserialize, Clone, Debug)]
7pub struct SystemMessage {
8 pub content: String,
10}
11
12#[derive(Serialize, Deserialize, Clone, Debug)]
14pub struct UserMessage {
15 pub content: String,
17}
18
19#[derive(Serialize, Deserialize, Clone, Debug)]
21pub struct AssistantMessage {
22 pub content: String,
24 #[serde(skip_serializing_if = "Option::is_none")]
26 pub tool_calls: Option<Vec<ToolCall>>,
27}
28
29#[derive(Serialize, Deserialize, Clone, Debug)]
31pub struct ToolMessage {
32 pub content: String,
34 pub tool_call_id: String,
36}
37
38#[derive(Serialize, Deserialize, Clone, Debug)]
40pub struct CustomMessage {
41 pub role: String,
43 #[serde(flatten)]
45 pub body: Value,
46}
47
48#[derive(Serialize, Deserialize, Clone, Debug)]
50#[serde(tag = "role", rename_all = "lowercase")]
51pub enum Message {
52 System(SystemMessage),
54 User(UserMessage),
56 Assistant(AssistantMessage),
58 Tool(ToolMessage),
60 #[serde(untagged)]
62 Custom(CustomMessage),
63}
64
65impl From<SystemMessage> for Message {
66 fn from(msg: SystemMessage) -> Self {
67 Message::System(msg)
68 }
69}
70
71impl From<UserMessage> for Message {
72 fn from(msg: UserMessage) -> Self {
73 Message::User(msg)
74 }
75}
76
77impl From<AssistantMessage> for Message {
78 fn from(msg: AssistantMessage) -> Self {
79 Message::Assistant(msg)
80 }
81}
82
83impl From<ToolMessage> for Message {
84 fn from(msg: ToolMessage) -> Self {
85 Message::Tool(msg)
86 }
87}
88
89impl From<CustomMessage> for Message {
90 fn from(msg: CustomMessage) -> Self {
91 Message::Custom(msg)
92 }
93}
94
95#[derive(Serialize, Deserialize, Clone, Debug)]
97pub struct ToolCall {
98 pub id: String,
100 #[serde(rename = "type")]
102 pub call_type: String,
103 pub function: FunctionCall,
105}
106
107#[derive(Serialize, Deserialize, Clone, Debug)]
109pub struct FunctionCall {
110 pub name: String,
112 pub arguments: String,
114}
115
116#[derive(Serialize, Deserialize, Clone, Debug)]
118pub struct TimedMessage {
119 pub message: Message,
120 pub timestamp: SystemTime,
122 pub elapsed: Duration,
124}
125
126#[derive(Serialize, Deserialize, Clone, Debug)]
128pub struct ToolResult {
129 pub tool_message: ToolMessage,
130 pub timestamp: SystemTime,
132 pub elapsed: Duration,
134}
135
// Round-trip tests for the wire format. Besides checking that each variant survives
// serialize -> deserialize, they pin the JSON shape itself: the `role` tag, the
// `call_type` -> `type` rename, and the flattened body of custom messages.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_system_roundtrip() {
        let msg = Message::System(SystemMessage {
            content: "test".into(),
        });
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""role":"system""#));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, Message::System(SystemMessage { content }) if content == "test"));
    }

    #[test]
    fn test_user_roundtrip() {
        let msg = Message::User(UserMessage {
            content: "test".into(),
        });
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""role":"user""#));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, Message::User(UserMessage { content }) if content == "test"));
    }

    #[test]
    fn test_assistant_no_tools_roundtrip() {
        let msg = Message::Assistant(AssistantMessage {
            content: "test".into(),
            tool_calls: None,
        });
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""role":"assistant""#));
        // `None` tool calls must be omitted entirely, not written as `null`.
        assert!(!json.contains("tool_calls"));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(parsed, Message::Assistant(AssistantMessage { content, tool_calls: None }) if content == "test")
        );
    }

    #[test]
    fn test_assistant_with_tools_roundtrip() {
        let msg = Message::Assistant(AssistantMessage {
            content: "test".into(),
            tool_calls: Some(vec![ToolCall {
                id: "call_1".into(),
                call_type: "function".into(),
                function: FunctionCall {
                    name: "fn".into(),
                    arguments: "{}".into(),
                },
            }]),
        });
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""type":"function""#));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            parsed,
            Message::Assistant(AssistantMessage { tool_calls: Some(calls), .. })
                if calls.len() == 1 && calls[0].id == "call_1" && calls[0].function.name == "fn"
        ));
    }

    #[test]
    fn test_tool_roundtrip() {
        let msg = Message::Tool(ToolMessage {
            content: "result".into(),
            tool_call_id: "call_123".into(),
        });
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""role":"tool""#));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(parsed, Message::Tool(ToolMessage { content, tool_call_id })
                if content == "result" && tool_call_id == "call_123")
        );
    }

    #[test]
    fn test_custom_roundtrip() {
        let msg = Message::Custom(CustomMessage {
            role: "custom".into(),
            body: serde_json::json!({"content": "test", "extra": "field"}),
        });
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""role":"custom""#));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        match parsed {
            Message::Custom(CustomMessage { role, body }) => {
                assert_eq!(role, "custom");
                // The flattened fields must come back intact, without the `role` key.
                assert_eq!(body, serde_json::json!({"content": "test", "extra": "field"}));
            }
            other => panic!("expected Message::Custom, got {:?}", other),
        }
    }

    #[test]
    fn test_unknown_role_falls_back_to_custom() {
        let parsed: Message =
            serde_json::from_str(r#"{"role":"developer","content":"hi"}"#).unwrap();
        match parsed {
            Message::Custom(CustomMessage { role, body }) => {
                assert_eq!(role, "developer");
                assert_eq!(body, serde_json::json!({"content": "hi"}));
            }
            other => panic!("expected Message::Custom, got {:?}", other),
        }
    }

    #[test]
    fn test_tool_call_roundtrip() {
        let tc = ToolCall {
            id: "call_1".into(),
            call_type: "function".into(),
            function: FunctionCall {
                name: "test".into(),
                arguments: r#"{"key":"value"}"#.into(),
            },
        };
        let json = serde_json::to_string(&tc).unwrap();
        // The Rust field name must not leak onto the wire.
        assert!(json.contains(r#""type":"function""#));
        assert!(!json.contains("call_type"));
        let parsed: ToolCall = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "call_1");
        assert_eq!(parsed.call_type, "function");
        assert_eq!(parsed.function.name, "test");
        assert_eq!(parsed.function.arguments, r#"{"key":"value"}"#);
    }
}