Skip to main content

tiny_loop/types/
message.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::time::{Duration, SystemTime};
4
5/// System message body
6#[derive(Serialize, Deserialize, Clone, Debug)]
7pub struct SystemMessage {
8    /// Message content
9    pub content: String,
10}
11
12/// User message body
13#[derive(Serialize, Deserialize, Clone, Debug)]
14pub struct UserMessage {
15    /// Message content
16    pub content: String,
17}
18
19/// Assistant message body
20#[derive(Serialize, Deserialize, Clone, Debug)]
21pub struct AssistantMessage {
22    /// Message content
23    pub content: String,
24    /// Tool calls requested by the assistant
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub tool_calls: Option<Vec<ToolCall>>,
27}
28
29/// Tool message body
30#[derive(Serialize, Deserialize, Clone, Debug)]
31pub struct ToolMessage {
32    /// Tool execution result content
33    pub content: String,
34    /// ID of the tool call this responds to
35    pub tool_call_id: String,
36}
37
38/// Custom message body
39#[derive(Serialize, Deserialize, Clone, Debug)]
40pub struct CustomMessage {
41    /// Custom role name
42    pub role: String,
43    /// Custom message body
44    #[serde(flatten)]
45    pub body: Value,
46}
47
48/// LLM message with role-specific fields
49#[derive(Serialize, Deserialize, Clone, Debug)]
50#[serde(tag = "role", rename_all = "lowercase")]
51pub enum Message {
52    /// System message with instructions
53    System(SystemMessage),
54    /// User message with input
55    User(UserMessage),
56    /// Assistant message with response and optional tool calls
57    Assistant(AssistantMessage),
58    /// Tool execution result
59    Tool(ToolMessage),
60    /// Custom role with arbitrary fields
61    #[serde(untagged)]
62    Custom(CustomMessage),
63}
64
65impl From<SystemMessage> for Message {
66    fn from(msg: SystemMessage) -> Self {
67        Message::System(msg)
68    }
69}
70
71impl From<UserMessage> for Message {
72    fn from(msg: UserMessage) -> Self {
73        Message::User(msg)
74    }
75}
76
77impl From<AssistantMessage> for Message {
78    fn from(msg: AssistantMessage) -> Self {
79        Message::Assistant(msg)
80    }
81}
82
83impl From<ToolMessage> for Message {
84    fn from(msg: ToolMessage) -> Self {
85        Message::Tool(msg)
86    }
87}
88
89impl From<CustomMessage> for Message {
90    fn from(msg: CustomMessage) -> Self {
91        Message::Custom(msg)
92    }
93}
94
95/// Tool call from LLM
96#[derive(Serialize, Deserialize, Clone, Debug)]
97pub struct ToolCall {
98    /// Unique identifier for this tool call
99    pub id: String,
100    /// Type of the call (typically "function")
101    #[serde(rename = "type")]
102    pub call_type: String,
103    /// Function call details
104    pub function: FunctionCall,
105}
106
107/// Function call details
108#[derive(Serialize, Deserialize, Clone, Debug)]
109pub struct FunctionCall {
110    /// Function name to call
111    pub name: String,
112    /// JSON-encoded function arguments
113    pub arguments: String,
114}
115
116/// Message with timing metadata
117#[derive(Serialize, Deserialize, Clone, Debug)]
118pub struct TimedMessage {
119    pub message: Message,
120    /// When the message was completed
121    pub timestamp: SystemTime,
122    /// Time taken to generate this message
123    pub elapsed: Duration,
124}
125
126/// Tool execution result with timing metadata
127#[derive(Serialize, Deserialize, Clone, Debug)]
128pub struct ToolResult {
129    pub tool_message: ToolMessage,
130    /// When the tool execution completed
131    pub timestamp: SystemTime,
132    /// Time taken to execute the tool
133    pub elapsed: Duration,
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_system_roundtrip() {
        let msg = Message::System(SystemMessage {
            content: "test".into(),
        });
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, Message::System(SystemMessage { content }) if content == "test"));
    }

    #[test]
    fn test_user_roundtrip() {
        let msg = Message::User(UserMessage {
            content: "test".into(),
        });
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, Message::User(UserMessage { content }) if content == "test"));
    }

    #[test]
    fn test_assistant_no_tools_roundtrip() {
        let msg = Message::Assistant(AssistantMessage {
            content: "test".into(),
            tool_calls: None,
        });
        let json = serde_json::to_string(&msg).unwrap();
        // `tool_calls: None` must be omitted entirely, not written as `null`.
        assert!(!json.contains("tool_calls"));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(parsed, Message::Assistant(AssistantMessage { content, tool_calls: None }) if content == "test")
        );
    }

    #[test]
    fn test_assistant_with_tools_roundtrip() {
        let msg = Message::Assistant(AssistantMessage {
            content: "test".into(),
            tool_calls: Some(vec![ToolCall {
                id: "call_1".into(),
                call_type: "function".into(),
                function: FunctionCall {
                    name: "fn".into(),
                    arguments: "{}".into(),
                },
            }]),
        });
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        match parsed {
            Message::Assistant(AssistantMessage {
                content,
                tool_calls: Some(calls),
            }) => {
                assert_eq!(content, "test");
                assert_eq!(calls.len(), 1);
                assert_eq!(calls[0].id, "call_1");
                assert_eq!(calls[0].function.name, "fn");
                assert_eq!(calls[0].function.arguments, "{}");
            }
            other => panic!("expected Assistant with tool calls, got {:?}", other),
        }
    }

    #[test]
    fn test_tool_roundtrip() {
        let msg = Message::Tool(ToolMessage {
            content: "result".into(),
            tool_call_id: "call_123".into(),
        });
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(parsed, Message::Tool(ToolMessage { content, tool_call_id })
            if content == "result" && tool_call_id == "call_123")
        );
    }

    #[test]
    fn test_custom_roundtrip() {
        let expected_body = serde_json::json!({"content": "test", "extra": "field"});
        let msg = Message::Custom(CustomMessage {
            role: "custom".into(),
            body: expected_body.clone(),
        });
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        match parsed {
            Message::Custom(CustomMessage { role, body }) => {
                assert_eq!(role, "custom");
                // `role` is consumed by its own field, so the body holds only the extras.
                assert_eq!(body, expected_body);
            }
            other => panic!("expected Custom, got {:?}", other),
        }
    }

    #[test]
    fn test_unknown_role_parses_as_custom() {
        let parsed: Message = serde_json::from_str(r#"{"role":"critic","score":3}"#).unwrap();
        match parsed {
            Message::Custom(CustomMessage { role, body }) => {
                assert_eq!(role, "critic");
                assert_eq!(body, serde_json::json!({"score": 3}));
            }
            other => panic!("expected Custom, got {:?}", other),
        }
    }

    #[test]
    fn test_tool_call_roundtrip() {
        let tc = ToolCall {
            id: "call_1".into(),
            call_type: "function".into(),
            function: FunctionCall {
                name: "test".into(),
                arguments: r#"{"key":"value"}"#.into(),
            },
        };
        let value = serde_json::to_value(&tc).unwrap();
        // `call_type` is renamed to the wire key `type`.
        assert_eq!(value["type"], "function");
        assert!(value.get("call_type").is_none());
        let parsed: ToolCall = serde_json::from_value(value).unwrap();
        assert_eq!(parsed.id, "call_1");
        assert_eq!(parsed.call_type, "function");
        assert_eq!(parsed.function.name, "test");
        assert_eq!(parsed.function.arguments, r#"{"key":"value"}"#);
    }

    #[test]
    fn test_from_conversions() {
        let system: Message = SystemMessage { content: "s".into() }.into();
        assert!(matches!(system, Message::System(_)));
        let user: Message = UserMessage { content: "u".into() }.into();
        assert!(matches!(user, Message::User(_)));
        let assistant: Message = AssistantMessage {
            content: "a".into(),
            tool_calls: None,
        }
        .into();
        assert!(matches!(assistant, Message::Assistant(_)));
        let tool: Message = ToolMessage {
            content: "t".into(),
            tool_call_id: "id".into(),
        }
        .into();
        assert!(matches!(tool, Message::Tool(_)));
        let custom: Message = CustomMessage {
            role: "r".into(),
            body: serde_json::json!({}),
        }
        .into();
        assert!(matches!(custom, Message::Custom(_)));
    }

    #[test]
    fn test_timed_message_roundtrip() {
        let timed = TimedMessage {
            message: Message::User(UserMessage { content: "hi".into() }),
            timestamp: SystemTime::UNIX_EPOCH + Duration::from_secs(1_700_000_000),
            elapsed: Duration::from_millis(1500),
        };
        let json = serde_json::to_string(&timed).unwrap();
        let parsed: TimedMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.timestamp, timed.timestamp);
        assert_eq!(parsed.elapsed, timed.elapsed);
        assert!(matches!(parsed.message, Message::User(UserMessage { content }) if content == "hi"));
    }
}