1use schemars::{JsonSchema, schema_for};
2use serde::{Deserialize, Serialize};
3use serde_json::{Map, Value};
4
5#[derive(Serialize, Deserialize, Clone, Debug)]
7#[serde(tag = "role", rename_all = "lowercase")]
8pub enum Message {
9 System {
11 content: String,
13 },
14 User {
16 content: String,
18 },
19 Assistant {
21 content: String,
23 #[serde(skip_serializing_if = "Option::is_none")]
25 tool_calls: Option<Vec<ToolCall>>,
26 },
27 Tool {
29 content: String,
31 tool_call_id: String,
33 },
34 #[serde(untagged)]
36 Custom {
37 role: String,
39 #[serde(flatten)]
41 body: Value,
42 },
43}
44
45#[derive(Serialize, Deserialize, Clone, Debug)]
47pub struct ToolCall {
48 pub id: String,
50 #[serde(rename = "type")]
52 pub call_type: String,
53 pub function: FunctionCall,
55}
56
57#[derive(Serialize, Deserialize, Clone, Debug)]
59pub struct FunctionCall {
60 pub name: String,
62 pub arguments: String,
64}
65
66#[derive(Serialize, Clone)]
68pub struct ToolDefinition {
69 #[serde(rename = "type")]
71 pub tool_type: String,
72 pub function: ToolFunction,
74}
75
76#[derive(Serialize, Clone)]
78pub struct ToolFunction {
79 pub name: String,
81 pub description: String,
83 pub parameters: Parameters,
85}
86
87#[derive(Serialize, Clone, Debug)]
89pub struct Parameters(Map<String, Value>);
90
91impl Parameters {
92 pub fn from_object(mut obj: Map<String, Value>) -> Self {
94 obj.remove("$schema");
96 obj.remove("title");
97
98 Self(obj)
99 }
100
101 pub fn from_schema(schema: schemars::Schema) -> Self {
103 let obj = schema.to_value().as_object().unwrap().clone();
104 Self::from_object(obj)
105 }
106
107 pub fn from_type<T: JsonSchema>() -> Self {
109 Self::from_schema(schema_for!(T))
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_system_roundtrip() {
        let msg = Message::System {
            content: "test".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, Message::System { content } if content == "test"));
    }

    #[test]
    fn test_user_roundtrip() {
        let msg = Message::User {
            content: "test".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, Message::User { content } if content == "test"));
    }

    #[test]
    fn test_assistant_no_tools_roundtrip() {
        let msg = Message::Assistant {
            content: "test".into(),
            tool_calls: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // `tool_calls: None` must be omitted, not serialized as `null`.
        assert!(!json.contains("tool_calls"));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(parsed, Message::Assistant { content, tool_calls: None } if content == "test")
        );
    }

    #[test]
    fn test_assistant_with_tools_roundtrip() {
        let msg = Message::Assistant {
            content: "test".into(),
            tool_calls: Some(vec![ToolCall {
                id: "call_1".into(),
                call_type: "function".into(),
                function: FunctionCall {
                    name: "fn".into(),
                    arguments: "{}".into(),
                },
            }]),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        // Check the call's contents survive, not just the count.
        assert!(matches!(
            parsed,
            Message::Assistant { content, tool_calls: Some(calls) }
                if content == "test"
                    && calls.len() == 1
                    && calls[0].id == "call_1"
                    && calls[0].call_type == "function"
                    && calls[0].function.name == "fn"
                    && calls[0].function.arguments == "{}"
        ));
    }

    #[test]
    fn test_tool_roundtrip() {
        let msg = Message::Tool {
            content: "result".into(),
            tool_call_id: "call_123".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, Message::Tool { content, tool_call_id }
            if content == "result" && tool_call_id == "call_123"));
    }

    #[test]
    fn test_custom_roundtrip() {
        let body = serde_json::json!({"content": "test", "extra": "field"});
        let msg = Message::Custom {
            role: "custom".into(),
            body: body.clone(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        // The flattened body must come back intact, not just the role.
        assert!(matches!(
            parsed,
            Message::Custom { role, body: parsed_body }
                if role == "custom" && parsed_body == body
        ));
    }

    #[test]
    fn test_unknown_role_falls_back_to_custom() {
        let parsed: Message =
            serde_json::from_str(r#"{"role":"developer","content":"hi"}"#).unwrap();
        assert!(matches!(
            parsed,
            Message::Custom { role, body }
                if role == "developer" && body == serde_json::json!({"content": "hi"})
        ));
    }

    #[test]
    fn test_tool_call_roundtrip() {
        let tc = ToolCall {
            id: "call_1".into(),
            call_type: "function".into(),
            function: FunctionCall {
                name: "test".into(),
                arguments: r#"{"key":"value"}"#.into(),
            },
        };
        let json = serde_json::to_string(&tc).unwrap();
        // `call_type` is renamed to `type` on the wire.
        assert_eq!(
            serde_json::from_str::<Value>(&json).unwrap()["type"],
            "function"
        );
        let parsed: ToolCall = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "call_1");
        assert_eq!(parsed.call_type, "function");
        assert_eq!(parsed.function.name, "test");
        assert_eq!(parsed.function.arguments, r#"{"key":"value"}"#);
    }

    #[test]
    fn test_tool_definition_serialization() {
        let td = ToolDefinition {
            tool_type: "function".into(),
            function: ToolFunction {
                name: "test".into(),
                description: "desc".into(),
                parameters: Parameters::from_schema(schema_for!(String)),
            },
        };
        // Compare structured JSON rather than substrings, which would also
        // match e.g. `"name":"testing"`.
        let v = serde_json::to_value(&td).unwrap();
        assert_eq!(v["type"], "function");
        assert_eq!(v["function"]["name"], "test");
        assert_eq!(v["function"]["description"], "desc");
        let params = &v["function"]["parameters"];
        assert!(params.is_object());
        assert!(params.get("$schema").is_none());
        assert!(params.get("title").is_none());
    }

    #[test]
    fn test_parameters_from_object_strips_metadata() {
        let obj = serde_json::json!({
            "$schema": "https://json-schema.org/draft/2020-12/schema",
            "title": "Args",
            "type": "object"
        })
        .as_object()
        .cloned()
        .expect("json! object literal is an object");
        let params = Parameters::from_object(obj);
        assert_eq!(
            serde_json::to_value(&params).unwrap(),
            serde_json::json!({"type": "object"})
        );
    }
}