trident_syn/codegen/
trident_flow_executor.rs1use proc_macro2::TokenStream;
2use quote::quote;
3use quote::ToTokens;
4
5use crate::types::trident_flow_executor::TridentFlowExecutorImpl;
6
7impl ToTokens for TridentFlowExecutorImpl {
8 fn to_tokens(&self, tokens: &mut TokenStream) {
9 let expanded = self.generate_flow_executor_impl();
10 tokens.extend(expanded);
11 }
12}
13
14impl TridentFlowExecutorImpl {
15 fn generate_flow_executor_impl(&self) -> TokenStream {
17 let type_name = &self.type_name;
18 let impl_items = &self.impl_block;
19 let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
20
21 let generated_impl = self.generate_generated_impl_block();
22
23 quote! {
24 impl #impl_generics #type_name #ty_generics #where_clause {
25 #(#impl_items)*
26 }
27
28 #generated_impl
29 }
30 }
31
32 fn generate_generated_impl_block(&self) -> TokenStream {
34 let type_name = &self.type_name;
35 let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
36
37 let execute_flows_method = self.generate_execute_flows_method();
38 let fuzz_method = self.generate_fuzz_method();
39
40 quote! {
41 impl #impl_generics #type_name #ty_generics #where_clause {
42 #execute_flows_method
43 #fuzz_method
44 }
45 }
46 }
47
48 fn generate_execute_flows_method(&self) -> TokenStream {
50 let init_call = self.generate_init_call();
51 let flow_execution_logic = self.generate_flow_execution_logic();
52 let end_call = self.generate_end_call();
53
54 quote! {
55 pub fn execute_flows(
56 &mut self,
57 flow_calls_per_iteration: u64,
58 ) -> std::result::Result<(), FuzzingError> {
59 #init_call
60 #flow_execution_logic
61 #end_call
62 Ok(())
63 }
64 }
65 }
66
67 fn generate_init_call(&self) -> TokenStream {
69 if let Some(init_method) = &self.init_method {
70 quote! {
71 self.#init_method();
72 }
73 } else {
74 quote! {}
75 }
76 }
77
78 fn generate_end_call(&self) -> TokenStream {
80 if let Some(end_method) = &self.end_method {
81 quote! {
82 self.#end_method();
83 }
84 } else {
85 quote! {}
86 }
87 }
88
89 fn generate_flow_execution_logic(&self) -> TokenStream {
91 let active_methods: Vec<_> = self
93 .flow_methods
94 .iter()
95 .filter(|method| !method.constraints.ignore)
96 .collect();
97
98 if active_methods.is_empty() {
99 quote! {
100 }
102 } else {
103 let flow_selection_logic = self.generate_flow_selection_logic(&active_methods);
104
105 quote! {
106 #flow_selection_logic
107 }
108 }
109 }
110
111 fn generate_flow_selection_logic(
113 &self,
114 active_methods: &[&crate::types::trident_flow_executor::FlowMethod],
115 ) -> TokenStream {
116 let has_weights = active_methods
118 .iter()
119 .any(|method| method.constraints.weight.is_some());
120
121 if has_weights {
122 self.generate_weighted_flow_selection(active_methods)
124 } else {
125 self.generate_uniform_flow_selection(active_methods)
127 }
128 }
129
130 fn generate_uniform_flow_selection(
132 &self,
133 active_methods: &[&crate::types::trident_flow_executor::FlowMethod],
134 ) -> TokenStream {
135 let flow_match_arms = active_methods.iter().enumerate().map(|(index, method)| {
136 let method_ident = &method.ident;
137 quote! {
138 #index => self.#method_ident(),
139 }
140 });
141 let num_flows = active_methods.len();
142
143 quote! {
144 let flows_results = for _ in 0..flow_calls_per_iteration {
146 let flow_index = self.trident.gen_range(0..#num_flows);
147 match flow_index {
148 #(#flow_match_arms)*
149 _ => unreachable!("Invalid flow index"),
150 }
151 };
152 }
153 }
154
155 fn generate_weighted_flow_selection(
157 &self,
158 active_methods: &[&crate::types::trident_flow_executor::FlowMethod],
159 ) -> TokenStream {
160 let weighted_methods: Vec<_> = active_methods
162 .iter()
163 .filter(|method| method.constraints.weight.unwrap_or(0) > 0)
164 .collect();
165
166 if weighted_methods.is_empty() {
167 return quote! {
168 };
170 }
171
172 let total_weight: u32 = weighted_methods
174 .iter()
175 .map(|method| method.constraints.weight.unwrap())
176 .sum();
177
178 let mut cumulative_weight = 0u32;
180 let weight_ranges: Vec<_> = weighted_methods
181 .iter()
182 .map(|method| {
183 let weight = method.constraints.weight.unwrap();
184 let _start = cumulative_weight;
185 cumulative_weight += weight;
186 let end = cumulative_weight;
187 let method_ident = &method.ident;
188
189 quote! {
190 if random_weight < #end {
191 self.#method_ident();
192 continue;
193 }
194 }
195 })
196 .collect();
197
198 quote! {
199 let flows_results = for _ in 0..flow_calls_per_iteration {
201 let random_weight = self.trident.gen_range(0..#total_weight);
202 #(#weight_ranges)*
203 };
204 }
205 }
206
207 fn generate_fuzz_method(&self) -> TokenStream {
209 let thread_management = self.generate_thread_management_logic();
210 let single_threaded_fallback = self.generate_single_threaded_fallback();
211
212 quote! {
213 fn fuzz(iterations: u64, flow_calls_per_iteration: u64) {
214 if std::env::var("TRIDENT_FUZZ_DEBUG").is_ok() {
216 println!("Debug mode detected: Running single iteration with provided seed");
217 let iterations = 1u64;
218 #single_threaded_fallback
219 return;
220 } else {
221 std::panic::set_hook(Box::new(|_info| {
223 }));
225 }
226
227 use std::thread;
228 use std::time::{Duration, Instant};
229
230 let master_seed = if let Ok(seed) = std::env::var("TRIDENT_FUZZ_SEED") {
231 let seed_bytes = hex::decode(&seed).unwrap_or_else(|_| panic!("The seed is not a valid hex string: {}", seed));
232 let mut seed = [0; 32];
233 seed.copy_from_slice(&seed_bytes);
234 seed
235 } else{
236 let mut seed = [0; 32];
237 if let Err(err) = getrandom::fill(&mut seed) {
238 panic!("from_entropy failed: {}", err);
239 }
240 seed
241 };
242
243 let num_threads = thread::available_parallelism()
244 .map(|n| n.get())
245 .unwrap_or(1)
246 .min(iterations as usize);
247
248 if num_threads <= 1 || iterations <= 1 {
249 #single_threaded_fallback
251 return;
252 }
253
254 #thread_management
255 }
256 }
257 }
258
259 fn generate_single_threaded_fallback(&self) -> TokenStream {
261 let type_name = &self.type_name;
262 let progress_bar_setup = self.generate_progress_bar_setup(false);
263 let fuzzing_loop = self.generate_single_threaded_fuzzing_loop();
264 let metrics_output = self.generate_metrics_output();
265
266 quote! {
267 let mut fuzzer = #type_name::new();
268
269 if let Ok(debug_seed_hex) = std::env::var("TRIDENT_FUZZ_DEBUG") {
271 let seed_bytes = hex::decode(&debug_seed_hex)
273 .unwrap_or_else(|_| panic!("Invalid hex string in debug seed: {}", debug_seed_hex));
274
275 if seed_bytes.len() != 32 {
276 panic!("Debug seed must be exactly 32 bytes (64 hex characters), got: {}", seed_bytes.len());
277 }
278
279 let mut seed = [0u8; 32];
280 seed.copy_from_slice(&seed_bytes);
281
282 println!("Using debug seed: {}", debug_seed_hex);
283 fuzzer.trident._set_master_seed_for_debug(seed);
284 }
285 let total_flow_calls = iterations * flow_calls_per_iteration;
286
287 #progress_bar_setup
288 #fuzzing_loop
289 #metrics_output
290 }
291 }
292
293 fn generate_progress_bar_setup(&self, is_parallel: bool) -> TokenStream {
295 let message_prefix = if is_parallel { "Overall: " } else { "" };
296 let message_content = if is_parallel {
297 quote! { format!("Fuzzing with {} threads - {} iterations with {} flow calls each", num_threads, iterations, flow_calls_per_iteration) }
298 } else {
299 quote! { format!("Fuzzing {} iterations with {} flow calls each...", iterations, flow_calls_per_iteration) }
300 };
301
302 quote! {
303 let pb = indicatif::ProgressBar::new(total_flow_calls);
304 pb.set_style(
305 indicatif::ProgressStyle::with_template(
306 concat!(#message_prefix, "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({percent}%) [{eta_precise}] {msg}")
307 )
308 .unwrap()
309 .progress_chars("#>-"),
310 );
311 pb.set_message(#message_content);
312 }
313 }
314
315 fn generate_single_threaded_fuzzing_loop(&self) -> TokenStream {
317 let generate_write_profile_logic = self.generate_write_profile_logic();
318 let loopcount_retrieval = self.generate_loopcount_retrieval();
319 let generate_coverage_server_port_retrieval =
320 self.generate_coverage_server_port_retrieval();
321
322 quote! {
323 #loopcount_retrieval
324 #generate_coverage_server_port_retrieval
325
326 for i in 0..iterations {
327 let result = fuzzer.execute_flows(flow_calls_per_iteration);
328 fuzzer.trident._next_iteration();
329 let _ = std::mem::take(&mut fuzzer.fuzz_accounts);
332
333 pb.inc(flow_calls_per_iteration);
334 pb.set_message(format!("Iteration {}/{} completed", i + 1, iterations));
335
336 #generate_write_profile_logic
337 }
338
339 pb.finish_with_message("Fuzzing completed!");
340
341 let fuzzing_data = fuzzer.trident._get_fuzzing_data();
342
343 }
344 }
345
346 fn generate_thread_management_logic(&self) -> TokenStream {
348 let parallel_progress_setup = self.generate_parallel_progress_setup();
349 let thread_spawn_logic = self.generate_thread_spawn_logic();
350 let metrics_collection = self.generate_metrics_collection_logic();
351
352 quote! {
353 let iterations_per_thread = iterations / num_threads as u64;
354 let remaining_iterations = iterations % num_threads as u64;
355 let total_flow_calls = iterations * flow_calls_per_iteration;
356
357 let mut handles = Vec::new();
358
359 #parallel_progress_setup
360
361 for thread_id in 0..num_threads {
362
363 let thread_iterations = iterations_per_thread;
364
365 if thread_iterations == 0 {
366 continue;
367 }
368
369 #thread_spawn_logic
370 }
371
372 #metrics_collection
373 }
374 }
375
376 fn generate_parallel_progress_setup(&self) -> TokenStream {
378 quote! {
379 let main_pb = indicatif::ProgressBar::new(total_flow_calls);
381 main_pb.set_style(
382 indicatif::ProgressStyle::with_template(
383 "Overall: {spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({percent}%) [{eta_precise}] {msg}"
384 )
385 .unwrap()
386 .progress_chars("#>-"),
387 );
388 main_pb.set_message(format!("Fuzzing with {} threads - {} iterations with {} flow calls each", num_threads, iterations, flow_calls_per_iteration));
389 }
390 }
391
392 fn generate_thread_spawn_logic(&self) -> TokenStream {
394 let type_name = &self.type_name;
395 let generate_loopcount_retrieval = self.generate_loopcount_retrieval();
396 let generate_coverage_server_port_retrieval =
397 self.generate_coverage_server_port_retrieval();
398 let generate_write_profile_logic = self.generate_multi_threaded_coverage();
399
400 quote! {
401 let main_pb_clone = main_pb.clone();
402 let handle = thread::spawn(move || -> TridentFuzzingData {
403 let mut fuzzer = #type_name::new();
405
406 fuzzer.trident._set_master_seed_and_thread_id(master_seed, thread_id);
407
408 const UPDATE_INTERVAL: u64 = 100;
410 let mut last_update = Instant::now();
411 let update_duration = Duration::from_millis(50);
412
413 let mut local_counter = 0u64;
414
415 #generate_loopcount_retrieval
416 #generate_coverage_server_port_retrieval
417
418 for i in 0..thread_iterations {
419 let _ = fuzzer.execute_flows(flow_calls_per_iteration);
420 fuzzer.trident._next_iteration();
421
422 let _ = std::mem::take(&mut fuzzer.fuzz_accounts);
425
426 local_counter += flow_calls_per_iteration;
427
428 let should_update = local_counter >= UPDATE_INTERVAL ||
430 last_update.elapsed() >= update_duration ||
431 i == thread_iterations - 1; if should_update {
434 main_pb_clone.inc(local_counter);
435 local_counter = 0;
436 last_update = Instant::now();
437 }
438
439 #generate_write_profile_logic
440 }
441
442 if local_counter > 0 {
444 main_pb_clone.inc(local_counter);
445 }
446
447 fuzzer.trident._get_fuzzing_data()
449 });
450
451 handles.push(handle);
452 }
453 }
454
455 fn generate_metrics_collection_logic(&self) -> TokenStream {
457 let metrics_output = self.generate_metrics_output();
458 quote! {
459 let mut fuzzing_data = TridentFuzzingData::with_master_seed(master_seed);
461
462 for handle in handles {
463 match handle.join() {
464 Ok(thread_metrics) => {
465 if std::env::var("FUZZING_METRICS").is_ok() {
466 fuzzing_data._merge(thread_metrics);
467 }
468 }
469 Err(err) => {
470 if let Some(s) = err.downcast_ref::<&str>() {
471 eprintln!("Thread panicked with message: {}", s);
472 } else if let Some(s) = err.downcast_ref::<String>() {
473 eprintln!("Thread panicked with message: {}", s);
474 } else {
475 eprintln!("Thread panicked with unknown error type");
476 }
477 panic!("Error joining thread: {:?}", err);
478 }
479 }
480 }
481
482 main_pb.finish_with_message("Parallel fuzzing completed!");
483 #metrics_output
484 }
485 }
486
487 fn generate_metrics_output(&self) -> TokenStream {
489 quote! {
490 if std::env::var("FUZZING_METRICS").is_ok() {
491 fuzzing_data.generate().unwrap();
492 }
493 }
494 }
495
496 fn generate_loopcount_retrieval(&self) -> TokenStream {
497 quote! {
498 let loopcount = match std::env::var("FUZZER_LOOPCOUNT") {
499 Ok(val) => val.parse().unwrap_or(0),
500 Err(_) => 0,
501 };
502 }
503 }
504
505 fn generate_coverage_server_port_retrieval(&self) -> TokenStream {
506 quote! {
507 let coverage_server_port = std::env::var("COVERAGE_SERVER_PORT").unwrap_or("58432".to_string());
508 }
509 }
510
511 fn retrieve_collect_coverage_flag(&self) -> String {
512 std::env::var("COLLECT_COVERAGE").unwrap_or("0".to_string())
513 }
514
515 #[allow(unused_doc_comments)]
516 fn generate_write_profile_logic(&self) -> TokenStream {
517 let generate_notify_extension_logic = self.generate_notify_extension_logic();
518
519 match self.retrieve_collect_coverage_flag().as_str() {
542 "1" => quote! {
543 if loopcount > 0 &&
544 i > 0 &&
545 i % loopcount == 0 {
546
547 unsafe {
548 let filename = format!("target/fuzz-cov-run-{}.profraw", i);
549 let filename_cstr = std::ffi::CString::new(filename).unwrap();
550 __llvm_profile_set_filename(filename_cstr.as_ptr());
551
552 let _ = __llvm_profile_write_file();
553 __llvm_profile_reset_counters();
554
555 #generate_notify_extension_logic
556
557 let final_filename = std::ffi::CString::new("target/fuzz-cov-run-final.profraw").unwrap();
558 __llvm_profile_set_filename(final_filename.as_ptr());
559 }
560 }
561 },
562 _ => quote! {},
563 }
564 }
565
566 fn generate_multi_threaded_coverage(&self) -> TokenStream {
567 let generate_write_profile_logic = self.generate_write_profile_logic();
568
569 quote! {
570 if thread_id == 0 {
571 #generate_write_profile_logic
572 }
573 }
574 }
575
576 fn generate_notify_extension_logic(&self) -> TokenStream {
582 quote! {
583 let url = format!(
584 "http://localhost:{}/update-decorations",
585 coverage_server_port
586 );
587
588 std::thread::spawn(move || {
591 let client = reqwest::blocking::Client::new();
592 let _ = client
593 .post(&url)
594 .header("Content-Type", "application/json")
595 .body("")
596 .send();
597 });
598 }
599 }
600}