sad_macros/
lib.rs

1// use derivation::parse_said_args;
2use field::TransField;
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{self};
6
7mod field;
8mod version;
9use version::parse_version_args;
10
11#[proc_macro_derive(SAD, attributes(said, version))]
12pub fn compute_digest_derive(input: TokenStream) -> TokenStream {
13    // Construct a representation of Rust code as a syntax tree
14    // that we can manipulate
15    let ast = syn::parse(input).unwrap();
16
17    // Build the trait implementation
18    impl_compute_digest(&ast)
19}
20
21fn impl_compute_digest(ast: &syn::DeriveInput) -> TokenStream {
22    let name = &ast.ident;
23    let fname = format!("{}TMP", name);
24    let varname = syn::Ident::new(&fname, name.span());
25
26    let generics = &ast.generics;
27    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
28
29    // Check if versioned attribute is added.
30    let version = ast
31        .attrs
32        .iter()
33        .find(|attr| attr.path().is_ident("version"))
34        .map(parse_version_args);
35
36    let fields = match &ast.data {
37        syn::Data::Struct(s) => s.fields.clone(),
38        _ => panic!("Not a struct"),
39    }
40    .into_iter()
41    .map(TransField::from_ast);
42
43    // Generate body of newly created struct fields.
44    // Replace field type with String if it is tagged as said.
45    let body = fields.clone().map(|field| {
46        if !field.said {
47            let original = field.original;
48            quote! {#original}
49        } else {
50            let name = &field.name;
51            let attrs = field.attributes;
52            quote! {
53                #(#attrs)*
54                #name: String
55            }
56        }
57    });
58
59    // Set fields tagged as said to computed digest string, depending on
60    // digest set in `dig_length` variable. Needed for generation of From
61    // implementation.
62    let concrete = fields.clone().map(|field| {
63        let name = &field.name;
64        if field.said {
65            quote! {#name: "#".repeat(dig_length).to_string()}
66        } else {
67            quote! {#name: value.#name.clone()}
68        }
69    });
70
71    // Set fields tagged as said to computed SAID set in `digest` variable.
72    let out = fields.map(|field| {
73        let name = &field.name;
74        if field.said {
75            quote! {self.#name = digest.clone();}
76        } else {
77            quote! {}
78        }
79    });
80
81    // Adding version field logic.
82    let version_field = if version.is_some() {
83        quote! {
84        #[serde(rename = "v")]
85        version: SerializationInfo,
86        }
87    } else {
88        quote! {}
89    };
90
91    // If version was set, implement Encode trait
92    let encode = if let Some((prot, major, minor)) = version.as_ref() {
93        quote! {
94                #[derive(Serialize)]
95            struct Version<D> {
96                v: SerializationInfo,
97                #[serde(flatten)]
98                d: D
99            }
100
101                use said::version::Encode;
102                impl #impl_generics Encode for #name #ty_generics #where_clause {
103                    fn encode(&self, code: &HashFunctionCode, format: &SerializationFormats) -> Result<Vec<u8>, said::version::error::Error> {
104                        let size = self.derivation_data(code, format).len();
105                        let v = SerializationInfo::new(#prot.to_string(), #major, #minor, format.clone(), size);
106                        let versioned = Version {v, d: self.clone()};
107                        Ok(format.encode(&versioned).unwrap())
108                    }
109                }
110
111
112        }
113    } else {
114        quote!()
115    };
116
117    let tmp_struct = if let Some((prot, major, minor)) = version {
118        quote! {
119           let mut tmp_self = Self {
120                version: SerializationInfo::new_empty(#prot.to_string(), #major, #minor, SerializationFormats::JSON),
121                #(#concrete,)*
122                };
123            let enc = tmp_self.version.serialize(&tmp_self).unwrap();
124            tmp_self.version.size = enc.len();
125            tmp_self
126        }
127    } else {
128        quote! {Self {
129            #(#concrete,)*
130        }}
131    };
132
133    let gen = quote! {
134    // Create temporary, serializable struct
135    #[derive(Serialize)]
136    struct #varname #ty_generics #where_clause {
137            #version_field
138            #(#body,)*
139    }
140
141    #encode
142
143    impl #impl_generics From<(&#name #ty_generics, usize)> for #varname #ty_generics #where_clause {
144        fn from(value: (&#name #ty_generics, usize)) -> Self {
145            let dig_length = value.1;
146
147            let value = value.0;
148            #tmp_struct
149        }
150    }
151
152    impl #impl_generics SAD for #name #ty_generics #where_clause {
153        fn compute_digest(&mut self, code: &HashFunctionCode, format: &SerializationFormats ) {
154            use said::derivation::{HashFunctionCode, HashFunction};
155            let serialized = self.derivation_data(code, format);
156            let digest = Some(HashFunction::from(code.clone()).derive(&serialized));
157            #(#out;)*
158        }
159
160        fn derivation_data(&self, code: &HashFunctionCode, serialization_format: &SerializationFormats) -> Vec<u8> {
161            use said::derivation::HashFunctionCode;
162            use said::sad::DerivationCode;
163            let tmp: #varname #ty_generics = (self, code.full_size()).into();
164            serialization_format.encode(&tmp).unwrap()
165        }
166    };
167    };
168    gen.into()
169}