Skip to main content

sp1_core_executor/vm/
gas.rs

1use crate::{
2    events::{MemoryRecord, NUM_PAGE_PROT_ENTRIES_PER_ROW_EXEC},
3    CompressedMemory, ExecutionReport, Instruction, Opcode, RiscvAirId, SyscallCode,
4};
5use enum_map::EnumMap;
6use hashbrown::{HashMap, HashSet};
7use std::str::FromStr;
8
9use super::shapes::riscv_air_id_from_opcode_flag;
10
11// Trusted gas estimation calculator
12// For a given executor, calculate the total complexity and trace area
13// Based off of ShapeChecker
14pub struct ReportGenerator {
15    pub opcode_counts: EnumMap<Opcode, u64>,
16    pub syscall_counts: EnumMap<SyscallCode, u64>,
17    pub deferred_syscall_counts: EnumMap<SyscallCode, u64>,
18    pub system_chips_counts: EnumMap<RiscvAirId, u64>,
19
20    pub(crate) syscall_sent: bool,
21    pub(crate) local_mem_counts: u64,
22    /// The number of local page prot accesses during this cycle.
23    pub(crate) local_page_prot_counts: u64,
24    is_last_read_external: CompressedMemory,
25    /// Whether the last page prot access was external, ie: it was read from a deferred precompile.
26    is_last_page_prot_access_external: HashMap<u64, bool>,
27
28    trace_cost_lookup: EnumMap<RiscvAirId, u64>,
29
30    /// Running count of page prot entries.
31    page_prot_entry_count: u64,
32    enable_untrusted_programs: bool,
33
34    shard_start_clk: u64,
35    exit_code: u64,
36}
37
38impl ReportGenerator {
39    pub fn new(shard_start_clk: u64, enable_untrusted_programs: bool) -> Self {
40        let costs: HashMap<String, usize> =
41            serde_json::from_str(include_str!("../artifacts/rv64im_costs.json")).unwrap();
42        let costs: EnumMap<RiscvAirId, u64> =
43            costs.into_iter().map(|(k, v)| (RiscvAirId::from_str(&k).unwrap(), v as u64)).collect();
44
45        Self {
46            trace_cost_lookup: costs,
47            opcode_counts: EnumMap::default(),
48            syscall_counts: EnumMap::default(),
49            deferred_syscall_counts: EnumMap::default(),
50            system_chips_counts: EnumMap::default(),
51            syscall_sent: false,
52            local_mem_counts: 0,
53            local_page_prot_counts: 0,
54            is_last_read_external: CompressedMemory::new(),
55            is_last_page_prot_access_external: HashMap::new(),
56            page_prot_entry_count: 0,
57            enable_untrusted_programs,
58            shard_start_clk,
59            exit_code: 0,
60        }
61    }
62
63    /// Set the start clock of the shard.
64    #[inline]
65    pub fn reset(&mut self, clk: u64) {
66        *self = Self::new(clk, self.enable_untrusted_programs);
67    }
68
69    pub fn get_costs(&self) -> (u64, u64) {
70        (self.sum_total_complexity(), self.sum_total_trace_area())
71    }
72
73    /// Generate an `ExecutionReport` from the current state of the `ReportGenerator`
74    pub fn generate_report(&self) -> ExecutionReport {
75        // Combine syscall_counts and deferred_syscall_counts, converting from row counts
76        // back to invocation counts for the report. Internally these fields store
77        // rows_per_event * invocations for gas calculation; the report should show
78        // actual invocation counts.
79        let mut total_syscall_counts = EnumMap::default();
80        for (syscall_code, &count) in self.syscall_counts.iter() {
81            if count > 0 {
82                if let Some(air_id) = syscall_code.as_air_id() {
83                    total_syscall_counts[syscall_code] += count / air_id.rows_per_event() as u64;
84                }
85            }
86        }
87        for (syscall_code, &count) in self.deferred_syscall_counts.iter() {
88            if count > 0 {
89                if let Some(air_id) = syscall_code.as_air_id() {
90                    total_syscall_counts[syscall_code] += count / air_id.rows_per_event() as u64;
91                }
92            }
93        }
94
95        let (complexity, trace_area) = self.get_costs();
96        // Use integer arithmetic to avoid f64 precision warnings
97        // 0.3 * trace_area + 0.1 * complexity ≈ (3 * trace_area + complexity) / 10
98        let gas = (3 * trace_area + complexity) / 10;
99
100        ExecutionReport {
101            opcode_counts: Box::new(self.opcode_counts),
102            syscall_counts: Box::new(total_syscall_counts),
103            cycle_tracker: HashMap::new(),
104            invocation_tracker: HashMap::new(),
105            touched_memory_addresses: 0,
106            gas: Some(gas),
107            exit_code: self.exit_code,
108        }
109    }
110
111    // Set the exit code which will be returned when the `ExecutionReport` is generated.
112    pub fn set_exit_code(&mut self, exit_code: u64) {
113        self.exit_code = exit_code;
114    }
115
116    /// Helper method to filter out opcode counts with zero values
117    fn filtered_opcode_counts(&self) -> impl Iterator<Item = (Opcode, u64)> + '_ {
118        self.opcode_counts
119            .iter()
120            .filter(|(_, &count)| count > 0)
121            .map(|(opcode, &count)| (opcode, count))
122    }
123
124    fn sum_total_complexity(&self) -> u64 {
125        self.filtered_opcode_counts()
126            .map(|(opcode, count)| {
127                get_complexity_mapping()
128                    [riscv_air_id_from_opcode_flag(opcode, self.enable_untrusted_programs)]
129                    * count
130            })
131            .sum::<u64>()
132            + self
133                .syscall_counts
134                .iter()
135                .map(|(syscall_code, count)| {
136                    if let Some(syscall_air_id) =
137                        syscall_code.as_air_id_flag(self.enable_untrusted_programs)
138                    {
139                        get_complexity_mapping()[syscall_air_id] * count
140                    } else {
141                        0
142                    }
143                })
144                .sum::<u64>()
145            + self
146                .deferred_syscall_counts
147                .iter()
148                .map(|(syscall_code, count)| {
149                    if let Some(syscall_air_id) =
150                        syscall_code.as_air_id_flag(self.enable_untrusted_programs)
151                    {
152                        get_complexity_mapping()[syscall_air_id] * count
153                    } else {
154                        0
155                    }
156                })
157                .sum::<u64>()
158            + self
159                .system_chips_counts
160                .iter()
161                .map(|(riscv_air_id, count)| get_complexity_mapping()[riscv_air_id] * count)
162                .sum::<u64>()
163    }
164
165    fn sum_total_trace_area(&self) -> u64 {
166        self.filtered_opcode_counts()
167            .map(|(opcode, count)| {
168                self.trace_cost_lookup
169                    [riscv_air_id_from_opcode_flag(opcode, self.enable_untrusted_programs)]
170                    * count
171            })
172            .sum::<u64>()
173            + self
174                .syscall_counts
175                .iter()
176                .map(|(syscall_code, count)| {
177                    if let Some(syscall_air_id) =
178                        syscall_code.as_air_id_flag(self.enable_untrusted_programs)
179                    {
180                        self.trace_cost_lookup[syscall_air_id] * count
181                    } else {
182                        0
183                    }
184                })
185                .sum::<u64>()
186            + self
187                .deferred_syscall_counts
188                .iter()
189                .map(|(syscall_code, count)| {
190                    if let Some(syscall_air_id) =
191                        syscall_code.as_air_id_flag(self.enable_untrusted_programs)
192                    {
193                        self.trace_cost_lookup[syscall_air_id] * count
194                    } else {
195                        0
196                    }
197                })
198                .sum::<u64>()
199            + self
200                .system_chips_counts
201                .iter()
202                .map(|(riscv_air_id, count)| self.trace_cost_lookup[riscv_air_id] * count)
203                .sum::<u64>()
204    }
205
206    #[inline]
207    pub fn handle_mem_event(&mut self, addr: u64, clk: u64) {
208        // Round down to the nearest 8-byte aligned address.
209        let addr = if addr > 31 { addr & !0b111 } else { addr };
210
211        let is_external = self.syscall_sent;
212
213        let is_first_read_this_shard = self.shard_start_clk > clk;
214
215        let is_last_read_external = self.is_last_read_external.insert(addr, is_external);
216
217        self.local_mem_counts +=
218            (is_first_read_this_shard || (is_last_read_external && !is_external)) as u64;
219    }
220
221    #[inline]
222    pub fn local_mem_syscall_rr(&mut self) {
223        self.local_mem_counts += self.syscall_sent as u64;
224    }
225
226    #[inline]
227    pub fn handle_page_prot_event(&mut self, page_idx: u64, clk: u64) {
228        let is_external = self.syscall_sent;
229        let is_first_read_this_shard = self.shard_start_clk > clk;
230        let is_last_page_prot_access_external =
231            self.is_last_page_prot_access_external.insert(page_idx, is_external).unwrap_or(false);
232
233        self.local_page_prot_counts += (is_first_read_this_shard
234            || (is_last_page_prot_access_external && !is_external))
235            as u64;
236    }
237
238    #[inline]
239    pub fn handle_page_prot_check(&mut self) {
240        self.page_prot_entry_count += 1;
241        self.system_chips_counts[RiscvAirId::PageProt] =
242            self.page_prot_entry_count.div_ceil(NUM_PAGE_PROT_ENTRIES_PER_ROW_EXEC as u64);
243    }
244
245    #[inline]
246    pub fn handle_trap_exec_event(&mut self) {
247        self.system_chips_counts[RiscvAirId::TrapExec] += 1;
248    }
249
250    #[inline]
251    pub fn handle_trap_mem_event(&mut self) {
252        self.system_chips_counts[RiscvAirId::TrapMem] += 1;
253    }
254
255    #[inline]
256    pub fn handle_trap_events(&mut self, bump_clk_high: bool) {
257        self.update_system_chip_counts(bump_clk_high, bump_clk_high);
258    }
259
260    #[inline]
261    pub fn handle_untrusted_instruction(&mut self) {
262        self.system_chips_counts[RiscvAirId::InstructionFetch] += 1;
263    }
264
265    #[inline]
266    pub fn handle_retained_syscall(&mut self, syscall_code: SyscallCode) {
267        if let Some(syscall_air_id) = syscall_code.as_air_id() {
268            let rows_per_event = syscall_air_id.rows_per_event() as u64;
269
270            self.syscall_counts[syscall_code] += rows_per_event;
271        }
272    }
273
274    #[inline]
275    pub fn get_syscall_sent(&self) -> bool {
276        self.syscall_sent
277    }
278
279    #[inline]
280    pub fn set_syscall_sent(&mut self, syscall_sent: bool) {
281        self.syscall_sent = syscall_sent;
282    }
283
284    #[inline]
285    pub fn add_global_init_and_finalize_counts(
286        &mut self,
287        final_registers: &[MemoryRecord; 32],
288        mut touched_addresses: HashSet<u64>,
289        hint_init_events_addrs: &HashSet<u64>,
290        memory_image_addrs: &[u64],
291    ) {
292        touched_addresses.extend(memory_image_addrs);
293
294        // Add init for registers
295        self.system_chips_counts[RiscvAirId::MemoryGlobalInit] += 32;
296
297        // Add finalize for registers
298        self.system_chips_counts[RiscvAirId::MemoryGlobalFinalize] +=
299            final_registers.iter().enumerate().filter(|(_, e)| e.timestamp != 0).count() as u64;
300
301        // Add memory init events
302        self.system_chips_counts[RiscvAirId::MemoryGlobalInit] +=
303            hint_init_events_addrs.len() as u64;
304
305        let memory_init_events = touched_addresses
306            .iter()
307            .filter(|addr| !memory_image_addrs.contains(*addr))
308            .filter(|addr| !hint_init_events_addrs.contains(*addr));
309        self.system_chips_counts[RiscvAirId::MemoryGlobalInit] += memory_init_events.count() as u64;
310
311        touched_addresses.extend(hint_init_events_addrs.clone());
312        self.system_chips_counts[RiscvAirId::MemoryGlobalFinalize] +=
313            touched_addresses.len() as u64;
314    }
315
316    /// Increment the trace area for the given instruction.
317    ///
318    /// # Arguments
319    ///
320    /// * `instruction`: The instruction that is being handled.
321    /// * `bump_clk_high`: Whether the clk's top 24 bits incremented during this cycle.
322    /// * `is_load_x0`: Whether the instruction is a load of x0, if so the riscv air id is `LoadX0`.
323    /// * `needs_state_bump`: Whether this cycle induced a state bump.
324    #[inline]
325    #[allow(clippy::fn_params_excessive_bools)]
326    pub fn handle_instruction(
327        &mut self,
328        instruction: &Instruction,
329        bump_clk_high: bool,
330        _is_load_x0: bool,
331        needs_state_bump: bool,
332    ) {
333        self.opcode_counts[instruction.opcode] += 1;
334        self.update_system_chip_counts(bump_clk_high, needs_state_bump);
335    }
336
337    /// Update system chip counts based on the current cycle's state.
338    fn update_system_chip_counts(&mut self, bump_clk_high: bool, needs_state_bump: bool) {
339        let touched_addresses: u64 = std::mem::take(&mut self.local_mem_counts);
340        let syscall_sent = std::mem::take(&mut self.syscall_sent);
341
342        let bump_clk_high_num_events = 32 * bump_clk_high as u64;
343        self.system_chips_counts[RiscvAirId::MemoryBump] += bump_clk_high_num_events;
344        self.system_chips_counts[RiscvAirId::MemoryLocal] += touched_addresses;
345        self.system_chips_counts[RiscvAirId::StateBump] += needs_state_bump as u64;
346        self.system_chips_counts[RiscvAirId::Global] += 2 * touched_addresses + syscall_sent as u64;
347        self.system_chips_counts[RiscvAirId::SyscallCore] += syscall_sent as u64;
348    }
349
350    /// Update system chip counts based on the current cycle's page count state.
351    pub fn update_page_chip_counts(&mut self) {
352        let touched_pages: u64 = std::mem::take(&mut self.local_page_prot_counts);
353        self.system_chips_counts[RiscvAirId::Global] += 2 * touched_pages;
354        self.system_chips_counts[RiscvAirId::PageProtLocal] += touched_pages;
355    }
356
357    #[inline]
358    pub fn syscall_sent(&mut self, syscall_code: SyscallCode) {
359        self.syscall_sent = true;
360        if let Some(syscall_air_id) = syscall_code.as_air_id() {
361            let rows_per_event = syscall_air_id.rows_per_event() as u64;
362
363            self.deferred_syscall_counts[syscall_code] += rows_per_event;
364        }
365    }
366}
367
368/// Returns a mapping of `RiscvAirId` to their associated complexity codes.
369/// This provides the complexity cost for each AIR component in the system.
370#[must_use]
371pub fn get_complexity_mapping() -> EnumMap<RiscvAirId, u64> {
372    #[cfg(not(feature = "mprotect"))]
373    let json = include_str!("../artifacts/rv64im_complexity.json");
374    #[cfg(feature = "mprotect")]
375    let json = include_str!("../artifacts/rv64im_complexity_mprotect.json");
376
377    let complexity: HashMap<String, u64> = serde_json::from_str(json).unwrap();
378    complexity.into_iter().map(|(k, v)| (RiscvAirId::from_str(&k).unwrap(), v)).collect()
379}