1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Meta};
4
5#[proc_macro_derive(Versioned, attributes(versioned))]
58pub fn derive_versioned(input: TokenStream) -> TokenStream {
59 let input = parse_macro_input!(input as DeriveInput);
60
61 let attrs = extract_attributes(&input);
63
64 let name = &input.ident;
65 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
66
67 let version = &attrs.version;
68 let version_key = &attrs.version_key;
69 let data_key = &attrs.data_key;
70
71 let versioned_impl = quote! {
72 impl #impl_generics version_migrate::Versioned for #name #ty_generics #where_clause {
73 const VERSION: &'static str = #version;
74 const VERSION_KEY: &'static str = #version_key;
75 const DATA_KEY: &'static str = #data_key;
76 }
77 };
78
79 let expanded = if attrs.auto_tag {
80 let serialize_impl = generate_serialize_impl(&input, &attrs);
82 let deserialize_impl = generate_deserialize_impl(&input, &attrs);
83
84 quote! {
85 #versioned_impl
86 #serialize_impl
87 #deserialize_impl
88 }
89 } else {
90 versioned_impl
91 };
92
93 TokenStream::from(expanded)
94}
95
/// Settings gathered from the `#[versioned(...)]` attributes of a derive input.
struct VersionedAttributes {
    /// Semantic version string, emitted as `Versioned::VERSION`.
    version: String,
    /// Map key carrying the version, emitted as `Versioned::VERSION_KEY`
    /// (and used as the tag key by the `auto_tag` serde impls).
    version_key: String,
    /// Emitted as `Versioned::DATA_KEY`.
    data_key: String,
    /// Whether `serde::Serialize` / `serde::Deserialize` impls are generated too.
    auto_tag: bool,
}
102
103fn extract_attributes(input: &DeriveInput) -> VersionedAttributes {
104 let mut version = None;
105 let mut version_key = String::from("version");
106 let mut data_key = String::from("data");
107 let mut auto_tag = false;
108
109 for attr in &input.attrs {
110 if attr.path().is_ident("versioned") {
111 if let Meta::List(meta_list) = &attr.meta {
112 let tokens = meta_list.tokens.to_string();
113 parse_versioned_attrs(
114 &tokens,
115 &mut version,
116 &mut version_key,
117 &mut data_key,
118 &mut auto_tag,
119 );
120 }
121 }
122 }
123
124 let version = version.unwrap_or_else(|| {
125 panic!("Missing #[versioned(version = \"x.y.z\")] attribute");
126 });
127
128 if let Err(e) = semver::Version::parse(&version) {
130 panic!("Invalid semantic version '{}': {}", version, e);
131 }
132
133 VersionedAttributes {
134 version,
135 version_key,
136 data_key,
137 auto_tag,
138 }
139}
140
/// Parses the stringified token list of one `#[versioned(...)]` attribute.
///
/// `tokens` is the `Display` form of the attribute's tokens, e.g.
/// `version = "1.0.0" , auto_tag = true`. Each recognised entry overwrites the
/// matching out-parameter; parameters whose key does not appear are left
/// untouched. Accepted entries are `version = "..."`, `version_key = "..."`,
/// `data_key = "..."`, `auto_tag = true|false` and the bare flag `auto_tag`
/// (meaning `true`). Commas inside quoted values do not split entries, and a
/// trailing comma is allowed.
///
/// # Panics
///
/// Panics (surfacing as a compile error at the derive site) on any entry that
/// is not one of the accepted forms, so typos such as `verison = "1.0.0"` or
/// `auto_tag = "true"` are reported instead of being silently ignored.
fn parse_versioned_attrs(
    tokens: &str,
    version: &mut Option<String>,
    version_key: &mut String,
    data_key: &mut String,
    auto_tag: &mut bool,
) {
    for part in split_outside_quotes(tokens) {
        let part = part.trim();
        // `#[versioned(version = "1.0.0",)]` yields an empty trailing entry.
        if part.is_empty() {
            continue;
        }

        if let Some(val) = parse_attr_value(part, "version") {
            *version = Some(val);
        } else if let Some(val) = parse_attr_value(part, "version_key") {
            *version_key = val;
        } else if let Some(val) = parse_attr_value(part, "data_key") {
            *data_key = val;
        } else if let Some(val) = parse_attr_bool_value(part, "auto_tag") {
            *auto_tag = val;
        } else {
            panic!(
                "Unrecognized #[versioned] argument `{}`; expected `version = \"...\"`, \
                 `version_key = \"...\"`, `data_key = \"...\"` or `auto_tag = true|false`",
                part
            );
        }
    }
}

/// Splits `s` on commas that are not inside a double-quoted string literal.
///
/// Backslash escapes inside a literal are honoured, so `"a\",b"` stays whole.
fn split_outside_quotes(s: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    let mut start = 0;
    let mut in_string = false;
    let mut escaped = false;

    for (idx, ch) in s.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
        } else if ch == '"' {
            in_string = true;
        } else if ch == ',' {
            parts.push(&s[start..idx]);
            start = idx + ch.len_utf8();
        }
    }
    parts.push(&s[start..]);
    parts
}

/// Returns the contents of a `key = "value"` entry, without the quotes.
///
/// Returns `None` when `token` is not exactly that shape for `key` (a
/// different key that merely shares the prefix, e.g. `version_key` vs
/// `version`, yields `None`). A lone `"` no longer panics; it yields `None`.
/// Escape sequences inside the value are returned verbatim, not unescaped.
fn parse_attr_value(token: &str, key: &str) -> Option<String> {
    let rest = token.trim().strip_prefix(key)?.trim();
    let rest = rest.strip_prefix('=')?.trim();
    // Requiring both an opening and a separate closing quote guards against
    // the single-character `"` case.
    let inner = rest.strip_prefix('"')?.strip_suffix('"')?;
    Some(inner.to_string())
}

/// Returns the boolean of a `key = true|false` entry, or `Some(true)` for the
/// bare flag form `key`.
///
/// Returns `None` for any other shape, including quoted booleans.
fn parse_attr_bool_value(token: &str, key: &str) -> Option<bool> {
    let rest = token.trim().strip_prefix(key)?.trim();
    if rest.is_empty() {
        return Some(true);
    }
    match rest.strip_prefix('=')?.trim() {
        "true" => Some(true),
        "false" => Some(false),
        _ => None,
    }
}
193
194fn generate_serialize_impl(
195 input: &DeriveInput,
196 attrs: &VersionedAttributes,
197) -> proc_macro2::TokenStream {
198 let name = &input.ident;
199 let version = &attrs.version;
200 let version_key = &attrs.version_key;
201
202 let fields = match &input.data {
204 syn::Data::Struct(data_struct) => match &data_struct.fields {
205 syn::Fields::Named(fields) => &fields.named,
206 _ => panic!("auto_tag only supports structs with named fields"),
207 },
208 _ => panic!("auto_tag only supports structs"),
209 };
210
211 let field_count = fields.len() + 1; let field_serializations = fields.iter().map(|field| {
213 let field_name = field.ident.as_ref().unwrap();
214 let field_name_str = field_name.to_string();
215 quote! {
216 state.serialize_field(#field_name_str, &self.#field_name)?;
217 }
218 });
219
220 quote! {
221 impl serde::Serialize for #name {
222 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
223 where
224 S: serde::Serializer,
225 {
226 use serde::ser::SerializeStruct;
227 let mut state = serializer.serialize_struct(stringify!(#name), #field_count)?;
228 state.serialize_field(#version_key, #version)?;
229 #(#field_serializations)*
230 state.end()
231 }
232 }
233 }
234}
235
236fn generate_deserialize_impl(
237 input: &DeriveInput,
238 attrs: &VersionedAttributes,
239) -> proc_macro2::TokenStream {
240 let name = &input.ident;
241 let version = &attrs.version;
242 let version_key = &attrs.version_key;
243
244 let fields = match &input.data {
246 syn::Data::Struct(data_struct) => match &data_struct.fields {
247 syn::Fields::Named(fields) => &fields.named,
248 _ => panic!("auto_tag only supports structs with named fields"),
249 },
250 _ => panic!("auto_tag only supports structs"),
251 };
252
253 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
254 let field_name_strs: Vec<_> = field_names.iter().map(|f| f.to_string()).collect();
255
256 let all_field_names = {
257 let mut names = vec![version_key.clone()];
258 names.extend(field_name_strs.iter().cloned());
259 names
260 };
261
262 let field_enum_variants = field_names.iter().map(|name| {
263 let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
264 quote! { #variant }
265 });
266
267 let field_match_arms =
268 field_names
269 .iter()
270 .zip(field_name_strs.iter())
271 .map(|(name, name_str)| {
272 let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
273 quote! {
274 #name_str => Ok(Field::#variant)
275 }
276 });
277
278 let field_visit_arms = field_names.iter().map(|name| {
279 let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
280 quote! {
281 Field::#variant => {
282 if #name.is_some() {
283 return Err(serde::de::Error::duplicate_field(stringify!(#name)));
284 }
285 #name = Some(map.next_value()?);
286 }
287 }
288 });
289
290 let field_unwrap = field_names.iter().map(|name| {
291 quote! {
292 let #name = #name.ok_or_else(|| serde::de::Error::missing_field(stringify!(#name)))?;
293 }
294 });
295
296 quote! {
297 impl<'de> serde::Deserialize<'de> for #name {
298 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
299 where
300 D: serde::Deserializer<'de>,
301 {
302 #[allow(non_camel_case_types)]
303 enum Field {
304 Version,
305 #(#field_enum_variants,)*
306 }
307
308 impl<'de> serde::Deserialize<'de> for Field {
309 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
310 where
311 D: serde::Deserializer<'de>,
312 {
313 struct FieldVisitor;
314
315 impl<'de> serde::de::Visitor<'de> for FieldVisitor {
316 type Value = Field;
317
318 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
319 formatter.write_str(&format!("field identifier: {}", &[#(#all_field_names),*].join(", ")))
320 }
321
322 fn visit_str<E>(self, value: &str) -> Result<Field, E>
323 where
324 E: serde::de::Error,
325 {
326 match value {
327 #version_key => Ok(Field::Version),
328 #(#field_match_arms,)*
329 _ => Err(serde::de::Error::unknown_field(value, &[#(#all_field_names),*])),
330 }
331 }
332 }
333
334 deserializer.deserialize_identifier(FieldVisitor)
335 }
336 }
337
338 struct StructVisitor;
339
340 impl<'de> serde::de::Visitor<'de> for StructVisitor {
341 type Value = #name;
342
343 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
344 formatter.write_str(&format!("struct {}", stringify!(#name)))
345 }
346
347 fn visit_map<V>(self, mut map: V) -> Result<#name, V::Error>
348 where
349 V: serde::de::MapAccess<'de>,
350 {
351 let mut version: Option<String> = None;
352 #(let mut #field_names = None;)*
353
354 while let Some(key) = map.next_key()? {
355 match key {
356 Field::Version => {
357 if version.is_some() {
358 return Err(serde::de::Error::duplicate_field(#version_key));
359 }
360 let v: String = map.next_value()?;
361 if v != #version {
362 return Err(serde::de::Error::custom(format!(
363 "version mismatch: expected {}, found {}",
364 #version, v
365 )));
366 }
367 version = Some(v);
368 }
369 #(#field_visit_arms)*
370 }
371 }
372
373 let _version = version.ok_or_else(|| serde::de::Error::missing_field(#version_key))?;
374 #(#field_unwrap)*
375
376 Ok(#name {
377 #(#field_names,)*
378 })
379 }
380 }
381
382 deserializer.deserialize_struct(
383 stringify!(#name),
384 &[#(#all_field_names),*],
385 StructVisitor,
386 )
387 }
388 }
389 }
390}