tskit_derive/
lib.rs

1use proc_macro::TokenStream;
2
3fn impl_serde_json_roundtrip(name: &syn::Ident) -> TokenStream {
4    let gen = quote::quote!(
5        impl ::tskit::metadata::MetadataRoundtrip for #name {
6            fn encode(&self) -> Result<Vec<u8>, ::tskit::metadata::MetadataError> {
7                match ::serde_json::to_string(self) {
8                    Ok(x) => Ok(x.as_bytes().to_vec()),
9                    Err(e) => {
10                        Err(::tskit::metadata::MetadataError::RoundtripError { value: Box::new(e) })
11                    }
12                }
13            }
14
15            fn decode(md: &[u8]) -> Result<Self, ::tskit::metadata::MetadataError> {
16                let value: Result<Self, ::serde_json::Error> = ::serde_json::from_slice(md);
17                match value {
18                    Ok(v) => Ok(v),
19                    Err(e) => {
20                        Err(::tskit::metadata::MetadataError::RoundtripError { value: Box::new(e) })
21                    }
22                }
23            }
24        }
25    );
26    gen.into()
27}
28
29fn impl_serde_bincode_roundtrip(name: &syn::Ident) -> TokenStream {
30    let gen = quote::quote!(
31        impl ::tskit::metadata::MetadataRoundtrip for #name {
32            fn encode(&self) -> Result<Vec<u8>, ::tskit::metadata::MetadataError> {
33                match ::bincode::serialize(&self) {
34                    Ok(x) => Ok(x),
35                    Err(e) => {
36                        Err(::tskit::metadata::MetadataError::RoundtripError { value: Box::new(e) })
37                    }
38                }
39            }
40            fn decode(md: &[u8]) -> Result<Self, ::tskit::metadata::MetadataError> {
41                match ::bincode::deserialize(md) {
42                    Ok(x) => Ok(x),
43                    Err(e) => {
44                        Err(::tskit::metadata::MetadataError::RoundtripError { value: Box::new(e) })
45                    }
46                }
47            }
48        }
49    );
50    gen.into()
51}
52
53fn impl_metadata_roundtrip_macro(ast: &syn::DeriveInput) -> Result<TokenStream, syn::Error> {
54    let name = &ast.ident;
55    let attrs = &ast.attrs;
56
57    for attr in attrs.iter() {
58        if attr.path.is_ident("serializer") {
59            let lit: syn::LitStr = attr.parse_args().unwrap();
60            let serializer = lit.value();
61
62            if &serializer == "serde_json" {
63                return Ok(impl_serde_json_roundtrip(name));
64            } else if &serializer == "bincode" {
65                return Ok(impl_serde_bincode_roundtrip(name));
66            } else {
67                proc_macro_error::abort!(serializer, "is not a supported protocol.");
68            }
69        } else {
70            proc_macro_error::abort!(attr.path, "is not a supported attribute.");
71        }
72    }
73
74    proc_macro_error::abort_call_site!("missing [serializer(...)] attribute")
75}
76
77macro_rules! make_derive_metadata_tag {
78    ($function: ident, $metadatatag: ident) => {
79        #[proc_macro_error::proc_macro_error]
80        #[proc_macro_derive($metadatatag, attributes(serializer))]
81        /// Register a type as metadata.
82        pub fn $function(input: TokenStream) -> TokenStream {
83            let ast: syn::DeriveInput = match syn::parse(input) {
84                Ok(ast) => ast,
85                Err(err) => proc_macro_error::abort!(err),
86            };
87            let mut roundtrip = impl_metadata_roundtrip_macro(&ast).unwrap();
88            let name = &ast.ident;
89            let gen: proc_macro::TokenStream = quote::quote!(
90                impl ::tskit::metadata::$metadatatag for #name {}
91            )
92            .into();
93            roundtrip.extend(gen);
94            roundtrip
95        }
96    };
97}
98
99make_derive_metadata_tag!(individual_metadata_derive, IndividualMetadata);
100make_derive_metadata_tag!(mutation_metadata_derive, MutationMetadata);
101make_derive_metadata_tag!(site_metadata_derive, SiteMetadata);
102make_derive_metadata_tag!(population_metadata_derive, PopulationMetadata);
103make_derive_metadata_tag!(node_metadata_derive, NodeMetadata);
104make_derive_metadata_tag!(edge_metadata_derive, EdgeMetadata);
105make_derive_metadata_tag!(migration_metadata_derive, MigrationMetadata);