trident_syn/codegen/
trident_flow_executor.rs1use proc_macro2::TokenStream;
2use quote::quote;
3use quote::ToTokens;
4
5use crate::types::trident_flow_executor::TridentFlowExecutorImpl;
6
7impl ToTokens for TridentFlowExecutorImpl {
8 fn to_tokens(&self, tokens: &mut TokenStream) {
9 let expanded = self.generate_flow_executor_impl();
10 tokens.extend(expanded);
11 }
12}
13
14impl TridentFlowExecutorImpl {
15 fn generate_flow_executor_impl(&self) -> TokenStream {
17 let type_name = &self.type_name;
18 let impl_items = &self.impl_block;
19 let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
20
21 let flow_executor_impl = self.generate_flow_executor_trait_impl();
22
23 quote! {
24 impl #impl_generics #type_name #ty_generics #where_clause {
25 #(#impl_items)*
26 }
27
28 #flow_executor_impl
29 }
30 }
31
32 fn generate_flow_executor_trait_impl(&self) -> TokenStream {
34 let type_name = &self.type_name;
35 let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
36
37 let execute_flows_method = self.generate_execute_flows_method();
38 let coverage_method = self.generate_coverage_method();
39
40 quote! {
41 impl #impl_generics FlowExecutor for #type_name #ty_generics #where_clause {
42 fn new() -> Self {
43 Self::new()
44 }
45
46 fn execute_flows(&mut self, flow_calls_per_iteration: u64) -> std::result::Result<(), FuzzingError> {
47 #execute_flows_method
48 Ok(())
49 }
50
51 fn trident_mut(&mut self) -> &mut Trident {
52 &mut self.trident
53 }
54
55 fn reset_fuzz_accounts(&mut self) {
56 let _ = std::mem::take(&mut self.fuzz_accounts);
59 }
60
61 fn handle_llvm_coverage(&mut self, current_iteration: u64) {
62 #coverage_method
63 }
64 }
65 }
66 }
67
68 fn generate_execute_flows_method(&self) -> TokenStream {
70 let init_call = self.generate_init_call();
71 let flow_execution_logic = self.generate_flow_execution_logic();
72 let end_call = self.generate_end_call();
73
74 quote! {
75 #init_call
76 #flow_execution_logic
77 #end_call
78
79 }
80 }
81
82 fn generate_init_call(&self) -> TokenStream {
84 if let Some(init_method) = &self.init_method {
85 quote! {
86 self.#init_method();
87 }
88 } else {
89 quote! {}
90 }
91 }
92
93 fn generate_end_call(&self) -> TokenStream {
95 if let Some(end_method) = &self.end_method {
96 quote! {
97 self.#end_method();
98 }
99 } else {
100 quote! {}
101 }
102 }
103
104 fn generate_flow_execution_logic(&self) -> TokenStream {
106 let active_methods: Vec<_> = self
108 .flow_methods
109 .iter()
110 .filter(|method| !method.constraints.ignore)
111 .collect();
112
113 if active_methods.is_empty() {
114 quote! {
115 }
117 } else {
118 let flow_selection_logic = self.generate_flow_selection_logic(&active_methods);
119
120 quote! {
121 #flow_selection_logic
122 }
123 }
124 }
125
126 fn generate_flow_selection_logic(
128 &self,
129 active_methods: &[&crate::types::trident_flow_executor::FlowMethod],
130 ) -> TokenStream {
131 let has_weights = active_methods
133 .iter()
134 .any(|method| method.constraints.weight.is_some());
135
136 if has_weights {
137 self.generate_weighted_flow_selection(active_methods)
139 } else {
140 self.generate_uniform_flow_selection(active_methods)
142 }
143 }
144
145 fn generate_uniform_flow_selection(
147 &self,
148 active_methods: &[&crate::types::trident_flow_executor::FlowMethod],
149 ) -> TokenStream {
150 let flow_match_arms = active_methods.iter().enumerate().map(|(index, method)| {
151 let method_ident = &method.ident;
152 quote! {
153 #index => self.#method_ident(),
154 }
155 });
156 let num_flows = active_methods.len();
157
158 quote! {
159 for _ in 0..flow_calls_per_iteration {
161 let flow_index = self.trident.random_from_range(0..#num_flows);
162 match flow_index {
163 #(#flow_match_arms)*
164 _ => unreachable!("Invalid flow index"),
165 }
166 }
167 }
168 }
169
170 fn generate_weighted_flow_selection(
172 &self,
173 active_methods: &[&crate::types::trident_flow_executor::FlowMethod],
174 ) -> TokenStream {
175 let weighted_methods: Vec<_> = active_methods
177 .iter()
178 .filter(|method| method.constraints.weight.unwrap_or(0) > 0)
179 .collect();
180
181 if weighted_methods.is_empty() {
182 return quote! {
183 };
185 }
186
187 let total_weight: u32 = weighted_methods
189 .iter()
190 .map(|method| method.constraints.weight.unwrap())
191 .sum();
192
193 let mut cumulative_weight = 0u32;
195 let weight_ranges: Vec<_> = weighted_methods
196 .iter()
197 .map(|method| {
198 let weight = method.constraints.weight.unwrap();
199 let _start = cumulative_weight;
200 cumulative_weight += weight;
201 let end = cumulative_weight;
202 let method_ident = &method.ident;
203
204 quote! {
205 if random_weight < #end {
206 self.#method_ident();
207 continue;
208 }
209 }
210 })
211 .collect();
212
213 quote! {
214 for _ in 0..flow_calls_per_iteration {
216 let random_weight = self.trident.random_from_range(0..#total_weight);
217 #(#weight_ranges)*
218 }
219 }
220 }
221
222 fn generate_coverage_method(&self) -> TokenStream {
223 let rustflags = std::env::var("RUSTFLAGS").unwrap_or_default();
226 let coverage_enabled = rustflags.contains("-C instrument-coverage");
227
228 if coverage_enabled {
229 quote! {
230 unsafe {
232 let filename = format!("target/fuzz-cov-run-{}.profraw", current_iteration);
233 if let Ok(filename_cstr) = std::ffi::CString::new(filename) {
234 trident_fuzz::fuzzing::__llvm_profile_set_filename(filename_cstr.as_ptr());
235 let _ = trident_fuzz::fuzzing::__llvm_profile_write_file();
236 trident_fuzz::fuzzing::__llvm_profile_reset_counters();
237
238 if let Ok(final_filename_cstr) = std::ffi::CString::new("target/fuzz-cov-run-final.profraw") {
240 trident_fuzz::fuzzing::__llvm_profile_set_filename(final_filename_cstr.as_ptr());
241 }
242 }
243 }
244 }
245 } else {
246 quote! {
247 }
249 }
250 }
251}