1use schemars::{JsonSchema, generate::SchemaSettings};
2use serde::{Deserialize, Serialize};
3use serde_json::{Map, Value};
4
5#[derive(Serialize, Deserialize, Clone, Debug)]
7#[serde(tag = "role", rename_all = "lowercase")]
8pub enum Message {
9 System {
11 content: String,
13 },
14 User {
16 content: String,
18 },
19 Assistant {
21 content: String,
23 #[serde(skip_serializing_if = "Option::is_none")]
25 tool_calls: Option<Vec<ToolCall>>,
26 },
27 Tool {
29 content: String,
31 tool_call_id: String,
33 },
34 #[serde(untagged)]
36 Custom {
37 role: String,
39 #[serde(flatten)]
41 body: Value,
42 },
43}
44
45#[derive(Serialize, Deserialize, Clone, Debug)]
47pub struct ToolCall {
48 pub id: String,
50 #[serde(rename = "type")]
52 pub call_type: String,
53 pub function: FunctionCall,
55}
56
57#[derive(Serialize, Deserialize, Clone, Debug)]
59pub struct FunctionCall {
60 pub name: String,
62 pub arguments: String,
64}
65
66#[derive(Serialize, Clone, Debug)]
68pub struct ToolDefinition {
69 #[serde(rename = "type")]
71 pub tool_type: String,
72 pub function: ToolFunction,
74}
75
76#[derive(Serialize, Clone, Debug)]
78pub struct ToolFunction {
79 pub name: String,
81 pub description: String,
83 pub parameters: Parameters,
85}
86
87#[derive(Serialize, Clone, Debug)]
89pub struct Parameters(Map<String, Value>);
90
91impl Parameters {
92 pub fn from_object(mut obj: Map<String, Value>) -> Self {
94 obj.remove("$schema");
96 obj.remove("title");
97 obj.remove("description");
98
99 Self(obj)
100 }
101
102 pub fn from_schema(schema: schemars::Schema) -> Self {
104 let obj = schema.to_value().as_object().unwrap().clone();
105 Self::from_object(obj)
106 }
107
108 pub fn from_type<T: JsonSchema>() -> Self {
110 let settings = SchemaSettings::default().with(|s| {
111 s.inline_subschemas = true;
112 });
113 let generator = settings.into_generator();
114 let schema = generator.into_root_schema_for::<T>();
115 Self::from_schema(schema)
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a message as a JSON string, panicking if serialization fails.
    fn encode(msg: &Message) -> String {
        serde_json::to_string(msg).unwrap()
    }

    /// Decodes a message from a JSON string, panicking if parsing fails.
    fn decode(json: &str) -> Message {
        serde_json::from_str(json).unwrap()
    }

    // A system message survives a serialize/parse cycle.
    #[test]
    fn test_system_roundtrip() {
        let original = Message::System {
            content: "test".into(),
        };
        match decode(&encode(&original)) {
            Message::System { content } => assert_eq!(content, "test"),
            other => panic!("expected a system message, got {other:?}"),
        }
    }

    // A user message survives a serialize/parse cycle.
    #[test]
    fn test_user_roundtrip() {
        let original = Message::User {
            content: "test".into(),
        };
        match decode(&encode(&original)) {
            Message::User { content } => assert_eq!(content, "test"),
            other => panic!("expected a user message, got {other:?}"),
        }
    }

    // Without tool calls, the `tool_calls` key is omitted and parses back as `None`.
    #[test]
    fn test_assistant_no_tools_roundtrip() {
        let original = Message::Assistant {
            content: "test".into(),
            tool_calls: None,
        };
        let encoded = encode(&original);
        assert!(!encoded.contains("tool_calls"));
        match decode(&encoded) {
            Message::Assistant {
                content,
                tool_calls: None,
            } => assert_eq!(content, "test"),
            other => panic!("expected an assistant message without tool calls, got {other:?}"),
        }
    }

    // An attached tool call is preserved across a serialize/parse cycle.
    #[test]
    fn test_assistant_with_tools_roundtrip() {
        let call = ToolCall {
            id: "call_1".into(),
            call_type: "function".into(),
            function: FunctionCall {
                name: "fn".into(),
                arguments: "{}".into(),
            },
        };
        let original = Message::Assistant {
            content: "test".into(),
            tool_calls: Some(vec![call]),
        };
        match decode(&encode(&original)) {
            Message::Assistant {
                tool_calls: Some(calls),
                ..
            } => assert_eq!(calls.len(), 1),
            other => panic!("expected an assistant message with tool calls, got {other:?}"),
        }
    }

    // A tool message keeps both its content and its call id.
    #[test]
    fn test_tool_roundtrip() {
        let original = Message::Tool {
            content: "result".into(),
            tool_call_id: "call_123".into(),
        };
        match decode(&encode(&original)) {
            Message::Tool {
                content,
                tool_call_id,
            } => {
                assert_eq!(content, "result");
                assert_eq!(tool_call_id, "call_123");
            }
            other => panic!("expected a tool message, got {other:?}"),
        }
    }

    // A role with no dedicated variant parses back as `Custom` with the same role.
    #[test]
    fn test_custom_roundtrip() {
        let original = Message::Custom {
            role: "custom".into(),
            body: serde_json::json!({"content": "test", "extra": "field"}),
        };
        match decode(&encode(&original)) {
            Message::Custom { role, .. } => assert_eq!(role, "custom"),
            other => panic!("expected a custom message, got {other:?}"),
        }
    }

    // A standalone tool call keeps its id and function name.
    #[test]
    fn test_tool_call_roundtrip() {
        let original = ToolCall {
            id: "call_1".into(),
            call_type: "function".into(),
            function: FunctionCall {
                name: "test".into(),
                arguments: r#"{"key":"value"}"#.into(),
            },
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ToolCall = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, "call_1");
        assert_eq!(decoded.function.name, "test");
    }

    // A tool definition serializes its `type` and function `name` keys.
    #[test]
    fn test_tool_definition_serialization() {
        let function = ToolFunction {
            name: "test".into(),
            description: "desc".into(),
            parameters: Parameters::from_type::<String>(),
        };
        let definition = ToolDefinition {
            tool_type: "function".into(),
            function,
        };
        let encoded = serde_json::to_string(&definition).unwrap();
        assert!(encoded.contains(r#""type":"function"#));
        assert!(encoded.contains(r#""name":"test"#));
    }
}