tfhe_versionable_derive/
lib.rs1mod associated;
9mod dispatch_type;
10mod transparent;
11mod version_type;
12mod versionize_attribute;
13mod versionize_impl;
14
15use dispatch_type::DispatchType;
16use proc_macro::TokenStream;
17use proc_macro2::Span;
18use quote::{quote, ToTokens};
19use syn::parse::Parse;
20use syn::punctuated::Punctuated;
21use syn::token::Plus;
22use syn::{
23 parse_macro_input, parse_quote, DeriveInput, GenericParam, Generics, Ident, Lifetime,
24 LifetimeParam, Path, TraitBound, TraitBoundModifier, Type, TypeParamBound, WhereClause,
25};
26
/// Expands to the absolute path string `"::tfhe_versionable::<item>"`.
///
/// The argument must be something `concat!` accepts (in practice a string literal), and the
/// result is a `&'static str` usable in `const` items.
macro_rules! crate_full_path {
    ($item_name:expr) => {
        concat!("::tfhe_versionable::", $item_name)
    };
}
33
/// Name of the lifetime parameter used by the generated borrowed associated types
/// (e.g. `type Versioned<'vers> = ...`).
pub(crate) const LIFETIME_NAME: &str = "'vers";

// Absolute paths of the `tfhe_versionable` items referenced by the generated code. The leading
// `::` makes them resolve from the `tfhe_versionable` crate root rather than from whatever
// names happen to be in scope at the derive site.
pub(crate) const VERSION_TRAIT_NAME: &str = crate_full_path!("Version");
pub(crate) const DISPATCH_TRAIT_NAME: &str = crate_full_path!("VersionsDispatch");
pub(crate) const VERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Versionize");
pub(crate) const VERSIONIZE_OWNED_TRAIT_NAME: &str = crate_full_path!("VersionizeOwned");
pub(crate) const VERSIONIZE_SLICE_TRAIT_NAME: &str = crate_full_path!("VersionizeSlice");
pub(crate) const VERSIONIZE_VEC_TRAIT_NAME: &str = crate_full_path!("VersionizeVec");
pub(crate) const UNVERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Unversionize");
pub(crate) const UNVERSIONIZE_VEC_TRAIT_NAME: &str = crate_full_path!("UnversionizeVec");
pub(crate) const UPGRADE_TRAIT_NAME: &str = crate_full_path!("Upgrade");
pub(crate) const UNVERSIONIZE_ERROR_NAME: &str = crate_full_path!("UnversionizeError");

// Absolute paths of external (serde / core / std) items used in generated bounds and signatures.
pub(crate) const SERIALIZE_TRAIT_NAME: &str = "::serde::Serialize";
pub(crate) const DESERIALIZE_TRAIT_NAME: &str = "::serde::Deserialize";
pub(crate) const DESERIALIZE_OWNED_TRAIT_NAME: &str = "::serde::de::DeserializeOwned";
pub(crate) const TRY_FROM_TRAIT_NAME: &str = "::core::convert::TryFrom";
pub(crate) const FROM_TRAIT_NAME: &str = "::core::convert::From";
pub(crate) const TRY_INTO_TRAIT_NAME: &str = "::core::convert::TryInto";
pub(crate) const INTO_TRAIT_NAME: &str = "::core::convert::Into";
// NOTE(review): `core::error::Error` is only available in `core` since Rust 1.81, so code
// expanded with this path needs at least that toolchain.
pub(crate) const ERROR_TRAIT_NAME: &str = "::core::error::Error";
pub(crate) const SYNC_TRAIT_NAME: &str = "::core::marker::Sync";
pub(crate) const SEND_TRAIT_NAME: &str = "::core::marker::Send";
pub(crate) const DEFAULT_TRAIT_NAME: &str = "::core::default::Default";
pub(crate) const RESULT_TYPE_NAME: &str = "::core::result::Result";
pub(crate) const VEC_TYPE_NAME: &str = "::std::vec::Vec";
/// The `'static` lifetime, as a parsable string.
pub(crate) const STATIC_LIFETIME_NAME: &str = "'static";
60
61use associated::AssociatingTrait;
62use versionize_impl::VersionizeImplementor;
63
64use crate::version_type::VersionType;
65use crate::versionize_attribute::VersionizeAttribute;
66
/// Unwraps a `Result`-like expression inside a macro entry point.
///
/// On `Ok(value)` the macro evaluates to `value`. On `Err(error)` it returns early from the
/// enclosing function with `error.to_compile_error().into()`, so the error is reported as a
/// compile error at the derive site instead of panicking.
macro_rules! syn_unwrap {
    ($fallible:expr) => {
        match $fallible {
            Err(error) => return error.to_compile_error().into(),
            Ok(value) => value,
        }
    };
}
77
78#[proc_macro_derive(Version, attributes(versionize))]
79pub fn derive_version(input: TokenStream) -> TokenStream {
81 let input = parse_macro_input!(input as DeriveInput);
82
83 impl_version_trait(&input).into()
84}
85
86fn impl_version_trait(input: &DeriveInput) -> proc_macro2::TokenStream {
89 let version_trait = syn_unwrap!(AssociatingTrait::<VersionType>::new(
90 input,
91 VERSION_TRAIT_NAME,
92 ));
93
94 let version_types = syn_unwrap!(version_trait.generate_types_declarations());
95
96 let version_impl = syn_unwrap!(version_trait.generate_impl());
97
98 quote! {
99 const _: () = {
100 #version_types
101
102 #[automatically_derived]
103 #version_impl
104 };
105 }
106}
107
108#[proc_macro_derive(VersionsDispatch)]
112pub fn derive_versions_dispatch(input: TokenStream) -> TokenStream {
113 let input = parse_macro_input!(input as DeriveInput);
114
115 let dispatch_trait = syn_unwrap!(AssociatingTrait::<DispatchType>::new(
116 &input,
117 DISPATCH_TRAIT_NAME,
118 ));
119
120 let dispatch_types = syn_unwrap!(dispatch_trait.generate_types_declarations());
121
122 let dispatch_impl = syn_unwrap!(dispatch_trait.generate_impl());
123
124 quote! {
125 const _: () = {
126 #dispatch_types
127
128 #[automatically_derived]
129 #dispatch_impl
130 };
131 }
132 .into()
133}
134
135#[proc_macro_derive(Versionize, attributes(versionize))]
183pub fn derive_versionize(input: TokenStream) -> TokenStream {
184 let input = parse_macro_input!(input as DeriveInput);
185
186 let input_generics = filter_unsized_bounds(&input.generics);
187
188 let attributes = syn_unwrap!(VersionizeAttribute::parse_from_attributes_list(
189 &input.attrs
190 ));
191
192 let implementor = syn_unwrap!(VersionizeImplementor::new(
193 attributes,
194 &input.data,
195 Span::call_site()
196 ));
197
198 let version_trait_impl: Option<proc_macro2::TokenStream> =
201 if implementor.is_directly_versioned() {
202 Some(impl_version_trait(&input))
203 } else {
204 None
205 };
206
207 let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
209 let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
210 let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
211 let versionize_vec_trait: Path = parse_const_str(VERSIONIZE_VEC_TRAIT_NAME);
212 let versionize_slice_trait: Path = parse_const_str(VERSIONIZE_SLICE_TRAIT_NAME);
213 let unversionize_vec_trait: Path = parse_const_str(UNVERSIONIZE_VEC_TRAIT_NAME);
214
215 let input_ident = &input.ident;
216 let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
217
218 let (_, ty_generics, _) = input_generics.split_for_impl();
220
221 let versioned_type = implementor.versioned_type(&lifetime, &input_generics);
223 let versioned_owned_type = implementor.versioned_owned_type(&input_generics);
224 let versioned_type_where_clause =
225 implementor.versioned_type_where_clause(&lifetime, &input_generics);
226 let versioned_owned_type_where_clause =
227 implementor.versioned_owned_type_where_clause(&input_generics);
228
229 let versionize_trait_where_clause =
232 syn_unwrap!(implementor.versionize_trait_where_clause(&input_generics));
233 let versionize_owned_trait_where_clause =
234 syn_unwrap!(implementor.versionize_owned_trait_where_clause(&input_generics));
235 let unversionize_trait_where_clause =
236 syn_unwrap!(implementor.unversionize_trait_where_clause(&input_generics));
237
238 let trait_impl_generics = input_generics.split_for_impl().0;
239
240 let versionize_body = implementor.versionize_method_body();
241 let versionize_owned_body = implementor.versionize_owned_method_body();
242 let unversionize_arg_name = Ident::new("versioned", Span::call_site());
243 let unversionize_body = implementor.unversionize_method_body(&unversionize_arg_name);
244 let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME);
245
246 let result_type: Path = parse_const_str(RESULT_TYPE_NAME);
247 let vec_type: Path = parse_const_str(VEC_TYPE_NAME);
248
249 quote! {
250 #version_trait_impl
251
252 #[automatically_derived]
253 impl #trait_impl_generics #versionize_trait for #input_ident #ty_generics
254 #versionize_trait_where_clause
255 {
256 type Versioned<#lifetime> = #versioned_type #versioned_type_where_clause;
257
258 fn versionize(&self) -> Self::Versioned<'_> {
259 #versionize_body
260 }
261 }
262
263 #[automatically_derived]
264 impl #trait_impl_generics #versionize_owned_trait for #input_ident #ty_generics
265 #versionize_owned_trait_where_clause
266 {
267 type VersionedOwned = #versioned_owned_type #versioned_owned_type_where_clause;
268
269 fn versionize_owned(self) -> Self::VersionedOwned {
270 #versionize_owned_body
271 }
272 }
273
274 #[automatically_derived]
275 impl #trait_impl_generics #unversionize_trait for #input_ident #ty_generics
276 #unversionize_trait_where_clause
277 {
278 fn unversionize(#unversionize_arg_name: Self::VersionedOwned) -> #result_type<Self, #unversionize_error> {
279 #unversionize_body
280 }
281 }
282
283 #[automatically_derived]
284 impl #trait_impl_generics #versionize_slice_trait for #input_ident #ty_generics
285 #versionize_trait_where_clause
286 {
287 type VersionedSlice<#lifetime> = #vec_type<<Self as #versionize_trait>::Versioned<#lifetime>> #versioned_type_where_clause;
288
289 fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> {
290 slice.iter().map(|val| #versionize_trait::versionize(val)).collect()
291 }
292 }
293
294 #[automatically_derived]
295 impl #trait_impl_generics #versionize_vec_trait for #input_ident #ty_generics
296 #versionize_owned_trait_where_clause
297 {
298
299 type VersionedVec = #vec_type<<Self as #versionize_owned_trait>::VersionedOwned> #versioned_owned_type_where_clause;
300
301 fn versionize_vec(vec: #vec_type<Self>) -> Self::VersionedVec {
302 vec.into_iter().map(|val| #versionize_owned_trait::versionize_owned(val)).collect()
303 }
304 }
305
306 #[automatically_derived]
307 impl #trait_impl_generics #unversionize_vec_trait for #input_ident #ty_generics
308 #unversionize_trait_where_clause
309 {
310 fn unversionize_vec(versioned: Self::VersionedVec) -> #result_type<#vec_type<Self>, #unversionize_error> {
311 versioned
312 .into_iter()
313 .map(|versioned| <Self as #unversionize_trait>::unversionize(versioned))
314 .collect()
315 }
316 }
317 }
318 .into()
319}
320
321#[proc_macro_derive(NotVersioned)]
324pub fn derive_not_versioned(input: TokenStream) -> TokenStream {
325 let input = parse_macro_input!(input as DeriveInput);
326
327 let mut versionize_generics = input.generics.clone();
329 syn_unwrap!(add_trait_where_clause(
330 &mut versionize_generics,
331 &[parse_quote! { Self }],
332 &[SERIALIZE_TRAIT_NAME]
333 ));
334
335 let mut versionize_owned_generics = input.generics.clone();
337 syn_unwrap!(add_trait_where_clause(
338 &mut versionize_owned_generics,
339 &[parse_quote! { Self }],
340 &[SERIALIZE_TRAIT_NAME, DESERIALIZE_OWNED_TRAIT_NAME]
341 ));
342
343 let (impl_generics, ty_generics, versionize_where_clause) =
344 versionize_generics.split_for_impl();
345 let (_, _, versionize_owned_where_clause) = versionize_owned_generics.split_for_impl();
346
347 let input_ident = &input.ident;
348
349 let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
350 let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
351 let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
352 let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME);
353 let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
354
355 let result_type: Path = parse_const_str(RESULT_TYPE_NAME);
356
357 quote! {
358 #[automatically_derived]
359 impl #impl_generics #versionize_trait for #input_ident #ty_generics #versionize_where_clause {
360 type Versioned<#lifetime> = &#lifetime Self where Self: 'vers;
361
362 fn versionize(&self) -> Self::Versioned<'_> {
363 self
364 }
365 }
366
367 #[automatically_derived]
368 impl #impl_generics #versionize_owned_trait for #input_ident #ty_generics #versionize_owned_where_clause {
369 type VersionedOwned = Self;
370
371 fn versionize_owned(self) -> Self::VersionedOwned {
372 self
373 }
374 }
375
376 #[automatically_derived]
377 impl #impl_generics #unversionize_trait for #input_ident #ty_generics #versionize_owned_where_clause {
378 fn unversionize(versioned: Self::VersionedOwned) -> #result_type<Self, #unversionize_error> {
379 Ok(versioned)
380 }
381 }
382
383 #[automatically_derived]
384 impl #impl_generics NotVersioned for #input_ident #ty_generics #versionize_owned_where_clause {}
385
386 }
387 .into()
388}
389
390fn add_where_lifetime_bound_to_generics(generics: &mut Generics, lifetime: &Lifetime) {
392 let mut params = Vec::new();
393 for param in generics.params.iter() {
394 let param_ident = match param {
395 GenericParam::Lifetime(generic_lifetime) => {
396 if generic_lifetime.lifetime.ident == lifetime.ident {
397 continue;
398 }
399 &generic_lifetime.lifetime.ident
400 }
401 GenericParam::Type(generic_type) => &generic_type.ident,
402 GenericParam::Const(_) => continue,
403 };
404 params.push(param_ident.clone());
405 }
406
407 for param in params.iter() {
408 generics
409 .make_where_clause()
410 .predicates
411 .push(parse_quote! { #param: #lifetime });
412 }
413}
414
415fn add_lifetime_param(generics: &mut Generics, lifetime: &Lifetime) {
417 generics
418 .params
419 .push(GenericParam::Lifetime(LifetimeParam::new(lifetime.clone())));
420 for param in generics.type_params_mut() {
421 param
422 .bounds
423 .push(TypeParamBound::Lifetime(lifetime.clone()));
424 }
425}
426
427fn parse_trait_bound(trait_name: &str) -> syn::Result<TraitBound> {
429 let trait_path: Path = syn::parse_str(trait_name)?;
430 Ok(parse_quote!(#trait_path))
431}
432
433fn add_trait_where_clause<'a, S: AsRef<str>, I: IntoIterator<Item = &'a Type>>(
435 generics: &mut Generics,
436 types: I,
437 traits_name: &[S],
438) -> syn::Result<()> {
439 let preds = &mut generics.make_where_clause().predicates;
440
441 if !traits_name.is_empty() {
442 let bounds: Vec<TraitBound> = traits_name
443 .iter()
444 .map(|bound_name| parse_trait_bound(bound_name.as_ref()))
445 .collect::<syn::Result<_>>()?;
446 for ty in types {
447 preds.push(parse_quote! { #ty: #(#bounds)+* });
448 }
449 }
450
451 Ok(())
452}
453
454fn add_lifetime_where_clause<'a, S: AsRef<str>, I: IntoIterator<Item = &'a Type>>(
456 generics: &mut Generics,
457 types: I,
458 lifetimes: &[S],
459) -> syn::Result<()> {
460 let preds = &mut generics.make_where_clause().predicates;
461
462 if !lifetimes.is_empty() {
463 let bounds: Vec<Lifetime> = lifetimes
464 .iter()
465 .map(|lifetime| syn::parse_str(lifetime.as_ref()))
466 .collect::<syn::Result<_>>()?;
467 for ty in types {
468 preds.push(parse_quote! { #ty: #(#bounds)+* });
469 }
470 }
471
472 Ok(())
473}
474
475fn extend_where_clause(base_clause: &mut WhereClause, extension_clause: &WhereClause) {
477 for extend_predicate in &extension_clause.predicates {
478 if base_clause.predicates.iter().all(|base_predicate| {
479 base_predicate.to_token_stream().to_string()
480 != extend_predicate.to_token_stream().to_string()
481 }) {
482 base_clause.predicates.push(extend_predicate.clone());
483 }
484 }
485}
486
487fn punctuated_from_iter_result<T, P: Default, I: IntoIterator<Item = syn::Result<T>>>(
489 iter: I,
490) -> syn::Result<Punctuated<T, P>> {
491 let mut ret = Punctuated::new();
492 for value in iter {
493 ret.push(value?)
494 }
495 Ok(ret)
496}
497
498fn parse_const_str<T: Parse>(s: &'static str) -> T {
500 syn::parse_str(s).expect("Parsing of const string should not fail")
501}
502
503fn filter_unsized_bounds(generics: &Generics) -> Generics {
509 let mut generics = generics.clone();
510
511 for param in generics.type_params_mut() {
512 param.bounds = remove_unsized_bound(¶m.bounds);
513 }
514
515 if let Some(clause) = &mut generics.where_clause {
516 for pred in &mut clause.predicates {
517 match pred {
518 syn::WherePredicate::Lifetime(_) => {}
519 syn::WherePredicate::Type(type_predicate) => {
520 type_predicate.bounds = remove_unsized_bound(&type_predicate.bounds);
521 }
522 _ => {}
523 }
524 }
525 }
526
527 generics
528}
529
530fn remove_unsized_bound(
532 bounds: &Punctuated<TypeParamBound, Plus>,
533) -> Punctuated<TypeParamBound, Plus> {
534 bounds
535 .iter()
536 .filter(|bound| match bound {
537 TypeParamBound::Trait(trait_bound) => {
538 if !matches!(trait_bound.modifier, TraitBoundModifier::None) {
539 if let Some(segment) = trait_bound.path.segments.iter().next_back() {
540 if segment.ident == "Sized" {
541 return false;
542 }
543 }
544 }
545 true
546 }
547 TypeParamBound::Lifetime(_) => true,
548 TypeParamBound::Verbatim(_) => true,
549 _ => true,
550 })
551 .cloned()
552 .collect()
553}