1use crate::{
2 events::{MemoryRecord, NUM_PAGE_PROT_ENTRIES_PER_ROW_EXEC},
3 CompressedMemory, ExecutionReport, Instruction, Opcode, RiscvAirId, SyscallCode,
4};
5use enum_map::EnumMap;
6use hashbrown::{HashMap, HashSet};
7use std::str::FromStr;
8
9use super::shapes::riscv_air_id_from_opcode_flag;
10
11pub struct ReportGenerator {
15 pub opcode_counts: EnumMap<Opcode, u64>,
16 pub syscall_counts: EnumMap<SyscallCode, u64>,
17 pub deferred_syscall_counts: EnumMap<SyscallCode, u64>,
18 pub system_chips_counts: EnumMap<RiscvAirId, u64>,
19
20 pub(crate) syscall_sent: bool,
21 pub(crate) local_mem_counts: u64,
22 pub(crate) local_page_prot_counts: u64,
24 is_last_read_external: CompressedMemory,
25 is_last_page_prot_access_external: HashMap<u64, bool>,
27
28 trace_cost_lookup: EnumMap<RiscvAirId, u64>,
29
30 page_prot_entry_count: u64,
32 enable_untrusted_programs: bool,
33
34 shard_start_clk: u64,
35 exit_code: u64,
36}
37
38impl ReportGenerator {
39 pub fn new(shard_start_clk: u64, enable_untrusted_programs: bool) -> Self {
40 let costs: HashMap<String, usize> =
41 serde_json::from_str(include_str!("../artifacts/rv64im_costs.json")).unwrap();
42 let costs: EnumMap<RiscvAirId, u64> =
43 costs.into_iter().map(|(k, v)| (RiscvAirId::from_str(&k).unwrap(), v as u64)).collect();
44
45 Self {
46 trace_cost_lookup: costs,
47 opcode_counts: EnumMap::default(),
48 syscall_counts: EnumMap::default(),
49 deferred_syscall_counts: EnumMap::default(),
50 system_chips_counts: EnumMap::default(),
51 syscall_sent: false,
52 local_mem_counts: 0,
53 local_page_prot_counts: 0,
54 is_last_read_external: CompressedMemory::new(),
55 is_last_page_prot_access_external: HashMap::new(),
56 page_prot_entry_count: 0,
57 enable_untrusted_programs,
58 shard_start_clk,
59 exit_code: 0,
60 }
61 }
62
63 #[inline]
65 pub fn reset(&mut self, clk: u64) {
66 *self = Self::new(clk, self.enable_untrusted_programs);
67 }
68
69 pub fn get_costs(&self) -> (u64, u64) {
70 (self.sum_total_complexity(), self.sum_total_trace_area())
71 }
72
73 pub fn generate_report(&self) -> ExecutionReport {
75 let mut total_syscall_counts = EnumMap::default();
80 for (syscall_code, &count) in self.syscall_counts.iter() {
81 if count > 0 {
82 if let Some(air_id) = syscall_code.as_air_id() {
83 total_syscall_counts[syscall_code] += count / air_id.rows_per_event() as u64;
84 }
85 }
86 }
87 for (syscall_code, &count) in self.deferred_syscall_counts.iter() {
88 if count > 0 {
89 if let Some(air_id) = syscall_code.as_air_id() {
90 total_syscall_counts[syscall_code] += count / air_id.rows_per_event() as u64;
91 }
92 }
93 }
94
95 let (complexity, trace_area) = self.get_costs();
96 let gas = (3 * trace_area + complexity) / 10;
99
100 ExecutionReport {
101 opcode_counts: Box::new(self.opcode_counts),
102 syscall_counts: Box::new(total_syscall_counts),
103 cycle_tracker: HashMap::new(),
104 invocation_tracker: HashMap::new(),
105 touched_memory_addresses: 0,
106 gas: Some(gas),
107 exit_code: self.exit_code,
108 }
109 }
110
111 pub fn set_exit_code(&mut self, exit_code: u64) {
113 self.exit_code = exit_code;
114 }
115
116 fn filtered_opcode_counts(&self) -> impl Iterator<Item = (Opcode, u64)> + '_ {
118 self.opcode_counts
119 .iter()
120 .filter(|(_, &count)| count > 0)
121 .map(|(opcode, &count)| (opcode, count))
122 }
123
124 fn sum_total_complexity(&self) -> u64 {
125 self.filtered_opcode_counts()
126 .map(|(opcode, count)| {
127 get_complexity_mapping()
128 [riscv_air_id_from_opcode_flag(opcode, self.enable_untrusted_programs)]
129 * count
130 })
131 .sum::<u64>()
132 + self
133 .syscall_counts
134 .iter()
135 .map(|(syscall_code, count)| {
136 if let Some(syscall_air_id) =
137 syscall_code.as_air_id_flag(self.enable_untrusted_programs)
138 {
139 get_complexity_mapping()[syscall_air_id] * count
140 } else {
141 0
142 }
143 })
144 .sum::<u64>()
145 + self
146 .deferred_syscall_counts
147 .iter()
148 .map(|(syscall_code, count)| {
149 if let Some(syscall_air_id) =
150 syscall_code.as_air_id_flag(self.enable_untrusted_programs)
151 {
152 get_complexity_mapping()[syscall_air_id] * count
153 } else {
154 0
155 }
156 })
157 .sum::<u64>()
158 + self
159 .system_chips_counts
160 .iter()
161 .map(|(riscv_air_id, count)| get_complexity_mapping()[riscv_air_id] * count)
162 .sum::<u64>()
163 }
164
165 fn sum_total_trace_area(&self) -> u64 {
166 self.filtered_opcode_counts()
167 .map(|(opcode, count)| {
168 self.trace_cost_lookup
169 [riscv_air_id_from_opcode_flag(opcode, self.enable_untrusted_programs)]
170 * count
171 })
172 .sum::<u64>()
173 + self
174 .syscall_counts
175 .iter()
176 .map(|(syscall_code, count)| {
177 if let Some(syscall_air_id) =
178 syscall_code.as_air_id_flag(self.enable_untrusted_programs)
179 {
180 self.trace_cost_lookup[syscall_air_id] * count
181 } else {
182 0
183 }
184 })
185 .sum::<u64>()
186 + self
187 .deferred_syscall_counts
188 .iter()
189 .map(|(syscall_code, count)| {
190 if let Some(syscall_air_id) =
191 syscall_code.as_air_id_flag(self.enable_untrusted_programs)
192 {
193 self.trace_cost_lookup[syscall_air_id] * count
194 } else {
195 0
196 }
197 })
198 .sum::<u64>()
199 + self
200 .system_chips_counts
201 .iter()
202 .map(|(riscv_air_id, count)| self.trace_cost_lookup[riscv_air_id] * count)
203 .sum::<u64>()
204 }
205
206 #[inline]
207 pub fn handle_mem_event(&mut self, addr: u64, clk: u64) {
208 let addr = if addr > 31 { addr & !0b111 } else { addr };
210
211 let is_external = self.syscall_sent;
212
213 let is_first_read_this_shard = self.shard_start_clk > clk;
214
215 let is_last_read_external = self.is_last_read_external.insert(addr, is_external);
216
217 self.local_mem_counts +=
218 (is_first_read_this_shard || (is_last_read_external && !is_external)) as u64;
219 }
220
221 #[inline]
222 pub fn local_mem_syscall_rr(&mut self) {
223 self.local_mem_counts += self.syscall_sent as u64;
224 }
225
226 #[inline]
227 pub fn handle_page_prot_event(&mut self, page_idx: u64, clk: u64) {
228 let is_external = self.syscall_sent;
229 let is_first_read_this_shard = self.shard_start_clk > clk;
230 let is_last_page_prot_access_external =
231 self.is_last_page_prot_access_external.insert(page_idx, is_external).unwrap_or(false);
232
233 self.local_page_prot_counts += (is_first_read_this_shard
234 || (is_last_page_prot_access_external && !is_external))
235 as u64;
236 }
237
238 #[inline]
239 pub fn handle_page_prot_check(&mut self) {
240 self.page_prot_entry_count += 1;
241 self.system_chips_counts[RiscvAirId::PageProt] =
242 self.page_prot_entry_count.div_ceil(NUM_PAGE_PROT_ENTRIES_PER_ROW_EXEC as u64);
243 }
244
245 #[inline]
246 pub fn handle_trap_exec_event(&mut self) {
247 self.system_chips_counts[RiscvAirId::TrapExec] += 1;
248 }
249
250 #[inline]
251 pub fn handle_trap_mem_event(&mut self) {
252 self.system_chips_counts[RiscvAirId::TrapMem] += 1;
253 }
254
255 #[inline]
256 pub fn handle_trap_events(&mut self, bump_clk_high: bool) {
257 self.update_system_chip_counts(bump_clk_high, bump_clk_high);
258 }
259
260 #[inline]
261 pub fn handle_untrusted_instruction(&mut self) {
262 self.system_chips_counts[RiscvAirId::InstructionFetch] += 1;
263 }
264
265 #[inline]
266 pub fn handle_retained_syscall(&mut self, syscall_code: SyscallCode) {
267 if let Some(syscall_air_id) = syscall_code.as_air_id() {
268 let rows_per_event = syscall_air_id.rows_per_event() as u64;
269
270 self.syscall_counts[syscall_code] += rows_per_event;
271 }
272 }
273
274 #[inline]
275 pub fn get_syscall_sent(&self) -> bool {
276 self.syscall_sent
277 }
278
279 #[inline]
280 pub fn set_syscall_sent(&mut self, syscall_sent: bool) {
281 self.syscall_sent = syscall_sent;
282 }
283
284 #[inline]
285 pub fn add_global_init_and_finalize_counts(
286 &mut self,
287 final_registers: &[MemoryRecord; 32],
288 mut touched_addresses: HashSet<u64>,
289 hint_init_events_addrs: &HashSet<u64>,
290 memory_image_addrs: &[u64],
291 ) {
292 touched_addresses.extend(memory_image_addrs);
293
294 self.system_chips_counts[RiscvAirId::MemoryGlobalInit] += 32;
296
297 self.system_chips_counts[RiscvAirId::MemoryGlobalFinalize] +=
299 final_registers.iter().enumerate().filter(|(_, e)| e.timestamp != 0).count() as u64;
300
301 self.system_chips_counts[RiscvAirId::MemoryGlobalInit] +=
303 hint_init_events_addrs.len() as u64;
304
305 let memory_init_events = touched_addresses
306 .iter()
307 .filter(|addr| !memory_image_addrs.contains(*addr))
308 .filter(|addr| !hint_init_events_addrs.contains(*addr));
309 self.system_chips_counts[RiscvAirId::MemoryGlobalInit] += memory_init_events.count() as u64;
310
311 touched_addresses.extend(hint_init_events_addrs.clone());
312 self.system_chips_counts[RiscvAirId::MemoryGlobalFinalize] +=
313 touched_addresses.len() as u64;
314 }
315
316 #[inline]
325 #[allow(clippy::fn_params_excessive_bools)]
326 pub fn handle_instruction(
327 &mut self,
328 instruction: &Instruction,
329 bump_clk_high: bool,
330 _is_load_x0: bool,
331 needs_state_bump: bool,
332 ) {
333 self.opcode_counts[instruction.opcode] += 1;
334 self.update_system_chip_counts(bump_clk_high, needs_state_bump);
335 }
336
337 fn update_system_chip_counts(&mut self, bump_clk_high: bool, needs_state_bump: bool) {
339 let touched_addresses: u64 = std::mem::take(&mut self.local_mem_counts);
340 let syscall_sent = std::mem::take(&mut self.syscall_sent);
341
342 let bump_clk_high_num_events = 32 * bump_clk_high as u64;
343 self.system_chips_counts[RiscvAirId::MemoryBump] += bump_clk_high_num_events;
344 self.system_chips_counts[RiscvAirId::MemoryLocal] += touched_addresses;
345 self.system_chips_counts[RiscvAirId::StateBump] += needs_state_bump as u64;
346 self.system_chips_counts[RiscvAirId::Global] += 2 * touched_addresses + syscall_sent as u64;
347 self.system_chips_counts[RiscvAirId::SyscallCore] += syscall_sent as u64;
348 }
349
350 pub fn update_page_chip_counts(&mut self) {
352 let touched_pages: u64 = std::mem::take(&mut self.local_page_prot_counts);
353 self.system_chips_counts[RiscvAirId::Global] += 2 * touched_pages;
354 self.system_chips_counts[RiscvAirId::PageProtLocal] += touched_pages;
355 }
356
357 #[inline]
358 pub fn syscall_sent(&mut self, syscall_code: SyscallCode) {
359 self.syscall_sent = true;
360 if let Some(syscall_air_id) = syscall_code.as_air_id() {
361 let rows_per_event = syscall_air_id.rows_per_event() as u64;
362
363 self.deferred_syscall_counts[syscall_code] += rows_per_event;
364 }
365 }
366}
367
368#[must_use]
371pub fn get_complexity_mapping() -> EnumMap<RiscvAirId, u64> {
372 #[cfg(not(feature = "mprotect"))]
373 let json = include_str!("../artifacts/rv64im_complexity.json");
374 #[cfg(feature = "mprotect")]
375 let json = include_str!("../artifacts/rv64im_complexity_mprotect.json");
376
377 let complexity: HashMap<String, u64> = serde_json::from_str(json).unwrap();
378 complexity.into_iter().map(|(k, v)| (RiscvAirId::from_str(&k).unwrap(), v)).collect()
379}