tfhe_versionable_derive/
lib.rs

//! Set of derive macros to automatically implement the `Versionize` and `Unversionize` traits.
//! The macros defined in this crate are:
//! - `Versionize`: should be derived on the main type that is used in your code
//! - `Version`: should be derived on a previous version of this type
//! - `VersionsDispatch`: should be derived on the enum that holds all the versions of the type
//! - `NotVersioned`: can be used to implement `Versionize` for a type that is not really versioned
7
8mod associated;
9mod dispatch_type;
10mod transparent;
11mod version_type;
12mod versionize_attribute;
13mod versionize_impl;
14
15use dispatch_type::DispatchType;
16use proc_macro::TokenStream;
17use proc_macro2::Span;
18use quote::{quote, ToTokens};
19use syn::parse::Parse;
20use syn::punctuated::Punctuated;
21use syn::token::Plus;
22use syn::{
23    parse_macro_input, parse_quote, DeriveInput, GenericParam, Generics, Ident, Lifetime,
24    LifetimeParam, Path, TraitBound, TraitBoundModifier, Type, TypeParamBound, WhereClause,
25};
26
/// Prefixes `$name` with the absolute path of the `tfhe_versionable` crate.
///
/// Generated code always refers to items through this absolute path so that it does not depend
/// on (or clash with) whatever names are in scope at the derive site.
macro_rules! crate_full_path {
    ($name:expr) => {
        concat!("::tfhe_versionable::", $name)
    };
}
33
34pub(crate) const LIFETIME_NAME: &str = "'vers";
35pub(crate) const VERSION_TRAIT_NAME: &str = crate_full_path!("Version");
36pub(crate) const DISPATCH_TRAIT_NAME: &str = crate_full_path!("VersionsDispatch");
37pub(crate) const VERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Versionize");
38pub(crate) const VERSIONIZE_OWNED_TRAIT_NAME: &str = crate_full_path!("VersionizeOwned");
39pub(crate) const VERSIONIZE_SLICE_TRAIT_NAME: &str = crate_full_path!("VersionizeSlice");
40pub(crate) const VERSIONIZE_VEC_TRAIT_NAME: &str = crate_full_path!("VersionizeVec");
41pub(crate) const UNVERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Unversionize");
42pub(crate) const UNVERSIONIZE_VEC_TRAIT_NAME: &str = crate_full_path!("UnversionizeVec");
43pub(crate) const UPGRADE_TRAIT_NAME: &str = crate_full_path!("Upgrade");
44pub(crate) const UNVERSIONIZE_ERROR_NAME: &str = crate_full_path!("UnversionizeError");
45
// Paths of the `serde` traits.
pub(crate) const SERIALIZE_TRAIT_NAME: &str = "::serde::Serialize";
pub(crate) const DESERIALIZE_TRAIT_NAME: &str = "::serde::Deserialize";
pub(crate) const DESERIALIZE_OWNED_TRAIT_NAME: &str = "::serde::de::DeserializeOwned";

// Paths of the `core` conversion traits.
pub(crate) const FROM_TRAIT_NAME: &str = "::core::convert::From";
pub(crate) const INTO_TRAIT_NAME: &str = "::core::convert::Into";
pub(crate) const TRY_FROM_TRAIT_NAME: &str = "::core::convert::TryFrom";
pub(crate) const TRY_INTO_TRAIT_NAME: &str = "::core::convert::TryInto";

// Paths of other `core` traits.
pub(crate) const DEFAULT_TRAIT_NAME: &str = "::core::default::Default";
pub(crate) const ERROR_TRAIT_NAME: &str = "::core::error::Error";
pub(crate) const SEND_TRAIT_NAME: &str = "::core::marker::Send";
pub(crate) const SYNC_TRAIT_NAME: &str = "::core::marker::Sync";

// Standard types and lifetimes.
pub(crate) const RESULT_TYPE_NAME: &str = "::core::result::Result";
pub(crate) const VEC_TYPE_NAME: &str = "::std::vec::Vec";
pub(crate) const STATIC_LIFETIME_NAME: &str = "'static";
60
61use associated::AssociatingTrait;
62use versionize_impl::VersionizeImplementor;
63
64use crate::version_type::VersionType;
65use crate::versionize_attribute::VersionizeAttribute;
66
/// Evaluates a `syn::Result`: yields the `Ok` value, or makes the enclosing function return early
/// with the error turned into a compile error (`to_compile_error().into()`).
macro_rules! syn_unwrap {
    ($result:expr) => {
        match $result {
            Ok(value) => value,
            Err(error) => return error.to_compile_error().into(),
        }
    };
}
77
78#[proc_macro_derive(Version, attributes(versionize))]
79/// Implement the `Version` trait for the target type.
80pub fn derive_version(input: TokenStream) -> TokenStream {
81    let input = parse_macro_input!(input as DeriveInput);
82
83    impl_version_trait(&input).into()
84}
85
86/// Actual implementation of the version trait. This will create the ref and owned
87/// associated types and use them to implement the trait.
88fn impl_version_trait(input: &DeriveInput) -> proc_macro2::TokenStream {
89    let version_trait = syn_unwrap!(AssociatingTrait::<VersionType>::new(
90        input,
91        VERSION_TRAIT_NAME,
92    ));
93
94    let version_types = syn_unwrap!(version_trait.generate_types_declarations());
95
96    let version_impl = syn_unwrap!(version_trait.generate_impl());
97
98    quote! {
99        const _: () = {
100            #version_types
101
102            #[automatically_derived]
103            #version_impl
104        };
105    }
106}
107
108/// Implement the `VersionsDispatch` trait for the target type. The type where this macro is
109/// applied should be an enum where each variant is a version of the type that we want to
110/// versionize.
111#[proc_macro_derive(VersionsDispatch)]
112pub fn derive_versions_dispatch(input: TokenStream) -> TokenStream {
113    let input = parse_macro_input!(input as DeriveInput);
114
115    let dispatch_trait = syn_unwrap!(AssociatingTrait::<DispatchType>::new(
116        &input,
117        DISPATCH_TRAIT_NAME,
118    ));
119
120    let dispatch_types = syn_unwrap!(dispatch_trait.generate_types_declarations());
121
122    let dispatch_impl = syn_unwrap!(dispatch_trait.generate_impl());
123
124    quote! {
125        const _: () = {
126            #dispatch_types
127
128            #[automatically_derived]
129            #dispatch_impl
130        };
131    }
132    .into()
133}
134
135/// This derives the `Versionize` and `Unversionize` trait for the target type.
136///
137/// This macro has a mandatory attribute parameter, which is the name of the versioned enum for this
138/// type. This enum can be anywhere in the code but should be in scope.
139///
140/// Example:
141/// ```ignore
142/// // The structure that should be versioned, as defined in your code
143/// #[derive(Versionize)]
144/// // We have to link to the enum type that will holds all the versions of this
145/// // type. This can also be written `#[versionize(dispatch = MyStructVersions)]`.
146/// #[versionize(MyStructVersions)]
147/// struct MyStruct<T> {
148///     attr: T,
149///     builtin: u32,
150/// }
151///
152/// // To avoid polluting your code, the old versions can be defined in another module/file, along with
153/// // the dispatch enum
154/// #[derive(Version)] // Used to mark an old version of the type
155/// struct MyStructV0 {
156///     builtin: u32,
157/// }
158///
159/// // The Upgrade trait tells how to go from the first version to the last. During unversioning, the
160/// // upgrade method will be called on the deserialized value enough times to go to the last variant.
161/// impl<T: Default> Upgrade<MyStruct<T>> for MyStructV0 {
162///     type Error = Infallible;
163///
164///     fn upgrade(self) -> Result<MyStruct<T>, Self::Error> {
165///         Ok(MyStruct {
166///             attr: T::default(),
167///             builtin: self.builtin,
168///         })
169///     }
170/// }
171///
172/// // This is the dispatch enum, that holds one variant for each version of your type.
173/// #[derive(VersionsDispatch)]
174/// // This enum is not directly used but serves as a template to generate a new enum that will be
175/// // serialized. This allows recursive versioning.
176/// #[allow(unused)]
177/// enum MyStructVersions<T> {
178///     V0(MyStructV0),
179///     V1(MyStruct<T>),
180/// }
181/// ```
182#[proc_macro_derive(Versionize, attributes(versionize))]
183pub fn derive_versionize(input: TokenStream) -> TokenStream {
184    let input = parse_macro_input!(input as DeriveInput);
185
186    let input_generics = filter_unsized_bounds(&input.generics);
187
188    let attributes = syn_unwrap!(VersionizeAttribute::parse_from_attributes_list(
189        &input.attrs
190    ));
191
192    let implementor = syn_unwrap!(VersionizeImplementor::new(
193        attributes,
194        &input.data,
195        Span::call_site()
196    ));
197
198    // If we apply a type conversion before the call to versionize, the type that implements
199    // the `Version` trait is the target type and not Self
200    let version_trait_impl: Option<proc_macro2::TokenStream> =
201        if implementor.is_directly_versioned() {
202            Some(impl_version_trait(&input))
203        } else {
204            None
205        };
206
207    // Parse the name of the traits that we will implement
208    let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
209    let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
210    let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
211    let versionize_vec_trait: Path = parse_const_str(VERSIONIZE_VEC_TRAIT_NAME);
212    let versionize_slice_trait: Path = parse_const_str(VERSIONIZE_SLICE_TRAIT_NAME);
213    let unversionize_vec_trait: Path = parse_const_str(UNVERSIONIZE_VEC_TRAIT_NAME);
214
215    let input_ident = &input.ident;
216    let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
217
218    // split generics so they can be used inside the generated code
219    let (_, ty_generics, _) = input_generics.split_for_impl();
220
221    // Generates the associated types required by the traits
222    let versioned_type = implementor.versioned_type(&lifetime, &input_generics);
223    let versioned_owned_type = implementor.versioned_owned_type(&input_generics);
224    let versioned_type_where_clause =
225        implementor.versioned_type_where_clause(&lifetime, &input_generics);
226    let versioned_owned_type_where_clause =
227        implementor.versioned_owned_type_where_clause(&input_generics);
228
229    // If the original type has some generics, we need to add bounds on them for
230    // the traits impl
231    let versionize_trait_where_clause =
232        syn_unwrap!(implementor.versionize_trait_where_clause(&input_generics));
233    let versionize_owned_trait_where_clause =
234        syn_unwrap!(implementor.versionize_owned_trait_where_clause(&input_generics));
235    let unversionize_trait_where_clause =
236        syn_unwrap!(implementor.unversionize_trait_where_clause(&input_generics));
237
238    let trait_impl_generics = input_generics.split_for_impl().0;
239
240    let versionize_body = implementor.versionize_method_body();
241    let versionize_owned_body = implementor.versionize_owned_method_body();
242    let unversionize_arg_name = Ident::new("versioned", Span::call_site());
243    let unversionize_body = implementor.unversionize_method_body(&unversionize_arg_name);
244    let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME);
245
246    let result_type: Path = parse_const_str(RESULT_TYPE_NAME);
247    let vec_type: Path = parse_const_str(VEC_TYPE_NAME);
248
249    quote! {
250        #version_trait_impl
251
252        #[automatically_derived]
253        impl #trait_impl_generics #versionize_trait for #input_ident #ty_generics
254        #versionize_trait_where_clause
255        {
256            type Versioned<#lifetime> = #versioned_type #versioned_type_where_clause;
257
258            fn versionize(&self) -> Self::Versioned<'_> {
259                #versionize_body
260            }
261        }
262
263        #[automatically_derived]
264        impl #trait_impl_generics #versionize_owned_trait for #input_ident #ty_generics
265        #versionize_owned_trait_where_clause
266        {
267            type VersionedOwned = #versioned_owned_type #versioned_owned_type_where_clause;
268
269            fn versionize_owned(self) -> Self::VersionedOwned {
270                #versionize_owned_body
271            }
272        }
273
274        #[automatically_derived]
275        impl #trait_impl_generics #unversionize_trait for #input_ident #ty_generics
276        #unversionize_trait_where_clause
277        {
278            fn unversionize(#unversionize_arg_name: Self::VersionedOwned) -> #result_type<Self, #unversionize_error>  {
279                #unversionize_body
280            }
281        }
282
283        #[automatically_derived]
284        impl #trait_impl_generics #versionize_slice_trait for #input_ident #ty_generics
285        #versionize_trait_where_clause
286        {
287            type VersionedSlice<#lifetime> = #vec_type<<Self as #versionize_trait>::Versioned<#lifetime>> #versioned_type_where_clause;
288
289            fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> {
290                slice.iter().map(|val| #versionize_trait::versionize(val)).collect()
291            }
292        }
293
294        #[automatically_derived]
295        impl #trait_impl_generics #versionize_vec_trait for #input_ident #ty_generics
296        #versionize_owned_trait_where_clause
297        {
298
299            type VersionedVec = #vec_type<<Self as #versionize_owned_trait>::VersionedOwned> #versioned_owned_type_where_clause;
300
301            fn versionize_vec(vec: #vec_type<Self>) -> Self::VersionedVec {
302                vec.into_iter().map(|val| #versionize_owned_trait::versionize_owned(val)).collect()
303            }
304        }
305
306        #[automatically_derived]
307        impl #trait_impl_generics #unversionize_vec_trait for #input_ident #ty_generics
308        #unversionize_trait_where_clause
309        {
310            fn unversionize_vec(versioned: Self::VersionedVec) -> #result_type<#vec_type<Self>, #unversionize_error> {
311                versioned
312                .into_iter()
313                .map(|versioned| <Self as #unversionize_trait>::unversionize(versioned))
314                .collect()
315            }
316        }
317    }
318    .into()
319}
320
321/// This derives the `Versionize` and `Unversionize` trait for a type that should not
322/// be versioned. The `versionize` method will simply return self
323#[proc_macro_derive(NotVersioned)]
324pub fn derive_not_versioned(input: TokenStream) -> TokenStream {
325    let input = parse_macro_input!(input as DeriveInput);
326
327    // Versionize needs T to impl Serialize
328    let mut versionize_generics = input.generics.clone();
329    syn_unwrap!(add_trait_where_clause(
330        &mut versionize_generics,
331        &[parse_quote! { Self }],
332        &[SERIALIZE_TRAIT_NAME]
333    ));
334
335    // VersionizeOwned needs T to impl Serialize and DeserializeOwned
336    let mut versionize_owned_generics = input.generics.clone();
337    syn_unwrap!(add_trait_where_clause(
338        &mut versionize_owned_generics,
339        &[parse_quote! { Self }],
340        &[SERIALIZE_TRAIT_NAME, DESERIALIZE_OWNED_TRAIT_NAME]
341    ));
342
343    let (impl_generics, ty_generics, versionize_where_clause) =
344        versionize_generics.split_for_impl();
345    let (_, _, versionize_owned_where_clause) = versionize_owned_generics.split_for_impl();
346
347    let input_ident = &input.ident;
348
349    let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
350    let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
351    let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
352    let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME);
353    let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
354
355    let result_type: Path = parse_const_str(RESULT_TYPE_NAME);
356
357    quote! {
358        #[automatically_derived]
359        impl #impl_generics #versionize_trait for #input_ident #ty_generics #versionize_where_clause {
360            type Versioned<#lifetime> = &#lifetime Self where Self: 'vers;
361
362            fn versionize(&self) -> Self::Versioned<'_> {
363                self
364            }
365        }
366
367        #[automatically_derived]
368        impl #impl_generics #versionize_owned_trait for #input_ident #ty_generics #versionize_owned_where_clause {
369            type VersionedOwned = Self;
370
371            fn versionize_owned(self) -> Self::VersionedOwned {
372                self
373            }
374        }
375
376        #[automatically_derived]
377        impl #impl_generics #unversionize_trait for #input_ident #ty_generics #versionize_owned_where_clause {
378            fn unversionize(versioned: Self::VersionedOwned) -> #result_type<Self, #unversionize_error> {
379                Ok(versioned)
380            }
381        }
382
383        #[automatically_derived]
384        impl #impl_generics NotVersioned for #input_ident #ty_generics #versionize_owned_where_clause {}
385
386    }
387    .into()
388}
389
390/// Adds a where clause with a lifetime bound on all the generic types and lifetimes in `generics`
391fn add_where_lifetime_bound_to_generics(generics: &mut Generics, lifetime: &Lifetime) {
392    let mut params = Vec::new();
393    for param in generics.params.iter() {
394        let param_ident = match param {
395            GenericParam::Lifetime(generic_lifetime) => {
396                if generic_lifetime.lifetime.ident == lifetime.ident {
397                    continue;
398                }
399                &generic_lifetime.lifetime.ident
400            }
401            GenericParam::Type(generic_type) => &generic_type.ident,
402            GenericParam::Const(_) => continue,
403        };
404        params.push(param_ident.clone());
405    }
406
407    for param in params.iter() {
408        generics
409            .make_where_clause()
410            .predicates
411            .push(parse_quote! { #param: #lifetime  });
412    }
413}
414
415/// Adds a new lifetime param with a bound for all the generic types in `generics`
416fn add_lifetime_param(generics: &mut Generics, lifetime: &Lifetime) {
417    generics
418        .params
419        .push(GenericParam::Lifetime(LifetimeParam::new(lifetime.clone())));
420    for param in generics.type_params_mut() {
421        param
422            .bounds
423            .push(TypeParamBound::Lifetime(lifetime.clone()));
424    }
425}
426
427/// Parse the input str trait bound
428fn parse_trait_bound(trait_name: &str) -> syn::Result<TraitBound> {
429    let trait_path: Path = syn::parse_str(trait_name)?;
430    Ok(parse_quote!(#trait_path))
431}
432
433/// Adds a "where clause" bound for all the input types with all the input traits
434fn add_trait_where_clause<'a, S: AsRef<str>, I: IntoIterator<Item = &'a Type>>(
435    generics: &mut Generics,
436    types: I,
437    traits_name: &[S],
438) -> syn::Result<()> {
439    let preds = &mut generics.make_where_clause().predicates;
440
441    if !traits_name.is_empty() {
442        let bounds: Vec<TraitBound> = traits_name
443            .iter()
444            .map(|bound_name| parse_trait_bound(bound_name.as_ref()))
445            .collect::<syn::Result<_>>()?;
446        for ty in types {
447            preds.push(parse_quote! { #ty: #(#bounds)+*  });
448        }
449    }
450
451    Ok(())
452}
453
454/// Adds a "where clause" bound for all the input types with all the input lifetimes
455fn add_lifetime_where_clause<'a, S: AsRef<str>, I: IntoIterator<Item = &'a Type>>(
456    generics: &mut Generics,
457    types: I,
458    lifetimes: &[S],
459) -> syn::Result<()> {
460    let preds = &mut generics.make_where_clause().predicates;
461
462    if !lifetimes.is_empty() {
463        let bounds: Vec<Lifetime> = lifetimes
464            .iter()
465            .map(|lifetime| syn::parse_str(lifetime.as_ref()))
466            .collect::<syn::Result<_>>()?;
467        for ty in types {
468            preds.push(parse_quote! { #ty: #(#bounds)+*  });
469        }
470    }
471
472    Ok(())
473}
474
475/// Extends a where clause with predicates from another one, filtering duplicates
476fn extend_where_clause(base_clause: &mut WhereClause, extension_clause: &WhereClause) {
477    for extend_predicate in &extension_clause.predicates {
478        if base_clause.predicates.iter().all(|base_predicate| {
479            base_predicate.to_token_stream().to_string()
480                != extend_predicate.to_token_stream().to_string()
481        }) {
482            base_clause.predicates.push(extend_predicate.clone());
483        }
484    }
485}
486
487/// Creates a Result [`syn::punctuated::Punctuated`] from an iterator of Results
488fn punctuated_from_iter_result<T, P: Default, I: IntoIterator<Item = syn::Result<T>>>(
489    iter: I,
490) -> syn::Result<Punctuated<T, P>> {
491    let mut ret = Punctuated::new();
492    for value in iter {
493        ret.push(value?)
494    }
495    Ok(ret)
496}
497
498/// Like [`syn::parse_str`] for inputs that are known at compile time to be valid
499fn parse_const_str<T: Parse>(s: &'static str) -> T {
500    syn::parse_str(s).expect("Parsing of const string should not fail")
501}
502
503/// Remove the '?Sized' bounds from the generics
504///
505/// The VersionDispatch trait requires that the versioned type is Sized so we have to remove this
506/// bound. It means that for a type `MyStruct<T: ?Sized>`, we will only be able to call
507/// `.versionize()` when T is Sized.
508fn filter_unsized_bounds(generics: &Generics) -> Generics {
509    let mut generics = generics.clone();
510
511    for param in generics.type_params_mut() {
512        param.bounds = remove_unsized_bound(&param.bounds);
513    }
514
515    if let Some(clause) = &mut generics.where_clause {
516        for pred in &mut clause.predicates {
517            match pred {
518                syn::WherePredicate::Lifetime(_) => {}
519                syn::WherePredicate::Type(type_predicate) => {
520                    type_predicate.bounds = remove_unsized_bound(&type_predicate.bounds);
521                }
522                _ => {}
523            }
524        }
525    }
526
527    generics
528}
529
530/// Filter the ?Sized bound in a list of bounds
531fn remove_unsized_bound(
532    bounds: &Punctuated<TypeParamBound, Plus>,
533) -> Punctuated<TypeParamBound, Plus> {
534    bounds
535        .iter()
536        .filter(|bound| match bound {
537            TypeParamBound::Trait(trait_bound) => {
538                if !matches!(trait_bound.modifier, TraitBoundModifier::None) {
539                    if let Some(segment) = trait_bound.path.segments.iter().next_back() {
540                        if segment.ident == "Sized" {
541                            return false;
542                        }
543                    }
544                }
545                true
546            }
547            TypeParamBound::Lifetime(_) => true,
548            TypeParamBound::Verbatim(_) => true,
549            _ => true,
550        })
551        .cloned()
552        .collect()
553}