1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Meta, Type};
4
5#[proc_macro_derive(Versioned, attributes(versioned))]
76pub fn derive_versioned(input: TokenStream) -> TokenStream {
77 let input = parse_macro_input!(input as DeriveInput);
78
79 let attrs = extract_attributes(&input);
81
82 let name = &input.ident;
83 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
84
85 let version = &attrs.version;
86 let version_key = &attrs.version_key;
87 let data_key = &attrs.data_key;
88
89 let versioned_impl = quote! {
90 impl #impl_generics version_migrate::Versioned for #name #ty_generics #where_clause {
91 const VERSION: &'static str = #version;
92 const VERSION_KEY: &'static str = #version_key;
93 const DATA_KEY: &'static str = #data_key;
94 }
95 };
96
97 let mut impls = vec![versioned_impl];
98
99 if attrs.auto_tag {
100 let serialize_impl = generate_serialize_impl(&input, &attrs);
102 let deserialize_impl = generate_deserialize_impl(&input, &attrs);
103 impls.push(serialize_impl);
104 impls.push(deserialize_impl);
105 }
106
107 if attrs.queryable {
108 let queryable_impl = generate_queryable_impl(&input, &attrs);
110 impls.push(queryable_impl);
111 }
112
113 let expanded = quote! {
114 #(#impls)*
115 };
116
117 TokenStream::from(expanded)
118}
119
/// Settings collected from a type's `#[versioned(...)]` attributes.
struct VersionedAttributes {
    /// Value of the required `version` key; exported as `Versioned::VERSION`.
    version: String,
    /// `version_key` (default `"version"`); exported as `Versioned::VERSION_KEY` and
    /// used as the version entry's key by the `auto_tag` serde impls.
    version_key: String,
    /// `data_key` (default `"data"`); exported as `Versioned::DATA_KEY`.
    data_key: String,
    /// `queryable_key`: explicit `Queryable::ENTITY_NAME`. When unset, the
    /// lowercased type name is used.
    queryable_key: Option<String>,
    /// `auto_tag` (default `false`): emit version-tagged serde impls.
    auto_tag: bool,
    /// `queryable` (default `false`): also emit a `version_migrate::Queryable` impl.
    queryable: bool,
}
128
129fn extract_attributes(input: &DeriveInput) -> VersionedAttributes {
130 let mut version = None;
131 let mut version_key = String::from("version");
132 let mut data_key = String::from("data");
133 let mut auto_tag = false;
134 let mut queryable = false;
135 let mut queryable_key = None;
136
137 for attr in &input.attrs {
138 if attr.path().is_ident("versioned") {
139 if let Meta::List(meta_list) = &attr.meta {
140 let tokens = meta_list.tokens.to_string();
141 parse_versioned_attrs(
142 &tokens,
143 &mut version,
144 &mut version_key,
145 &mut data_key,
146 &mut auto_tag,
147 &mut queryable,
148 &mut queryable_key,
149 );
150 }
151 }
152 }
153
154 let version = version.unwrap_or_else(|| {
155 panic!("Missing #[versioned(version = \"x.y.z\")] attribute");
156 });
157
158 if let Err(e) = semver::Version::parse(&version) {
160 panic!("Invalid semantic version '{}': {}", version, e);
161 }
162
163 VersionedAttributes {
164 version,
165 version_key,
166 data_key,
167 auto_tag,
168 queryable,
169 queryable_key,
170 }
171}
172
173fn parse_versioned_attrs(
174 tokens: &str,
175 version: &mut Option<String>,
176 version_key: &mut String,
177 data_key: &mut String,
178 auto_tag: &mut bool,
179 queryable: &mut bool,
180 queryable_key: &mut Option<String>,
181) {
182 for part in tokens.split(',') {
184 let part = part.trim();
185
186 if let Some(val) = parse_attr_value(part, "version") {
187 *version = Some(val);
188 } else if let Some(val) = parse_attr_value(part, "version_key") {
189 *version_key = val;
190 } else if let Some(val) = parse_attr_value(part, "data_key") {
191 *data_key = val;
192 } else if let Some(val) = parse_attr_bool_value(part, "auto_tag") {
193 *auto_tag = val;
194 } else if let Some(val) = parse_attr_bool_value(part, "queryable") {
195 *queryable = val;
196 } else if let Some(val) = parse_attr_value(part, "queryable_key") {
197 *queryable_key = Some(val);
198 }
199 }
200}
201
/// Parses a `key = "value"` fragment and returns the text between the quotes.
///
/// Whitespace around `token`, the key and the `=` is ignored. Returns `None` in
/// any of these cases:
/// - the fragment does not start with `key`;
/// - no `=` follows the key;
/// - the value is not wrapped in a pair of double quotes.
///
/// The quoted text is returned verbatim: escape sequences such as `\"` are not
/// unescaped.
fn parse_attr_value(token: &str, key: &str) -> Option<String> {
    let after_key = token.trim().strip_prefix(key)?.trim();
    let value = after_key.strip_prefix('=')?.trim();
    // Require a distinct opening AND closing quote. A lone `"` (e.g. when a naive
    // comma split cuts the literal `","` in half) used to be sliced as `[1..0]`,
    // which panicked.
    let inner = value.strip_prefix('"')?.strip_suffix('"')?;
    Some(inner.to_string())
}
215
/// Parses a `key = true` / `key = false` fragment.
///
/// Whitespace around `token`, the key and the `=` is ignored. Returns `None` in
/// any of these cases:
/// - the fragment does not start with `key`;
/// - no `=` follows the key;
/// - the value is not exactly `true` or `false` (case-sensitive).
fn parse_attr_bool_value(token: &str, key: &str) -> Option<bool> {
    let after_key = token.trim().strip_prefix(key)?;
    let literal = after_key.trim().strip_prefix('=')?.trim();
    // `bool::from_str` accepts exactly "true" and "false".
    literal.parse::<bool>().ok()
}
231
232fn generate_queryable_impl(
233 input: &DeriveInput,
234 attrs: &VersionedAttributes,
235) -> proc_macro2::TokenStream {
236 let name = &input.ident;
237
238 let entity_name = if let Some(ref key) = attrs.queryable_key {
240 key.clone()
241 } else {
242 name.to_string().to_lowercase()
244 };
245
246 quote! {
247 impl version_migrate::Queryable for #name {
248 const ENTITY_NAME: &'static str = #entity_name;
249 }
250 }
251}
252
253#[proc_macro_derive(Queryable, attributes(queryable))]
291pub fn derive_queryable(input: TokenStream) -> TokenStream {
292 let input = parse_macro_input!(input as DeriveInput);
293
294 let name = &input.ident;
295 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
296 let mut entity_name: Option<String> = None;
297
298 for attr in &input.attrs {
300 if attr.path().is_ident("queryable") {
301 if let Meta::List(meta_list) = &attr.meta {
302 let tokens = meta_list.tokens.to_string();
303 entity_name = parse_entity_attr(&tokens);
304 }
305 }
306 }
307
308 let entity_name = entity_name.unwrap_or_else(|| {
309 panic!("Missing #[queryable(entity = \"name\")] attribute");
310 });
311
312 let expanded = quote! {
313 impl #impl_generics version_migrate::Queryable for #name #ty_generics #where_clause {
314 const ENTITY_NAME: &'static str = #entity_name;
315 }
316 };
317
318 TokenStream::from(expanded)
319}
320
321fn parse_entity_attr(tokens: &str) -> Option<String> {
322 for part in tokens.split(',') {
323 let part = part.trim();
324 if let Some(val) = parse_attr_value(part, "entity") {
325 return Some(val);
326 }
327 }
328 None
329}
330
331fn generate_serialize_impl(
332 input: &DeriveInput,
333 attrs: &VersionedAttributes,
334) -> proc_macro2::TokenStream {
335 let name = &input.ident;
336 let version = &attrs.version;
337 let version_key = &attrs.version_key;
338
339 let fields = match &input.data {
341 syn::Data::Struct(data_struct) => match &data_struct.fields {
342 syn::Fields::Named(fields) => &fields.named,
343 _ => panic!("auto_tag only supports structs with named fields"),
344 },
345 _ => panic!("auto_tag only supports structs"),
346 };
347
348 let field_count = fields.len() + 1; let field_serializations = fields.iter().map(|field| {
350 let field_name = field.ident.as_ref().unwrap();
351 let field_name_str = field_name.to_string();
352 quote! {
353 state.serialize_field(#field_name_str, &self.#field_name)?;
354 }
355 });
356
357 quote! {
358 impl serde::Serialize for #name {
359 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
360 where
361 S: serde::Serializer,
362 {
363 use serde::ser::SerializeStruct;
364 let mut state = serializer.serialize_struct(stringify!(#name), #field_count)?;
365 state.serialize_field(#version_key, #version)?;
366 #(#field_serializations)*
367 state.end()
368 }
369 }
370 }
371}
372
373fn generate_deserialize_impl(
374 input: &DeriveInput,
375 attrs: &VersionedAttributes,
376) -> proc_macro2::TokenStream {
377 let name = &input.ident;
378 let version = &attrs.version;
379 let version_key = &attrs.version_key;
380
381 let fields = match &input.data {
383 syn::Data::Struct(data_struct) => match &data_struct.fields {
384 syn::Fields::Named(fields) => &fields.named,
385 _ => panic!("auto_tag only supports structs with named fields"),
386 },
387 _ => panic!("auto_tag only supports structs"),
388 };
389
390 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
391 let field_name_strs: Vec<_> = field_names.iter().map(|f| f.to_string()).collect();
392
393 let all_field_names = {
394 let mut names = vec![version_key.clone()];
395 names.extend(field_name_strs.iter().cloned());
396 names
397 };
398
399 let field_enum_variants = field_names.iter().map(|name| {
400 let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
401 quote! { #variant }
402 });
403
404 let field_match_arms =
405 field_names
406 .iter()
407 .zip(field_name_strs.iter())
408 .map(|(name, name_str)| {
409 let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
410 quote! {
411 #name_str => Ok(Field::#variant)
412 }
413 });
414
415 let field_visit_arms = field_names.iter().map(|name| {
416 let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
417 quote! {
418 Field::#variant => {
419 if #name.is_some() {
420 return Err(serde::de::Error::duplicate_field(stringify!(#name)));
421 }
422 #name = Some(map.next_value()?);
423 }
424 }
425 });
426
427 let field_unwrap = field_names.iter().map(|name| {
428 quote! {
429 let #name = #name.ok_or_else(|| serde::de::Error::missing_field(stringify!(#name)))?;
430 }
431 });
432
433 quote! {
434 impl<'de> serde::Deserialize<'de> for #name {
435 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
436 where
437 D: serde::Deserializer<'de>,
438 {
439 #[allow(non_camel_case_types)]
440 enum Field {
441 Version,
442 #(#field_enum_variants,)*
443 }
444
445 impl<'de> serde::Deserialize<'de> for Field {
446 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
447 where
448 D: serde::Deserializer<'de>,
449 {
450 struct FieldVisitor;
451
452 impl<'de> serde::de::Visitor<'de> for FieldVisitor {
453 type Value = Field;
454
455 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
456 formatter.write_str(&format!("field identifier: {}", &[#(#all_field_names),*].join(", ")))
457 }
458
459 fn visit_str<E>(self, value: &str) -> Result<Field, E>
460 where
461 E: serde::de::Error,
462 {
463 match value {
464 #version_key => Ok(Field::Version),
465 #(#field_match_arms,)*
466 _ => Err(serde::de::Error::unknown_field(value, &[#(#all_field_names),*])),
467 }
468 }
469 }
470
471 deserializer.deserialize_identifier(FieldVisitor)
472 }
473 }
474
475 struct StructVisitor;
476
477 impl<'de> serde::de::Visitor<'de> for StructVisitor {
478 type Value = #name;
479
480 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
481 formatter.write_str(&format!("struct {}", stringify!(#name)))
482 }
483
484 fn visit_map<V>(self, mut map: V) -> Result<#name, V::Error>
485 where
486 V: serde::de::MapAccess<'de>,
487 {
488 let mut version: Option<String> = None;
489 #(let mut #field_names = None;)*
490
491 while let Some(key) = map.next_key()? {
492 match key {
493 Field::Version => {
494 if version.is_some() {
495 return Err(serde::de::Error::duplicate_field(#version_key));
496 }
497 let v: String = map.next_value()?;
498 if v != #version {
499 return Err(serde::de::Error::custom(format!(
500 "version mismatch: expected {}, found {}",
501 #version, v
502 )));
503 }
504 version = Some(v);
505 }
506 #(#field_visit_arms)*
507 }
508 }
509
510 let _version = version.ok_or_else(|| serde::de::Error::missing_field(#version_key))?;
511 #(#field_unwrap)*
512
513 Ok(#name {
514 #(#field_names,)*
515 })
516 }
517 }
518
519 deserializer.deserialize_struct(
520 stringify!(#name),
521 &[#(#all_field_names),*],
522 StructVisitor,
523 )
524 }
525 }
526 }
527}
528
529#[proc_macro_derive(VersionMigrate, attributes(version_migrate))]
605pub fn derive_version_migrate(input: TokenStream) -> TokenStream {
606 let input = parse_macro_input!(input as DeriveInput);
607
608 let name = &input.ident;
609 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
610
611 let mut entity_name: Option<String> = None;
613 let mut latest_type: Option<Type> = None;
614 let mut save = false; for attr in &input.attrs {
617 if attr.path().is_ident("version_migrate") {
618 if let Meta::List(meta_list) = &attr.meta {
619 let tokens = meta_list.tokens.to_string();
620 parse_version_migrate_attrs(&tokens, &mut entity_name, &mut latest_type, &mut save);
621 }
622 }
623 }
624
625 let entity_name = entity_name.unwrap_or_else(|| {
626 panic!("Missing #[version_migrate(entity = \"name\", ...)] attribute");
627 });
628
629 let latest_type = latest_type.unwrap_or_else(|| {
630 panic!("Missing #[version_migrate(..., latest = Type)] attribute");
631 });
632
633 let expanded = quote! {
634 impl #impl_generics version_migrate::LatestVersioned for #name #ty_generics #where_clause {
635 type Latest = #latest_type;
636 const ENTITY_NAME: &'static str = #entity_name;
637 const SAVE: bool = #save;
638 }
639 };
640
641 TokenStream::from(expanded)
642}
643
644fn parse_version_migrate_attrs(
645 tokens: &str,
646 entity_name: &mut Option<String>,
647 latest_type: &mut Option<Type>,
648 save: &mut bool,
649) {
650 let parts: Vec<&str> = tokens.split(',').collect();
652
653 for part in parts {
654 let part = part.trim();
655
656 if let Some(val) = parse_attr_value(part, "entity") {
657 *entity_name = Some(val);
658 } else if let Some(rest) = part.strip_prefix("latest") {
659 let rest = rest.trim();
660 if let Some(rest) = rest.strip_prefix('=') {
661 let type_str = rest.trim();
662 if let Ok(ty) = syn::parse_str::<Type>(type_str) {
664 *latest_type = Some(ty);
665 }
666 }
667 } else if let Some(val) = parse_attr_bool_value(part, "save") {
668 *save = val;
669 }
670 }
671}