Skip to main content

tiny_loop/types/
message.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4/// System message body
5#[derive(Serialize, Deserialize, Clone, Debug)]
6pub struct SystemMessage {
7    /// Message content
8    pub content: String,
9}
10
11/// User message body
12#[derive(Serialize, Deserialize, Clone, Debug)]
13pub struct UserMessage {
14    /// Message content
15    pub content: String,
16}
17
18/// Assistant message body
19#[derive(Serialize, Deserialize, Clone, Debug)]
20pub struct AssistantMessage {
21    /// Message content
22    pub content: String,
23    /// Tool calls requested by the assistant
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub tool_calls: Option<Vec<ToolCall>>,
26}
27
28/// Tool message body
29#[derive(Serialize, Deserialize, Clone, Debug)]
30pub struct ToolMessage {
31    /// Tool execution result content
32    pub content: String,
33    /// ID of the tool call this responds to
34    pub tool_call_id: String,
35}
36
37/// Custom message body
38#[derive(Serialize, Deserialize, Clone, Debug)]
39pub struct CustomMessage {
40    /// Custom role name
41    pub role: String,
42    /// Custom message body
43    #[serde(flatten)]
44    pub body: Value,
45}
46
47/// LLM message with role-specific fields
48#[derive(Serialize, Deserialize, Clone, Debug)]
49#[serde(tag = "role", rename_all = "lowercase")]
50pub enum Message {
51    /// System message with instructions
52    System(SystemMessage),
53    /// User message with input
54    User(UserMessage),
55    /// Assistant message with response and optional tool calls
56    Assistant(AssistantMessage),
57    /// Tool execution result
58    Tool(ToolMessage),
59    /// Custom role with arbitrary fields
60    #[serde(untagged)]
61    Custom(CustomMessage),
62}
63
64impl From<SystemMessage> for Message {
65    fn from(msg: SystemMessage) -> Self {
66        Message::System(msg)
67    }
68}
69
70impl From<UserMessage> for Message {
71    fn from(msg: UserMessage) -> Self {
72        Message::User(msg)
73    }
74}
75
76impl From<AssistantMessage> for Message {
77    fn from(msg: AssistantMessage) -> Self {
78        Message::Assistant(msg)
79    }
80}
81
82impl From<ToolMessage> for Message {
83    fn from(msg: ToolMessage) -> Self {
84        Message::Tool(msg)
85    }
86}
87
88impl From<CustomMessage> for Message {
89    fn from(msg: CustomMessage) -> Self {
90        Message::Custom(msg)
91    }
92}
93
94/// Tool call from LLM
95#[derive(Serialize, Deserialize, Clone, Debug)]
96pub struct ToolCall {
97    /// Unique identifier for this tool call
98    pub id: String,
99    /// Type of the call (typically "function")
100    #[serde(rename = "type")]
101    pub call_type: String,
102    /// Function call details
103    pub function: FunctionCall,
104}
105
106/// Function call details
107#[derive(Serialize, Deserialize, Clone, Debug)]
108pub struct FunctionCall {
109    /// Function name to call
110    pub name: String,
111    /// JSON-encoded function arguments
112    pub arguments: String,
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serializes `msg` to a JSON string and parses it back into a `Message`.
    fn roundtrip(msg: &Message) -> Message {
        let json = serde_json::to_string(msg).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    #[test]
    fn test_system_roundtrip() {
        let msg = Message::System(SystemMessage {
            content: "test".into(),
        });
        // Pin the wire format: the role tag sits next to the body fields.
        assert_eq!(
            serde_json::to_value(&msg).unwrap(),
            json!({"role": "system", "content": "test"})
        );
        assert!(matches!(
            roundtrip(&msg),
            Message::System(SystemMessage { content }) if content == "test"
        ));
    }

    #[test]
    fn test_user_roundtrip() {
        let msg = Message::User(UserMessage {
            content: "test".into(),
        });
        assert_eq!(
            serde_json::to_value(&msg).unwrap(),
            json!({"role": "user", "content": "test"})
        );
        assert!(matches!(
            roundtrip(&msg),
            Message::User(UserMessage { content }) if content == "test"
        ));
    }

    #[test]
    fn test_assistant_no_tools_roundtrip() {
        let msg = Message::Assistant(AssistantMessage {
            content: "test".into(),
            tool_calls: None,
        });
        let json = serde_json::to_string(&msg).unwrap();
        // `tool_calls: None` must be omitted entirely, not written as `null`.
        assert!(!json.contains("tool_calls"));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            parsed,
            Message::Assistant(AssistantMessage { content, tool_calls: None }) if content == "test"
        ));
    }

    #[test]
    fn test_assistant_with_tools_roundtrip() {
        let msg = Message::Assistant(AssistantMessage {
            content: "test".into(),
            tool_calls: Some(vec![ToolCall {
                id: "call_1".into(),
                call_type: "function".into(),
                function: FunctionCall {
                    name: "fn".into(),
                    arguments: "{}".into(),
                },
            }]),
        });
        match roundtrip(&msg) {
            Message::Assistant(AssistantMessage {
                content,
                tool_calls: Some(calls),
            }) => {
                assert_eq!(content, "test");
                assert_eq!(calls.len(), 1);
                assert_eq!(calls[0].id, "call_1");
                assert_eq!(calls[0].call_type, "function");
                assert_eq!(calls[0].function.name, "fn");
                assert_eq!(calls[0].function.arguments, "{}");
            }
            other => panic!("expected assistant message with tool calls, got {:?}", other),
        }
    }

    #[test]
    fn test_tool_roundtrip() {
        let msg = Message::Tool(ToolMessage {
            content: "result".into(),
            tool_call_id: "call_123".into(),
        });
        assert_eq!(
            serde_json::to_value(&msg).unwrap(),
            json!({"role": "tool", "content": "result", "tool_call_id": "call_123"})
        );
        assert!(matches!(
            roundtrip(&msg),
            Message::Tool(ToolMessage { content, tool_call_id })
                if content == "result" && tool_call_id == "call_123"
        ));
    }

    #[test]
    fn test_custom_roundtrip() {
        let msg = Message::Custom(CustomMessage {
            role: "custom".into(),
            body: json!({"content": "test", "extra": "field"}),
        });
        // The body is flattened next to `role`, not nested under a key.
        assert_eq!(
            serde_json::to_value(&msg).unwrap(),
            json!({"role": "custom", "content": "test", "extra": "field"})
        );
        match roundtrip(&msg) {
            Message::Custom(CustomMessage { role, body }) => {
                assert_eq!(role, "custom");
                assert_eq!(body, json!({"content": "test", "extra": "field"}));
            }
            other => panic!("expected custom message, got {:?}", other),
        }
    }

    #[test]
    fn test_unknown_role_parses_as_custom() {
        let parsed: Message =
            serde_json::from_str(r#"{"role":"developer","content":"hi"}"#).unwrap();
        match parsed {
            Message::Custom(CustomMessage { role, body }) => {
                assert_eq!(role, "developer");
                assert_eq!(body, json!({"content": "hi"}));
            }
            other => panic!("expected custom message, got {:?}", other),
        }
    }

    #[test]
    fn test_from_conversions() {
        let msg: Message = SystemMessage { content: "s".into() }.into();
        assert!(matches!(msg, Message::System(_)));
        let msg: Message = UserMessage { content: "u".into() }.into();
        assert!(matches!(msg, Message::User(_)));
        let msg: Message = AssistantMessage {
            content: "a".into(),
            tool_calls: None,
        }
        .into();
        assert!(matches!(msg, Message::Assistant(_)));
        let msg: Message = ToolMessage {
            content: "t".into(),
            tool_call_id: "id".into(),
        }
        .into();
        assert!(matches!(msg, Message::Tool(_)));
        let msg: Message = CustomMessage {
            role: "x".into(),
            body: json!({}),
        }
        .into();
        assert!(matches!(msg, Message::Custom(_)));
    }

    #[test]
    fn test_tool_call_roundtrip() {
        let tc = ToolCall {
            id: "call_1".into(),
            call_type: "function".into(),
            function: FunctionCall {
                name: "test".into(),
                arguments: r#"{"key":"value"}"#.into(),
            },
        };
        let value = serde_json::to_value(&tc).unwrap();
        // `call_type` must be renamed to `type` on the wire.
        assert_eq!(value["type"], "function");
        assert!(value.get("call_type").is_none());

        let json = serde_json::to_string(&tc).unwrap();
        let parsed: ToolCall = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "call_1");
        assert_eq!(parsed.call_type, "function");
        assert_eq!(parsed.function.name, "test");
        assert_eq!(parsed.function.arguments, r#"{"key":"value"}"#);
    }
}