sway_ir_macros/
lib.rs

1use {
2    itertools::Itertools,
3    proc_macro::TokenStream,
4    quote::{format_ident, quote},
5    syn::{
6        parse_macro_input, Attribute, Data, DeriveInput, Fields, FieldsNamed, FieldsUnnamed, Ident,
7        Variant,
8    },
9};
10
11#[proc_macro_derive(DebugWithContext, attributes(in_context))]
12pub fn derive_debug_with_context(input: TokenStream) -> TokenStream {
13    let DeriveInput {
14        ident,
15        generics,
16        data,
17        ..
18    } = parse_macro_input!(input);
19    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
20    let type_name = ident.to_string();
21    let body = match data {
22        Data::Struct(data_struct) => match &data_struct.fields {
23            Fields::Named(fields_named) => {
24                let (field_names, fmt_fields) = fmt_fields_named(&type_name, fields_named);
25                quote! {
26                    let #ident { #(#field_names,)* } = self;
27                    #fmt_fields
28                }
29            }
30            Fields::Unnamed(fields_unnamed) => {
31                let (field_names, fmt_fields) = fmt_fields_unnamed(&type_name, fields_unnamed);
32                quote! {
33                    let #ident(#(#field_names,)*) = self;
34                    #fmt_fields
35                }
36            }
37            Fields::Unit => {
38                quote! {
39                    formatter.write_str(#type_name)
40                }
41            }
42        },
43        Data::Enum(data_enum) => {
44            let branches = {
45                data_enum.variants.iter().map(|variant| {
46                    let Variant {
47                        ident: variant_ident,
48                        fields,
49                        ..
50                    } = variant;
51                    let type_variant_name = format!("{type_name}::{variant_ident}");
52                    match fields {
53                        Fields::Named(fields_named) => {
54                            let (field_names, fmt_fields) =
55                                fmt_fields_named(&type_variant_name, fields_named);
56                            quote! {
57                                #ident::#variant_ident { #(#field_names,)* } => {
58                                    #fmt_fields
59                                },
60                            }
61                        }
62                        Fields::Unnamed(fields_unnamed) => {
63                            let (field_names, fmt_fields) =
64                                fmt_fields_unnamed(&type_variant_name, fields_unnamed);
65                            quote! {
66                                #ident::#variant_ident(#(#field_names,)*) => {
67                                    #fmt_fields
68                                },
69                            }
70                        }
71                        Fields::Unit => {
72                            quote! {
73                                #ident::#variant_ident => {
74                                    formatter.write_str(#type_variant_name)
75                                },
76                            }
77                        }
78                    }
79                })
80            };
81            quote! {
82                match self {
83                    #(#branches)*
84                }
85            }
86        }
87        Data::Union(_) => {
88            panic!("#[derive(DebugWithContext)] cannot be used on unions");
89        }
90    };
91    let output = quote! {
92        impl #impl_generics DebugWithContext for #ident #ty_generics
93        #where_clause
94        {
95            fn fmt_with_context<'a, 'c>(
96                &'a self,
97                formatter: &mut std::fmt::Formatter,
98                context: &'c Context,
99            ) -> std::fmt::Result {
100                #body
101            }
102        }
103    };
104    output.into()
105}
106
107fn fmt_fields_named<'i>(
108    name: &str,
109    fields_named: &'i FieldsNamed,
110) -> (Vec<&'i Ident>, proc_macro2::TokenStream) {
111    let field_names = {
112        fields_named
113            .named
114            .iter()
115            .map(|field| field.ident.as_ref().unwrap())
116            .collect::<Vec<_>>()
117    };
118    let fmt_fields = {
119        fields_named
120            .named
121            .iter()
122            .zip(field_names.iter())
123            .map(|(field, name)| {
124                let name_str = name.to_string();
125                let expr = pass_through_context(name, &field.attrs);
126                quote! {
127                    debug_struct = debug_struct.field(#name_str, &#expr);
128                }
129            })
130    };
131    let token_tree = quote! {
132        let mut debug_struct = &mut formatter.debug_struct(#name);
133        #(#fmt_fields)*
134        debug_struct.finish()
135    };
136    (field_names, token_tree)
137}
138
139fn fmt_fields_unnamed(
140    name: &str,
141    fields_unnamed: &FieldsUnnamed,
142) -> (Vec<Ident>, proc_macro2::TokenStream) {
143    let field_names = {
144        (0..fields_unnamed.unnamed.len())
145            .map(|i| format_ident!("field_{}", i))
146            .collect::<Vec<_>>()
147    };
148    let fmt_fields = {
149        fields_unnamed
150            .unnamed
151            .iter()
152            .zip(field_names.iter())
153            .map(|(field, name)| {
154                let expr = pass_through_context(name, &field.attrs);
155                quote! {
156                    debug_tuple = debug_tuple.field(&#expr);
157                }
158            })
159    };
160    let token_tree = quote! {
161        let mut debug_tuple = &mut formatter.debug_tuple(#name);
162        #(#fmt_fields)*
163        debug_tuple.finish()
164    };
165    (field_names, token_tree)
166}
167
168fn pass_through_context(field_name: &Ident, attrs: &[Attribute]) -> proc_macro2::TokenStream {
169    let context_field_opt = {
170        attrs
171            .iter()
172            .filter_map(|attr| {
173                let attr_name = attr.path().get_ident()?;
174                if attr_name != "in_context" {
175                    return None;
176                }
177                let context_field = {
178                    try_parse_context_field_from_attr(attr)
179                        .expect("malformed #[in_context(..)] attribute")
180                };
181                Some(context_field)
182            })
183            .dedup()
184            .at_most_one()
185            .expect("multiple #[in_context(..)] attributes on field")
186    };
187    match context_field_opt {
188        None => {
189            quote! {
190                #field_name.with_context(context)
191            }
192        }
193        Some(context_field) => {
194            quote! {
195                context.#context_field[*#field_name].with_context(context)
196            }
197        }
198    }
199}
200
201fn try_parse_context_field_from_attr(attr: &Attribute) -> Option<Ident> {
202    let mut context_fields = Vec::new();
203
204    let _ = attr.parse_nested_meta(|nested_meta| {
205        context_fields.push(nested_meta.path.get_ident().unwrap().clone());
206        Ok(())
207    });
208
209    if context_fields.len() != 1 {
210        None
211    } else {
212        context_fields.pop()
213    }
214}