1use proc_macro2::{Ident, Punct, Spacing, TokenStream};
2use quote::{quote, ToTokens};
3use syn::{parse_macro_input, DeriveInput, Field};
4
5mod r#enum;
6mod r#struct;
7
8use r#enum::gen_enum;
9use r#struct::gen_struct;
10
11#[derive(Clone)]
12enum FieldConfig {
13 Default,
14 Skip,
15 With(syn::Path),
16}
17
/// How a field-access expression is handed to the size function in generated code.
///
/// See `gen_call_with_arg` for the exact token shape produced by each mode.
#[derive(Clone, Copy)]
pub(crate) enum PassMode {
    /// Pass the expression unchanged: `f(expr)`.
    AsIs,
    /// Borrow the expression at the call site: `f(&expr)`.
    InsertRef,
    /// Bind the expression to a local first and borrow the local:
    /// `({ let tmp = expr; f(&tmp) })`.
    ///
    /// Presumably selected for `#[repr(packed)]` types (see `check_repr_packed`),
    /// where borrowing a field directly could create an unaligned reference.
    Packed,
}
24
25fn gen_call_with_arg(
26 func_name: &TokenStream,
27 arg: &TokenStream,
28 pass_mode: PassMode,
29) -> TokenStream {
30 match pass_mode {
31 PassMode::AsIs => quote!(#func_name(#arg)),
32 PassMode::InsertRef => quote!(#func_name(&#arg)),
33 PassMode::Packed => {
34 quote!(({
35 let __typesize_internal_temp = #arg;
36 #func_name(&__typesize_internal_temp)
37 }))
38 }
39 }
40}
41
42fn join_tokens(
43 exprs: impl ExactSizeIterator<Item = impl ToTokens>,
44 sep: impl ToTokens,
45) -> TokenStream {
46 let expr_count = exprs.len();
47 let mut out_tokens = TokenStream::new();
48 for (i, expr) in exprs.enumerate() {
49 expr.to_tokens(&mut out_tokens);
50 if expr_count != i + 1 {
51 sep.to_tokens(&mut out_tokens);
52 }
53 }
54
55 out_tokens
56}
57
58fn try_join_tokens(
59 exprs: impl ExactSizeIterator<Item = syn::Result<impl ToTokens>>,
60 sep: impl ToTokens,
61) -> syn::Result<TokenStream> {
62 let expr_count = exprs.len();
63 let mut out_tokens = TokenStream::new();
64 for (i, expr) in exprs.enumerate() {
65 expr?.to_tokens(&mut out_tokens);
66 if expr_count != i + 1 {
67 sep.to_tokens(&mut out_tokens);
68 }
69 }
70
71 Ok(out_tokens)
72}
73
74fn gen_named_exprs<'a>(
75 named_fields: syn::punctuated::Iter<'a, Field>,
76 transform_named: impl Fn(&'a Ident) -> TokenStream + 'a,
77 common_body: impl Fn(TokenStream, TokenStream, FieldConfig) -> TokenStream + 'a,
78) -> Option<impl ExactSizeIterator<Item = syn::Result<TokenStream>> + 'a> {
79 if named_fields.len() == 0 {
80 return None;
81 }
82
83 Some(named_fields.map(move |field| {
84 let ident = field.ident.as_ref().unwrap();
85 let field_config = get_field_config(&field.attrs)?;
86 Ok(common_body(
87 transform_named(ident),
88 quote!(#ident),
89 field_config,
90 ))
91 }))
92}
93
94fn gen_unnamed_exprs<'a>(
95 unnamed_fields: syn::punctuated::Iter<'a, Field>,
96 transform_unnamed: impl Fn(usize) -> TokenStream + 'a,
97 common_body: impl Fn(TokenStream, TokenStream, FieldConfig) -> TokenStream + 'a,
98) -> Option<impl ExactSizeIterator<Item = syn::Result<TokenStream>> + 'a> {
99 if unnamed_fields.len() == 0 {
100 return None;
101 };
102
103 let enumerated_iter = unnamed_fields.enumerate();
104 Some(enumerated_iter.map(move |(i, field)| {
105 let field_config = get_field_config(&field.attrs)?;
106 Ok(common_body(transform_unnamed(i), quote!(#i), field_config))
107 }))
108}
109
110fn for_each_field<'a>(
111 fields: &'a syn::Fields,
112 join_with: Punct,
113 transform_named: impl Fn(&'a Ident) -> TokenStream + 'a,
114 transform_unnamed: impl Fn(usize) -> TokenStream + 'a,
115 common_body: impl Fn(TokenStream, TokenStream, FieldConfig) -> TokenStream + 'a,
116) -> Option<syn::Result<TokenStream>> {
117 match fields {
118 syn::Fields::Named(fields) => Some(try_join_tokens(
119 gen_named_exprs(fields.named.iter(), transform_named, common_body)?,
120 join_with,
121 )),
122 syn::Fields::Unnamed(fields) => Some(try_join_tokens(
123 gen_unnamed_exprs(fields.unnamed.iter(), transform_unnamed, common_body)?,
124 join_with,
125 )),
126 syn::Fields::Unit => None,
127 }
128}
129
130fn extra_details_visit_fields<'a>(
131 fields: &'a syn::Fields,
132 transform_named: impl Fn(&'a Ident) -> TokenStream + 'a,
133 transform_unnamed: impl Fn(usize) -> TokenStream + 'a,
134 pass_mode: PassMode,
135) -> syn::Result<TokenStream> {
136 for_each_field(
137 fields,
138 Punct::new('+', Spacing::Alone),
139 transform_named,
140 transform_unnamed,
141 move |ident, _name, config| match config {
142 FieldConfig::Skip => quote!(0),
143 FieldConfig::Default => {
144 gen_call_with_arg("e!(::typesize::TypeSize::extra_size), &ident, pass_mode)
145 }
146 FieldConfig::With(fn_path) => {
147 gen_call_with_arg(&fn_path.into_token_stream(), &ident, pass_mode)
148 }
149 },
150 )
151 .unwrap_or_else(|| Ok(quote!(0_usize)))
152}
153
154fn check_repr_packed(attrs: &[syn::Attribute]) -> bool {
155 fn is_valid_repr_for_packed(ident: &syn::Ident) -> bool {
156 ident == "C" || ident == "Rust"
159 }
160
161 struct CheckIsPacked(bool);
162 impl syn::parse::Parse for CheckIsPacked {
163 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
164 let first_token = input.parse::<syn::Ident>()?;
165 if !input.peek(syn::Token![,]) {
166 let is_packed = first_token == "packed";
167 return Ok(Self(is_packed));
168 }
169
170 input.parse::<syn::Token![,]>()?;
171
172 let second_token = input.parse::<syn::Ident>()?;
173 if is_valid_repr_for_packed(&first_token) && second_token == "packed" {
174 return Ok(Self(true));
175 }
176
177 if first_token == "packed" && is_valid_repr_for_packed(&second_token) {
178 return Ok(Self(true));
179 }
180
181 Ok(Self(false))
182 }
183 }
184
185 attrs.iter().any(|attr| {
186 let syn::Meta::List(meta) = &attr.meta else {
187 return false;
188 };
189
190 let Some(ident) = meta.path.get_ident() else {
191 return false;
192 };
193
194 if ident != "repr" {
195 return false;
196 }
197
198 syn::parse2::<CheckIsPacked>(meta.tokens.clone()).unwrap().0
199 })
200}
201
202fn get_field_config(attrs: &[syn::Attribute]) -> syn::Result<FieldConfig> {
203 mod kw {
207 syn::custom_keyword!(skip);
208 syn::custom_keyword!(with);
209 }
210
211 enum Input {
212 Skip {
213 _skip: kw::skip,
214 },
215 With {
216 _with: kw::with,
217 _eq: syn::Token![=],
218 path: syn::Path,
219 },
220 }
221
222 impl syn::parse::Parse for Input {
223 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
224 let lookahead = input.lookahead1();
225 if lookahead.peek(kw::skip) {
226 Ok(Self::Skip {
227 _skip: input.parse()?,
228 })
229 } else if lookahead.peek(kw::with) {
230 Ok(Self::With {
231 _with: input.parse()?,
232 _eq: input.parse()?,
233 path: input.parse()?,
234 })
235 } else {
236 Err(lookahead.error())
237 }
238 }
239 }
240
241 for attr in attrs {
242 let syn::Meta::List(meta) = &attr.meta else {
243 continue;
244 };
245
246 let Some(path) = meta.path.get_ident() else {
247 continue;
248 };
249
250 if path != "typesize" {
251 continue;
252 }
253
254 let input = syn::parse::<Input>(meta.tokens.clone().into())?;
255 return Ok(match input {
256 Input::Skip { .. } => FieldConfig::Skip,
257 Input::With { path, .. } => FieldConfig::With(path),
258 });
259 }
260
261 Ok(FieldConfig::Default)
262}
263
264struct GenerationRet {
265 extra_size: TokenStream,
266 #[cfg(feature = "details")]
267 details: Option<TokenStream>,
268}
269
270#[proc_macro_derive(TypeSize, attributes(typesize))]
291pub fn typesize_derive(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
292 let DeriveInput {
293 attrs,
294 vis: _,
295 ident,
296 generics,
297 data,
298 } = parse_macro_input!(tokens as DeriveInput);
299
300 let is_packed = check_repr_packed(&attrs);
301
302 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
303 let bodies = match data {
304 syn::Data::Struct(data) => gen_struct(&data.fields, is_packed),
305 syn::Data::Enum(data) => gen_enum(data.variants.into_iter(), is_packed),
306 syn::Data::Union(data) => Err(syn::Error::new(
307 data.union_token.span,
308 "Unions are unsupported for typesize derive.",
309 )),
310 };
311
312 let bodies = match bodies {
313 Ok(bodies) => bodies,
314 Err(err) => {
315 return err.into_compile_error().into();
316 }
317 };
318
319 let extra_size = bodies.extra_size;
320 #[cfg_attr(not(feature = "details"), allow(unused_mut))]
321 let mut impl_body = quote!(
322 fn extra_size(&self) -> usize {
323 #extra_size
324 }
325 );
326
327 #[cfg(feature = "details")]
328 if let Some(details) = bodies.details {
329 impl_body = quote!(
330 #impl_body
331
332 fn get_size_details(&self) -> Vec<::typesize::Field> {
333 #details
334 }
335 );
336 }
337
338 let output = quote! {
339 #[automatically_derived]
340 impl #impl_generics ::typesize::TypeSize for #ident #ty_generics #where_clause {
341 #impl_body
342 }
343 };
344
345 output.into()
346}