1use field::TransField;
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{self};
6
7mod field;
8mod version;
9use version::parse_version_args;
10
11#[proc_macro_derive(SAD, attributes(said, version))]
12pub fn compute_digest_derive(input: TokenStream) -> TokenStream {
13 let ast = syn::parse(input).unwrap();
16
17 impl_compute_digest(&ast)
19}
20
21fn impl_compute_digest(ast: &syn::DeriveInput) -> TokenStream {
22 let name = &ast.ident;
23 let fname = format!("{}TMP", name);
24 let varname = syn::Ident::new(&fname, name.span());
25
26 let generics = &ast.generics;
27 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
28
29 let version = ast
31 .attrs
32 .iter()
33 .find(|attr| attr.path().is_ident("version"))
34 .map(parse_version_args);
35
36 let fields = match &ast.data {
37 syn::Data::Struct(s) => s.fields.clone(),
38 _ => panic!("Not a struct"),
39 }
40 .into_iter()
41 .map(TransField::from_ast);
42
43 let body = fields.clone().map(|field| {
46 if !field.said {
47 let original = field.original;
48 quote! {#original}
49 } else {
50 let name = &field.name;
51 let attrs = field.attributes;
52 quote! {
53 #(#attrs)*
54 #name: String
55 }
56 }
57 });
58
59 let concrete = fields.clone().map(|field| {
63 let name = &field.name;
64 if field.said {
65 quote! {#name: "#".repeat(dig_length).to_string()}
66 } else {
67 quote! {#name: value.#name.clone()}
68 }
69 });
70
71 let out = fields.map(|field| {
73 let name = &field.name;
74 if field.said {
75 quote! {self.#name = digest.clone();}
76 } else {
77 quote! {}
78 }
79 });
80
81 let version_field = if version.is_some() {
83 quote! {
84 #[serde(rename = "v")]
85 version: SerializationInfo,
86 }
87 } else {
88 quote! {}
89 };
90
91 let encode = if let Some((prot, major, minor)) = version.as_ref() {
93 quote! {
94 #[derive(Serialize)]
95 struct Version<D> {
96 v: SerializationInfo,
97 #[serde(flatten)]
98 d: D
99 }
100
101 use said::version::Encode;
102 impl #impl_generics Encode for #name #ty_generics #where_clause {
103 fn encode(&self, code: &HashFunctionCode, format: &SerializationFormats) -> Result<Vec<u8>, said::version::error::Error> {
104 let size = self.derivation_data(code, format).len();
105 let v = SerializationInfo::new(#prot.to_string(), #major, #minor, format.clone(), size);
106 let versioned = Version {v, d: self.clone()};
107 Ok(format.encode(&versioned).unwrap())
108 }
109 }
110
111
112 }
113 } else {
114 quote!()
115 };
116
117 let tmp_struct = if let Some((prot, major, minor)) = version {
118 quote! {
119 let mut tmp_self = Self {
120 version: SerializationInfo::new_empty(#prot.to_string(), #major, #minor, SerializationFormats::JSON),
121 #(#concrete,)*
122 };
123 let enc = tmp_self.version.serialize(&tmp_self).unwrap();
124 tmp_self.version.size = enc.len();
125 tmp_self
126 }
127 } else {
128 quote! {Self {
129 #(#concrete,)*
130 }}
131 };
132
133 let gen = quote! {
134 #[derive(Serialize)]
136 struct #varname #ty_generics #where_clause {
137 #version_field
138 #(#body,)*
139 }
140
141 #encode
142
143 impl #impl_generics From<(&#name #ty_generics, usize)> for #varname #ty_generics #where_clause {
144 fn from(value: (&#name #ty_generics, usize)) -> Self {
145 let dig_length = value.1;
146
147 let value = value.0;
148 #tmp_struct
149 }
150 }
151
152 impl #impl_generics SAD for #name #ty_generics #where_clause {
153 fn compute_digest(&mut self, code: &HashFunctionCode, format: &SerializationFormats ) {
154 use said::derivation::{HashFunctionCode, HashFunction};
155 let serialized = self.derivation_data(code, format);
156 let digest = Some(HashFunction::from(code.clone()).derive(&serialized));
157 #(#out;)*
158 }
159
160 fn derivation_data(&self, code: &HashFunctionCode, serialization_format: &SerializationFormats) -> Vec<u8> {
161 use said::derivation::HashFunctionCode;
162 use said::sad::DerivationCode;
163 let tmp: #varname #ty_generics = (self, code.full_size()).into();
164 serialization_format.encode(&tmp).unwrap()
165 }
166 };
167 };
168 gen.into()
169}