1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
extern crate proc_macro;
extern crate syn;
#[macro_use]
extern crate quote;
extern crate synstructure;

use proc_macro::TokenStream;

/// Entry point for `#[derive(Shippai)]`.
///
/// Parses the annotated item and returns the expansion built by
/// `impl_shippai`.
///
/// # Panics
///
/// Panics if `input` cannot be parsed as a `syn::DeriveInput`. Inside a
/// proc-macro, that panic is reported as a compile error at the derive site.
#[proc_macro_derive(Shippai)]
pub fn shippai(input: TokenStream) -> TokenStream {
    let ast: syn::DeriveInput = syn::parse(input)
        .expect("#[derive(Shippai)]: input could not be parsed as a struct, enum or union");
    impl_shippai(&ast).into()
}

/// Builds the full expansion of `#[derive(Shippai)]` for `ast`.
///
/// It always emits `shippai_cast_error_<Name>`, a `#[no_mangle]` C-ABI
/// function. That function tries to downcast the `error` held by a
/// `*const ShippaiError` to `<Name>` and returns a pointer to it. It returns
/// null when the input pointer is null or the downcast fails. If `ast` is an
/// enum, the variant helpers produced by `impl_enums` are appended.
///
/// NOTE(review): the expansion names `ShippaiError` unqualified and calls
/// `downcast_ref` on its `error` field, so that type presumably has to be in
/// scope at the derive site. Confirm against the runtime crate.
fn impl_shippai(ast: &syn::DeriveInput) -> quote::Tokens {
    let error_name = &ast.ident;
    let fn_name = syn::Ident::from(format!("shippai_cast_error_{}", error_name));

    // `Some(tokens)` for enums, `None` otherwise. The `#(...)*` repetition
    // below expands an `Option` to zero or one copy.
    let variant_helpers = impl_enums(error_name, ast);

    quote! {
        #[no_mangle]
        pub unsafe extern "C" fn #fn_name(t: *const ShippaiError) -> *const #error_name {
            // Callers across the FFI boundary may pass null. Answer
            // "not this type" instead of dereferencing it.
            if t.is_null() {
                return ::std::ptr::null();
            }
            (*t).error
                .downcast_ref::<#error_name>()
                .map(|x| x as *const #error_name)
                .unwrap_or_else(::std::ptr::null)
        }

        #( #variant_helpers )*
    }
}

/// Generates variant-inspection helpers for enum error types.
///
/// For an enum `E` with variants `V1 .. Vn` in declaration order, this emits:
///
/// * one `#[no_mangle] pub static SHIPPAI_VARIANT_E_Vi: u8 = i;` per variant,
///   numbered from 1, and
/// * `#[no_mangle] pub unsafe extern "C" fn shippai_get_variant_E(f: *const E) -> u8`,
///   which returns the number of the variant that `*f` holds.
///
/// Numbering starts at 1, so 0 never names a variant. The generated getter
/// returns 0 for a null pointer instead of dereferencing it.
///
/// Returns `None` when `ast` is not an enum; structs and unions get no helpers.
///
/// # Panics
///
/// Panics if the enum has more than 255 variants, because the numbers are
/// exported as `u8`.
fn impl_enums(error_name: &syn::Ident, ast: &syn::DeriveInput) -> Option<quote::Tokens> {
    match ast.data {
        syn::Data::Enum(_) => {}
        _ => return None,
    }

    let s = synstructure::Structure::new(ast);

    // Names of the exported statics, in variant order. The static at
    // position `i` holds the value `i + 1`.
    let mut exported_discriminants = vec![];
    let match_arms = s.each_variant(|v| {
        let number = exported_discriminants.len() + 1;
        // Without this check, a 256th variant would generate an out-of-range
        // `u8` literal and a confusing error inside the user's crate.
        assert!(
            number <= usize::from(u8::max_value()),
            "#[derive(Shippai)]: enum `{}` has more than {} variants, which cannot be numbered as u8",
            error_name,
            u8::max_value()
        );
        let name = &v.ast().ident;
        exported_discriminants.push(syn::Ident::from(format!(
            "SHIPPAI_VARIANT_{}_{}",
            error_name, name
        )));
        // Each arm evaluates to the same number as this variant's static, so
        // C callers can compare the getter's result with the exported constants.
        let index = syn::Index::from(number);
        quote!(#index)
    });

    let indices = (1..exported_discriminants.len() + 1).map(syn::Index::from);
    let fn_name = syn::Ident::from(format!("shippai_get_variant_{}", error_name));
    Some(quote! {
        #(
            #[no_mangle]
            pub static #exported_discriminants: u8 = #indices;
        )*

        #[no_mangle]
        pub unsafe extern "C" fn #fn_name(f: *const #error_name) -> u8 {
            // 0 is never a variant number, so it safely signals "no value".
            if f.is_null() {
                return 0;
            }
            match *f {
                #match_arms
            }
        }
    })
}