1use proc_macro2::{Ident, Span, TokenStream};
7use syn::{
8 parse_str, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Generics,
9 TypeParamBound,
10};
11use quote::quote;
12
13#[proc_macro_derive(UniversalCopy)]
14pub fn universal_copy(input: BaseTokenStream) -> BaseTokenStream {
15 let ast = syn::parse(input).unwrap();
16 let gen = impl_universal_copy(&ast);
17 BaseTokenStream::from(gen)
18}
19
20use proc_macro::TokenStream as BaseTokenStream;
21
22fn impl_universal_copy(input: &DeriveInput) -> TokenStream {
23 let input_type = &input.ident;
24
25 let check_types_code = match input.data {
26 Data::Struct(ref data_struct) => type_check_struct(data_struct),
27 Data::Enum(ref data_enum) => type_check_enum(data_enum),
28 Data::Union(ref data_union) => type_check_union(data_union),
29 };
30
31 let type_test_func_name = format!(
32 "__ulib_derive_verify_{}_can_implement_universalcopy",
33 input_type.to_string().to_lowercase()
34 );
35 let type_test_func_ident = Ident::new(&type_test_func_name, Span::call_site());
36
37 let generics = add_bound_to_generics(&input.generics, quote! {
40 ::std::marker::Copy
41 });
42 #[cfg(feature = "cuda")]
43 let generics = add_bound_to_generics(&generics, quote! {
44 ::std::marker::Copy
45 });
46 let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
47
48 #[cfg(feature = "cuda")]
50 let impl_cuda = quote! {
51 unsafe impl #impl_generics ::ulib::cust::memory::DeviceCopy for #input_type #type_generics #where_clause {}
52 };
53 #[cfg(not(feature = "cuda"))]
54 let impl_cuda = quote! {};
55
56 #[cfg(feature = "cuda")]
57 let trait_bounds = quote! {
58 ::std::marker::Copy + ::ulib::cust::memory::DeviceCopy
59 };
60 #[cfg(not(feature = "cuda"))]
61 let trait_bounds = quote! {
62 ::std::marker::Copy
63 };
64
65 let generated_code = quote! {
66 impl #impl_generics ::std::marker::Copy for #input_type #type_generics #where_clause {}
67 #impl_cuda
68
69 #[doc(hidden)]
70 #[allow(all)]
71 fn #type_test_func_ident #impl_generics(value: & #input_type #type_generics) #where_clause {
72 fn assert_impl<T: #trait_bounds>() {}
73 #check_types_code
74 }
75 };
76
77 generated_code
78}
79
80fn add_bound_to_generics(generics: &Generics, import: TokenStream) -> Generics {
81 let mut new_generics = generics.clone();
82 let bound: TypeParamBound = parse_str("e! {#import}.to_string()).unwrap();
83
84 for type_param in &mut new_generics.type_params_mut() {
85 type_param.bounds.push(bound.clone())
86 }
87
88 new_generics
89}
90
91fn type_check_struct(s: &DataStruct) -> TokenStream {
92 let checks = match s.fields {
93 Fields::Named(ref named_fields) => {
94 let fields: Vec<&Field> = named_fields.named.iter().collect();
95 check_fields(&fields)
96 }
97 Fields::Unnamed(ref unnamed_fields) => {
98 let fields: Vec<&Field> = unnamed_fields.unnamed.iter().collect();
99 check_fields(&fields)
100 }
101 Fields::Unit => vec![],
102 };
103 quote!(
104 #(#checks)*
105 )
106}
107
108fn type_check_enum(s: &DataEnum) -> TokenStream {
109 let mut checks = vec![];
110
111 for variant in &s.variants {
112 match variant.fields {
113 Fields::Named(ref named_fields) => {
114 let fields: Vec<&Field> = named_fields.named.iter().collect();
115 checks.extend(check_fields(&fields));
116 }
117 Fields::Unnamed(ref unnamed_fields) => {
118 let fields: Vec<&Field> = unnamed_fields.unnamed.iter().collect();
119 checks.extend(check_fields(&fields));
120 }
121 Fields::Unit => {}
122 }
123 }
124 quote!(
125 #(#checks)*
126 )
127}
128
129fn type_check_union(s: &DataUnion) -> TokenStream {
130 let fields: Vec<&Field> = s.fields.named.iter().collect();
131 let checks = check_fields(&fields);
132 quote!(
133 #(#checks)*
134 )
135}
136
137fn check_fields(fields: &[&Field]) -> Vec<TokenStream> {
138 fields
139 .iter()
140 .map(|field| {
141 let field_type = &field.ty;
142 quote! {assert_impl::<#field_type>();}
143 })
144 .collect()
145}