// trident_syn/codegen/trident_flow_executor.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use quote::ToTokens;
4
5use crate::types::trident_flow_executor::TridentFlowExecutorImpl;
6
7impl ToTokens for TridentFlowExecutorImpl {
8    fn to_tokens(&self, tokens: &mut TokenStream) {
9        let expanded = self.generate_flow_executor_impl();
10        tokens.extend(expanded);
11    }
12}
13
14impl TridentFlowExecutorImpl {
15    /// Generate the complete flow executor implementation
16    fn generate_flow_executor_impl(&self) -> TokenStream {
17        let type_name = &self.type_name;
18        let impl_items = &self.impl_block;
19        let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
20
21        let generated_impl = self.generate_generated_impl_block();
22
23        quote! {
24            impl #impl_generics #type_name #ty_generics #where_clause {
25                #(#impl_items)*
26            }
27
28            #generated_impl
29        }
30    }
31
32    /// Generate the main implementation block with flow execution methods
33    fn generate_generated_impl_block(&self) -> TokenStream {
34        let type_name = &self.type_name;
35        let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
36
37        let execute_flows_method = self.generate_execute_flows_method();
38        let fuzz_method = self.generate_fuzz_method();
39
40        quote! {
41            impl #impl_generics #type_name #ty_generics #where_clause {
42                #execute_flows_method
43                #fuzz_method
44            }
45        }
46    }
47
48    /// Generate the main execute_flows method
49    fn generate_execute_flows_method(&self) -> TokenStream {
50        let init_call = self.generate_init_call();
51        let flow_execution_logic = self.generate_flow_execution_logic();
52        let end_call = self.generate_end_call();
53
54        quote! {
55            pub fn execute_flows(
56                &mut self,
57                flow_calls_per_iteration: u64,
58            ) -> std::result::Result<(), FuzzingError> {
59                #init_call
60                #flow_execution_logic
61                #end_call
62                Ok(())
63            }
64        }
65    }
66
67    /// Generate the initialization call if an init method exists
68    fn generate_init_call(&self) -> TokenStream {
69        if let Some(init_method) = &self.init_method {
70            quote! {
71                self.#init_method();
72            }
73        } else {
74            quote! {}
75        }
76    }
77
78    /// Generate the end call if an end method exists
79    fn generate_end_call(&self) -> TokenStream {
80        if let Some(end_method) = &self.end_method {
81            quote! {
82                self.#end_method();
83            }
84        } else {
85            quote! {}
86        }
87    }
88
89    /// Generate the flow execution logic
90    fn generate_flow_execution_logic(&self) -> TokenStream {
91        // Filter out ignored flows
92        let active_methods: Vec<_> = self
93            .flow_methods
94            .iter()
95            .filter(|method| !method.constraints.ignore)
96            .collect();
97
98        if active_methods.is_empty() {
99            quote! {
100                // No flow methods defined or all are ignored, nothing to execute
101            }
102        } else {
103            let flow_selection_logic = self.generate_flow_selection_logic(&active_methods);
104
105            quote! {
106                #flow_selection_logic
107            }
108        }
109    }
110
111    /// Generate the random flow selection logic
112    fn generate_flow_selection_logic(
113        &self,
114        active_methods: &[&crate::types::trident_flow_executor::FlowMethod],
115    ) -> TokenStream {
116        // Check if any flow has weights
117        let has_weights = active_methods
118            .iter()
119            .any(|method| method.constraints.weight.is_some());
120
121        if has_weights {
122            // Generate weighted selection logic
123            self.generate_weighted_flow_selection(active_methods)
124        } else {
125            // Generate uniform random selection logic (original behavior)
126            self.generate_uniform_flow_selection(active_methods)
127        }
128    }
129
130    /// Generate uniform random flow selection (original behavior)
131    fn generate_uniform_flow_selection(
132        &self,
133        active_methods: &[&crate::types::trident_flow_executor::FlowMethod],
134    ) -> TokenStream {
135        let flow_match_arms = active_methods.iter().enumerate().map(|(index, method)| {
136            let method_ident = &method.ident;
137            quote! {
138                #index => self.#method_ident(),
139            }
140        });
141        let num_flows = active_methods.len();
142
143        quote! {
144            // Randomly select and execute flows for the specified number of calls
145            let flows_results = for _ in 0..flow_calls_per_iteration {
146                let flow_index = self.trident.gen_range(0..#num_flows);
147                match flow_index {
148                    #(#flow_match_arms)*
149                    _ => unreachable!("Invalid flow index"),
150                }
151            };
152        }
153    }
154
155    /// Generate weighted flow selection logic
156    fn generate_weighted_flow_selection(
157        &self,
158        active_methods: &[&crate::types::trident_flow_executor::FlowMethod],
159    ) -> TokenStream {
160        // Filter out flows with weight 0 (they should be skipped)
161        let weighted_methods: Vec<_> = active_methods
162            .iter()
163            .filter(|method| method.constraints.weight.unwrap_or(0) > 0)
164            .collect();
165
166        if weighted_methods.is_empty() {
167            return quote! {
168                // All flows have weight 0, nothing to execute
169            };
170        }
171
172        // Calculate total weight
173        let total_weight: u32 = weighted_methods
174            .iter()
175            .map(|method| method.constraints.weight.unwrap())
176            .sum();
177
178        // Generate weight ranges and method calls
179        let mut cumulative_weight = 0u32;
180        let weight_ranges: Vec<_> = weighted_methods
181            .iter()
182            .map(|method| {
183                let weight = method.constraints.weight.unwrap();
184                let _start = cumulative_weight;
185                cumulative_weight += weight;
186                let end = cumulative_weight;
187                let method_ident = &method.ident;
188
189                quote! {
190                    if random_weight < #end {
191                        self.#method_ident();
192                        continue;
193                    }
194                }
195            })
196            .collect();
197
198        quote! {
199            // Weighted flow selection based on specified weights
200            let flows_results = for _ in 0..flow_calls_per_iteration {
201                let random_weight = self.trident.gen_range(0..#total_weight);
202                #(#weight_ranges)*
203            };
204        }
205    }
206
207    /// Generate the unified fuzz method that runs in parallel by default
208    fn generate_fuzz_method(&self) -> TokenStream {
209        let thread_management = self.generate_thread_management_logic();
210        let single_threaded_fallback = self.generate_single_threaded_fallback();
211
212        quote! {
213            fn fuzz(iterations: u64, flow_calls_per_iteration: u64) {
214                // Check for debug mode first - if present, run single iteration immediately
215                if std::env::var("TRIDENT_FUZZ_DEBUG").is_ok() {
216                    println!("Debug mode detected: Running single iteration with provided seed");
217                    let iterations = 1u64;
218                    #single_threaded_fallback
219                    return;
220                } else {
221                    // TODO: this is a hack to suppress the panic message
222                    std::panic::set_hook(Box::new(|_info| {
223                        // Do nothing — suppress the panic message
224                    }));
225                }
226
227                use std::thread;
228                use std::time::{Duration, Instant};
229
230                let master_seed = if let Ok(seed) = std::env::var("TRIDENT_FUZZ_SEED") {
231                    let seed_bytes = hex::decode(&seed).unwrap_or_else(|_| panic!("The seed is not a valid hex string: {}", seed));
232                    let mut seed = [0; 32];
233                    seed.copy_from_slice(&seed_bytes);
234                    seed
235                } else{
236                    let mut seed = [0; 32];
237                    if let Err(err) = getrandom::fill(&mut seed) {
238                        panic!("from_entropy failed: {}", err);
239                    }
240                    seed
241                };
242
243                let num_threads = thread::available_parallelism()
244                    .map(|n| n.get())
245                    .unwrap_or(1)
246                    .min(iterations as usize);
247
248                if num_threads <= 1 || iterations <= 1 {
249                    // Single-threaded fallback
250                    #single_threaded_fallback
251                    return;
252                }
253
254                #thread_management
255            }
256        }
257    }
258
259    /// Generate single-threaded fallback logic
260    fn generate_single_threaded_fallback(&self) -> TokenStream {
261        let type_name = &self.type_name;
262        let progress_bar_setup = self.generate_progress_bar_setup(false);
263        let fuzzing_loop = self.generate_single_threaded_fuzzing_loop();
264        let metrics_output = self.generate_metrics_output();
265
266        quote! {
267            let mut fuzzer = #type_name::new();
268
269            // Set debug seed if in debug mode
270            if let Ok(debug_seed_hex) = std::env::var("TRIDENT_FUZZ_DEBUG") {
271                // Parse hex string to [u8; 32] using hex crate
272                let seed_bytes = hex::decode(&debug_seed_hex)
273                    .unwrap_or_else(|_| panic!("Invalid hex string in debug seed: {}", debug_seed_hex));
274
275                if seed_bytes.len() != 32 {
276                    panic!("Debug seed must be exactly 32 bytes (64 hex characters), got: {}", seed_bytes.len());
277                }
278
279                let mut seed = [0u8; 32];
280                seed.copy_from_slice(&seed_bytes);
281
282                println!("Using debug seed: {}", debug_seed_hex);
283                fuzzer.trident._set_master_seed_for_debug(seed);
284            }
285            let total_flow_calls = iterations * flow_calls_per_iteration;
286
287            #progress_bar_setup
288            #fuzzing_loop
289            #metrics_output
290        }
291    }
292
293    /// Generate progress bar setup code
294    fn generate_progress_bar_setup(&self, is_parallel: bool) -> TokenStream {
295        let message_prefix = if is_parallel { "Overall: " } else { "" };
296        let message_content = if is_parallel {
297            quote! { format!("Fuzzing with {} threads - {} iterations with {} flow calls each", num_threads, iterations, flow_calls_per_iteration) }
298        } else {
299            quote! { format!("Fuzzing {} iterations with {} flow calls each...", iterations, flow_calls_per_iteration) }
300        };
301
302        quote! {
303            let pb = indicatif::ProgressBar::new(total_flow_calls);
304            pb.set_style(
305                indicatif::ProgressStyle::with_template(
306                    concat!(#message_prefix, "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({percent}%) [{eta_precise}] {msg}")
307                )
308                .unwrap()
309                .progress_chars("#>-"),
310            );
311            pb.set_message(#message_content);
312        }
313    }
314
315    /// Generate the single-threaded fuzzing loop
316    fn generate_single_threaded_fuzzing_loop(&self) -> TokenStream {
317        let generate_write_profile_logic = self.generate_write_profile_logic();
318        let loopcount_retrieval = self.generate_loopcount_retrieval();
319        let generate_coverage_server_port_retrieval =
320            self.generate_coverage_server_port_retrieval();
321
322        quote! {
323            #loopcount_retrieval
324            #generate_coverage_server_port_retrieval
325
326            for i in 0..iterations {
327                let result = fuzzer.execute_flows(flow_calls_per_iteration);
328                fuzzer.trident._next_iteration();
329                // this will ensure the fuzz accounts will reset without
330                // specifiyng the type of the fuzz accounts
331                let _ = std::mem::take(&mut fuzzer.fuzz_accounts);
332
333                pb.inc(flow_calls_per_iteration);
334                pb.set_message(format!("Iteration {}/{} completed", i + 1, iterations));
335
336                #generate_write_profile_logic
337            }
338
339            pb.finish_with_message("Fuzzing completed!");
340
341            let fuzzing_data = fuzzer.trident._get_fuzzing_data();
342
343        }
344    }
345
346    /// Generate thread management logic for parallel execution
347    fn generate_thread_management_logic(&self) -> TokenStream {
348        let parallel_progress_setup = self.generate_parallel_progress_setup();
349        let thread_spawn_logic = self.generate_thread_spawn_logic();
350        let metrics_collection = self.generate_metrics_collection_logic();
351
352        quote! {
353            let iterations_per_thread = iterations / num_threads as u64;
354            let remaining_iterations = iterations % num_threads as u64;
355            let total_flow_calls = iterations * flow_calls_per_iteration;
356
357            let mut handles = Vec::new();
358
359            #parallel_progress_setup
360
361            for thread_id in 0..num_threads {
362
363                let thread_iterations = iterations_per_thread;
364
365                if thread_iterations == 0 {
366                    continue;
367                }
368
369                #thread_spawn_logic
370            }
371
372            #metrics_collection
373        }
374    }
375
376    /// Generate parallel progress bar setup
377    fn generate_parallel_progress_setup(&self) -> TokenStream {
378        quote! {
379            // Create a separate progress bar for overall status
380            let main_pb = indicatif::ProgressBar::new(total_flow_calls);
381            main_pb.set_style(
382                indicatif::ProgressStyle::with_template(
383                    "Overall: {spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({percent}%) [{eta_precise}] {msg}"
384                )
385                .unwrap()
386                .progress_chars("#>-"),
387            );
388            main_pb.set_message(format!("Fuzzing with {} threads - {} iterations with {} flow calls each", num_threads, iterations, flow_calls_per_iteration));
389        }
390    }
391
392    /// Generate thread spawning logic
393    fn generate_thread_spawn_logic(&self) -> TokenStream {
394        let type_name = &self.type_name;
395        let generate_loopcount_retrieval = self.generate_loopcount_retrieval();
396        let generate_coverage_server_port_retrieval =
397            self.generate_coverage_server_port_retrieval();
398        let generate_write_profile_logic = self.generate_multi_threaded_coverage();
399
400        quote! {
401            let main_pb_clone = main_pb.clone();
402            let handle = thread::spawn(move || -> TridentFuzzingData {
403                // Each thread creates its own client and fuzzer instance
404                let mut fuzzer = #type_name::new();
405
406                fuzzer.trident._set_master_seed_and_thread_id(master_seed, thread_id);
407
408                // Update progress every 100 flow calls or every 50ms, whichever comes first
409                const UPDATE_INTERVAL: u64 = 100;
410                let mut last_update = Instant::now();
411                let update_duration = Duration::from_millis(50);
412
413                let mut local_counter = 0u64;
414
415                #generate_loopcount_retrieval
416                #generate_coverage_server_port_retrieval
417
418                for i in 0..thread_iterations {
419                    let _ = fuzzer.execute_flows(flow_calls_per_iteration);
420                    fuzzer.trident._next_iteration();
421
422                    // this will ensure the fuzz accounts will reset without
423                    // specifiyng the type of the fuzz accounts
424                    let _ = std::mem::take(&mut fuzzer.fuzz_accounts);
425
426                    local_counter += flow_calls_per_iteration;
427
428                    // Update progress bars with granularity control
429                    let should_update = local_counter >= UPDATE_INTERVAL ||
430                                      last_update.elapsed() >= update_duration ||
431                                      i == thread_iterations - 1; // Always update on last iteration
432
433                    if should_update {
434                        main_pb_clone.inc(local_counter);
435                        local_counter = 0;
436                        last_update = Instant::now();
437                    }
438
439                    #generate_write_profile_logic
440                }
441
442                // Ensure final update
443                if local_counter > 0 {
444                    main_pb_clone.inc(local_counter);
445                }
446
447                // Return the metrics from this thread
448                fuzzer.trident._get_fuzzing_data()
449            });
450
451            handles.push(handle);
452        }
453    }
454
455    /// Generate metrics collection and output logic
456    fn generate_metrics_collection_logic(&self) -> TokenStream {
457        let metrics_output = self.generate_metrics_output();
458        quote! {
459            // Collect metrics from all threads
460            let mut fuzzing_data = TridentFuzzingData::with_master_seed(master_seed);
461
462            for handle in handles {
463                match handle.join() {
464                    Ok(thread_metrics) => {
465                        if std::env::var("FUZZING_METRICS").is_ok() {
466                            fuzzing_data._merge(thread_metrics);
467                        }
468                    }
469                    Err(err) => {
470                        if let Some(s) = err.downcast_ref::<&str>() {
471                            eprintln!("Thread panicked with message: {}", s);
472                        } else if let Some(s) = err.downcast_ref::<String>() {
473                            eprintln!("Thread panicked with message: {}", s);
474                        } else {
475                            eprintln!("Thread panicked with unknown error type");
476                        }
477                        panic!("Error joining thread: {:?}", err);
478                    }
479                }
480            }
481
482            main_pb.finish_with_message("Parallel fuzzing completed!");
483            #metrics_output
484        }
485    }
486
487    /// Generate metrics output logic
488    fn generate_metrics_output(&self) -> TokenStream {
489        quote! {
490            if std::env::var("FUZZING_METRICS").is_ok() {
491                fuzzing_data.generate().unwrap();
492            }
493        }
494    }
495
496    fn generate_loopcount_retrieval(&self) -> TokenStream {
497        quote! {
498            let loopcount = match std::env::var("FUZZER_LOOPCOUNT") {
499                Ok(val) => val.parse().unwrap_or(0),
500                Err(_) => 0,
501            };
502        }
503    }
504
505    fn generate_coverage_server_port_retrieval(&self) -> TokenStream {
506        quote! {
507            let coverage_server_port = std::env::var("COVERAGE_SERVER_PORT").unwrap_or("58432".to_string());
508        }
509    }
510
511    fn retrieve_collect_coverage_flag(&self) -> String {
512        std::env::var("COLLECT_COVERAGE").unwrap_or("0".to_string())
513    }
514
515    #[allow(unused_doc_comments)]
516    fn generate_write_profile_logic(&self) -> TokenStream {
517        let generate_notify_extension_logic = self.generate_notify_extension_logic();
518
519        /// This part is a bit tricky and requires a thorough explanation:
520        ///
521        /// LLVM automatically creates a profraw file when the process ends, but since
522        /// we run fuzz tests in a single process with multiple threads, we only get
523        /// one file with combined data from all threads. To enable real-time coverage
524        /// display, we manually create profraw files at intervals.
525        ///
526        /// set_filename: sets the filename for the profraw file
527        /// write_file: creates a profraw file with collected data
528        /// reset_counters: resets the counters to 0
529        ///
530        /// Only thread 0 writes files to avoid duplicates. We use unique filenames
531        /// for each iteration and reset counters after writing. Since the final
532        /// profraw file is created automatically at process end, we preemptively
533        /// set the filename to avoid overwriting the last intermediate file.
534        ///
535        /// Coverage won't be 100% accurate because while the first thread creates
536        /// the profraw file, the other threads are still running and generating data,
537        /// which we reset after writing.
538        ///
539        /// We only generate this code if COLLECT_COVERAGE is set to 1 because
540        /// if -C instrument-coverage is not enabled, llvm methods will not be available
541        match self.retrieve_collect_coverage_flag().as_str() {
542            "1" => quote! {
543                if loopcount > 0 &&
544                    i > 0 &&
545                    i % loopcount == 0 {
546
547                    unsafe {
548                        let filename = format!("target/fuzz-cov-run-{}.profraw", i);
549                        let filename_cstr = std::ffi::CString::new(filename).unwrap();
550                        __llvm_profile_set_filename(filename_cstr.as_ptr());
551
552                        let _ = __llvm_profile_write_file();
553                        __llvm_profile_reset_counters();
554
555                        #generate_notify_extension_logic
556
557                        let final_filename = std::ffi::CString::new("target/fuzz-cov-run-final.profraw").unwrap();
558                        __llvm_profile_set_filename(final_filename.as_ptr());
559                    }
560                }
561            },
562            _ => quote! {},
563        }
564    }
565
566    fn generate_multi_threaded_coverage(&self) -> TokenStream {
567        let generate_write_profile_logic = self.generate_write_profile_logic();
568
569        quote! {
570            if thread_id == 0 {
571                #generate_write_profile_logic
572            }
573        }
574    }
575
576    /// Notifies the extension to update coverage
577    /// decorations if dynamic coverage is enabled
578    ///
579    /// Not very nice many things hardcoded here
580    /// TODO: improve architecture to avoid this
581    fn generate_notify_extension_logic(&self) -> TokenStream {
582        quote! {
583            let url = format!(
584                "http://localhost:{}/update-decorations",
585                coverage_server_port
586            );
587
588            // Right now requests are rapidly fired, regardless of whether extension is ready
589            // TODO: Only fire if extension responded
590            std::thread::spawn(move || {
591                let client = reqwest::blocking::Client::new();
592                let _ = client
593                    .post(&url)
594                    .header("Content-Type", "application/json")
595                    .body("")
596                    .send();
597            });
598        }
599    }
600}