tfhe_versionable_derive/
lib.rs1mod associated;
9mod dispatch_type;
10mod transparent;
11mod version_type;
12mod versionize_attribute;
13mod versionize_impl;
14
15use dispatch_type::DispatchType;
16use proc_macro::TokenStream;
17use proc_macro2::Span;
18use quote::{quote, ToTokens};
19use syn::parse::Parse;
20use syn::punctuated::Punctuated;
21use syn::token::Plus;
22use syn::{
23 parse_macro_input, parse_quote, DeriveInput, GenericParam, Generics, Ident, Lifetime,
24 LifetimeParam, Path, TraitBound, TraitBoundModifier, Type, TypeParamBound, WhereClause,
25};
26
/// Expands to the absolute path (as a `&'static str`) of an item living at the root of the
/// `tfhe_versionable` crate.
///
/// `crate_full_path!("Version")` expands to `"::tfhe_versionable::Version"`.
macro_rules! crate_full_path {
    ($item_name:expr) => {
        concat!("::tfhe_versionable::", $item_name)
    };
}
33
34pub(crate) const LIFETIME_NAME: &str = "'vers";
35pub(crate) const VERSION_TRAIT_NAME: &str = crate_full_path!("Version");
36pub(crate) const DISPATCH_TRAIT_NAME: &str = crate_full_path!("VersionsDispatch");
37pub(crate) const VERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Versionize");
38pub(crate) const VERSIONIZE_OWNED_TRAIT_NAME: &str = crate_full_path!("VersionizeOwned");
39pub(crate) const VERSIONIZE_SLICE_TRAIT_NAME: &str = crate_full_path!("VersionizeSlice");
40pub(crate) const VERSIONIZE_VEC_TRAIT_NAME: &str = crate_full_path!("VersionizeVec");
41pub(crate) const UNVERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Unversionize");
42pub(crate) const UNVERSIONIZE_VEC_TRAIT_NAME: &str = crate_full_path!("UnversionizeVec");
43pub(crate) const UPGRADE_TRAIT_NAME: &str = crate_full_path!("Upgrade");
44pub(crate) const UNVERSIONIZE_ERROR_NAME: &str = crate_full_path!("UnversionizeError");
45
46pub(crate) const SERIALIZE_TRAIT_NAME: &str = "::serde::Serialize";
47pub(crate) const DESERIALIZE_TRAIT_NAME: &str = "::serde::Deserialize";
48pub(crate) const DESERIALIZE_OWNED_TRAIT_NAME: &str = "::serde::de::DeserializeOwned";
49pub(crate) const FROM_TRAIT_NAME: &str = "::core::convert::From";
50pub(crate) const TRY_INTO_TRAIT_NAME: &str = "::core::convert::TryInto";
51pub(crate) const INTO_TRAIT_NAME: &str = "::core::convert::Into";
52pub(crate) const ERROR_TRAIT_NAME: &str = "::core::error::Error";
53pub(crate) const SYNC_TRAIT_NAME: &str = "::core::marker::Sync";
54pub(crate) const SEND_TRAIT_NAME: &str = "::core::marker::Send";
55pub(crate) const DEFAULT_TRAIT_NAME: &str = "::core::default::Default";
56pub(crate) const STATIC_LIFETIME_NAME: &str = "'static";
57
58use associated::AssociatingTrait;
59use versionize_impl::VersionizeImplementor;
60
61use crate::version_type::VersionType;
62use crate::versionize_attribute::VersionizeAttribute;
63
/// Evaluates `$result`. On `Ok` it yields the contained value; on `Err` it makes the enclosing
/// function return the error's `to_compile_error()` output, converted with `Into` into that
/// function's return type.
macro_rules! syn_unwrap {
    ($result:expr) => {
        match $result {
            ::core::result::Result::Ok(value) => value,
            ::core::result::Result::Err(error) => {
                return ::core::convert::Into::into(error.to_compile_error());
            }
        }
    };
}
74
75#[proc_macro_derive(Version, attributes(versionize))]
76pub fn derive_version(input: TokenStream) -> TokenStream {
78 let input = parse_macro_input!(input as DeriveInput);
79
80 impl_version_trait(&input).into()
81}
82
83fn impl_version_trait(input: &DeriveInput) -> proc_macro2::TokenStream {
86 let version_trait = syn_unwrap!(AssociatingTrait::<VersionType>::new(
87 input,
88 VERSION_TRAIT_NAME,
89 ));
90
91 let version_types = syn_unwrap!(version_trait.generate_types_declarations());
92
93 let version_impl = syn_unwrap!(version_trait.generate_impl());
94
95 quote! {
96 const _: () = {
97 #version_types
98
99 #[automatically_derived]
100 #version_impl
101 };
102 }
103}
104
105#[proc_macro_derive(VersionsDispatch)]
109pub fn derive_versions_dispatch(input: TokenStream) -> TokenStream {
110 let input = parse_macro_input!(input as DeriveInput);
111
112 let dispatch_trait = syn_unwrap!(AssociatingTrait::<DispatchType>::new(
113 &input,
114 DISPATCH_TRAIT_NAME,
115 ));
116
117 let dispatch_types = syn_unwrap!(dispatch_trait.generate_types_declarations());
118
119 let dispatch_impl = syn_unwrap!(dispatch_trait.generate_impl());
120
121 quote! {
122 const _: () = {
123 #dispatch_types
124
125 #[automatically_derived]
126 #dispatch_impl
127 };
128 }
129 .into()
130}
131
132#[proc_macro_derive(Versionize, attributes(versionize))]
180pub fn derive_versionize(input: TokenStream) -> TokenStream {
181 let input = parse_macro_input!(input as DeriveInput);
182
183 let input_generics = filter_unsized_bounds(&input.generics);
184
185 let attributes = syn_unwrap!(VersionizeAttribute::parse_from_attributes_list(
186 &input.attrs
187 ));
188
189 let implementor = syn_unwrap!(VersionizeImplementor::new(
190 attributes,
191 &input.data,
192 Span::call_site()
193 ));
194
195 let version_trait_impl: Option<proc_macro2::TokenStream> =
198 if implementor.is_directly_versioned() {
199 Some(impl_version_trait(&input))
200 } else {
201 None
202 };
203
204 let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
206 let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
207 let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
208 let versionize_vec_trait: Path = parse_const_str(VERSIONIZE_VEC_TRAIT_NAME);
209 let versionize_slice_trait: Path = parse_const_str(VERSIONIZE_SLICE_TRAIT_NAME);
210 let unversionize_vec_trait: Path = parse_const_str(UNVERSIONIZE_VEC_TRAIT_NAME);
211
212 let input_ident = &input.ident;
213 let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
214
215 let (_, ty_generics, _) = input_generics.split_for_impl();
217
218 let versioned_type = implementor.versioned_type(&lifetime, &input_generics);
220 let versioned_owned_type = implementor.versioned_owned_type(&input_generics);
221 let versioned_type_where_clause =
222 implementor.versioned_type_where_clause(&lifetime, &input_generics);
223 let versioned_owned_type_where_clause =
224 implementor.versioned_owned_type_where_clause(&input_generics);
225
226 let versionize_trait_where_clause =
229 syn_unwrap!(implementor.versionize_trait_where_clause(&input_generics));
230 let versionize_owned_trait_where_clause =
231 syn_unwrap!(implementor.versionize_owned_trait_where_clause(&input_generics));
232 let unversionize_trait_where_clause =
233 syn_unwrap!(implementor.unversionize_trait_where_clause(&input_generics));
234
235 let trait_impl_generics = input_generics.split_for_impl().0;
236
237 let versionize_body = implementor.versionize_method_body();
238 let versionize_owned_body = implementor.versionize_owned_method_body();
239 let unversionize_arg_name = Ident::new("versioned", Span::call_site());
240 let unversionize_body = implementor.unversionize_method_body(&unversionize_arg_name);
241 let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME);
242
243 quote! {
244 #version_trait_impl
245
246 #[automatically_derived]
247 impl #trait_impl_generics #versionize_trait for #input_ident #ty_generics
248 #versionize_trait_where_clause
249 {
250 type Versioned<#lifetime> = #versioned_type #versioned_type_where_clause;
251
252 fn versionize(&self) -> Self::Versioned<'_> {
253 #versionize_body
254 }
255 }
256
257 #[automatically_derived]
258 impl #trait_impl_generics #versionize_owned_trait for #input_ident #ty_generics
259 #versionize_owned_trait_where_clause
260 {
261 type VersionedOwned = #versioned_owned_type #versioned_owned_type_where_clause;
262
263 fn versionize_owned(self) -> Self::VersionedOwned {
264 #versionize_owned_body
265 }
266 }
267
268 #[automatically_derived]
269 impl #trait_impl_generics #unversionize_trait for #input_ident #ty_generics
270 #unversionize_trait_where_clause
271 {
272 fn unversionize(#unversionize_arg_name: Self::VersionedOwned) -> Result<Self, #unversionize_error> {
273 #unversionize_body
274 }
275 }
276
277 #[automatically_derived]
278 impl #trait_impl_generics #versionize_slice_trait for #input_ident #ty_generics
279 #versionize_trait_where_clause
280 {
281 type VersionedSlice<#lifetime> = Vec<<Self as #versionize_trait>::Versioned<#lifetime>> #versioned_type_where_clause;
282
283 fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> {
284 slice.iter().map(|val| #versionize_trait::versionize(val)).collect()
285 }
286 }
287
288 #[automatically_derived]
289 impl #trait_impl_generics #versionize_vec_trait for #input_ident #ty_generics
290 #versionize_owned_trait_where_clause
291 {
292
293 type VersionedVec = Vec<<Self as #versionize_owned_trait>::VersionedOwned> #versioned_owned_type_where_clause;
294
295 fn versionize_vec(vec: Vec<Self>) -> Self::VersionedVec {
296 vec.into_iter().map(|val| #versionize_owned_trait::versionize_owned(val)).collect()
297 }
298 }
299
300 #[automatically_derived]
301 impl #trait_impl_generics #unversionize_vec_trait for #input_ident #ty_generics
302 #unversionize_trait_where_clause
303 {
304 fn unversionize_vec(versioned: Self::VersionedVec) -> Result<Vec<Self>, #unversionize_error> {
305 versioned
306 .into_iter()
307 .map(|versioned| <Self as #unversionize_trait>::unversionize(versioned))
308 .collect()
309 }
310 }
311 }
312 .into()
313}
314
315#[proc_macro_derive(NotVersioned)]
318pub fn derive_not_versioned(input: TokenStream) -> TokenStream {
319 let input = parse_macro_input!(input as DeriveInput);
320
321 let mut versionize_generics = input.generics.clone();
323 syn_unwrap!(add_trait_where_clause(
324 &mut versionize_generics,
325 &[parse_quote! { Self }],
326 &[SERIALIZE_TRAIT_NAME]
327 ));
328
329 let mut versionize_owned_generics = input.generics.clone();
331 syn_unwrap!(add_trait_where_clause(
332 &mut versionize_owned_generics,
333 &[parse_quote! { Self }],
334 &[SERIALIZE_TRAIT_NAME, DESERIALIZE_OWNED_TRAIT_NAME]
335 ));
336
337 let (impl_generics, ty_generics, versionize_where_clause) =
338 versionize_generics.split_for_impl();
339 let (_, _, versionize_owned_where_clause) = versionize_owned_generics.split_for_impl();
340
341 let input_ident = &input.ident;
342
343 let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
344 let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
345 let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
346 let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME);
347 let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
348
349 quote! {
350 #[automatically_derived]
351 impl #impl_generics #versionize_trait for #input_ident #ty_generics #versionize_where_clause {
352 type Versioned<#lifetime> = &#lifetime Self where Self: 'vers;
353
354 fn versionize(&self) -> Self::Versioned<'_> {
355 self
356 }
357 }
358
359 #[automatically_derived]
360 impl #impl_generics #versionize_owned_trait for #input_ident #ty_generics #versionize_owned_where_clause {
361 type VersionedOwned = Self;
362
363 fn versionize_owned(self) -> Self::VersionedOwned {
364 self
365 }
366 }
367
368 #[automatically_derived]
369 impl #impl_generics #unversionize_trait for #input_ident #ty_generics #versionize_owned_where_clause {
370 fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, #unversionize_error> {
371 Ok(versioned)
372 }
373 }
374
375 #[automatically_derived]
376 impl #impl_generics NotVersioned for #input_ident #ty_generics #versionize_owned_where_clause {}
377
378 }
379 .into()
380}
381
382fn add_where_lifetime_bound_to_generics(generics: &mut Generics, lifetime: &Lifetime) {
384 let mut params = Vec::new();
385 for param in generics.params.iter() {
386 let param_ident = match param {
387 GenericParam::Lifetime(generic_lifetime) => {
388 if generic_lifetime.lifetime.ident == lifetime.ident {
389 continue;
390 }
391 &generic_lifetime.lifetime.ident
392 }
393 GenericParam::Type(generic_type) => &generic_type.ident,
394 GenericParam::Const(_) => continue,
395 };
396 params.push(param_ident.clone());
397 }
398
399 for param in params.iter() {
400 generics
401 .make_where_clause()
402 .predicates
403 .push(parse_quote! { #param: #lifetime });
404 }
405}
406
407fn add_lifetime_param(generics: &mut Generics, lifetime: &Lifetime) {
409 generics
410 .params
411 .push(GenericParam::Lifetime(LifetimeParam::new(lifetime.clone())));
412 for param in generics.type_params_mut() {
413 param
414 .bounds
415 .push(TypeParamBound::Lifetime(lifetime.clone()));
416 }
417}
418
419fn parse_trait_bound(trait_name: &str) -> syn::Result<TraitBound> {
421 let trait_path: Path = syn::parse_str(trait_name)?;
422 Ok(parse_quote!(#trait_path))
423}
424
425fn add_trait_where_clause<'a, S: AsRef<str>, I: IntoIterator<Item = &'a Type>>(
427 generics: &mut Generics,
428 types: I,
429 traits_name: &[S],
430) -> syn::Result<()> {
431 let preds = &mut generics.make_where_clause().predicates;
432
433 if !traits_name.is_empty() {
434 let bounds: Vec<TraitBound> = traits_name
435 .iter()
436 .map(|bound_name| parse_trait_bound(bound_name.as_ref()))
437 .collect::<syn::Result<_>>()?;
438 for ty in types {
439 preds.push(parse_quote! { #ty: #(#bounds)+* });
440 }
441 }
442
443 Ok(())
444}
445
446fn add_lifetime_where_clause<'a, S: AsRef<str>, I: IntoIterator<Item = &'a Type>>(
448 generics: &mut Generics,
449 types: I,
450 lifetimes: &[S],
451) -> syn::Result<()> {
452 let preds = &mut generics.make_where_clause().predicates;
453
454 if !lifetimes.is_empty() {
455 let bounds: Vec<Lifetime> = lifetimes
456 .iter()
457 .map(|lifetime| syn::parse_str(lifetime.as_ref()))
458 .collect::<syn::Result<_>>()?;
459 for ty in types {
460 preds.push(parse_quote! { #ty: #(#bounds)+* });
461 }
462 }
463
464 Ok(())
465}
466
467fn extend_where_clause(base_clause: &mut WhereClause, extension_clause: &WhereClause) {
469 for extend_predicate in &extension_clause.predicates {
470 if base_clause.predicates.iter().all(|base_predicate| {
471 base_predicate.to_token_stream().to_string()
472 != extend_predicate.to_token_stream().to_string()
473 }) {
474 base_clause.predicates.push(extend_predicate.clone());
475 }
476 }
477}
478
479fn punctuated_from_iter_result<T, P: Default, I: IntoIterator<Item = syn::Result<T>>>(
481 iter: I,
482) -> syn::Result<Punctuated<T, P>> {
483 let mut ret = Punctuated::new();
484 for value in iter {
485 ret.push(value?)
486 }
487 Ok(ret)
488}
489
490fn parse_const_str<T: Parse>(s: &'static str) -> T {
492 syn::parse_str(s).expect("Parsing of const string should not fail")
493}
494
495fn filter_unsized_bounds(generics: &Generics) -> Generics {
501 let mut generics = generics.clone();
502
503 for param in generics.type_params_mut() {
504 param.bounds = remove_unsized_bound(¶m.bounds);
505 }
506
507 if let Some(clause) = &mut generics.where_clause {
508 for pred in &mut clause.predicates {
509 match pred {
510 syn::WherePredicate::Lifetime(_) => {}
511 syn::WherePredicate::Type(type_predicate) => {
512 type_predicate.bounds = remove_unsized_bound(&type_predicate.bounds);
513 }
514 _ => {}
515 }
516 }
517 }
518
519 generics
520}
521
522fn remove_unsized_bound(
524 bounds: &Punctuated<TypeParamBound, Plus>,
525) -> Punctuated<TypeParamBound, Plus> {
526 bounds
527 .iter()
528 .filter(|bound| match bound {
529 TypeParamBound::Trait(trait_bound) => {
530 if !matches!(trait_bound.modifier, TraitBoundModifier::None) {
531 if let Some(segment) = trait_bound.path.segments.iter().next_back() {
532 if segment.ident == "Sized" {
533 return false;
534 }
535 }
536 }
537 true
538 }
539 TypeParamBound::Lifetime(_) => true,
540 TypeParamBound::Verbatim(_) => true,
541 _ => true,
542 })
543 .cloned()
544 .collect()
545}