tfhe_versionable_derive/
lib.rs

//! Set of derive macros to automatically implement the `Versionize` and `Unversionize` traits.
//! The macros defined in this crate are:
//! - `Versionize`: should be derived on the main type that is used in your code
//! - `Version`: should be derived on a previous version of this type
//! - `VersionsDispatch`: should be derived on the enum that holds all the versions of the type
//! - `NotVersioned`: can be used to implement `Versionize` for a type that is not really versioned
7
8mod associated;
9mod dispatch_type;
10mod transparent;
11mod version_type;
12mod versionize_attribute;
13mod versionize_impl;
14
15use dispatch_type::DispatchType;
16use proc_macro::TokenStream;
17use proc_macro2::Span;
18use quote::{quote, ToTokens};
19use syn::parse::Parse;
20use syn::punctuated::Punctuated;
21use syn::token::Plus;
22use syn::{
23    parse_macro_input, parse_quote, DeriveInput, GenericParam, Generics, Ident, Lifetime,
24    LifetimeParam, Path, TraitBound, TraitBoundModifier, Type, TypeParamBound, WhereClause,
25};
26
/// Expands to the absolute path (as a string literal) of an item exported by the
/// `tfhe_versionable` crate.
///
/// Generated code refers to these items through a leading `::` path so that a user item or import
/// with the same name cannot clash with them.
macro_rules! crate_full_path {
    ($item_name:expr) => {
        concat!("::tfhe_versionable::", $item_name)
    };
}

/// Lifetime used as the generic parameter of the generated `Versioned<'vers>` associated types.
pub(crate) const LIFETIME_NAME: &str = "'vers";
/// The `'static` lifetime.
pub(crate) const STATIC_LIFETIME_NAME: &str = "'static";

// Absolute paths of the `tfhe_versionable` items.
pub(crate) const VERSION_TRAIT_NAME: &str = crate_full_path!("Version");
pub(crate) const DISPATCH_TRAIT_NAME: &str = crate_full_path!("VersionsDispatch");
pub(crate) const VERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Versionize");
pub(crate) const VERSIONIZE_OWNED_TRAIT_NAME: &str = crate_full_path!("VersionizeOwned");
pub(crate) const VERSIONIZE_SLICE_TRAIT_NAME: &str = crate_full_path!("VersionizeSlice");
pub(crate) const VERSIONIZE_VEC_TRAIT_NAME: &str = crate_full_path!("VersionizeVec");
pub(crate) const UNVERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Unversionize");
pub(crate) const UNVERSIONIZE_VEC_TRAIT_NAME: &str = crate_full_path!("UnversionizeVec");
pub(crate) const UPGRADE_TRAIT_NAME: &str = crate_full_path!("Upgrade");
pub(crate) const UNVERSIONIZE_ERROR_NAME: &str = crate_full_path!("UnversionizeError");

// Absolute paths of the `serde` items.
pub(crate) const SERIALIZE_TRAIT_NAME: &str = "::serde::Serialize";
pub(crate) const DESERIALIZE_TRAIT_NAME: &str = "::serde::Deserialize";
pub(crate) const DESERIALIZE_OWNED_TRAIT_NAME: &str = "::serde::de::DeserializeOwned";

// Absolute paths of the `core` items.
pub(crate) const FROM_TRAIT_NAME: &str = "::core::convert::From";
pub(crate) const TRY_INTO_TRAIT_NAME: &str = "::core::convert::TryInto";
pub(crate) const INTO_TRAIT_NAME: &str = "::core::convert::Into";
pub(crate) const ERROR_TRAIT_NAME: &str = "::core::error::Error";
pub(crate) const SYNC_TRAIT_NAME: &str = "::core::marker::Sync";
pub(crate) const SEND_TRAIT_NAME: &str = "::core::marker::Send";
pub(crate) const DEFAULT_TRAIT_NAME: &str = "::core::default::Default";
57
58use associated::AssociatingTrait;
59use versionize_impl::VersionizeImplementor;
60
61use crate::version_type::VersionType;
62use crate::versionize_attribute::VersionizeAttribute;
63
/// Evaluates to the `Ok` value of a `syn::Result`.
///
/// On `Err`, it returns early from the enclosing function with `error.to_compile_error()`
/// converted (via `.into()`) into that function's return type, so the error is reported by the
/// compiler at the macro call site.
macro_rules! syn_unwrap {
    ($result:expr) => {
        match $result {
            Err(error) => return error.to_compile_error().into(),
            Ok(value) => value,
        }
    };
}
74
75#[proc_macro_derive(Version, attributes(versionize))]
76/// Implement the `Version` trait for the target type.
77pub fn derive_version(input: TokenStream) -> TokenStream {
78    let input = parse_macro_input!(input as DeriveInput);
79
80    impl_version_trait(&input).into()
81}
82
83/// Actual implementation of the version trait. This will create the ref and owned
84/// associated types and use them to implement the trait.
85fn impl_version_trait(input: &DeriveInput) -> proc_macro2::TokenStream {
86    let version_trait = syn_unwrap!(AssociatingTrait::<VersionType>::new(
87        input,
88        VERSION_TRAIT_NAME,
89    ));
90
91    let version_types = syn_unwrap!(version_trait.generate_types_declarations());
92
93    let version_impl = syn_unwrap!(version_trait.generate_impl());
94
95    quote! {
96        const _: () = {
97            #version_types
98
99            #[automatically_derived]
100            #version_impl
101        };
102    }
103}
104
105/// Implement the `VersionsDispatch` trait for the target type. The type where this macro is
106/// applied should be an enum where each variant is a version of the type that we want to
107/// versionize.
108#[proc_macro_derive(VersionsDispatch)]
109pub fn derive_versions_dispatch(input: TokenStream) -> TokenStream {
110    let input = parse_macro_input!(input as DeriveInput);
111
112    let dispatch_trait = syn_unwrap!(AssociatingTrait::<DispatchType>::new(
113        &input,
114        DISPATCH_TRAIT_NAME,
115    ));
116
117    let dispatch_types = syn_unwrap!(dispatch_trait.generate_types_declarations());
118
119    let dispatch_impl = syn_unwrap!(dispatch_trait.generate_impl());
120
121    quote! {
122        const _: () = {
123            #dispatch_types
124
125            #[automatically_derived]
126            #dispatch_impl
127        };
128    }
129    .into()
130}
131
132/// This derives the `Versionize` and `Unversionize` trait for the target type.
133///
134/// This macro has a mandatory attribute parameter, which is the name of the versioned enum for this
135/// type. This enum can be anywhere in the code but should be in scope.
136///
137/// Example:
138/// ```ignore
139/// // The structure that should be versioned, as defined in your code
140/// #[derive(Versionize)]
141/// // We have to link to the enum type that will holds all the versions of this
142/// // type. This can also be written `#[versionize(dispatch = MyStructVersions)]`.
143/// #[versionize(MyStructVersions)]
144/// struct MyStruct<T> {
145///     attr: T,
146///     builtin: u32,
147/// }
148///
149/// // To avoid polluting your code, the old versions can be defined in another module/file, along with
150/// // the dispatch enum
151/// #[derive(Version)] // Used to mark an old version of the type
152/// struct MyStructV0 {
153///     builtin: u32,
154/// }
155///
156/// // The Upgrade trait tells how to go from the first version to the last. During unversioning, the
157/// // upgrade method will be called on the deserialized value enough times to go to the last variant.
158/// impl<T: Default> Upgrade<MyStruct<T>> for MyStructV0 {
159///     type Error = Infallible;
160///
161///     fn upgrade(self) -> Result<MyStruct<T>, Self::Error> {
162///         Ok(MyStruct {
163///             attr: T::default(),
164///             builtin: self.builtin,
165///         })
166///     }
167/// }
168///
169/// // This is the dispatch enum, that holds one variant for each version of your type.
170/// #[derive(VersionsDispatch)]
171/// // This enum is not directly used but serves as a template to generate a new enum that will be
172/// // serialized. This allows recursive versioning.
173/// #[allow(unused)]
174/// enum MyStructVersions<T> {
175///     V0(MyStructV0),
176///     V1(MyStruct<T>),
177/// }
178/// ```
179#[proc_macro_derive(Versionize, attributes(versionize))]
180pub fn derive_versionize(input: TokenStream) -> TokenStream {
181    let input = parse_macro_input!(input as DeriveInput);
182
183    let input_generics = filter_unsized_bounds(&input.generics);
184
185    let attributes = syn_unwrap!(VersionizeAttribute::parse_from_attributes_list(
186        &input.attrs
187    ));
188
189    let implementor = syn_unwrap!(VersionizeImplementor::new(
190        attributes,
191        &input.data,
192        Span::call_site()
193    ));
194
195    // If we apply a type conversion before the call to versionize, the type that implements
196    // the `Version` trait is the target type and not Self
197    let version_trait_impl: Option<proc_macro2::TokenStream> =
198        if implementor.is_directly_versioned() {
199            Some(impl_version_trait(&input))
200        } else {
201            None
202        };
203
204    // Parse the name of the traits that we will implement
205    let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
206    let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
207    let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
208    let versionize_vec_trait: Path = parse_const_str(VERSIONIZE_VEC_TRAIT_NAME);
209    let versionize_slice_trait: Path = parse_const_str(VERSIONIZE_SLICE_TRAIT_NAME);
210    let unversionize_vec_trait: Path = parse_const_str(UNVERSIONIZE_VEC_TRAIT_NAME);
211
212    let input_ident = &input.ident;
213    let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
214
215    // split generics so they can be used inside the generated code
216    let (_, ty_generics, _) = input_generics.split_for_impl();
217
218    // Generates the associated types required by the traits
219    let versioned_type = implementor.versioned_type(&lifetime, &input_generics);
220    let versioned_owned_type = implementor.versioned_owned_type(&input_generics);
221    let versioned_type_where_clause =
222        implementor.versioned_type_where_clause(&lifetime, &input_generics);
223    let versioned_owned_type_where_clause =
224        implementor.versioned_owned_type_where_clause(&input_generics);
225
226    // If the original type has some generics, we need to add bounds on them for
227    // the traits impl
228    let versionize_trait_where_clause =
229        syn_unwrap!(implementor.versionize_trait_where_clause(&input_generics));
230    let versionize_owned_trait_where_clause =
231        syn_unwrap!(implementor.versionize_owned_trait_where_clause(&input_generics));
232    let unversionize_trait_where_clause =
233        syn_unwrap!(implementor.unversionize_trait_where_clause(&input_generics));
234
235    let trait_impl_generics = input_generics.split_for_impl().0;
236
237    let versionize_body = implementor.versionize_method_body();
238    let versionize_owned_body = implementor.versionize_owned_method_body();
239    let unversionize_arg_name = Ident::new("versioned", Span::call_site());
240    let unversionize_body = implementor.unversionize_method_body(&unversionize_arg_name);
241    let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME);
242
243    quote! {
244        #version_trait_impl
245
246        #[automatically_derived]
247        impl #trait_impl_generics #versionize_trait for #input_ident #ty_generics
248        #versionize_trait_where_clause
249        {
250            type Versioned<#lifetime> = #versioned_type #versioned_type_where_clause;
251
252            fn versionize(&self) -> Self::Versioned<'_> {
253                #versionize_body
254            }
255        }
256
257        #[automatically_derived]
258        impl #trait_impl_generics #versionize_owned_trait for #input_ident #ty_generics
259        #versionize_owned_trait_where_clause
260        {
261            type VersionedOwned = #versioned_owned_type #versioned_owned_type_where_clause;
262
263            fn versionize_owned(self) -> Self::VersionedOwned {
264                #versionize_owned_body
265            }
266        }
267
268        #[automatically_derived]
269        impl #trait_impl_generics #unversionize_trait for #input_ident #ty_generics
270        #unversionize_trait_where_clause
271        {
272            fn unversionize(#unversionize_arg_name: Self::VersionedOwned) -> Result<Self, #unversionize_error>  {
273                #unversionize_body
274            }
275        }
276
277        #[automatically_derived]
278        impl #trait_impl_generics #versionize_slice_trait for #input_ident #ty_generics
279        #versionize_trait_where_clause
280        {
281            type VersionedSlice<#lifetime> = Vec<<Self as #versionize_trait>::Versioned<#lifetime>> #versioned_type_where_clause;
282
283            fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> {
284                slice.iter().map(|val| #versionize_trait::versionize(val)).collect()
285            }
286        }
287
288        #[automatically_derived]
289        impl #trait_impl_generics #versionize_vec_trait for #input_ident #ty_generics
290        #versionize_owned_trait_where_clause
291        {
292
293            type VersionedVec = Vec<<Self as #versionize_owned_trait>::VersionedOwned> #versioned_owned_type_where_clause;
294
295            fn versionize_vec(vec: Vec<Self>) -> Self::VersionedVec {
296                vec.into_iter().map(|val| #versionize_owned_trait::versionize_owned(val)).collect()
297            }
298        }
299
300        #[automatically_derived]
301        impl #trait_impl_generics #unversionize_vec_trait for #input_ident #ty_generics
302        #unversionize_trait_where_clause
303        {
304            fn unversionize_vec(versioned: Self::VersionedVec) -> Result<Vec<Self>, #unversionize_error> {
305                versioned
306                .into_iter()
307                .map(|versioned| <Self as #unversionize_trait>::unversionize(versioned))
308                .collect()
309            }
310        }
311    }
312    .into()
313}
314
315/// This derives the `Versionize` and `Unversionize` trait for a type that should not
316/// be versioned. The `versionize` method will simply return self
317#[proc_macro_derive(NotVersioned)]
318pub fn derive_not_versioned(input: TokenStream) -> TokenStream {
319    let input = parse_macro_input!(input as DeriveInput);
320
321    // Versionize needs T to impl Serialize
322    let mut versionize_generics = input.generics.clone();
323    syn_unwrap!(add_trait_where_clause(
324        &mut versionize_generics,
325        &[parse_quote! { Self }],
326        &[SERIALIZE_TRAIT_NAME]
327    ));
328
329    // VersionizeOwned needs T to impl Serialize and DeserializeOwned
330    let mut versionize_owned_generics = input.generics.clone();
331    syn_unwrap!(add_trait_where_clause(
332        &mut versionize_owned_generics,
333        &[parse_quote! { Self }],
334        &[SERIALIZE_TRAIT_NAME, DESERIALIZE_OWNED_TRAIT_NAME]
335    ));
336
337    let (impl_generics, ty_generics, versionize_where_clause) =
338        versionize_generics.split_for_impl();
339    let (_, _, versionize_owned_where_clause) = versionize_owned_generics.split_for_impl();
340
341    let input_ident = &input.ident;
342
343    let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
344    let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
345    let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
346    let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME);
347    let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
348
349    quote! {
350        #[automatically_derived]
351        impl #impl_generics #versionize_trait for #input_ident #ty_generics #versionize_where_clause {
352            type Versioned<#lifetime> = &#lifetime Self where Self: 'vers;
353
354            fn versionize(&self) -> Self::Versioned<'_> {
355                self
356            }
357        }
358
359        #[automatically_derived]
360        impl #impl_generics #versionize_owned_trait for #input_ident #ty_generics #versionize_owned_where_clause {
361            type VersionedOwned = Self;
362
363            fn versionize_owned(self) -> Self::VersionedOwned {
364                self
365            }
366        }
367
368        #[automatically_derived]
369        impl #impl_generics #unversionize_trait for #input_ident #ty_generics #versionize_owned_where_clause {
370            fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, #unversionize_error> {
371                Ok(versioned)
372            }
373        }
374
375        #[automatically_derived]
376        impl #impl_generics NotVersioned for #input_ident #ty_generics #versionize_owned_where_clause {}
377
378    }
379    .into()
380}
381
382/// Adds a where clause with a lifetime bound on all the generic types and lifetimes in `generics`
383fn add_where_lifetime_bound_to_generics(generics: &mut Generics, lifetime: &Lifetime) {
384    let mut params = Vec::new();
385    for param in generics.params.iter() {
386        let param_ident = match param {
387            GenericParam::Lifetime(generic_lifetime) => {
388                if generic_lifetime.lifetime.ident == lifetime.ident {
389                    continue;
390                }
391                &generic_lifetime.lifetime.ident
392            }
393            GenericParam::Type(generic_type) => &generic_type.ident,
394            GenericParam::Const(_) => continue,
395        };
396        params.push(param_ident.clone());
397    }
398
399    for param in params.iter() {
400        generics
401            .make_where_clause()
402            .predicates
403            .push(parse_quote! { #param: #lifetime  });
404    }
405}
406
407/// Adds a new lifetime param with a bound for all the generic types in `generics`
408fn add_lifetime_param(generics: &mut Generics, lifetime: &Lifetime) {
409    generics
410        .params
411        .push(GenericParam::Lifetime(LifetimeParam::new(lifetime.clone())));
412    for param in generics.type_params_mut() {
413        param
414            .bounds
415            .push(TypeParamBound::Lifetime(lifetime.clone()));
416    }
417}
418
419/// Parse the input str trait bound
420fn parse_trait_bound(trait_name: &str) -> syn::Result<TraitBound> {
421    let trait_path: Path = syn::parse_str(trait_name)?;
422    Ok(parse_quote!(#trait_path))
423}
424
425/// Adds a "where clause" bound for all the input types with all the input traits
426fn add_trait_where_clause<'a, S: AsRef<str>, I: IntoIterator<Item = &'a Type>>(
427    generics: &mut Generics,
428    types: I,
429    traits_name: &[S],
430) -> syn::Result<()> {
431    let preds = &mut generics.make_where_clause().predicates;
432
433    if !traits_name.is_empty() {
434        let bounds: Vec<TraitBound> = traits_name
435            .iter()
436            .map(|bound_name| parse_trait_bound(bound_name.as_ref()))
437            .collect::<syn::Result<_>>()?;
438        for ty in types {
439            preds.push(parse_quote! { #ty: #(#bounds)+*  });
440        }
441    }
442
443    Ok(())
444}
445
446/// Adds a "where clause" bound for all the input types with all the input lifetimes
447fn add_lifetime_where_clause<'a, S: AsRef<str>, I: IntoIterator<Item = &'a Type>>(
448    generics: &mut Generics,
449    types: I,
450    lifetimes: &[S],
451) -> syn::Result<()> {
452    let preds = &mut generics.make_where_clause().predicates;
453
454    if !lifetimes.is_empty() {
455        let bounds: Vec<Lifetime> = lifetimes
456            .iter()
457            .map(|lifetime| syn::parse_str(lifetime.as_ref()))
458            .collect::<syn::Result<_>>()?;
459        for ty in types {
460            preds.push(parse_quote! { #ty: #(#bounds)+*  });
461        }
462    }
463
464    Ok(())
465}
466
467/// Extends a where clause with predicates from another one, filtering duplicates
468fn extend_where_clause(base_clause: &mut WhereClause, extension_clause: &WhereClause) {
469    for extend_predicate in &extension_clause.predicates {
470        if base_clause.predicates.iter().all(|base_predicate| {
471            base_predicate.to_token_stream().to_string()
472                != extend_predicate.to_token_stream().to_string()
473        }) {
474            base_clause.predicates.push(extend_predicate.clone());
475        }
476    }
477}
478
479/// Creates a Result [`syn::punctuated::Punctuated`] from an iterator of Results
480fn punctuated_from_iter_result<T, P: Default, I: IntoIterator<Item = syn::Result<T>>>(
481    iter: I,
482) -> syn::Result<Punctuated<T, P>> {
483    let mut ret = Punctuated::new();
484    for value in iter {
485        ret.push(value?)
486    }
487    Ok(ret)
488}
489
490/// Like [`syn::parse_str`] for inputs that are known at compile time to be valid
491fn parse_const_str<T: Parse>(s: &'static str) -> T {
492    syn::parse_str(s).expect("Parsing of const string should not fail")
493}
494
495/// Remove the '?Sized' bounds from the generics
496///
497/// The VersionDispatch trait requires that the versioned type is Sized so we have to remove this
498/// bound. It means that for a type `MyStruct<T: ?Sized>`, we will only be able to call
499/// `.versionize()` when T is Sized.
500fn filter_unsized_bounds(generics: &Generics) -> Generics {
501    let mut generics = generics.clone();
502
503    for param in generics.type_params_mut() {
504        param.bounds = remove_unsized_bound(&param.bounds);
505    }
506
507    if let Some(clause) = &mut generics.where_clause {
508        for pred in &mut clause.predicates {
509            match pred {
510                syn::WherePredicate::Lifetime(_) => {}
511                syn::WherePredicate::Type(type_predicate) => {
512                    type_predicate.bounds = remove_unsized_bound(&type_predicate.bounds);
513                }
514                _ => {}
515            }
516        }
517    }
518
519    generics
520}
521
522/// Filter the ?Sized bound in a list of bounds
523fn remove_unsized_bound(
524    bounds: &Punctuated<TypeParamBound, Plus>,
525) -> Punctuated<TypeParamBound, Plus> {
526    bounds
527        .iter()
528        .filter(|bound| match bound {
529            TypeParamBound::Trait(trait_bound) => {
530                if !matches!(trait_bound.modifier, TraitBoundModifier::None) {
531                    if let Some(segment) = trait_bound.path.segments.iter().next_back() {
532                        if segment.ident == "Sized" {
533                            return false;
534                        }
535                    }
536                }
537                true
538            }
539            TypeParamBound::Lifetime(_) => true,
540            TypeParamBound::Verbatim(_) => true,
541            _ => true,
542        })
543        .cloned()
544        .collect()
545}