1use schemars::{JsonSchema, generate::SchemaSettings};
2use serde::{Deserialize, Serialize};
3use serde_json::{Map, Value};
4
5#[derive(Serialize, Deserialize, Clone, Debug)]
7#[serde(tag = "role", rename_all = "lowercase")]
8pub enum Message {
9 System {
11 content: String,
13 },
14 User {
16 content: String,
18 },
19 Assistant {
21 content: String,
23 #[serde(skip_serializing_if = "Option::is_none")]
25 tool_calls: Option<Vec<ToolCall>>,
26 },
27 Tool {
29 content: String,
31 tool_call_id: String,
33 },
34 #[serde(untagged)]
36 Custom {
37 role: String,
39 #[serde(flatten)]
41 body: Value,
42 },
43}
44
45#[derive(Serialize, Deserialize, Clone, Debug)]
47pub struct ToolCall {
48 pub id: String,
50 #[serde(rename = "type")]
52 pub call_type: String,
53 pub function: FunctionCall,
55}
56
57#[derive(Serialize, Deserialize, Clone, Debug)]
59pub struct FunctionCall {
60 pub name: String,
62 pub arguments: String,
64}
65
66#[derive(Serialize, Clone, Debug)]
68pub struct ToolDefinition {
69 #[serde(rename = "type")]
71 pub tool_type: String,
72 pub function: ToolFunction,
74}
75
76#[derive(Serialize, Clone, Debug)]
78pub struct ToolFunction {
79 pub name: String,
81 pub description: String,
83 pub parameters: Parameters,
85}
86
87#[derive(Serialize, Clone, Debug)]
89pub struct Parameters(Map<String, Value>);
90
91impl Parameters {
92 pub fn from_object(mut obj: Map<String, Value>) -> Self {
94 obj.remove("$schema");
96 obj.remove("title");
97
98 Self(obj)
99 }
100
101 pub fn from_schema(schema: schemars::Schema) -> Self {
103 let obj = schema.to_value().as_object().unwrap().clone();
104 Self::from_object(obj)
105 }
106
107 pub fn from_type<T: JsonSchema>() -> Self {
109 let settings = SchemaSettings::default().with(|s| {
110 s.inline_subschemas = true;
111 });
112 let generator = settings.into_generator();
113 let schema = generator.into_root_schema_for::<T>();
114 Self::from_schema(schema)
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_system_roundtrip() {
        let msg = Message::System {
            content: "test".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        // The role tag is the lowercased variant name.
        assert!(json.contains(r#""role":"system""#));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, Message::System { content } if content == "test"));
    }

    #[test]
    fn test_user_roundtrip() {
        let msg = Message::User {
            content: "test".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""role":"user""#));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, Message::User { content } if content == "test"));
    }

    #[test]
    fn test_assistant_no_tools_roundtrip() {
        let msg = Message::Assistant {
            content: "test".into(),
            tool_calls: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""role":"assistant""#));
        // `None` tool calls must be omitted entirely, not serialized as null.
        assert!(!json.contains("tool_calls"));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(parsed, Message::Assistant { content, tool_calls: None } if content == "test")
        );
    }

    #[test]
    fn test_assistant_with_tools_roundtrip() {
        let msg = Message::Assistant {
            content: "test".into(),
            tool_calls: Some(vec![ToolCall {
                id: "call_1".into(),
                call_type: "function".into(),
                function: FunctionCall {
                    name: "fn".into(),
                    arguments: "{}".into(),
                },
            }]),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            parsed,
            Message::Assistant { content, tool_calls: Some(calls) }
                if content == "test"
                    && calls.len() == 1
                    && calls[0].id == "call_1"
                    && calls[0].function.name == "fn"
        ));
    }

    #[test]
    fn test_tool_roundtrip() {
        let msg = Message::Tool {
            content: "result".into(),
            tool_call_id: "call_123".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""role":"tool""#));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, Message::Tool { content, tool_call_id }
            if content == "result" && tool_call_id == "call_123"));
    }

    #[test]
    fn test_custom_roundtrip() {
        let body = serde_json::json!({"content": "test", "extra": "field"});
        let msg = Message::Custom {
            role: "custom".into(),
            body: body.clone(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""role":"custom""#));
        let parsed: Message = serde_json::from_str(&json).unwrap();
        // `role` is consumed by the named field; every other key must come
        // back in the flattened body unchanged.
        assert!(matches!(
            parsed,
            Message::Custom { role, body: parsed_body }
                if role == "custom" && parsed_body == body
        ));
    }

    #[test]
    fn test_tool_call_roundtrip() {
        let tc = ToolCall {
            id: "call_1".into(),
            call_type: "function".into(),
            function: FunctionCall {
                name: "test".into(),
                arguments: r#"{"key":"value"}"#.into(),
            },
        };
        let json = serde_json::to_string(&tc).unwrap();
        // `call_type` is renamed to `type` on the wire.
        assert!(json.contains(r#""type":"function""#));
        let parsed: ToolCall = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "call_1");
        assert_eq!(parsed.call_type, "function");
        assert_eq!(parsed.function.name, "test");
        assert_eq!(parsed.function.arguments, r#"{"key":"value"}"#);
    }

    #[test]
    fn test_tool_definition_serialization() {
        let td = ToolDefinition {
            tool_type: "function".into(),
            function: ToolFunction {
                name: "test".into(),
                description: "desc".into(),
                parameters: Parameters::from_type::<String>(),
            },
        };
        let value: Value = serde_json::to_value(&td).unwrap();
        assert_eq!(value["type"], "function");
        assert_eq!(value["function"]["name"], "test");
        assert_eq!(value["function"]["description"], "desc");

        let params = &value["function"]["parameters"];
        assert_eq!(params["type"], "string");
        // Root schema metadata must have been stripped by `from_object`.
        assert!(params.get("$schema").is_none());
        assert!(params.get("title").is_none());
    }
}