version_migrate_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Meta};
4
5/// Derives the `Versioned` trait for a struct.
6///
7/// # Attributes
8///
9/// - `#[versioned(version = "x.y.z")]`: Specifies the semantic version (required).
10///   The version string must be a valid semantic version.
11/// - `#[versioned(version_key = "...")]`: Customizes the version field key (optional, default: "version").
12/// - `#[versioned(data_key = "...")]`: Customizes the data field key (optional, default: "data").
13/// - `#[versioned(auto_tag = true)]`: Auto-generates Serialize/Deserialize with version field (optional, default: false).
14///   When enabled, the version field is automatically inserted during serialization and validated during deserialization.
15///
16/// # Examples
17///
18/// Basic usage:
19/// ```ignore
20/// use version_migrate::Versioned;
21///
22/// #[derive(Versioned)]
23/// #[versioned(version = "1.0.0")]
24/// pub struct Task_V1_0_0 {
25///     pub id: String,
26///     pub title: String,
27/// }
28/// ```
29///
30/// Custom keys:
31/// ```ignore
32/// #[derive(Versioned)]
33/// #[versioned(
34///     version = "1.0.0",
35///     version_key = "schema_version",
36///     data_key = "payload"
37/// )]
38/// pub struct Task { ... }
39/// // When used with Migrator:
40/// // Serializes to: {"schema_version":"1.0.0","payload":{...}}
41/// ```
42///
43/// Auto-tag for direct serialization:
44/// ```ignore
45/// #[derive(Versioned)]
46/// #[versioned(version = "1.0.0", auto_tag = true)]
47/// pub struct Task {
48///     pub id: String,
49///     pub title: String,
50/// }
51///
52/// // Use serde directly without Migrator
53/// let task = Task { id: "1".into(), title: "Test".into() };
54/// let json = serde_json::to_string(&task)?;
55/// // → {"version":"1.0.0","id":"1","title":"Test"}
56/// ```
57#[proc_macro_derive(Versioned, attributes(versioned))]
58pub fn derive_versioned(input: TokenStream) -> TokenStream {
59    let input = parse_macro_input!(input as DeriveInput);
60
61    // Extract attributes
62    let attrs = extract_attributes(&input);
63
64    let name = &input.ident;
65    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
66
67    let version = &attrs.version;
68    let version_key = &attrs.version_key;
69    let data_key = &attrs.data_key;
70
71    let versioned_impl = quote! {
72        impl #impl_generics version_migrate::Versioned for #name #ty_generics #where_clause {
73            const VERSION: &'static str = #version;
74            const VERSION_KEY: &'static str = #version_key;
75            const DATA_KEY: &'static str = #data_key;
76        }
77    };
78
79    let expanded = if attrs.auto_tag {
80        // Generate custom Serialize and Deserialize implementations
81        let serialize_impl = generate_serialize_impl(&input, &attrs);
82        let deserialize_impl = generate_deserialize_impl(&input, &attrs);
83
84        quote! {
85            #versioned_impl
86            #serialize_impl
87            #deserialize_impl
88        }
89    } else {
90        versioned_impl
91    };
92
93    TokenStream::from(expanded)
94}
95
/// Options collected from the `#[versioned(...)]` attributes of the deriving struct.
struct VersionedAttributes {
    /// Whether to emit hand-written `serde` impls carrying the version tag
    /// (`auto_tag = true`).
    auto_tag: bool,
    /// Key under which the version string is stored; `"version"` unless overridden.
    version_key: String,
    /// Key under which the payload is stored; `"data"` unless overridden.
    data_key: String,
    /// The `version = "..."` value, checked to be valid semver by `extract_attributes`.
    version: String,
}
102
103fn extract_attributes(input: &DeriveInput) -> VersionedAttributes {
104    let mut version = None;
105    let mut version_key = String::from("version");
106    let mut data_key = String::from("data");
107    let mut auto_tag = false;
108
109    for attr in &input.attrs {
110        if attr.path().is_ident("versioned") {
111            if let Meta::List(meta_list) = &attr.meta {
112                let tokens = meta_list.tokens.to_string();
113                parse_versioned_attrs(
114                    &tokens,
115                    &mut version,
116                    &mut version_key,
117                    &mut data_key,
118                    &mut auto_tag,
119                );
120            }
121        }
122    }
123
124    let version = version.unwrap_or_else(|| {
125        panic!("Missing #[versioned(version = \"x.y.z\")] attribute");
126    });
127
128    // Validate semver at compile time
129    if let Err(e) = semver::Version::parse(&version) {
130        panic!("Invalid semantic version '{}': {}", version, e);
131    }
132
133    VersionedAttributes {
134        version,
135        version_key,
136        data_key,
137        auto_tag,
138    }
139}
140
141fn parse_versioned_attrs(
142    tokens: &str,
143    version: &mut Option<String>,
144    version_key: &mut String,
145    data_key: &mut String,
146    auto_tag: &mut bool,
147) {
148    // Parse comma-separated key = "value" pairs
149    for part in tokens.split(',') {
150        let part = part.trim();
151
152        if let Some(val) = parse_attr_value(part, "version") {
153            *version = Some(val);
154        } else if let Some(val) = parse_attr_value(part, "version_key") {
155            *version_key = val;
156        } else if let Some(val) = parse_attr_value(part, "data_key") {
157            *data_key = val;
158        } else if let Some(val) = parse_attr_bool_value(part, "auto_tag") {
159            *auto_tag = val;
160        }
161    }
162}
163
/// Extracts the string value from a `key = "value"` fragment.
///
/// `token` is one comma-separated fragment of a rendered attribute list. Returns the text
/// between the surrounding double quotes when `token` is `key`, then `=`, then a
/// double-quoted value. Whitespace around each part is ignored. Returns `None` otherwise,
/// including when the fragment's key merely starts with `key` (`version_key` vs
/// `version`).
///
/// Escape sequences inside the quotes are returned verbatim, not unescaped.
fn parse_attr_value(token: &str, key: &str) -> Option<String> {
    let value = token.trim().strip_prefix(key)?.trim().strip_prefix('=')?.trim();
    // Stripping the opening and closing quote separately needs two quote characters.
    // A lone `"` (e.g. left over after splitting `","` on commas) therefore yields
    // `None` instead of panicking on an inverted slice range.
    let inner = value.strip_prefix('"')?.strip_suffix('"')?;
    Some(inner.to_string())
}
177
/// Reads a boolean from a `key = true` / `key = false` fragment.
///
/// Whitespace around the fragment and around `=` is ignored. Any other shape yields
/// `None`: a different key, a missing `=`, or a value other than the bare words
/// `true` or `false`.
fn parse_attr_bool_value(token: &str, key: &str) -> Option<bool> {
    let after_key = token.trim().strip_prefix(key)?;
    let literal = after_key.trim().strip_prefix('=')?.trim();
    // `bool`'s `FromStr` accepts exactly "true" and "false", the two allowed literals.
    literal.parse::<bool>().ok()
}
193
194fn generate_serialize_impl(
195    input: &DeriveInput,
196    attrs: &VersionedAttributes,
197) -> proc_macro2::TokenStream {
198    let name = &input.ident;
199    let version = &attrs.version;
200    let version_key = &attrs.version_key;
201
202    // Extract field information
203    let fields = match &input.data {
204        syn::Data::Struct(data_struct) => match &data_struct.fields {
205            syn::Fields::Named(fields) => &fields.named,
206            _ => panic!("auto_tag only supports structs with named fields"),
207        },
208        _ => panic!("auto_tag only supports structs"),
209    };
210
211    let field_count = fields.len() + 1; // +1 for version field
212    let field_serializations = fields.iter().map(|field| {
213        let field_name = field.ident.as_ref().unwrap();
214        let field_name_str = field_name.to_string();
215        quote! {
216            state.serialize_field(#field_name_str, &self.#field_name)?;
217        }
218    });
219
220    quote! {
221        impl serde::Serialize for #name {
222            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
223            where
224                S: serde::Serializer,
225            {
226                use serde::ser::SerializeStruct;
227                let mut state = serializer.serialize_struct(stringify!(#name), #field_count)?;
228                state.serialize_field(#version_key, #version)?;
229                #(#field_serializations)*
230                state.end()
231            }
232        }
233    }
234}
235
236fn generate_deserialize_impl(
237    input: &DeriveInput,
238    attrs: &VersionedAttributes,
239) -> proc_macro2::TokenStream {
240    let name = &input.ident;
241    let version = &attrs.version;
242    let version_key = &attrs.version_key;
243
244    // Extract field information
245    let fields = match &input.data {
246        syn::Data::Struct(data_struct) => match &data_struct.fields {
247            syn::Fields::Named(fields) => &fields.named,
248            _ => panic!("auto_tag only supports structs with named fields"),
249        },
250        _ => panic!("auto_tag only supports structs"),
251    };
252
253    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
254    let field_name_strs: Vec<_> = field_names.iter().map(|f| f.to_string()).collect();
255
256    let all_field_names = {
257        let mut names = vec![version_key.clone()];
258        names.extend(field_name_strs.iter().cloned());
259        names
260    };
261
262    let field_enum_variants = field_names.iter().map(|name| {
263        let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
264        quote! { #variant }
265    });
266
267    let field_match_arms =
268        field_names
269            .iter()
270            .zip(field_name_strs.iter())
271            .map(|(name, name_str)| {
272                let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
273                quote! {
274                    #name_str => Ok(Field::#variant)
275                }
276            });
277
278    let field_visit_arms = field_names.iter().map(|name| {
279        let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
280        quote! {
281            Field::#variant => {
282                if #name.is_some() {
283                    return Err(serde::de::Error::duplicate_field(stringify!(#name)));
284                }
285                #name = Some(map.next_value()?);
286            }
287        }
288    });
289
290    let field_unwrap = field_names.iter().map(|name| {
291        quote! {
292            let #name = #name.ok_or_else(|| serde::de::Error::missing_field(stringify!(#name)))?;
293        }
294    });
295
296    quote! {
297        impl<'de> serde::Deserialize<'de> for #name {
298            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
299            where
300                D: serde::Deserializer<'de>,
301            {
302                #[allow(non_camel_case_types)]
303                enum Field {
304                    Version,
305                    #(#field_enum_variants,)*
306                }
307
308                impl<'de> serde::Deserialize<'de> for Field {
309                    fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
310                    where
311                        D: serde::Deserializer<'de>,
312                    {
313                        struct FieldVisitor;
314
315                        impl<'de> serde::de::Visitor<'de> for FieldVisitor {
316                            type Value = Field;
317
318                            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
319                                formatter.write_str(&format!("field identifier: {}", &[#(#all_field_names),*].join(", ")))
320                            }
321
322                            fn visit_str<E>(self, value: &str) -> Result<Field, E>
323                            where
324                                E: serde::de::Error,
325                            {
326                                match value {
327                                    #version_key => Ok(Field::Version),
328                                    #(#field_match_arms,)*
329                                    _ => Err(serde::de::Error::unknown_field(value, &[#(#all_field_names),*])),
330                                }
331                            }
332                        }
333
334                        deserializer.deserialize_identifier(FieldVisitor)
335                    }
336                }
337
338                struct StructVisitor;
339
340                impl<'de> serde::de::Visitor<'de> for StructVisitor {
341                    type Value = #name;
342
343                    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
344                        formatter.write_str(&format!("struct {}", stringify!(#name)))
345                    }
346
347                    fn visit_map<V>(self, mut map: V) -> Result<#name, V::Error>
348                    where
349                        V: serde::de::MapAccess<'de>,
350                    {
351                        let mut version: Option<String> = None;
352                        #(let mut #field_names = None;)*
353
354                        while let Some(key) = map.next_key()? {
355                            match key {
356                                Field::Version => {
357                                    if version.is_some() {
358                                        return Err(serde::de::Error::duplicate_field(#version_key));
359                                    }
360                                    let v: String = map.next_value()?;
361                                    if v != #version {
362                                        return Err(serde::de::Error::custom(format!(
363                                            "version mismatch: expected {}, found {}",
364                                            #version, v
365                                        )));
366                                    }
367                                    version = Some(v);
368                                }
369                                #(#field_visit_arms)*
370                            }
371                        }
372
373                        let _version = version.ok_or_else(|| serde::de::Error::missing_field(#version_key))?;
374                        #(#field_unwrap)*
375
376                        Ok(#name {
377                            #(#field_names,)*
378                        })
379                    }
380                }
381
382                deserializer.deserialize_struct(
383                    stringify!(#name),
384                    &[#(#all_field_names),*],
385                    StructVisitor,
386                )
387            }
388        }
389    }
390}