ulib_derive/
lib.rs

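//! Ulib derive macros.
//!
//! This crate implements the derive macro for `ulib::UniversalCopy`, which emits a `Copy`
//! impl and, with the `cuda` feature, an `unsafe impl` of `ulib::cust::memory::DeviceCopy`.
//! It is adapted from [the source code of `cust_derive`](https://docs.rs/cust_derive/latest/src/cust_derive/lib.rs.html#20-24).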
5
6use proc_macro2::{Ident, Span, TokenStream};
7use syn::{
8    parse_str, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Generics,
9    TypeParamBound,
10};
11use quote::quote;
12
13#[proc_macro_derive(UniversalCopy)]
14pub fn universal_copy(input: BaseTokenStream) -> BaseTokenStream {
15    let ast = syn::parse(input).unwrap();
16    let gen = impl_universal_copy(&ast);
17    BaseTokenStream::from(gen)
18}
19
20use proc_macro::TokenStream as BaseTokenStream;
21
22fn impl_universal_copy(input: &DeriveInput) -> TokenStream {
23    let input_type = &input.ident;
24
25    let check_types_code = match input.data {
26        Data::Struct(ref data_struct) => type_check_struct(data_struct),
27        Data::Enum(ref data_enum) => type_check_enum(data_enum),
28        Data::Union(ref data_union) => type_check_union(data_union),
29    };
30
31    let type_test_func_name = format!(
32        "__ulib_derive_verify_{}_can_implement_universalcopy",
33        input_type.to_string().to_lowercase()
34    );
35    let type_test_func_ident = Ident::new(&type_test_func_name, Span::call_site());
36
37    // If the struct/enum/union is generic, we need to add the DeviceCopy bound to the generics
38    // when implementing DeviceCopy.
39    let generics = add_bound_to_generics(&input.generics, quote! {
40        ::std::marker::Copy
41    });
42    #[cfg(feature = "cuda")]
43    let generics = add_bound_to_generics(&generics, quote! {
44        ::std::marker::Copy
45    });
46    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
47
48    // Finally, generate the unsafe impl and the type-checking function.
49    #[cfg(feature = "cuda")]
50    let impl_cuda = quote! {
51        unsafe impl #impl_generics ::ulib::cust::memory::DeviceCopy for #input_type #type_generics #where_clause {}
52    };
53    #[cfg(not(feature = "cuda"))]
54    let impl_cuda = quote! {};
55    
56    #[cfg(feature = "cuda")]
57    let trait_bounds = quote! {
58        ::std::marker::Copy + ::ulib::cust::memory::DeviceCopy
59    };
60    #[cfg(not(feature = "cuda"))]
61    let trait_bounds = quote! {
62        ::std::marker::Copy
63    };
64    
65    let generated_code = quote! {
66        impl #impl_generics ::std::marker::Copy for #input_type #type_generics #where_clause {}
67        #impl_cuda
68
69        #[doc(hidden)]
70        #[allow(all)]
71        fn #type_test_func_ident #impl_generics(value: & #input_type #type_generics) #where_clause {
72            fn assert_impl<T: #trait_bounds>() {}
73            #check_types_code
74        }
75    };
76
77    generated_code
78}
79
80fn add_bound_to_generics(generics: &Generics, import: TokenStream) -> Generics {
81    let mut new_generics = generics.clone();
82    let bound: TypeParamBound = parse_str(&quote! {#import}.to_string()).unwrap();
83
84    for type_param in &mut new_generics.type_params_mut() {
85        type_param.bounds.push(bound.clone())
86    }
87
88    new_generics
89}
90
91fn type_check_struct(s: &DataStruct) -> TokenStream {
92    let checks = match s.fields {
93        Fields::Named(ref named_fields) => {
94            let fields: Vec<&Field> = named_fields.named.iter().collect();
95            check_fields(&fields)
96        }
97        Fields::Unnamed(ref unnamed_fields) => {
98            let fields: Vec<&Field> = unnamed_fields.unnamed.iter().collect();
99            check_fields(&fields)
100        }
101        Fields::Unit => vec![],
102    };
103    quote!(
104        #(#checks)*
105    )
106}
107
108fn type_check_enum(s: &DataEnum) -> TokenStream {
109    let mut checks = vec![];
110
111    for variant in &s.variants {
112        match variant.fields {
113            Fields::Named(ref named_fields) => {
114                let fields: Vec<&Field> = named_fields.named.iter().collect();
115                checks.extend(check_fields(&fields));
116            }
117            Fields::Unnamed(ref unnamed_fields) => {
118                let fields: Vec<&Field> = unnamed_fields.unnamed.iter().collect();
119                checks.extend(check_fields(&fields));
120            }
121            Fields::Unit => {}
122        }
123    }
124    quote!(
125        #(#checks)*
126    )
127}
128
129fn type_check_union(s: &DataUnion) -> TokenStream {
130    let fields: Vec<&Field> = s.fields.named.iter().collect();
131    let checks = check_fields(&fields);
132    quote!(
133        #(#checks)*
134    )
135}
136
137fn check_fields(fields: &[&Field]) -> Vec<TokenStream> {
138    fields
139        .iter()
140        .map(|field| {
141            let field_type = &field.ty;
142            quote! {assert_impl::<#field_type>();}
143        })
144        .collect()
145}