1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Meta};
4
5#[proc_macro_derive(Versioned, attributes(versioned))]
76pub fn derive_versioned(input: TokenStream) -> TokenStream {
77 let input = parse_macro_input!(input as DeriveInput);
78
79 let attrs = extract_attributes(&input);
81
82 let name = &input.ident;
83 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
84
85 let version = &attrs.version;
86 let version_key = &attrs.version_key;
87 let data_key = &attrs.data_key;
88
89 let versioned_impl = quote! {
90 impl #impl_generics version_migrate::Versioned for #name #ty_generics #where_clause {
91 const VERSION: &'static str = #version;
92 const VERSION_KEY: &'static str = #version_key;
93 const DATA_KEY: &'static str = #data_key;
94 }
95 };
96
97 let mut impls = vec![versioned_impl];
98
99 if attrs.auto_tag {
100 let serialize_impl = generate_serialize_impl(&input, &attrs);
102 let deserialize_impl = generate_deserialize_impl(&input, &attrs);
103 impls.push(serialize_impl);
104 impls.push(deserialize_impl);
105 }
106
107 if attrs.queryable {
108 let queryable_impl = generate_queryable_impl(&input, &attrs);
110 impls.push(queryable_impl);
111 }
112
113 let expanded = quote! {
114 #(#impls)*
115 };
116
117 TokenStream::from(expanded)
118}
119
/// Settings collected from the `#[versioned(...)]` attribute(s) of a derive input.
struct VersionedAttributes {
    // Opt-in code generation switches.
    /// `auto_tag = true`: also generate the serde `Serialize`/`Deserialize` impls.
    auto_tag: bool,
    /// `queryable = true`: also generate the `version_migrate::Queryable` impl.
    queryable: bool,
    /// Explicit `ENTITY_NAME` for the `Queryable` impl; when `None` the lowercased type
    /// name is used.
    queryable_key: Option<String>,

    // Values written into the `Versioned` impl.
    /// The required `version = "..."` string (semver-checked by `extract_attributes`).
    version: String,
    /// Key under which the version is stored; `"version"` unless overridden.
    version_key: String,
    /// Key reported as `DATA_KEY`; `"data"` unless overridden.
    data_key: String,
}
128
129fn extract_attributes(input: &DeriveInput) -> VersionedAttributes {
130 let mut version = None;
131 let mut version_key = String::from("version");
132 let mut data_key = String::from("data");
133 let mut auto_tag = false;
134 let mut queryable = false;
135 let mut queryable_key = None;
136
137 for attr in &input.attrs {
138 if attr.path().is_ident("versioned") {
139 if let Meta::List(meta_list) = &attr.meta {
140 let tokens = meta_list.tokens.to_string();
141 parse_versioned_attrs(
142 &tokens,
143 &mut version,
144 &mut version_key,
145 &mut data_key,
146 &mut auto_tag,
147 &mut queryable,
148 &mut queryable_key,
149 );
150 }
151 }
152 }
153
154 let version = version.unwrap_or_else(|| {
155 panic!("Missing #[versioned(version = \"x.y.z\")] attribute");
156 });
157
158 if let Err(e) = semver::Version::parse(&version) {
160 panic!("Invalid semantic version '{}': {}", version, e);
161 }
162
163 VersionedAttributes {
164 version,
165 version_key,
166 data_key,
167 auto_tag,
168 queryable,
169 queryable_key,
170 }
171}
172
173fn parse_versioned_attrs(
174 tokens: &str,
175 version: &mut Option<String>,
176 version_key: &mut String,
177 data_key: &mut String,
178 auto_tag: &mut bool,
179 queryable: &mut bool,
180 queryable_key: &mut Option<String>,
181) {
182 for part in tokens.split(',') {
184 let part = part.trim();
185
186 if let Some(val) = parse_attr_value(part, "version") {
187 *version = Some(val);
188 } else if let Some(val) = parse_attr_value(part, "version_key") {
189 *version_key = val;
190 } else if let Some(val) = parse_attr_value(part, "data_key") {
191 *data_key = val;
192 } else if let Some(val) = parse_attr_bool_value(part, "auto_tag") {
193 *auto_tag = val;
194 } else if let Some(val) = parse_attr_bool_value(part, "queryable") {
195 *queryable = val;
196 } else if let Some(val) = parse_attr_value(part, "queryable_key") {
197 *queryable_key = Some(val);
198 }
199 }
200}
201
/// Extracts the string value from an attribute entry of the form `key = "value"`.
///
/// Whitespace around `key`, `=` and the literal is ignored. Returns `None` when the entry
/// does not start with `key`, has no `=` right after it, or the value is not enclosed in a
/// pair of double quotes. Only the enclosing quotes are removed; escape sequences inside
/// the literal are returned verbatim (not unescaped).
///
/// A lone `"` (for example what `data_key = ","` turns into after naive comma splitting)
/// yields `None` instead of panicking on an inverted slice range.
fn parse_attr_value(token: &str, key: &str) -> Option<String> {
    let value = token
        .trim()
        .strip_prefix(key)?
        .trim()
        .strip_prefix('=')?
        .trim();
    value
        .strip_prefix('"')
        .and_then(|inner| inner.strip_suffix('"'))
        .map(str::to_string)
}
215
/// Extracts a boolean from an attribute entry of the form `key = true` / `key = false`.
///
/// Whitespace around `key`, `=` and the value is ignored. Returns `None` when the entry
/// does not start with `key`, lacks the `=`, or the value is anything other than the exact
/// words `true` or `false` (so `True`, `1` or `"true"` are all rejected).
fn parse_attr_bool_value(token: &str, key: &str) -> Option<bool> {
    let literal = token
        .trim()
        .strip_prefix(key)?
        .trim()
        .strip_prefix('=')?
        .trim();
    // `bool`'s `FromStr` accepts exactly "true" and "false".
    literal.parse::<bool>().ok()
}
231
232fn generate_queryable_impl(
233 input: &DeriveInput,
234 attrs: &VersionedAttributes,
235) -> proc_macro2::TokenStream {
236 let name = &input.ident;
237
238 let entity_name = if let Some(ref key) = attrs.queryable_key {
240 key.clone()
241 } else {
242 name.to_string().to_lowercase()
244 };
245
246 quote! {
247 impl version_migrate::Queryable for #name {
248 const ENTITY_NAME: &'static str = #entity_name;
249 }
250 }
251}
252
253fn generate_serialize_impl(
254 input: &DeriveInput,
255 attrs: &VersionedAttributes,
256) -> proc_macro2::TokenStream {
257 let name = &input.ident;
258 let version = &attrs.version;
259 let version_key = &attrs.version_key;
260
261 let fields = match &input.data {
263 syn::Data::Struct(data_struct) => match &data_struct.fields {
264 syn::Fields::Named(fields) => &fields.named,
265 _ => panic!("auto_tag only supports structs with named fields"),
266 },
267 _ => panic!("auto_tag only supports structs"),
268 };
269
270 let field_count = fields.len() + 1; let field_serializations = fields.iter().map(|field| {
272 let field_name = field.ident.as_ref().unwrap();
273 let field_name_str = field_name.to_string();
274 quote! {
275 state.serialize_field(#field_name_str, &self.#field_name)?;
276 }
277 });
278
279 quote! {
280 impl serde::Serialize for #name {
281 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
282 where
283 S: serde::Serializer,
284 {
285 use serde::ser::SerializeStruct;
286 let mut state = serializer.serialize_struct(stringify!(#name), #field_count)?;
287 state.serialize_field(#version_key, #version)?;
288 #(#field_serializations)*
289 state.end()
290 }
291 }
292 }
293}
294
295fn generate_deserialize_impl(
296 input: &DeriveInput,
297 attrs: &VersionedAttributes,
298) -> proc_macro2::TokenStream {
299 let name = &input.ident;
300 let version = &attrs.version;
301 let version_key = &attrs.version_key;
302
303 let fields = match &input.data {
305 syn::Data::Struct(data_struct) => match &data_struct.fields {
306 syn::Fields::Named(fields) => &fields.named,
307 _ => panic!("auto_tag only supports structs with named fields"),
308 },
309 _ => panic!("auto_tag only supports structs"),
310 };
311
312 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
313 let field_name_strs: Vec<_> = field_names.iter().map(|f| f.to_string()).collect();
314
315 let all_field_names = {
316 let mut names = vec![version_key.clone()];
317 names.extend(field_name_strs.iter().cloned());
318 names
319 };
320
321 let field_enum_variants = field_names.iter().map(|name| {
322 let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
323 quote! { #variant }
324 });
325
326 let field_match_arms =
327 field_names
328 .iter()
329 .zip(field_name_strs.iter())
330 .map(|(name, name_str)| {
331 let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
332 quote! {
333 #name_str => Ok(Field::#variant)
334 }
335 });
336
337 let field_visit_arms = field_names.iter().map(|name| {
338 let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
339 quote! {
340 Field::#variant => {
341 if #name.is_some() {
342 return Err(serde::de::Error::duplicate_field(stringify!(#name)));
343 }
344 #name = Some(map.next_value()?);
345 }
346 }
347 });
348
349 let field_unwrap = field_names.iter().map(|name| {
350 quote! {
351 let #name = #name.ok_or_else(|| serde::de::Error::missing_field(stringify!(#name)))?;
352 }
353 });
354
355 quote! {
356 impl<'de> serde::Deserialize<'de> for #name {
357 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
358 where
359 D: serde::Deserializer<'de>,
360 {
361 #[allow(non_camel_case_types)]
362 enum Field {
363 Version,
364 #(#field_enum_variants,)*
365 }
366
367 impl<'de> serde::Deserialize<'de> for Field {
368 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
369 where
370 D: serde::Deserializer<'de>,
371 {
372 struct FieldVisitor;
373
374 impl<'de> serde::de::Visitor<'de> for FieldVisitor {
375 type Value = Field;
376
377 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
378 formatter.write_str(&format!("field identifier: {}", &[#(#all_field_names),*].join(", ")))
379 }
380
381 fn visit_str<E>(self, value: &str) -> Result<Field, E>
382 where
383 E: serde::de::Error,
384 {
385 match value {
386 #version_key => Ok(Field::Version),
387 #(#field_match_arms,)*
388 _ => Err(serde::de::Error::unknown_field(value, &[#(#all_field_names),*])),
389 }
390 }
391 }
392
393 deserializer.deserialize_identifier(FieldVisitor)
394 }
395 }
396
397 struct StructVisitor;
398
399 impl<'de> serde::de::Visitor<'de> for StructVisitor {
400 type Value = #name;
401
402 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
403 formatter.write_str(&format!("struct {}", stringify!(#name)))
404 }
405
406 fn visit_map<V>(self, mut map: V) -> Result<#name, V::Error>
407 where
408 V: serde::de::MapAccess<'de>,
409 {
410 let mut version: Option<String> = None;
411 #(let mut #field_names = None;)*
412
413 while let Some(key) = map.next_key()? {
414 match key {
415 Field::Version => {
416 if version.is_some() {
417 return Err(serde::de::Error::duplicate_field(#version_key));
418 }
419 let v: String = map.next_value()?;
420 if v != #version {
421 return Err(serde::de::Error::custom(format!(
422 "version mismatch: expected {}, found {}",
423 #version, v
424 )));
425 }
426 version = Some(v);
427 }
428 #(#field_visit_arms)*
429 }
430 }
431
432 let _version = version.ok_or_else(|| serde::de::Error::missing_field(#version_key))?;
433 #(#field_unwrap)*
434
435 Ok(#name {
436 #(#field_names,)*
437 })
438 }
439 }
440
441 deserializer.deserialize_struct(
442 stringify!(#name),
443 &[#(#all_field_names),*],
444 StructVisitor,
445 )
446 }
447 }
448 }
449}