version_migrate_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Meta, Type};
4
5/// Derives the `Versioned` trait for a struct.
6///
7/// # Attributes
8///
9/// - `#[versioned(version = "x.y.z")]`: Specifies the semantic version (required).
10///   The version string must be a valid semantic version.
11/// - `#[versioned(version_key = "...")]`: Customizes the version field key (optional, default: "version").
12/// - `#[versioned(data_key = "...")]`: Customizes the data field key (optional, default: "data").
13/// - `#[versioned(auto_tag = true)]`: Auto-generates Serialize/Deserialize with version field (optional, default: false).
14///   When enabled, the version field is automatically inserted during serialization and validated during deserialization.
15/// - `#[versioned(queryable = true)]`: Auto-generates Queryable trait implementation (optional, default: false).
16///   Enables use with ConfigMigrator for ORM-like queries.
17/// - `#[versioned(queryable_key = "...")]`: Customizes the entity name for Queryable (optional).
18///   If not specified, uses the lowercased type name. Only used when `queryable = true`.
19///
20/// # Examples
21///
22/// Basic usage:
23/// ```ignore
24/// use version_migrate::Versioned;
25///
26/// #[derive(Versioned)]
27/// #[versioned(version = "1.0.0")]
28/// pub struct Task_V1_0_0 {
29///     pub id: String,
30///     pub title: String,
31/// }
32/// ```
33///
34/// Custom keys:
35/// ```ignore
36/// #[derive(Versioned)]
37/// #[versioned(
38///     version = "1.0.0",
39///     version_key = "schema_version",
40///     data_key = "payload"
41/// )]
42/// pub struct Task { ... }
43/// // When used with Migrator:
44/// // Serializes to: {"schema_version":"1.0.0","payload":{...}}
45/// ```
46///
47/// Auto-tag for direct serialization:
48/// ```ignore
49/// #[derive(Versioned)]
50/// #[versioned(version = "1.0.0", auto_tag = true)]
51/// pub struct Task {
52///     pub id: String,
53///     pub title: String,
54/// }
55///
56/// // Use serde directly without Migrator
57/// let task = Task { id: "1".into(), title: "Test".into() };
58/// let json = serde_json::to_string(&task)?;
59/// // → {"version":"1.0.0","id":"1","title":"Test"}
60/// ```
61///
62/// Queryable for ConfigMigrator:
63/// ```ignore
64/// #[derive(Serialize, Deserialize, Versioned)]
65/// #[versioned(version = "2.0.0", queryable = true, queryable_key = "task")]
66/// pub struct TaskEntity {
67///     pub id: String,
68///     pub title: String,
69///     pub description: Option<String>,
70/// }
71///
72/// // Now TaskEntity implements Queryable automatically
73/// let tasks: Vec<TaskEntity> = config_migrator.query("tasks")?;
74/// ```
75#[proc_macro_derive(Versioned, attributes(versioned))]
76pub fn derive_versioned(input: TokenStream) -> TokenStream {
77    let input = parse_macro_input!(input as DeriveInput);
78
79    // Extract attributes
80    let attrs = extract_attributes(&input);
81
82    let name = &input.ident;
83    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
84
85    let version = &attrs.version;
86    let version_key = &attrs.version_key;
87    let data_key = &attrs.data_key;
88
89    let versioned_impl = quote! {
90        impl #impl_generics version_migrate::Versioned for #name #ty_generics #where_clause {
91            const VERSION: &'static str = #version;
92            const VERSION_KEY: &'static str = #version_key;
93            const DATA_KEY: &'static str = #data_key;
94        }
95    };
96
97    let mut impls = vec![versioned_impl];
98
99    if attrs.auto_tag {
100        // Generate custom Serialize and Deserialize implementations
101        let serialize_impl = generate_serialize_impl(&input, &attrs);
102        let deserialize_impl = generate_deserialize_impl(&input, &attrs);
103        impls.push(serialize_impl);
104        impls.push(deserialize_impl);
105    }
106
107    if attrs.queryable {
108        // Generate Queryable trait implementation
109        let queryable_impl = generate_queryable_impl(&input, &attrs);
110        impls.push(queryable_impl);
111    }
112
113    let expanded = quote! {
114        #(#impls)*
115    };
116
117    TokenStream::from(expanded)
118}
119
/// Settings collected from the `#[versioned(...)]` attributes of a derive input.
struct VersionedAttributes {
    /// Semantic version string given as `version = "x.y.z"`.
    version: String,
    /// Key for the version field (`version_key`; `"version"` when omitted).
    version_key: String,
    /// Key for the payload field (`data_key`; `"data"` when omitted).
    data_key: String,
    /// Explicit `Queryable` entity name (`queryable_key`); `None` when omitted.
    queryable_key: Option<String>,
    /// Whether custom serde impls carrying the version were requested (`auto_tag`).
    auto_tag: bool,
    /// Whether a `Queryable` impl was requested (`queryable`).
    queryable: bool,
}
128
129fn extract_attributes(input: &DeriveInput) -> VersionedAttributes {
130    let mut version = None;
131    let mut version_key = String::from("version");
132    let mut data_key = String::from("data");
133    let mut auto_tag = false;
134    let mut queryable = false;
135    let mut queryable_key = None;
136
137    for attr in &input.attrs {
138        if attr.path().is_ident("versioned") {
139            if let Meta::List(meta_list) = &attr.meta {
140                let tokens = meta_list.tokens.to_string();
141                parse_versioned_attrs(
142                    &tokens,
143                    &mut version,
144                    &mut version_key,
145                    &mut data_key,
146                    &mut auto_tag,
147                    &mut queryable,
148                    &mut queryable_key,
149                );
150            }
151        }
152    }
153
154    let version = version.unwrap_or_else(|| {
155        panic!("Missing #[versioned(version = \"x.y.z\")] attribute");
156    });
157
158    // Validate semver at compile time
159    if let Err(e) = semver::Version::parse(&version) {
160        panic!("Invalid semantic version '{}': {}", version, e);
161    }
162
163    VersionedAttributes {
164        version,
165        version_key,
166        data_key,
167        auto_tag,
168        queryable,
169        queryable_key,
170    }
171}
172
173fn parse_versioned_attrs(
174    tokens: &str,
175    version: &mut Option<String>,
176    version_key: &mut String,
177    data_key: &mut String,
178    auto_tag: &mut bool,
179    queryable: &mut bool,
180    queryable_key: &mut Option<String>,
181) {
182    // Parse comma-separated key = "value" pairs
183    for part in tokens.split(',') {
184        let part = part.trim();
185
186        if let Some(val) = parse_attr_value(part, "version") {
187            *version = Some(val);
188        } else if let Some(val) = parse_attr_value(part, "version_key") {
189            *version_key = val;
190        } else if let Some(val) = parse_attr_value(part, "data_key") {
191            *data_key = val;
192        } else if let Some(val) = parse_attr_bool_value(part, "auto_tag") {
193            *auto_tag = val;
194        } else if let Some(val) = parse_attr_bool_value(part, "queryable") {
195            *queryable = val;
196        } else if let Some(val) = parse_attr_value(part, "queryable_key") {
197            *queryable_key = Some(val);
198        }
199    }
200}
201
/// Parses a single `key = "value"` entry and returns the text between the quotes.
///
/// Returns `None` unless `token` (after trimming) starts with `key`, followed by
/// optional whitespace, `=`, optional whitespace, and a double-quoted string.
/// Escape sequences inside the quotes are returned verbatim, not unescaped.
fn parse_attr_value(token: &str, key: &str) -> Option<String> {
    let after_key = token.trim().strip_prefix(key)?.trim_start();
    let quoted = after_key.strip_prefix('=')?.trim();
    // Strip each quote independently: a lone `"` then yields `None` instead of
    // satisfying both `starts_with`/`ends_with` and slicing `[1..0]` (a panic).
    let inner = quoted.strip_prefix('"')?.strip_suffix('"')?;
    Some(inner.to_string())
}
215
/// Parses a single `key = true` / `key = false` entry.
///
/// Yields `Some(bool)` only when the trimmed `token` is `key`, optional
/// whitespace, `=`, optional whitespace and a bare `true` or `false`. Any other
/// shape — a different key, a quoted value, a missing `=` — yields `None`.
fn parse_attr_bool_value(token: &str, key: &str) -> Option<bool> {
    let after_key = token.trim().strip_prefix(key)?;
    let literal = after_key.trim().strip_prefix('=')?.trim();
    // `bool::from_str` accepts exactly "true" and "false" (case-sensitive).
    literal.parse::<bool>().ok()
}
231
232fn generate_queryable_impl(
233    input: &DeriveInput,
234    attrs: &VersionedAttributes,
235) -> proc_macro2::TokenStream {
236    let name = &input.ident;
237
238    // Determine the entity name
239    let entity_name = if let Some(ref key) = attrs.queryable_key {
240        key.clone()
241    } else {
242        // Default: use the type name in lowercase
243        name.to_string().to_lowercase()
244    };
245
246    quote! {
247        impl version_migrate::Queryable for #name {
248            const ENTITY_NAME: &'static str = #entity_name;
249        }
250    }
251}
252
253/// Derives the `Queryable` trait for a struct.
254///
255/// This is a standalone macro for domain entities that need to be queryable
256/// via `ConfigMigrator` but don't have version information themselves.
257///
258/// # Attributes
259///
260/// - `#[queryable(entity = "name")]`: Specifies the entity name (required).
261///   This must match the entity name used when registering migration paths.
262///
263/// # Examples
264///
265/// Basic usage:
266/// ```ignore
267/// use version_migrate::Queryable;
268///
269/// #[derive(Queryable)]
270/// #[queryable(entity = "task")]
271/// pub struct TaskEntity {
272///     pub id: String,
273///     pub title: String,
274/// }
275///
276/// // Now can be used with ConfigMigrator
277/// let tasks: Vec<TaskEntity> = config.query("tasks")?;
278/// ```
279///
280/// The entity name must match the Migrator registration:
281/// ```ignore
282/// let path = Migrator::define("task")  // ← This name
283///     .from::<TaskV1>()
284///     .into::<TaskEntity>();
285///
286/// #[derive(Queryable)]
287/// #[queryable(entity = "task")]  // ← Must match
288/// struct TaskEntity { ... }
289/// ```
290#[proc_macro_derive(Queryable, attributes(queryable))]
291pub fn derive_queryable(input: TokenStream) -> TokenStream {
292    let input = parse_macro_input!(input as DeriveInput);
293
294    let name = &input.ident;
295    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
296    let mut entity_name: Option<String> = None;
297
298    // Extract entity attribute
299    for attr in &input.attrs {
300        if attr.path().is_ident("queryable") {
301            if let Meta::List(meta_list) = &attr.meta {
302                let tokens = meta_list.tokens.to_string();
303                entity_name = parse_entity_attr(&tokens);
304            }
305        }
306    }
307
308    let entity_name = entity_name.unwrap_or_else(|| {
309        panic!("Missing #[queryable(entity = \"name\")] attribute");
310    });
311
312    let expanded = quote! {
313        impl #impl_generics version_migrate::Queryable for #name #ty_generics #where_clause {
314            const ENTITY_NAME: &'static str = #entity_name;
315        }
316    };
317
318    TokenStream::from(expanded)
319}
320
321fn parse_entity_attr(tokens: &str) -> Option<String> {
322    for part in tokens.split(',') {
323        let part = part.trim();
324        if let Some(val) = parse_attr_value(part, "entity") {
325            return Some(val);
326        }
327    }
328    None
329}
330
331fn generate_serialize_impl(
332    input: &DeriveInput,
333    attrs: &VersionedAttributes,
334) -> proc_macro2::TokenStream {
335    let name = &input.ident;
336    let version = &attrs.version;
337    let version_key = &attrs.version_key;
338
339    // Extract field information
340    let fields = match &input.data {
341        syn::Data::Struct(data_struct) => match &data_struct.fields {
342            syn::Fields::Named(fields) => &fields.named,
343            _ => panic!("auto_tag only supports structs with named fields"),
344        },
345        _ => panic!("auto_tag only supports structs"),
346    };
347
348    let field_count = fields.len() + 1; // +1 for version field
349    let field_serializations = fields.iter().map(|field| {
350        let field_name = field.ident.as_ref().unwrap();
351        let field_name_str = field_name.to_string();
352        quote! {
353            state.serialize_field(#field_name_str, &self.#field_name)?;
354        }
355    });
356
357    quote! {
358        impl serde::Serialize for #name {
359            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
360            where
361                S: serde::Serializer,
362            {
363                use serde::ser::SerializeStruct;
364                let mut state = serializer.serialize_struct(stringify!(#name), #field_count)?;
365                state.serialize_field(#version_key, #version)?;
366                #(#field_serializations)*
367                state.end()
368            }
369        }
370    }
371}
372
373fn generate_deserialize_impl(
374    input: &DeriveInput,
375    attrs: &VersionedAttributes,
376) -> proc_macro2::TokenStream {
377    let name = &input.ident;
378    let version = &attrs.version;
379    let version_key = &attrs.version_key;
380
381    // Extract field information
382    let fields = match &input.data {
383        syn::Data::Struct(data_struct) => match &data_struct.fields {
384            syn::Fields::Named(fields) => &fields.named,
385            _ => panic!("auto_tag only supports structs with named fields"),
386        },
387        _ => panic!("auto_tag only supports structs"),
388    };
389
390    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
391    let field_name_strs: Vec<_> = field_names.iter().map(|f| f.to_string()).collect();
392
393    let all_field_names = {
394        let mut names = vec![version_key.clone()];
395        names.extend(field_name_strs.iter().cloned());
396        names
397    };
398
399    let field_enum_variants = field_names.iter().map(|name| {
400        let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
401        quote! { #variant }
402    });
403
404    let field_match_arms =
405        field_names
406            .iter()
407            .zip(field_name_strs.iter())
408            .map(|(name, name_str)| {
409                let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
410                quote! {
411                    #name_str => Ok(Field::#variant)
412                }
413            });
414
415    let field_visit_arms = field_names.iter().map(|name| {
416        let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
417        quote! {
418            Field::#variant => {
419                if #name.is_some() {
420                    return Err(serde::de::Error::duplicate_field(stringify!(#name)));
421                }
422                #name = Some(map.next_value()?);
423            }
424        }
425    });
426
427    let field_unwrap = field_names.iter().map(|name| {
428        quote! {
429            let #name = #name.ok_or_else(|| serde::de::Error::missing_field(stringify!(#name)))?;
430        }
431    });
432
433    quote! {
434        impl<'de> serde::Deserialize<'de> for #name {
435            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
436            where
437                D: serde::Deserializer<'de>,
438            {
439                #[allow(non_camel_case_types)]
440                enum Field {
441                    Version,
442                    #(#field_enum_variants,)*
443                }
444
445                impl<'de> serde::Deserialize<'de> for Field {
446                    fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
447                    where
448                        D: serde::Deserializer<'de>,
449                    {
450                        struct FieldVisitor;
451
452                        impl<'de> serde::de::Visitor<'de> for FieldVisitor {
453                            type Value = Field;
454
455                            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
456                                formatter.write_str(&format!("field identifier: {}", &[#(#all_field_names),*].join(", ")))
457                            }
458
459                            fn visit_str<E>(self, value: &str) -> Result<Field, E>
460                            where
461                                E: serde::de::Error,
462                            {
463                                match value {
464                                    #version_key => Ok(Field::Version),
465                                    #(#field_match_arms,)*
466                                    _ => Err(serde::de::Error::unknown_field(value, &[#(#all_field_names),*])),
467                                }
468                            }
469                        }
470
471                        deserializer.deserialize_identifier(FieldVisitor)
472                    }
473                }
474
475                struct StructVisitor;
476
477                impl<'de> serde::de::Visitor<'de> for StructVisitor {
478                    type Value = #name;
479
480                    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
481                        formatter.write_str(&format!("struct {}", stringify!(#name)))
482                    }
483
484                    fn visit_map<V>(self, mut map: V) -> Result<#name, V::Error>
485                    where
486                        V: serde::de::MapAccess<'de>,
487                    {
488                        let mut version: Option<String> = None;
489                        #(let mut #field_names = None;)*
490
491                        while let Some(key) = map.next_key()? {
492                            match key {
493                                Field::Version => {
494                                    if version.is_some() {
495                                        return Err(serde::de::Error::duplicate_field(#version_key));
496                                    }
497                                    let v: String = map.next_value()?;
498                                    if v != #version {
499                                        return Err(serde::de::Error::custom(format!(
500                                            "version mismatch: expected {}, found {}",
501                                            #version, v
502                                        )));
503                                    }
504                                    version = Some(v);
505                                }
506                                #(#field_visit_arms)*
507                            }
508                        }
509
510                        let _version = version.ok_or_else(|| serde::de::Error::missing_field(#version_key))?;
511                        #(#field_unwrap)*
512
513                        Ok(#name {
514                            #(#field_names,)*
515                        })
516                    }
517                }
518
519                deserializer.deserialize_struct(
520                    stringify!(#name),
521                    &[#(#all_field_names),*],
522                    StructVisitor,
523                )
524            }
525        }
526    }
527}
528
529/// Derives the `LatestVersioned` trait for a domain entity.
530///
531/// This macro associates a domain entity with its latest versioned representation,
532/// enabling automatic conversion and saving using the latest version.
533///
534/// # Attributes
535///
536/// - `#[version_migrate(entity = "name", latest = Type)]`: Specifies the entity name
537///   and the latest versioned type (both required).
538/// - `#[version_migrate(..., save = true|false)]`: Controls whether to enable save functionality (default: false)
539///   When `save = false` (default), uses `into()` for read-only access.
540///   When `save = true`, uses `into_with_save()` to enable domain entity saving.
541///
542/// # Requirements
543///
544/// You must manually implement `FromDomain<YourEntity>` on the latest versioned type
545/// to define how to convert from the domain entity to the versioned format.
546/// When `save = true`, the `FromDomain` trait is required for the save functionality.
547///
548/// # Examples
549///
550/// Basic usage (read-only, default):
551/// ```ignore
552/// use version_migrate::{VersionMigrate, FromDomain, Versioned};
553/// use serde::{Serialize, Deserialize};
554///
555/// // Latest versioned type
556/// #[derive(Serialize, Deserialize, Versioned)]
557/// #[versioned(version = "1.1.0")]
558/// struct TaskV1_1_0 {
559///     id: String,
560///     title: String,
561///     description: Option<String>,
562/// }
563///
564/// // Domain entity (read-only, default)
565/// #[derive(Serialize, Deserialize, VersionMigrate)]
566/// #[version_migrate(entity = "task", latest = TaskV1_1_0)]
567/// struct TaskEntity {
568///     id: String,
569///     title: String,
570///     description: Option<String>,
571/// }
572/// ```
573///
574/// With save support:
575/// ```ignore
576/// // Domain entity with save support
577/// #[derive(Serialize, Deserialize, VersionMigrate)]
578/// #[version_migrate(entity = "task", latest = TaskV1_1_0, save = true)]
579/// struct TaskEntity {
580///     id: String,
581///     title: String,
582///     description: Option<String>,
583/// }
584///
585/// // Implement FromDomain to define the conversion
586/// impl FromDomain<TaskEntity> for TaskV1_1_0 {
587///     fn from_domain(entity: TaskEntity) -> Self {
588///         TaskV1_1_0 {
589///             id: entity.id,
590///             title: entity.title,
591///             description: entity.description,
592///         }
593///     }
594/// }
595///
596/// // Now you can save entities directly
597/// let entity = TaskEntity {
598///     id: "1".into(),
599///     title: "My Task".into(),
600///     description: Some("Description".into()),
601/// };
602/// let json = migrator.save_entity(entity)?; // Automatically uses TaskV1_1_0
603/// ```
604#[proc_macro_derive(VersionMigrate, attributes(version_migrate))]
605pub fn derive_version_migrate(input: TokenStream) -> TokenStream {
606    let input = parse_macro_input!(input as DeriveInput);
607
608    let name = &input.ident;
609    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
610
611    // Extract attributes
612    let mut entity_name: Option<String> = None;
613    let mut latest_type: Option<Type> = None;
614    let mut save = false; // Default to false (read-only)
615
616    for attr in &input.attrs {
617        if attr.path().is_ident("version_migrate") {
618            if let Meta::List(meta_list) = &attr.meta {
619                let tokens = meta_list.tokens.to_string();
620                parse_version_migrate_attrs(&tokens, &mut entity_name, &mut latest_type, &mut save);
621            }
622        }
623    }
624
625    let entity_name = entity_name.unwrap_or_else(|| {
626        panic!("Missing #[version_migrate(entity = \"name\", ...)] attribute");
627    });
628
629    let latest_type = latest_type.unwrap_or_else(|| {
630        panic!("Missing #[version_migrate(..., latest = Type)] attribute");
631    });
632
633    let expanded = quote! {
634        impl #impl_generics version_migrate::LatestVersioned for #name #ty_generics #where_clause {
635            type Latest = #latest_type;
636            const ENTITY_NAME: &'static str = #entity_name;
637            const SAVE: bool = #save;
638        }
639    };
640
641    TokenStream::from(expanded)
642}
643
644fn parse_version_migrate_attrs(
645    tokens: &str,
646    entity_name: &mut Option<String>,
647    latest_type: &mut Option<Type>,
648    save: &mut bool,
649) {
650    // Split by commas but preserve type paths
651    let parts: Vec<&str> = tokens.split(',').collect();
652
653    for part in parts {
654        let part = part.trim();
655
656        if let Some(val) = parse_attr_value(part, "entity") {
657            *entity_name = Some(val);
658        } else if let Some(rest) = part.strip_prefix("latest") {
659            let rest = rest.trim();
660            if let Some(rest) = rest.strip_prefix('=') {
661                let type_str = rest.trim();
662                // Parse the type using syn
663                if let Ok(ty) = syn::parse_str::<Type>(type_str) {
664                    *latest_type = Some(ty);
665                }
666            }
667        } else if let Some(val) = parse_attr_bool_value(part, "save") {
668            *save = val;
669        }
670    }
671}