// shumai_config_impl/lib.rs

1extern crate proc_macro;
2extern crate syn;
3#[macro_use]
4extern crate quote;
5use proc_macro::TokenStream;
6use syn::{parse_macro_input, GenericArgument};
7
8#[proc_macro_attribute]
9pub fn config(args: TokenStream, input: TokenStream) -> TokenStream {
10    let args = parse_macro_input!(args as syn::AttributeArgs)
11        .first()
12        .expect("Benchmark file must be annotated with #[config(path = \"/path/to/file.toml\")]")
13        .clone();
14    let file_path = get_config_file_path(&args)
15        .expect("Benchmark file must be annotated with #[config(path = \"/path/to/file.toml\")]");
16
17    let ty: syn::Item = syn::parse_macro_input!(input as syn::Item);
18
19    let item_struct = if let syn::Item::Struct(m) = ty {
20        m
21    } else {
22        panic!("config attribute must be applied to a Struct");
23    };
24
25    let name = &item_struct.ident;
26    let matrix_name = gen_matrix_name(name);
27
28    let fields = if let syn::Fields::Named(syn::FieldsNamed { ref named, .. }) = item_struct.fields
29    {
30        named
31    } else {
32        panic!("config attribute must be applied to a Struct with named fields");
33    };
34
35    let config_fields = fields.iter().map(|f| {
36        let name = &f.ident;
37        let ty = &f.ty;
38
39        if name.as_ref().unwrap() == "threads" && is_matrix_field(f) {
40            panic!("threads can't be marked as matrix, it's matrix by definition");
41        }
42        if is_matrix_field(f) {
43            // // If the type is Option, return Option<Vec<ty>>; otherwise return Vec<ty>
44            if let Some(t) = get_optional_inner_type(ty) {
45                quote! {#name: std::option::Option<std::vec::Vec<#t>>}
46            } else {
47                quote! {#name: std::vec::Vec<#ty>}
48            }
49        } else {
50            quote! {#name: #ty}
51        }
52    });
53
54    let methods = gen_methods(fields, 0, name);
55    let dummy_struct_name = syn::Ident::new(&format!("{name}DummyStruct"), name.span());
56    let expanded = quote! {
57        #[derive(Debug, shumai::__dep::serde::Deserialize)]
58        pub struct #matrix_name {
59            #(#config_fields, )*
60        }
61
62        impl #matrix_name {
63            pub fn unfold(&self) -> std::vec::Vec<#name> {
64                let mut configs: std::vec::Vec<#name> = std::vec::Vec::new();
65
66                #methods
67
68                configs
69            }
70        }
71
72        #[derive(Debug, Clone, shumai::ShumaiConfig, shumai::__dep::serde::Serialize, shumai::__dep::serde::Deserialize)]
73        #item_struct
74
75        #[derive(shumai::__dep::serde::Deserialize, Debug)]
76        #[allow(non_snake_case)]
77        struct #dummy_struct_name {
78            #name: std::option::Option<std::vec::Vec<#matrix_name>>,
79        }
80
81        impl #name {
82            #[allow(non_snake_case)]
83            pub fn load() -> std::option::Option<std::vec::Vec<#name>> {
84                let contents = std::fs::read_to_string(#file_path).expect(&format!("failed to read the benchmark config file at {}", #file_path));
85                let configs = shumai::__dep::toml::from_str::<#dummy_struct_name>(&contents).expect(&format!("failed to parse the benchmark config file at {}", #file_path));
86
87                let configs = configs.#name?;
88
89                let mut expanded = std::vec::Vec::new();
90                for b in configs.iter() {
91                    expanded.extend(b.unfold());
92                }
93
94
95                match std::env::var("SHUMAI_FILTER") {
96                    Ok(filter) => {
97                        let regex_filter =
98                            shumai::__dep::regex::Regex::new(filter.as_ref())
99                            .expect(&format!("Filter {} from env `SHUMAI_FILTER` is not a valid regex expression!", filter));
100                        let configs: std::vec::Vec<_> = expanded.into_iter().filter(|c| regex_filter.is_match(&c.name)).collect();
101                        Some(configs)
102                    },
103                    Err(_) => {
104                        Some(expanded)
105                    }
106                }
107            }
108        }
109
110        impl shumai::BenchConfig for #name {
111            fn name(&self) -> &String {
112                &self.name
113            }
114
115            fn thread(&self) -> &[usize] {
116                &self.threads
117            }
118
119            fn bench_sec(&self) -> usize {
120                self.time
121            }
122        }
123    };
124
125    // eprintln!("{}", expanded);
126    expanded.into()
127}
128
129#[proc_macro_derive(ShumaiConfig, attributes(matrix))]
130pub fn derive_bench_config(_input: TokenStream) -> TokenStream {
131    quote!().into()
132}
133
134fn gen_matrix_name(name: &syn::Ident) -> syn::Ident {
135    let gen_name = format!("{name}Matrix");
136    syn::Ident::new(&gen_name, name.span())
137}
138
139fn gen_methods(
140    fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
141    current: usize,
142    origin_name: &syn::Ident,
143) -> proc_macro2::TokenStream {
144    if current == fields.len() {
145        let name_prefix = origin_name.to_string().to_ascii_lowercase();
146        let mut name_gen = quote! {
147            let mut name_lit = format!("{}-{}", #name_prefix, name.clone());
148        };
149        for f in fields {
150            let f_name = &f.ident;
151            if is_matrix_field(f) {
152                if get_optional_inner_type(&f.ty).is_some() {
153                    name_gen = quote! {
154                        #name_gen
155                        if let Some(t_v) = &self.#f_name {
156                            if t_v.len() > 1 {
157                                name_lit = format!("{}-{:?}", name_lit, #f_name);
158                            }
159                        }
160                    }
161                } else {
162                    name_gen = quote! {
163                        #name_gen
164                        if self.#f_name.len() > 1 {
165                            name_lit = format!("{}-{:?}", name_lit, #f_name);
166                        }
167                    }
168                }
169            }
170        }
171        let assign_fields = fields.iter().map(|f| {
172            let name = &f.ident;
173            // We skip the `name` field because it's handled separately
174            if name.as_ref().unwrap() == "name" {
175                quote! {}
176            } else {
177                quote! {
178                    #name: #name.clone(),
179                }
180            }
181        });
182
183        return quote! {
184            #name_gen
185            configs.push(#origin_name {
186                name: name_lit,
187                #(#assign_fields)*
188            });
189        };
190    }
191
192    let inner = gen_methods(fields, current + 1, origin_name);
193
194    let current = &fields[current];
195    let name = &current.ident;
196
197    if is_matrix_field(current) {
198        if get_optional_inner_type(&current.ty).is_some() {
199            quote! {
200                if let Some(#name) = &self.#name {
201                    for i in #name.iter() {
202                        let #name = Some(i.clone());
203                        #inner
204                    }
205                }else{
206                    let #name = None;
207                    #inner
208                }
209            }
210        } else {
211            quote! {
212                for i in self.#name.iter() {
213                    let #name = i.clone();
214                    #inner
215                }
216            }
217        }
218    } else {
219        quote! {
220            let #name = self.#name.clone();
221            #inner
222        }
223    }
224}
225
226fn is_matrix_field(f: &syn::Field) -> bool {
227    for attr in &f.attrs {
228        if attr.path.segments.len() == 1 && attr.path.segments[0].ident == "matrix" {
229            return true;
230        }
231    }
232    false
233}
234
235fn get_optional_inner_type(ty: &syn::Type) -> Option<&GenericArgument> {
236    if let syn::Type::Path(syn::TypePath {
237        path: syn::Path { segments, .. },
238        ..
239    }) = ty
240    {
241        if segments.len() == 1 {
242            let segment = segments.first().unwrap();
243            if segment.ident == "Option" {
244                let option_inner = &segment.arguments;
245                match option_inner {
246                    syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
247                        args,
248                        ..
249                    }) => {
250                        let ty = args.first().unwrap();
251                        return Some(ty);
252                    }
253                    _ => panic!("Option must be used with angle bracketed generic arguments"),
254                }
255            }
256        }
257    }
258    None
259}
260
261fn get_config_file_path(meta: &syn::NestedMeta) -> Option<syn::LitStr> {
262    let meta = if let syn::NestedMeta::Meta(m) = meta {
263        m
264    } else {
265        return None;
266    };
267
268    let name_value = if let syn::Meta::NameValue(v) = meta {
269        v
270    } else {
271        return None;
272    };
273
274    if name_value.path.segments[0].ident != "path" {
275        return None;
276    }
277
278    let v = if let syn::Lit::Str(v) = name_value.lit.clone() {
279        v
280    } else {
281        return None;
282    };
283
284    Some(v)
285}