// tiny_loop/types/message.rs
use serde::{Deserialize, Serialize};
use serde_json::Value;

4#[derive(Serialize, Deserialize, Clone, Debug)]
6pub struct SystemMessage {
7 pub content: String,
9}
10
11#[derive(Serialize, Deserialize, Clone, Debug)]
13pub struct UserMessage {
14 pub content: String,
16}
17
18#[derive(Serialize, Deserialize, Clone, Debug)]
20pub struct AssistantMessage {
21 pub content: String,
23 #[serde(skip_serializing_if = "Option::is_none")]
25 pub tool_calls: Option<Vec<ToolCall>>,
26}
27
28#[derive(Serialize, Deserialize, Clone, Debug)]
30pub struct ToolMessage {
31 pub content: String,
33 pub tool_call_id: String,
35}
36
37#[derive(Serialize, Deserialize, Clone, Debug)]
39pub struct CustomMessage {
40 pub role: String,
42 #[serde(flatten)]
44 pub body: Value,
45}
46
47#[derive(Serialize, Deserialize, Clone, Debug)]
49#[serde(tag = "role", rename_all = "lowercase")]
50pub enum Message {
51 System(SystemMessage),
53 User(UserMessage),
55 Assistant(AssistantMessage),
57 Tool(ToolMessage),
59 #[serde(untagged)]
61 Custom(CustomMessage),
62}
63
64impl From<SystemMessage> for Message {
65 fn from(msg: SystemMessage) -> Self {
66 Message::System(msg)
67 }
68}
69
70impl From<UserMessage> for Message {
71 fn from(msg: UserMessage) -> Self {
72 Message::User(msg)
73 }
74}
75
76impl From<AssistantMessage> for Message {
77 fn from(msg: AssistantMessage) -> Self {
78 Message::Assistant(msg)
79 }
80}
81
82impl From<ToolMessage> for Message {
83 fn from(msg: ToolMessage) -> Self {
84 Message::Tool(msg)
85 }
86}
87
88impl From<CustomMessage> for Message {
89 fn from(msg: CustomMessage) -> Self {
90 Message::Custom(msg)
91 }
92}
93
94#[derive(Serialize, Deserialize, Clone, Debug)]
96pub struct ToolCall {
97 pub id: String,
99 #[serde(rename = "type")]
101 pub call_type: String,
102 pub function: FunctionCall,
104}
105
106#[derive(Serialize, Deserialize, Clone, Debug)]
108pub struct FunctionCall {
109 pub name: String,
111 pub arguments: String,
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serializes `msg` to a JSON string and parses it back into a [`Message`].
    fn roundtrip(msg: &Message) -> Message {
        let text = serde_json::to_string(msg).expect("message should serialize");
        serde_json::from_str(&text).expect("serialized message should deserialize")
    }

    /// Serializes `msg` into a JSON value so the wire format can be compared
    /// structurally (key order is irrelevant).
    fn to_json(msg: &Message) -> serde_json::Value {
        serde_json::to_value(msg).expect("message should serialize")
    }

    #[test]
    fn test_system_roundtrip() {
        let msg = Message::System(SystemMessage {
            content: "test".into(),
        });
        assert_eq!(to_json(&msg), json!({"role": "system", "content": "test"}));
        match roundtrip(&msg) {
            Message::System(SystemMessage { content }) => assert_eq!(content, "test"),
            other => panic!("expected a system message, got {other:?}"),
        }
    }

    #[test]
    fn test_user_roundtrip() {
        let msg = Message::User(UserMessage {
            content: "test".into(),
        });
        assert_eq!(to_json(&msg), json!({"role": "user", "content": "test"}));
        match roundtrip(&msg) {
            Message::User(UserMessage { content }) => assert_eq!(content, "test"),
            other => panic!("expected a user message, got {other:?}"),
        }
    }

    #[test]
    fn test_assistant_no_tools_roundtrip() {
        let msg = Message::Assistant(AssistantMessage {
            content: "test".into(),
            tool_calls: None,
        });
        // `tool_calls: None` must be omitted entirely, not written as `null`.
        assert_eq!(
            to_json(&msg),
            json!({"role": "assistant", "content": "test"})
        );
        match roundtrip(&msg) {
            Message::Assistant(AssistantMessage {
                content,
                tool_calls: None,
            }) => assert_eq!(content, "test"),
            other => panic!("expected an assistant message without tools, got {other:?}"),
        }
    }

    #[test]
    fn test_assistant_with_tools_roundtrip() {
        let msg = Message::Assistant(AssistantMessage {
            content: "test".into(),
            tool_calls: Some(vec![ToolCall {
                id: "call_1".into(),
                call_type: "function".into(),
                function: FunctionCall {
                    name: "fn".into(),
                    arguments: "{}".into(),
                },
            }]),
        });
        assert_eq!(
            to_json(&msg),
            json!({
                "role": "assistant",
                "content": "test",
                "tool_calls": [{
                    "id": "call_1",
                    "type": "function",
                    "function": {"name": "fn", "arguments": "{}"}
                }]
            })
        );
        match roundtrip(&msg) {
            Message::Assistant(AssistantMessage {
                content,
                tool_calls: Some(calls),
            }) => {
                assert_eq!(content, "test");
                assert_eq!(calls.len(), 1);
                let call = &calls[0];
                assert_eq!(call.id, "call_1");
                assert_eq!(call.call_type, "function");
                assert_eq!(call.function.name, "fn");
                assert_eq!(call.function.arguments, "{}");
            }
            other => panic!("expected an assistant message with tools, got {other:?}"),
        }
    }

    #[test]
    fn test_tool_roundtrip() {
        let msg = Message::Tool(ToolMessage {
            content: "result".into(),
            tool_call_id: "call_123".into(),
        });
        assert_eq!(
            to_json(&msg),
            json!({"role": "tool", "content": "result", "tool_call_id": "call_123"})
        );
        match roundtrip(&msg) {
            Message::Tool(ToolMessage {
                content,
                tool_call_id,
            }) => {
                assert_eq!(content, "result");
                assert_eq!(tool_call_id, "call_123");
            }
            other => panic!("expected a tool message, got {other:?}"),
        }
    }

    #[test]
    fn test_custom_roundtrip() {
        let body = json!({"content": "test", "extra": "field"});
        let msg = Message::Custom(CustomMessage {
            role: "custom".into(),
            body: body.clone(),
        });
        // Untagged + flatten: `role` and the body's keys share a single object.
        assert_eq!(
            to_json(&msg),
            json!({"role": "custom", "content": "test", "extra": "field"})
        );
        match roundtrip(&msg) {
            Message::Custom(CustomMessage {
                role,
                body: parsed_body,
            }) => {
                assert_eq!(role, "custom");
                assert_eq!(parsed_body, body);
            }
            other => panic!("expected a custom message, got {other:?}"),
        }
    }

    #[test]
    fn test_unknown_role_falls_back_to_custom() {
        let parsed: Message =
            serde_json::from_str(r#"{"role":"developer","content":"hi"}"#).unwrap();
        match parsed {
            Message::Custom(CustomMessage { role, body }) => {
                assert_eq!(role, "developer");
                assert_eq!(body, json!({"content": "hi"}));
            }
            other => panic!("expected a custom message, got {other:?}"),
        }
    }

    #[test]
    fn test_tool_call_roundtrip() {
        let args = r#"{"key":"value"}"#;
        let tc = ToolCall {
            id: "call_1".into(),
            call_type: "function".into(),
            function: FunctionCall {
                name: "test".into(),
                arguments: args.into(),
            },
        };
        // `call_type` travels under the JSON key `type`.
        assert_eq!(
            serde_json::to_value(&tc).unwrap(),
            json!({
                "id": "call_1",
                "type": "function",
                "function": {"name": "test", "arguments": args}
            })
        );
        let json = serde_json::to_string(&tc).unwrap();
        let parsed: ToolCall = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "call_1");
        assert_eq!(parsed.call_type, "function");
        assert_eq!(parsed.function.name, "test");
        assert_eq!(parsed.function.arguments, args);
    }
}