1use proc_macro::TokenStream;
2use quote::{quote};
3use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
4
5#[proc_macro_derive(Sampleable)]
6pub fn sampleable_derive(input: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(input as DeriveInput);
9
10 let name = input.ident.clone();
12
13 match input.data {
15 Data::Struct(data_struct) => {
16 expand_struct(name, data_struct)
18 },
19 Data::Enum(data_enum) => {
20 expand_enum(name, data_enum)
22 },
23 _ => {
24 unimplemented!("Sampleable can only be derived for structs and enums");
25 }
26 }
27}
28
29fn expand_struct(name: syn::Ident, data_struct: syn::DataStruct) -> TokenStream {
30 let fields = match data_struct.fields {
32 Fields::Named(fields_named) => fields_named.named,
33 _ => unimplemented!("Sampleable can only be derived for structs with named fields"),
34 };
35
36 let field_samples = fields.iter().map(|field| {
38 let field_name = field.ident.as_ref().unwrap();
39 let field_name_str = field_name.to_string();
40 let field_type = &field.ty;
41
42 let sample_code = generate_sample_code(field_type, &field_name_str, "e!(config));
43
44 quote! {
45 #field_name: #sample_code
46 }
47 });
48
49 let expanded = quote! {
51 impl #name {
52 pub fn sample_with_config(config: &serde_json::Map<String, serde_json::Value>) -> Result<Self, String> {
53 use rand::Rng;
54 use rand::seq::SliceRandom;
55
56 Ok(Self {
57 #(#field_samples),*
58 })
59 }
60 }
61 };
62
63 TokenStream::from(expanded)
65}
66
67fn expand_enum(name: syn::Ident, data_enum: syn::DataEnum) -> TokenStream {
68 let variants = data_enum.variants;
70
71 let variant_names = variants.iter().map(|v| v.ident.clone());
73
74 let variant_sample_cases = variants.iter().map(|variant| {
75 let variant_name = &variant.ident;
76 let variant_name_str = variant_name.to_string();
77
78 match &variant.fields {
79 Fields::Unit => {
80 quote! {
82 #variant_name_str => {
83 #name::#variant_name
84 }
85 }
86 },
87 Fields::Named(fields_named) => {
88 let field_samples = fields_named.named.iter().map(|field| {
90 let field_name = &field.ident;
91 let field_name_str = field_name.as_ref().unwrap().to_string();
92 let field_type = &field.ty;
93
94 let sample_code = generate_sample_code(field_type, &field_name_str, "e!(variant_data));
95
96 quote! {
97 #field_name: #sample_code
98 }
99 });
100
101 quote! {
102 #variant_name_str => {
103 if let Some(serde_json::Value::Object(variant_data)) = variant_config.get(#variant_name_str) {
104 #name::#variant_name {
105 #(#field_samples),*
106 }
107 } else {
108 return Err(format!("Configuration for variant '{}' is missing or invalid", #variant_name_str));
109 }
110 }
111 }
112 },
113 Fields::Unnamed(fields_unnamed) => {
114 let field_samples = fields_unnamed.unnamed.iter().enumerate().map(|(i, field)| {
116 let field_name_str = format!("field{}", i);
117 let field_type = &field.ty;
118
119 let sample_code = generate_sample_code(field_type, &field_name_str, "e!(variant_data));
120
121 quote! {
122 #sample_code
123 }
124 });
125
126 quote! {
127 #variant_name_str => {
128 if let Some(serde_json::Value::Object(variant_data)) = variant_config.get(#variant_name_str) {
129 #name::#variant_name(
130 #(#field_samples),*
131 )
132 } else {
133 return Err(format!("Configuration for variant '{}' is missing or invalid", #variant_name_str));
134 }
135 }
136 }
137 },
138 }
139 });
140
141 let expanded = quote! {
143 impl #name {
144 pub fn sample_with_config(config: &serde_json::Map<String, serde_json::Value>) -> Result<Self, String> {
145 use rand::Rng;
146 use rand::seq::SliceRandom;
147
148 let variants: Vec<String> = if let Some(serde_json::Value::Array(variant_array)) = config.get("variants") {
150 variant_array.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect()
151 } else {
152 {
153 let mut vec = Vec::new();
154 #(
155 vec.push(String::from(stringify!(#variant_names)));
156 )*
157 vec
158 }
159 };
160
161 if variants.is_empty() {
162 return Err("No variants specified for enum sampling".to_string());
163 }
164
165 let selected_variant = variants.choose(&mut rand::thread_rng()).unwrap();
166
167 let variant_config = if let Some(serde_json::Value::Object(map)) = config.get("variant_data") {
169 map
170 } else {
171 &serde_json::Map::new()
172 };
173
174 let result = match selected_variant.as_str() {
175 #(#variant_sample_cases),*,
176 _ => return Err(format!("Variant '{}' is not recognized", selected_variant)),
177 };
178
179 Ok(result)
180 }
181 }
182 };
183
184 TokenStream::from(expanded)
185}
186
187fn generate_sample_code(field_type: &Type, field_name_str: &str, config_var: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
189 if is_option(field_type) {
190 let inner_type = get_inner_type(field_type);
191 let inner_sample_code = generate_sample_code(&inner_type, field_name_str, config_var);
192
193 quote! {
194 {
195 if let Some(config_value) = #config_var.get(#field_name_str) {
196 if config_value.is_null() {
197 None
198 } else {
199 Some(#inner_sample_code)
200 }
201 } else {
202 None
203 }
204 }
205 }
206 } else if is_vec(field_type) {
207 let inner_type = get_inner_type(field_type);
208 let inner_sample_code = generate_sample_code_for_vec_elements(&inner_type, field_name_str, config_var);
209
210 quote! {
211 {
212 #inner_sample_code
213 }
214 }
215 } else if is_box(field_type) {
216 let inner_type = get_inner_type(field_type);
217 let inner_sample_code = generate_sample_code(&inner_type, field_name_str, config_var);
218
219 quote! {
220 Box::new(#inner_sample_code)
221 }
222 } else if is_primitive(field_type) {
223 generate_primitive_sample_code(field_type, field_name_str, config_var)
224 } else {
225 quote! {
227 {
228 if let Some(serde_json::Value::Object(map)) = #config_var.get(#field_name_str) {
229 <#field_type>::sample_with_config(map)?
230 } else {
231 return Err(format!("Configuration for '{}' must be an object", #field_name_str));
232 }
233 }
234 }
235 }
236}
237
238fn generate_sample_code_for_vec_elements(element_type: &Type, field_name_str: &str, config_var: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
239 if is_primitive(element_type) {
240 let element_type_str = match element_type {
242 Type::Path(type_path) => {
243 type_path.path.segments.last().unwrap().ident.to_string()
244 },
245 _ => "".to_string(),
246 };
247 let parse_value = match element_type_str.as_str() {
248 "String" => quote! {
249 v.as_str().map(|s| s.to_string())
250 },
251 "i32" | "i64" | "u32" | "u64" | "usize" | "isize" => quote! {
252 v.as_i64().map(|n| n as #element_type)
253 },
254 "f32" | "f64" => quote! {
255 v.as_f64().map(|n| n as #element_type)
256 },
257 "bool" => quote! {
258 v.as_bool()
259 },
260 _ => quote! {
261 None
262 },
263 };
264
265 quote! {
266 {
267 if let Some(config_value) = #config_var.get(#field_name_str) {
268 if let serde_json::Value::Array(values_array) = config_value {
269 let values: Vec<#element_type> = values_array.iter()
270 .filter_map(|v| #parse_value)
271 .collect();
272 if values.is_empty() {
273 return Err(format!("Values array for field '{}' is empty or contains invalid types", #field_name_str));
274 }
275 let mut rng = rand::thread_rng();
276 let sample_size = rng.gen_range(1..=values.len());
277 let samples = values.choose_multiple(&mut rng, sample_size)
278 .cloned()
279 .collect::<Vec<#element_type>>();
280 samples
281 } else {
282 return Err(format!("Configuration for '{}' must be an array", #field_name_str));
283 }
284 } else {
285 Vec::<#element_type>::new()
286 }
287 }
288 }
289 } else {
290 quote! {
292 {
293 if let Some(config_value) = #config_var.get(#field_name_str) {
294 if let serde_json::Value::Array(array) = config_value {
295 let mut vec = Vec::new();
296 for item in array {
297 if let serde_json::Value::Object(item_config) = item {
298 vec.push(<#element_type>::sample_with_config(&item_config)?);
299 } else {
300 return Err(format!("Each item in '{}' must be an object", #field_name_str));
301 }
302 }
303 vec
304 } else {
305 return Err(format!("Configuration for '{}' must be an array", #field_name_str));
306 }
307 } else {
308 Vec::<#element_type>::new()
309 }
310 }
311 }
312 }
313}
314
315fn is_option(ty: &Type) -> bool {
318 match ty {
319 Type::Path(type_path) => type_path.path.segments.last().unwrap().ident == "Option",
320 _ => false,
321 }
322}
323
324fn is_vec(ty: &Type) -> bool {
325 match ty {
326 Type::Path(type_path) => type_path.path.segments.last().unwrap().ident == "Vec",
327 _ => false,
328 }
329}
330
331fn get_inner_type(ty: &Type) -> Type {
332 match ty {
333 Type::Path(type_path) => {
334 if let syn::PathArguments::AngleBracketed(args) = &type_path.path.segments.last().unwrap().arguments {
335 if let Some(syn::GenericArgument::Type(inner_type)) = args.args.first() {
336 inner_type.clone()
337 } else {
338 panic!("Expected a type argument");
339 }
340 } else {
341 panic!("Expected angle bracketed arguments");
342 }
343 }
344 _ => panic!("Expected a type path"),
345 }
346}
347
348fn is_primitive(ty: &Type) -> bool {
349 match ty {
350 Type::Path(type_path) => {
351 let ident = &type_path.path.segments.last().unwrap().ident;
352 ["f64", "f32", "i32", "i64", "u32", "u64", "usize", "isize", "String", "bool"].contains(&ident.to_string().as_str())
353 }
354 _ => false,
355 }
356}
357
358fn is_box(ty: &Type) -> bool {
359 match ty {
360 Type::Path(type_path) => type_path.path.segments.last().unwrap().ident == "Box",
361 _ => false,
362 }
363}
364
365fn generate_primitive_sample_code(field_type: &Type, field_name_str: &str, config_var: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
366 let type_ident = match field_type {
367 Type::Path(type_path) => &type_path.path.segments.last().unwrap().ident,
368 _ => panic!("Expected a type path"),
369 };
370 let type_ident_str = type_ident.to_string();
371
372 if ["f64", "f32"].contains(&type_ident_str.as_str()) {
373 quote! {
375 {
376 if let Some(config_value) = #config_var.get(#field_name_str) {
377 if let Some(range_array) = config_value.as_array() {
378 if range_array.len() == 2 {
379 if let (Some(start), Some(end)) = (range_array[0].as_f64(), range_array[1].as_f64()) {
380 rand::thread_rng().gen_range(start..end)
381 } else {
382 return Err(format!("Invalid range values for field '{}'", #field_name_str));
383 }
384 } else {
385 return Err(format!("Range array for field '{}' must have exactly two elements", #field_name_str));
386 }
387 } else {
388 return Err(format!("Configuration for field '{}' must be an array", #field_name_str));
389 }
390 } else {
391 return Err(format!("Configuration for '{}' is missing", #field_name_str));
392 }
393 }
394 }
395 } else if ["i32", "i64", "u32", "u64", "usize", "isize"].contains(&type_ident_str.as_str()) {
396 quote! {
398 {
399 if let Some(config_value) = #config_var.get(#field_name_str) {
400 if let Some(range_array) = config_value.as_array() {
401 if range_array.len() == 2 {
402 if let (Some(start), Some(end)) = (range_array[0].as_i64(), range_array[1].as_i64()) {
403 rand::thread_rng().gen_range(start..end) as #field_type
404 } else {
405 return Err(format!("Invalid range values for field '{}'", #field_name_str));
406 }
407 } else {
408 return Err(format!("Range array for field '{}' must have exactly two elements", #field_name_str));
409 }
410 } else {
411 return Err(format!("Configuration for field '{}' must be an array", #field_name_str));
412 }
413 } else {
414 return Err(format!("Configuration for '{}' is missing", #field_name_str));
415 }
416 }
417 }
418 } else if type_ident_str == "String" {
419 quote! {
421 {
422 if let Some(config_value) = #config_var.get(#field_name_str) {
423 if let Some(values_array) = config_value.as_array() {
424 let values: Vec<String> = values_array.iter()
425 .filter_map(|v| v.as_str().map(|s| s.to_string()))
426 .collect();
427 if !values.is_empty() {
428 values.choose(&mut rand::thread_rng()).unwrap().clone()
429 } else {
430 return Err(format!("Values array for field '{}' is empty", #field_name_str));
431 }
432 } else if let Some(value_str) = config_value.as_str() {
433 value_str.to_string()
434 } else {
435 return Err(format!("Configuration for '{}' must be an array or string", #field_name_str));
436 }
437 } else {
438 return Err(format!("Configuration for '{}' is missing", #field_name_str));
439 }
440 }
441 }
442 } else if type_ident_str == "bool" {
443 quote! {
445 {
446 if let Some(config_value) = #config_var.get(#field_name_str) {
447 if let Some(value_bool) = config_value.as_bool() {
448 value_bool
449 } else {
450 return Err(format!("Configuration for '{}' must be a boolean", #field_name_str));
451 }
452 } else {
453 return Err(format!("Configuration for '{}' is missing", #field_name_str));
454 }
455 }
456 }
457 } else {
458 quote! {
460 return Err(format!("Unsupported type for field '{}'", #field_name_str));
461 }
462 }
463}