// trident_fuzz/trident/flow_executor.rs

1use std::panic::catch_unwind;
2use std::panic::AssertUnwindSafe;
3use std::sync::atomic::AtomicBool;
4use std::sync::atomic::Ordering;
5use std::sync::Arc;
6use std::thread;
7use std::time::Instant;
8use trident_fuzz_metrics::TridentFuzzingData;
9
10use crate::trident::Trident;
11
// Per-thread slot holding the location of the most recent panic on that thread.
//
// The panic hook installed by `FlowExecutor::setup_panic_handler` stores a
// `"file:line:column"` string here; once `catch_unwind` has returned, `handle_panic` `take()`s
// the value to include it in the failure report. The slot is thread-local because the hook runs
// on the panicking thread, so parallel workers never overwrite each other's locations.
thread_local! {
    static PANIC_LOCATION: std::cell::Cell<Option<String>> =
        const { std::cell::Cell::new(None) };
}
18
/// Tunables and environment-variable names used by the flow executor.
mod config {
    use std::time::Duration;

    // ---- Progress reporting ------------------------------------------------

    /// Batched progress is pushed to the bar once this many flow calls have accumulated...
    pub const PROGRESS_UPDATE_INTERVAL: u64 = 100;

    /// ...or once this much wall-clock time has passed since the previous push.
    pub const PROGRESS_UPDATE_DURATION: Duration = Duration::from_millis(50);

    // ---- Seeding -----------------------------------------------------------

    /// Length in bytes of a master/debug seed (twice as many hex characters).
    pub const SEED_SIZE: usize = 32;

    // ---- Environment variables ---------------------------------------------

    /// Hex seed; when present, a single debug iteration is replayed with it.
    pub const ENV_FUZZ_DEBUG: &str = "TRIDENT_FUZZ_DEBUG";
    /// Hex master seed used instead of a freshly generated random one.
    pub const ENV_FUZZ_SEED: &str = "TRIDENT_FUZZ_SEED";
    /// When present, fuzzing metrics are generated at the end of the run.
    pub const ENV_FUZZING_METRICS: &str = "FUZZING_METRICS";
    /// When present, failures are reported through the process exit code.
    pub const ENV_WITH_EXIT_CODE: &str = "TRIDENT_WITH_EXIT_CODE";
}
38
39/// Trait for executing fuzzing flows in the Trident framework
40///
41/// This trait defines the interface for fuzzing executors that can run
42/// multiple iterations of randomized program interactions. Implementors
43/// should provide the core fuzzing logic while this trait handles
44/// parallelization, progress tracking, and metrics collection.
45pub trait FlowExecutor: Send + 'static + Sized {
46    /// Creates a new instance of the flow executor
47    fn new() -> Self;
48
49    /// Executes a specified number of flow calls in a single iteration
50    ///
51    /// # Arguments
52    /// * `flow_calls_per_iteration` - Number of individual flow calls to execute
53    ///
54    /// # Returns
55    /// Result indicating success or a fuzzing error
56    fn execute_flows(
57        &mut self,
58        flow_calls_per_iteration: u64,
59    ) -> Result<(), crate::error::FuzzingError>;
60
61    /// Returns a mutable reference to the underlying Trident instance
62    fn trident_mut(&mut self) -> &mut Trident;
63
64    /// Resets fuzz accounts to their initial state for the next iteration
65    fn reset_fuzz_accounts(&mut self);
66
67    /// Handles LLVM coverage collection (generated by macro)
68    ///
69    /// This method is typically empty or contains LLVM coverage calls
70    /// depending on whether coverage profiling is enabled.
71    ///
72    /// # Arguments
73    /// * `current_iteration` - The current iteration number
74    fn handle_llvm_coverage(&mut self, current_iteration: u64);
75
76    /// Main entry point for fuzzing execution
77    ///
78    /// This method orchestrates the entire fuzzing process, handling both
79    /// single-threaded and parallel execution based on the environment
80    /// and available system resources.
81    ///
82    /// # Arguments
83    /// * `iterations` - Total number of fuzzing iterations to run
84    /// * `flow_calls_per_iteration` - Number of flow calls per iteration
85    fn fuzz(iterations: u64, flow_calls_per_iteration: u64) {
86        // Setup panic handler to capture location information when panics occur
87        Self::setup_panic_handler();
88
89        // Debug mode: run single iteration with provided seed (for reproducing specific failures)
90        if std::env::var(config::ENV_FUZZ_DEBUG).is_ok() {
91            println!("Debug mode detected: Running single iteration with provided seed");
92            Self::fuzz_single_threaded(1, flow_calls_per_iteration);
93            return;
94        }
95
96        // Get or generate master seed for reproducible fuzzing
97        let master_seed = Self::get_or_generate_master_seed();
98
99        // Determine number of threads to use (limited by available parallelism and iteration count)
100        let num_threads = thread::available_parallelism()
101            .map(|n| n.get())
102            .unwrap_or(1)
103            .min(iterations as usize);
104
105        // Use single-threaded mode if we only have one thread or one iteration
106        if num_threads <= 1 || iterations <= 1 {
107            Self::fuzz_single_threaded(iterations, flow_calls_per_iteration);
108            return;
109        }
110
111        // Use parallel mode for better performance with multiple threads
112        Self::fuzz_parallel(
113            iterations,
114            flow_calls_per_iteration,
115            num_threads,
116            master_seed,
117        );
118    }
119
120    /// Sets up a global panic handler that captures panic location information.
121    /// This allows us to retrieve the file, line, and column where a panic occurred
122    /// even after catching it with catch_unwind.
123    fn setup_panic_handler() {
124        std::panic::set_hook(Box::new(|info| {
125            let location = info
126                .location()
127                .map(|loc| format!("{}:{}:{}", loc.file(), loc.line(), loc.column()))
128                .unwrap_or_else(|| "unknown".to_string());
129
130            PANIC_LOCATION.with(|cell| {
131                cell.set(Some(location));
132            });
133        }));
134    }
135
136    /// Extracts the panic message from a panic payload.
137    /// Panics can have either &str or String payloads, so we handle both cases.
138    fn extract_panic_message(panic_err: &Box<dyn std::any::Any + Send>) -> String {
139        panic_err
140            .downcast_ref::<&str>()
141            .map(|s| s.to_string())
142            .or_else(|| panic_err.downcast_ref::<String>().cloned())
143            .unwrap_or_else(|| "unknown panic".to_string())
144    }
145
146    /// Handles a caught panic by logging it and updating the panic tracking flag.
147    /// Returns the formatted panic message for display.
148    fn handle_panic(
149        panic_err: &Box<dyn std::any::Any + Send>,
150        fuzzer: &mut Self,
151        panic_occurred: Option<&Arc<AtomicBool>>,
152    ) -> String {
153        // Mark that a panic occurred (for exit code handling)
154        if let Some(flag) = panic_occurred {
155            flag.store(true, Ordering::Relaxed);
156        }
157
158        // Extract panic details
159        let message = Self::extract_panic_message(panic_err);
160        let location =
161            PANIC_LOCATION.with(|cell| cell.take().unwrap_or_else(|| "unknown".to_string()));
162        let seed = hex::encode(fuzzer.trident_mut().get_current_seed());
163
164        format!(
165            "Assertion failed at {}: {} (seed: {})",
166            location, message, seed
167        )
168    }
169
170    /// Determines the exit code based on panic status and configuration.
171    /// Returns 99 if with_exit_code is enabled and panics occurred, otherwise uses metrics exit code.
172    fn determine_exit_code(
173        with_exit_code: bool,
174        panic_occurred: bool,
175        fuzzing_data: &TridentFuzzingData,
176    ) -> i32 {
177        if with_exit_code {
178            // When exit code mode is enabled, return 99 if any panics occurred
179            if panic_occurred || fuzzing_data.get_exit_code() != 0 {
180                99
181            } else {
182                0
183            }
184        } else {
185            // Otherwise, use the exit code from metrics (which may be 0 or 99)
186            fuzzing_data.get_exit_code()
187        }
188    }
189
190    /// Gets the master seed from environment variable or generates a random one.
191    /// The master seed is used to initialize all fuzzer instances for reproducible runs.
192    fn get_or_generate_master_seed() -> [u8; config::SEED_SIZE] {
193        if let Ok(seed_hex) = std::env::var(config::ENV_FUZZ_SEED) {
194            Self::parse_hex_seed(&seed_hex)
195        } else {
196            Self::generate_random_seed()
197        }
198    }
199
200    /// Parses a hex-encoded seed string into a byte array.
201    /// Validates that the seed is exactly the required size.
202    fn parse_hex_seed(seed_hex: &str) -> [u8; config::SEED_SIZE] {
203        let seed_bytes = hex::decode(seed_hex)
204            .unwrap_or_else(|_| panic!("Invalid hex string in seed: {}", seed_hex));
205
206        if seed_bytes.len() != config::SEED_SIZE {
207            panic!(
208                "Seed must be exactly {} bytes ({} hex characters), got: {}",
209                config::SEED_SIZE,
210                config::SEED_SIZE * 2,
211                seed_bytes.len()
212            );
213        }
214
215        let mut seed = [0u8; config::SEED_SIZE];
216        seed.copy_from_slice(&seed_bytes);
217        seed
218    }
219
220    /// Generates a cryptographically secure random seed.
221    fn generate_random_seed() -> [u8; config::SEED_SIZE] {
222        let mut seed = [0u8; config::SEED_SIZE];
223        if let Err(err) = getrandom::fill(&mut seed) {
224            panic!("Failed to generate random seed: {}", err);
225        }
226        seed
227    }
228
229    /// Outputs fuzzing metrics (JSON, dashboard, etc.) if metrics are enabled.
230    fn output_metrics_if_enabled(fuzzing_data: &TridentFuzzingData) {
231        if std::env::var(config::ENV_FUZZING_METRICS).is_ok() {
232            if let Err(e) = fuzzing_data.generate() {
233                eprintln!("Warning: Failed to generate metrics: {}", e);
234            }
235        }
236    }
237
238    /// Executes fuzzing in a single thread.
239    /// This is used for debug mode, small iteration counts, or when only one thread is available.
240    fn fuzz_single_threaded(iterations: u64, flow_calls_per_iteration: u64) {
241        let mut fuzzer = Self::new();
242        let is_debug_mode = std::env::var(config::ENV_FUZZ_DEBUG).is_ok();
243        let with_exit_code = std::env::var(config::ENV_WITH_EXIT_CODE).is_ok();
244        let mut panic_occurred = false; // Simple bool since we're single-threaded
245
246        // Configure debug seed if in debug mode
247        if is_debug_mode {
248            let debug_seed_hex = std::env::var(config::ENV_FUZZ_DEBUG).unwrap();
249            let debug_seed = Self::parse_hex_seed(&debug_seed_hex);
250            println!("Using debug seed: {}", debug_seed_hex);
251            fuzzer.trident_mut().set_master_seed_for_debug(debug_seed);
252        }
253
254        // Setup progress bar (disabled in debug mode for cleaner output)
255        let pb = if is_debug_mode {
256            None
257        } else {
258            let total_flow_calls = iterations * flow_calls_per_iteration;
259            let pb = indicatif::ProgressBar::new(total_flow_calls);
260            pb.set_style(
261                indicatif::ProgressStyle::with_template(
262                    "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({percent}%) [{eta_precise}] {msg}"
263                )
264                .unwrap()
265                .progress_chars("#>-"),
266            );
267            pb.set_message(format!(
268                "Fuzzing {} iterations with {} flow calls each...",
269                iterations, flow_calls_per_iteration
270            ));
271            Some(pb)
272        };
273
274        // Main fuzzing loop: execute flows, catch panics, and track progress
275        for i in 0..iterations {
276            // Catch panics from user code (assertions, invariants, etc.)
277            let panic_result = catch_unwind(AssertUnwindSafe(|| {
278                let _ = fuzzer.execute_flows(flow_calls_per_iteration);
279            }));
280
281            // Handle any panics that occurred
282            if let Err(panic_err) = panic_result {
283                panic_occurred = true;
284                let panic_msg = Self::handle_panic(&panic_err, &mut fuzzer, None);
285
286                // Display panic message via progress bar or stderr
287                if let Some(ref pb) = pb {
288                    pb.println(panic_msg);
289                } else {
290                    eprintln!("{}", panic_msg);
291                }
292            }
293
294            // Prepare for next iteration
295            fuzzer.trident_mut().next_iteration();
296            fuzzer.reset_fuzz_accounts();
297
298            // Handle coverage profiling if enabled
299            Self::handle_coverage_if_enabled(&mut fuzzer, i + 1);
300
301            // Update progress bar
302            if let Some(ref pb) = pb {
303                pb.inc(flow_calls_per_iteration);
304                pb.set_message(format!("Iteration {}/{} completed", i + 1, iterations));
305            }
306        }
307
308        // Finalize progress bar
309        if let Some(pb) = pb {
310            pb.finish_with_message("Fuzzing completed!");
311        }
312
313        // Generate metrics if enabled
314        let fuzzing_data = fuzzer.trident_mut().get_fuzzing_data();
315        Self::output_metrics_if_enabled(&fuzzing_data);
316
317        // Exit with appropriate code if exit code mode is enabled
318        if with_exit_code {
319            let exit_code =
320                Self::determine_exit_code(with_exit_code, panic_occurred, &fuzzing_data);
321            std::process::exit(exit_code);
322        }
323    }
324
325    /// Executes fuzzing across multiple threads for better performance.
326    /// Each thread runs a subset of iterations with its own fuzzer instance.
327    fn fuzz_parallel(
328        iterations: u64,
329        flow_calls_per_iteration: u64,
330        num_threads: usize,
331        master_seed: [u8; 32],
332    ) {
333        let iterations_per_thread = iterations / num_threads as u64;
334        let total_flow_calls = iterations * flow_calls_per_iteration;
335        let with_exit_code = std::env::var(config::ENV_WITH_EXIT_CODE).is_ok();
336        let panic_occurred = Arc::new(AtomicBool::new(false)); // Shared across threads
337
338        // Setup shared progress bar
339        let main_pb = indicatif::ProgressBar::new(total_flow_calls);
340        main_pb.set_style(
341            indicatif::ProgressStyle::with_template(
342                "Overall: {spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({percent}%) [{eta_precise}] {msg}"
343            )
344            .unwrap()
345            .progress_chars("#>-"),
346        );
347        main_pb.set_message(format!(
348            "Fuzzing with {} threads - {} iterations with {} flow calls each",
349            num_threads, iterations, flow_calls_per_iteration
350        ));
351
352        // Spawn worker threads
353        let mut handles = Vec::new();
354        for thread_id in 0..num_threads {
355            let thread_iterations = iterations_per_thread;
356            if thread_iterations == 0 {
357                continue; // Skip threads with no work
358            }
359
360            let main_pb_clone = main_pb.clone();
361            let panic_occurred_clone = panic_occurred.clone();
362            let handle = thread::spawn(move || -> TridentFuzzingData {
363                Self::run_thread_workload(
364                    master_seed,
365                    thread_id,
366                    thread_iterations,
367                    flow_calls_per_iteration,
368                    main_pb_clone,
369                    panic_occurred_clone,
370                )
371            });
372
373            handles.push(handle);
374        }
375
376        // Collect results from all threads
377        let mut fuzzing_data = TridentFuzzingData::with_master_seed(master_seed);
378        for handle in handles {
379            match handle.join() {
380                Ok(thread_metrics) => {
381                    fuzzing_data._merge(thread_metrics);
382                }
383                Err(err) => {
384                    // This should rarely happen since we catch panics inside threads
385                    // Only occurs if the thread itself crashes (not user code)
386                    eprintln!("Warning: Thread failed to join (not a fuzz test panic)");
387                    if let Some(s) = err.downcast_ref::<&str>() {
388                        eprintln!("  Message: {}", s);
389                    } else if let Some(s) = err.downcast_ref::<String>() {
390                        eprintln!("  Message: {}", s);
391                    }
392                    // Continue processing other threads
393                }
394            }
395        }
396
397        main_pb.finish_with_message("Parallel fuzzing completed!");
398
399        // Determine and set exit code
400        let exit_code = Self::determine_exit_code(
401            with_exit_code,
402            panic_occurred.load(Ordering::Relaxed),
403            &fuzzing_data,
404        );
405
406        Self::output_metrics_if_enabled(&fuzzing_data);
407        println!("MASTER SEED used: {:?}", &hex::encode(master_seed));
408
409        std::process::exit(exit_code);
410    }
411
412    /// Runs the fuzzing workload for a single thread.
413    /// This is extracted to reduce complexity in fuzz_parallel.
414    fn run_thread_workload(
415        master_seed: [u8; 32],
416        thread_id: usize,
417        thread_iterations: u64,
418        flow_calls_per_iteration: u64,
419        progress_bar: indicatif::ProgressBar,
420        panic_occurred: Arc<AtomicBool>,
421    ) -> TridentFuzzingData {
422        let mut fuzzer = Self::new();
423        fuzzer
424            .trident_mut()
425            .set_master_seed_and_thread_id(master_seed, thread_id);
426
427        // Track progress updates to avoid excessive bar updates
428        let mut last_update = Instant::now();
429        let mut local_counter = 0u64;
430
431        // Execute iterations for this thread
432        for i in 0..thread_iterations {
433            // Catch panics from user code (assertions, invariants, etc.)
434            let panic_result = catch_unwind(AssertUnwindSafe(|| {
435                let _ = fuzzer.execute_flows(flow_calls_per_iteration);
436            }));
437
438            // Handle any panics that occurred
439            if let Err(panic_err) = panic_result {
440                let panic_msg = Self::handle_panic(&panic_err, &mut fuzzer, Some(&panic_occurred));
441                progress_bar.println(panic_msg);
442            }
443
444            // Prepare for next iteration
445            fuzzer.trident_mut().next_iteration();
446            fuzzer.reset_fuzz_accounts();
447
448            // Handle coverage profiling (only thread 0 to avoid duplicate work)
449            if thread_id == 0 {
450                Self::handle_coverage_if_enabled(&mut fuzzer, i + 1);
451            }
452
453            // Batch progress updates for performance
454            local_counter += flow_calls_per_iteration;
455            let should_update = local_counter >= config::PROGRESS_UPDATE_INTERVAL
456                || last_update.elapsed() >= config::PROGRESS_UPDATE_DURATION
457                || i == thread_iterations - 1; // Always update on last iteration
458
459            if should_update {
460                progress_bar.inc(local_counter);
461                local_counter = 0;
462                last_update = Instant::now();
463            }
464        }
465
466        // Ensure any remaining progress is reported
467        if local_counter > 0 {
468            progress_bar.inc(local_counter);
469        }
470
471        fuzzer.trident_mut().get_fuzzing_data()
472    }
473
474    /// Handles LLVM coverage collection if coverage profiling is enabled.
475    /// Coverage is collected periodically based on FUZZER_LOOPCOUNT environment variable.
476    fn handle_coverage_if_enabled(fuzzer: &mut Self, current_iteration: u64) {
477        let loopcount = std::env::var("FUZZER_LOOPCOUNT")
478            .ok()
479            .and_then(|val| val.parse::<u64>().ok())
480            .unwrap_or(0);
481
482        // Collect coverage at specified intervals
483        if loopcount > 0 && current_iteration > 0 && current_iteration % loopcount == 0 {
484            // Call the macro-generated LLVM method to write coverage data
485            fuzzer.handle_llvm_coverage(current_iteration);
486
487            // Notify VS Code extension to update coverage decorations
488            Self::notify_coverage_extension();
489        }
490    }
491
492    /// Notifies the VS Code coverage extension to update coverage decorations.
493    /// This runs in a background thread to avoid blocking fuzzing execution.
494    fn notify_coverage_extension() {
495        let coverage_server_port =
496            std::env::var("COVERAGE_SERVER_PORT").unwrap_or_else(|_| "58432".to_string());
497
498        let url = format!(
499            "http://localhost:{}/update-decorations",
500            coverage_server_port
501        );
502        std::thread::spawn(move || {
503            let client = reqwest::blocking::Client::new();
504            let _ = client
505                .post(&url)
506                .header("Content-Type", "application/json")
507                .body("")
508                .send();
509        });
510    }
511}