trident_syn/codegen/
trident_flow_executor.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use quote::ToTokens;
4
5use crate::types::trident_flow_executor::TridentFlowExecutorImpl;
6
7impl ToTokens for TridentFlowExecutorImpl {
8    fn to_tokens(&self, tokens: &mut TokenStream) {
9        let expanded = self.generate_flow_executor_impl();
10        tokens.extend(expanded);
11    }
12}
13
14impl TridentFlowExecutorImpl {
15    /// Generate the complete flow executor implementation
16    fn generate_flow_executor_impl(&self) -> TokenStream {
17        let type_name = &self.type_name;
18        let impl_items = &self.impl_block;
19        let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
20
21        let flow_executor_impl = self.generate_flow_executor_trait_impl();
22
23        quote! {
24            impl #impl_generics #type_name #ty_generics #where_clause {
25                #(#impl_items)*
26            }
27
28            #flow_executor_impl
29        }
30    }
31
32    /// Generate the FlowExecutor trait implementation
33    fn generate_flow_executor_trait_impl(&self) -> TokenStream {
34        let type_name = &self.type_name;
35        let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
36
37        let execute_flows_method = self.generate_execute_flows_method();
38        let coverage_method = self.generate_coverage_method();
39
40        quote! {
41            impl #impl_generics FlowExecutor for #type_name #ty_generics #where_clause {
42                fn new() -> Self {
43                    Self::new()
44                }
45
46                fn execute_flows(&mut self, flow_calls_per_iteration: u64) -> std::result::Result<(), FuzzingError> {
47                #execute_flows_method
48                    Ok(())
49                }
50
51                fn trident_mut(&mut self) -> &mut Trident {
52                    &mut self.trident
53                }
54
55                fn reset_fuzz_accounts(&mut self) {
56                    // this will ensure the fuzz accounts will reset without
57                    // specifying the type of the fuzz accounts
58                    let _ = std::mem::take(&mut self.fuzz_accounts);
59                }
60
61                fn handle_llvm_coverage(&mut self, current_iteration: u64) {
62                    #coverage_method
63                }
64            }
65        }
66    }
67
68    /// Generate the execute_flows method implementation
69    fn generate_execute_flows_method(&self) -> TokenStream {
70        let init_call = self.generate_init_call();
71        let flow_execution_logic = self.generate_flow_execution_logic();
72        let end_call = self.generate_end_call();
73
74        quote! {
75                #init_call
76                #flow_execution_logic
77                #end_call
78
79        }
80    }
81
82    /// Generate the initialization call if an init method exists
83    fn generate_init_call(&self) -> TokenStream {
84        if let Some(init_method) = &self.init_method {
85            quote! {
86                self.#init_method();
87            }
88        } else {
89            quote! {}
90        }
91    }
92
93    /// Generate the end call if an end method exists
94    fn generate_end_call(&self) -> TokenStream {
95        if let Some(end_method) = &self.end_method {
96            quote! {
97                self.#end_method();
98            }
99        } else {
100            quote! {}
101        }
102    }
103
104    /// Generate the flow execution logic
105    fn generate_flow_execution_logic(&self) -> TokenStream {
106        // Filter out ignored flows
107        let active_methods: Vec<_> = self
108            .flow_methods
109            .iter()
110            .filter(|method| !method.constraints.ignore)
111            .collect();
112
113        if active_methods.is_empty() {
114            quote! {
115                // No flow methods defined or all are ignored, nothing to execute
116            }
117        } else {
118            let flow_selection_logic = self.generate_flow_selection_logic(&active_methods);
119
120            quote! {
121                #flow_selection_logic
122            }
123        }
124    }
125
126    /// Generate the random flow selection logic
127    fn generate_flow_selection_logic(
128        &self,
129        active_methods: &[&crate::types::trident_flow_executor::FlowMethod],
130    ) -> TokenStream {
131        // Check if any flow has weights
132        let has_weights = active_methods
133            .iter()
134            .any(|method| method.constraints.weight.is_some());
135
136        if has_weights {
137            // Generate weighted selection logic
138            self.generate_weighted_flow_selection(active_methods)
139        } else {
140            // Generate uniform random selection logic (original behavior)
141            self.generate_uniform_flow_selection(active_methods)
142        }
143    }
144
145    /// Generate uniform random flow selection (original behavior)
146    fn generate_uniform_flow_selection(
147        &self,
148        active_methods: &[&crate::types::trident_flow_executor::FlowMethod],
149    ) -> TokenStream {
150        let flow_match_arms = active_methods.iter().enumerate().map(|(index, method)| {
151            let method_ident = &method.ident;
152            quote! {
153                #index => self.#method_ident(),
154            }
155        });
156        let num_flows = active_methods.len();
157
158        quote! {
159            // Randomly select and execute flows for the specified number of calls
160            for _ in 0..flow_calls_per_iteration {
161                let flow_index = self.trident.random_from_range(0..#num_flows);
162                match flow_index {
163                    #(#flow_match_arms)*
164                    _ => unreachable!("Invalid flow index"),
165                }
166            }
167        }
168    }
169
170    /// Generate weighted flow selection logic
171    fn generate_weighted_flow_selection(
172        &self,
173        active_methods: &[&crate::types::trident_flow_executor::FlowMethod],
174    ) -> TokenStream {
175        // Filter out flows with weight 0 (they should be skipped)
176        let weighted_methods: Vec<_> = active_methods
177            .iter()
178            .filter(|method| method.constraints.weight.unwrap_or(0) > 0)
179            .collect();
180
181        if weighted_methods.is_empty() {
182            return quote! {
183                // All flows have weight 0, nothing to execute
184            };
185        }
186
187        // Calculate total weight
188        let total_weight: u32 = weighted_methods
189            .iter()
190            .map(|method| method.constraints.weight.unwrap())
191            .sum();
192
193        // Generate weight ranges and method calls
194        let mut cumulative_weight = 0u32;
195        let weight_ranges: Vec<_> = weighted_methods
196            .iter()
197            .map(|method| {
198                let weight = method.constraints.weight.unwrap();
199                let _start = cumulative_weight;
200                cumulative_weight += weight;
201                let end = cumulative_weight;
202                let method_ident = &method.ident;
203
204                quote! {
205                    if random_weight < #end {
206                        self.#method_ident();
207                        continue;
208                    }
209                }
210            })
211            .collect();
212
213        quote! {
214            // Weighted flow selection based on specified weights
215            for _ in 0..flow_calls_per_iteration {
216                let random_weight = self.trident.random_from_range(0..#total_weight);
217                #(#weight_ranges)*
218            }
219        }
220    }
221
222    fn generate_coverage_method(&self) -> TokenStream {
223        // Check if coverage is enabled by looking for RUSTFLAGS containing -C instrument-coverage
224        // This is set by the Trident CLI when running with coverage via run_with_coverage()
225        let rustflags = std::env::var("RUSTFLAGS").unwrap_or_default();
226        let coverage_enabled = rustflags.contains("-C instrument-coverage");
227
228        if coverage_enabled {
229            quote! {
230                // LLVM coverage profiling calls - only generated when coverage is enabled
231                unsafe {
232                    let filename = format!("target/fuzz-cov-run-{}.profraw", current_iteration);
233                    if let Ok(filename_cstr) = std::ffi::CString::new(filename) {
234                        trident_fuzz::fuzzing::__llvm_profile_set_filename(filename_cstr.as_ptr());
235                        let _ = trident_fuzz::fuzzing::__llvm_profile_write_file();
236                        trident_fuzz::fuzzing::__llvm_profile_reset_counters();
237
238                        // Set final filename to avoid overwriting intermediate files
239                        if let Ok(final_filename_cstr) = std::ffi::CString::new("target/fuzz-cov-run-final.profraw") {
240                            trident_fuzz::fuzzing::__llvm_profile_set_filename(final_filename_cstr.as_ptr());
241                        }
242                    }
243                }
244            }
245        } else {
246            quote! {
247                // Coverage profiling disabled - prevents linking errors
248            }
249        }
250    }
251}