1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
//! Ulib derive macros.
//!
//! This crate provides `#[derive(UniversalCopy)]`, the derive macro for `ulib::UniversalCopy`.
//! It is adapted from [the source code of `cust_derive`](https://docs.rs/cust_derive/latest/src/cust_derive/lib.rs.html#20-24).

use proc_macro2::{Ident, Span, TokenStream};
use syn::{
    parse_str, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Generics,
    TypeParamBound,
};
use quote::quote;

#[proc_macro_derive(UniversalCopy)]
pub fn universal_copy(input: BaseTokenStream) -> BaseTokenStream {
    let ast = syn::parse(input).unwrap();
    let gen = impl_universal_copy(&ast);
    BaseTokenStream::from(gen)
}

use proc_macro::TokenStream as BaseTokenStream;

/// Generates the expansion of `#[derive(UniversalCopy)]` for `input`.
///
/// The expansion contains:
/// * an `impl ::std::marker::Copy` for the type (the type must still
///   implement `Clone` itself, since `Copy: Clone`);
/// * with the `cuda` feature, an `unsafe impl ::ulib::cust::memory::DeviceCopy`;
/// * a hidden, never-called function that calls `assert_impl::<FieldType>()`
///   for every field, so the expansion only compiles when each field type
///   implements the same traits (`Copy`, plus `DeviceCopy` under `cuda`).
///
/// Every type parameter of `input` receives those same bounds, and the bounded
/// generics are used for all generated items.
fn impl_universal_copy(input: &DeriveInput) -> TokenStream {
    let input_type = &input.ident;

    // One `assert_impl::<FieldType>();` statement per field.
    let check_types_code = match input.data {
        Data::Struct(ref data_struct) => type_check_struct(data_struct),
        Data::Enum(ref data_enum) => type_check_enum(data_enum),
        Data::Union(ref data_union) => type_check_union(data_union),
    };

    // Name of the hidden checker function. A raw identifier such as `r#type`
    // stringifies with its `r#` prefix, and `#` is not allowed in an identifier
    // (`Ident::new` would panic), so strip the prefix first.
    let type_test_func_name = format!(
        "__ulib_derive_verify_{}_can_implement_universalcopy",
        input_type.to_string().trim_start_matches("r#").to_lowercase()
    );
    let type_test_func_ident = Ident::new(&type_test_func_name, Span::call_site());

    // Every type parameter must carry the bounds the generated code relies on:
    // `Copy` always, and `DeviceCopy` when the `cuda` feature is enabled.
    // Without the `DeviceCopy` bound, fields whose type is a generic parameter
    // would fail the `assert_impl::<T: Copy + DeviceCopy>` checks below.
    let generics = add_bound_to_generics(&input.generics, quote! {
        ::std::marker::Copy
    });
    #[cfg(feature = "cuda")]
    let generics = add_bound_to_generics(&generics, quote! {
        ::ulib::cust::memory::DeviceCopy
    });
    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();

    #[cfg(feature = "cuda")]
    let impl_cuda = quote! {
        unsafe impl #impl_generics ::ulib::cust::memory::DeviceCopy for #input_type #type_generics #where_clause {}
    };
    #[cfg(not(feature = "cuda"))]
    let impl_cuda = quote! {};

    // Traits each field type is checked against in the hidden function.
    #[cfg(feature = "cuda")]
    let trait_bounds = quote! {
        ::std::marker::Copy + ::ulib::cust::memory::DeviceCopy
    };
    #[cfg(not(feature = "cuda"))]
    let trait_bounds = quote! {
        ::std::marker::Copy
    };

    // The checker is never called and `value` is never read; it exists only so
    // the compiler type-checks the `assert_impl` calls. `all` is not a rustc lint
    // name, so `#[allow(warnings)]` is used to silence the resulting warnings.
    quote! {
        impl #impl_generics ::std::marker::Copy for #input_type #type_generics #where_clause {}
        #impl_cuda

        #[doc(hidden)]
        #[allow(warnings)]
        fn #type_test_func_ident #impl_generics(value: & #input_type #type_generics) #where_clause {
            fn assert_impl<T: #trait_bounds>() {}
            #check_types_code
        }
    }
}

/// Returns a copy of `generics` in which every type parameter additionally
/// carries the trait bound given by the tokens in `import`
/// (e.g. `::std::marker::Copy`).
///
/// Lifetime and const parameters are left unchanged.
///
/// # Panics
///
/// Panics if `import` does not parse as a `TypeParamBound`. Callers in this
/// crate only pass trait paths, so a panic indicates a bug in the macro itself.
fn add_bound_to_generics(generics: &Generics, import: TokenStream) -> Generics {
    // `import` is already a token stream; stringify it directly rather than
    // re-wrapping it in `quote!` first.
    let bound: TypeParamBound = parse_str(&import.to_string())
        .expect("trait bound passed to add_bound_to_generics must be a valid TypeParamBound");

    let mut new_generics = generics.clone();
    for type_param in new_generics.type_params_mut() {
        type_param.bounds.push(bound.clone());
    }
    new_generics
}

/// Emits one `assert_impl::<FieldType>();` statement per field of the struct
/// `s`, in declaration order. A unit struct produces no statements.
fn type_check_struct(s: &DataStruct) -> TokenStream {
    // `Fields::iter` treats named, tuple and unit structs uniformly
    // (it yields nothing for a unit struct).
    let all_fields: Vec<&Field> = s.fields.iter().collect();
    let assertions = check_fields(&all_fields);
    quote! {
        #(#assertions)*
    }
}

/// Emits one `assert_impl::<FieldType>();` statement per field of every
/// variant of the enum `s`, variant by variant in declaration order.
/// Unit variants contribute nothing.
fn type_check_enum(s: &DataEnum) -> TokenStream {
    let variant_fields: Vec<&Field> = s
        .variants
        .iter()
        .flat_map(|variant| variant.fields.iter())
        .collect();
    let assertions = check_fields(&variant_fields);
    quote! {
        #(#assertions)*
    }
}

/// Emits one `assert_impl::<FieldType>();` statement per named field of the
/// union `s`, in declaration order.
fn type_check_union(s: &DataUnion) -> TokenStream {
    let assertions = check_fields(&s.fields.named.iter().collect::<Vec<&Field>>());
    quote! {
        #(#assertions)*
    }
}

/// Builds one `assert_impl::<T>();` statement for each field, where `T` is
/// the field's declared type. Output order matches the order of `fields`.
fn check_fields(fields: &[&Field]) -> Vec<TokenStream> {
    let mut assertions = Vec::with_capacity(fields.len());
    for field in fields {
        let ty = &field.ty;
        assertions.push(quote! { assert_impl::<#ty>(); });
    }
    assertions
}