shumai_config_impl/
lib.rs1extern crate proc_macro;
2extern crate syn;
3#[macro_use]
4extern crate quote;
5use proc_macro::TokenStream;
6use syn::{parse_macro_input, GenericArgument};
7
8#[proc_macro_attribute]
9pub fn config(args: TokenStream, input: TokenStream) -> TokenStream {
10 let args = parse_macro_input!(args as syn::AttributeArgs)
11 .first()
12 .expect("Benchmark file must be annotated with #[config(path = \"/path/to/file.toml\")]")
13 .clone();
14 let file_path = get_config_file_path(&args)
15 .expect("Benchmark file must be annotated with #[config(path = \"/path/to/file.toml\")]");
16
17 let ty: syn::Item = syn::parse_macro_input!(input as syn::Item);
18
19 let item_struct = if let syn::Item::Struct(m) = ty {
20 m
21 } else {
22 panic!("config attribute must be applied to a Struct");
23 };
24
25 let name = &item_struct.ident;
26 let matrix_name = gen_matrix_name(name);
27
28 let fields = if let syn::Fields::Named(syn::FieldsNamed { ref named, .. }) = item_struct.fields
29 {
30 named
31 } else {
32 panic!("config attribute must be applied to a Struct with named fields");
33 };
34
35 let config_fields = fields.iter().map(|f| {
36 let name = &f.ident;
37 let ty = &f.ty;
38
39 if name.as_ref().unwrap() == "threads" && is_matrix_field(f) {
40 panic!("threads can't be marked as matrix, it's matrix by definition");
41 }
42 if is_matrix_field(f) {
43 if let Some(t) = get_optional_inner_type(ty) {
45 quote! {#name: std::option::Option<std::vec::Vec<#t>>}
46 } else {
47 quote! {#name: std::vec::Vec<#ty>}
48 }
49 } else {
50 quote! {#name: #ty}
51 }
52 });
53
54 let methods = gen_methods(fields, 0, name);
55 let dummy_struct_name = syn::Ident::new(&format!("{name}DummyStruct"), name.span());
56 let expanded = quote! {
57 #[derive(Debug, shumai::__dep::serde::Deserialize)]
58 pub struct #matrix_name {
59 #(#config_fields, )*
60 }
61
62 impl #matrix_name {
63 pub fn unfold(&self) -> std::vec::Vec<#name> {
64 let mut configs: std::vec::Vec<#name> = std::vec::Vec::new();
65
66 #methods
67
68 configs
69 }
70 }
71
72 #[derive(Debug, Clone, shumai::ShumaiConfig, shumai::__dep::serde::Serialize, shumai::__dep::serde::Deserialize)]
73 #item_struct
74
75 #[derive(shumai::__dep::serde::Deserialize, Debug)]
76 #[allow(non_snake_case)]
77 struct #dummy_struct_name {
78 #name: std::option::Option<std::vec::Vec<#matrix_name>>,
79 }
80
81 impl #name {
82 #[allow(non_snake_case)]
83 pub fn load() -> std::option::Option<std::vec::Vec<#name>> {
84 let contents = std::fs::read_to_string(#file_path).expect(&format!("failed to read the benchmark config file at {}", #file_path));
85 let configs = shumai::__dep::toml::from_str::<#dummy_struct_name>(&contents).expect(&format!("failed to parse the benchmark config file at {}", #file_path));
86
87 let configs = configs.#name?;
88
89 let mut expanded = std::vec::Vec::new();
90 for b in configs.iter() {
91 expanded.extend(b.unfold());
92 }
93
94
95 match std::env::var("SHUMAI_FILTER") {
96 Ok(filter) => {
97 let regex_filter =
98 shumai::__dep::regex::Regex::new(filter.as_ref())
99 .expect(&format!("Filter {} from env `SHUMAI_FILTER` is not a valid regex expression!", filter));
100 let configs: std::vec::Vec<_> = expanded.into_iter().filter(|c| regex_filter.is_match(&c.name)).collect();
101 Some(configs)
102 },
103 Err(_) => {
104 Some(expanded)
105 }
106 }
107 }
108 }
109
110 impl shumai::BenchConfig for #name {
111 fn name(&self) -> &String {
112 &self.name
113 }
114
115 fn thread(&self) -> &[usize] {
116 &self.threads
117 }
118
119 fn bench_sec(&self) -> usize {
120 self.time
121 }
122 }
123 };
124
125 expanded.into()
127}
128
129#[proc_macro_derive(ShumaiConfig, attributes(matrix))]
130pub fn derive_bench_config(_input: TokenStream) -> TokenStream {
131 quote!().into()
132}
133
134fn gen_matrix_name(name: &syn::Ident) -> syn::Ident {
135 let gen_name = format!("{name}Matrix");
136 syn::Ident::new(&gen_name, name.span())
137}
138
139fn gen_methods(
140 fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
141 current: usize,
142 origin_name: &syn::Ident,
143) -> proc_macro2::TokenStream {
144 if current == fields.len() {
145 let name_prefix = origin_name.to_string().to_ascii_lowercase();
146 let mut name_gen = quote! {
147 let mut name_lit = format!("{}-{}", #name_prefix, name.clone());
148 };
149 for f in fields {
150 let f_name = &f.ident;
151 if is_matrix_field(f) {
152 if get_optional_inner_type(&f.ty).is_some() {
153 name_gen = quote! {
154 #name_gen
155 if let Some(t_v) = &self.#f_name {
156 if t_v.len() > 1 {
157 name_lit = format!("{}-{:?}", name_lit, #f_name);
158 }
159 }
160 }
161 } else {
162 name_gen = quote! {
163 #name_gen
164 if self.#f_name.len() > 1 {
165 name_lit = format!("{}-{:?}", name_lit, #f_name);
166 }
167 }
168 }
169 }
170 }
171 let assign_fields = fields.iter().map(|f| {
172 let name = &f.ident;
173 if name.as_ref().unwrap() == "name" {
175 quote! {}
176 } else {
177 quote! {
178 #name: #name.clone(),
179 }
180 }
181 });
182
183 return quote! {
184 #name_gen
185 configs.push(#origin_name {
186 name: name_lit,
187 #(#assign_fields)*
188 });
189 };
190 }
191
192 let inner = gen_methods(fields, current + 1, origin_name);
193
194 let current = &fields[current];
195 let name = ¤t.ident;
196
197 if is_matrix_field(current) {
198 if get_optional_inner_type(¤t.ty).is_some() {
199 quote! {
200 if let Some(#name) = &self.#name {
201 for i in #name.iter() {
202 let #name = Some(i.clone());
203 #inner
204 }
205 }else{
206 let #name = None;
207 #inner
208 }
209 }
210 } else {
211 quote! {
212 for i in self.#name.iter() {
213 let #name = i.clone();
214 #inner
215 }
216 }
217 }
218 } else {
219 quote! {
220 let #name = self.#name.clone();
221 #inner
222 }
223 }
224}
225
226fn is_matrix_field(f: &syn::Field) -> bool {
227 for attr in &f.attrs {
228 if attr.path.segments.len() == 1 && attr.path.segments[0].ident == "matrix" {
229 return true;
230 }
231 }
232 false
233}
234
235fn get_optional_inner_type(ty: &syn::Type) -> Option<&GenericArgument> {
236 if let syn::Type::Path(syn::TypePath {
237 path: syn::Path { segments, .. },
238 ..
239 }) = ty
240 {
241 if segments.len() == 1 {
242 let segment = segments.first().unwrap();
243 if segment.ident == "Option" {
244 let option_inner = &segment.arguments;
245 match option_inner {
246 syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
247 args,
248 ..
249 }) => {
250 let ty = args.first().unwrap();
251 return Some(ty);
252 }
253 _ => panic!("Option must be used with angle bracketed generic arguments"),
254 }
255 }
256 }
257 }
258 None
259}
260
261fn get_config_file_path(meta: &syn::NestedMeta) -> Option<syn::LitStr> {
262 let meta = if let syn::NestedMeta::Meta(m) = meta {
263 m
264 } else {
265 return None;
266 };
267
268 let name_value = if let syn::Meta::NameValue(v) = meta {
269 v
270 } else {
271 return None;
272 };
273
274 if name_value.path.segments[0].ident != "path" {
275 return None;
276 }
277
278 let v = if let syn::Lit::Str(v) = name_value.lit.clone() {
279 v
280 } else {
281 return None;
282 };
283
284 Some(v)
285}