Skip to main content

sp1_core_executor/vm/
gas.rs

1use crate::{
2    events::MemoryRecord, vm::shapes::riscv_air_id_from_opcode, CompressedMemory, ExecutionReport,
3    Instruction, Opcode, RiscvAirId, SyscallCode,
4};
5use enum_map::EnumMap;
6use hashbrown::{HashMap, HashSet};
7use std::str::FromStr;
8
9// Trusted gas estimation calculator
10// For a given executor, calculate the total complexity and trace area
11// Based off of ShapeChecker
12pub struct ReportGenerator {
13    pub opcode_counts: EnumMap<Opcode, u64>,
14    pub syscall_counts: EnumMap<SyscallCode, u64>,
15    pub deferred_syscall_counts: EnumMap<SyscallCode, u64>,
16    pub system_chips_counts: EnumMap<RiscvAirId, u64>,
17
18    pub(crate) syscall_sent: bool,
19    pub(crate) local_mem_counts: u64,
20    is_last_read_external: CompressedMemory,
21
22    trace_cost_lookup: EnumMap<RiscvAirId, u64>,
23
24    shard_start_clk: u64,
25    exit_code: u64,
26}
27
28impl ReportGenerator {
29    pub fn new(shard_start_clk: u64) -> Self {
30        let costs: HashMap<String, usize> =
31            serde_json::from_str(include_str!("../artifacts/rv64im_costs.json")).unwrap();
32        let costs: EnumMap<RiscvAirId, u64> =
33            costs.into_iter().map(|(k, v)| (RiscvAirId::from_str(&k).unwrap(), v as u64)).collect();
34
35        Self {
36            trace_cost_lookup: costs,
37            opcode_counts: EnumMap::default(),
38            syscall_counts: EnumMap::default(),
39            deferred_syscall_counts: EnumMap::default(),
40            system_chips_counts: EnumMap::default(),
41            syscall_sent: false,
42            local_mem_counts: 0,
43            is_last_read_external: CompressedMemory::new(),
44            shard_start_clk,
45            exit_code: 0,
46        }
47    }
48
49    /// Set the start clock of the shard.
50    #[inline]
51    pub fn reset(&mut self, clk: u64) {
52        *self = Self::new(clk);
53    }
54
55    pub fn get_costs(&self) -> (u64, u64) {
56        (self.sum_total_complexity(), self.sum_total_trace_area())
57    }
58
59    /// Generate an `ExecutionReport` from the current state of the `ReportGenerator`
60    pub fn generate_report(&self) -> ExecutionReport {
61        // Combine syscall_counts and deferred_syscall_counts, converting from row counts
62        // back to invocation counts for the report. Internally these fields store
63        // rows_per_event * invocations for gas calculation; the report should show
64        // actual invocation counts.
65        let mut total_syscall_counts = EnumMap::default();
66        for (syscall_code, &count) in self.syscall_counts.iter() {
67            if count > 0 {
68                if let Some(air_id) = syscall_code.as_air_id() {
69                    total_syscall_counts[syscall_code] += count / air_id.rows_per_event() as u64;
70                }
71            }
72        }
73        for (syscall_code, &count) in self.deferred_syscall_counts.iter() {
74            if count > 0 {
75                if let Some(air_id) = syscall_code.as_air_id() {
76                    total_syscall_counts[syscall_code] += count / air_id.rows_per_event() as u64;
77                }
78            }
79        }
80
81        let (complexity, trace_area) = self.get_costs();
82        // Use integer arithmetic to avoid f64 precision warnings
83        // 0.3 * trace_area + 0.1 * complexity ≈ (3 * trace_area + complexity) / 10
84        let gas = (3 * trace_area + complexity) / 10;
85
86        ExecutionReport {
87            opcode_counts: Box::new(self.opcode_counts),
88            syscall_counts: Box::new(total_syscall_counts),
89            cycle_tracker: HashMap::new(),
90            invocation_tracker: HashMap::new(),
91            touched_memory_addresses: 0,
92            gas: Some(gas),
93            exit_code: self.exit_code,
94        }
95    }
96
97    // Set the exit code which will be returned when the `ExecutionReport` is generated.
98    pub fn set_exit_code(&mut self, exit_code: u64) {
99        self.exit_code = exit_code;
100    }
101
102    /// Helper method to filter out opcode counts with zero values
103    fn filtered_opcode_counts(&self) -> impl Iterator<Item = (Opcode, u64)> + '_ {
104        self.opcode_counts
105            .iter()
106            .filter(|(_, &count)| count > 0)
107            .map(|(opcode, &count)| (opcode, count))
108    }
109
110    fn sum_total_complexity(&self) -> u64 {
111        self.filtered_opcode_counts()
112            .map(|(opcode, count)| {
113                get_complexity_mapping()[riscv_air_id_from_opcode(opcode)] * count
114            })
115            .sum::<u64>()
116            + self
117                .syscall_counts
118                .iter()
119                .map(|(syscall_code, count)| {
120                    if let Some(syscall_air_id) = syscall_code.as_air_id() {
121                        get_complexity_mapping()[syscall_air_id] * count
122                    } else {
123                        0
124                    }
125                })
126                .sum::<u64>()
127            + self
128                .deferred_syscall_counts
129                .iter()
130                .map(|(syscall_code, count)| {
131                    if let Some(syscall_air_id) = syscall_code.as_air_id() {
132                        get_complexity_mapping()[syscall_air_id] * count
133                    } else {
134                        0
135                    }
136                })
137                .sum::<u64>()
138            + self
139                .system_chips_counts
140                .iter()
141                .map(|(riscv_air_id, count)| get_complexity_mapping()[riscv_air_id] * count)
142                .sum::<u64>()
143    }
144
145    fn sum_total_trace_area(&self) -> u64 {
146        self.filtered_opcode_counts()
147            .map(|(opcode, count)| self.trace_cost_lookup[riscv_air_id_from_opcode(opcode)] * count)
148            .sum::<u64>()
149            + self
150                .syscall_counts
151                .iter()
152                .map(|(syscall_code, count)| {
153                    if let Some(syscall_air_id) = syscall_code.as_air_id() {
154                        self.trace_cost_lookup[syscall_air_id] * count
155                    } else {
156                        0
157                    }
158                })
159                .sum::<u64>()
160            + self
161                .deferred_syscall_counts
162                .iter()
163                .map(|(syscall_code, count)| {
164                    if let Some(syscall_air_id) = syscall_code.as_air_id() {
165                        self.trace_cost_lookup[syscall_air_id] * count
166                    } else {
167                        0
168                    }
169                })
170                .sum::<u64>()
171            + self
172                .system_chips_counts
173                .iter()
174                .map(|(riscv_air_id, count)| self.trace_cost_lookup[riscv_air_id] * count)
175                .sum::<u64>()
176    }
177
178    #[inline]
179    pub fn handle_mem_event(&mut self, addr: u64, clk: u64) {
180        // Round down to the nearest 8-byte aligned address.
181        let addr = if addr > 31 { addr & !0b111 } else { addr };
182
183        let is_external = self.syscall_sent;
184
185        let is_first_read_this_shard = self.shard_start_clk > clk;
186
187        let is_last_read_external = self.is_last_read_external.insert(addr, is_external);
188
189        self.local_mem_counts +=
190            (is_first_read_this_shard || (is_last_read_external && !is_external)) as u64;
191    }
192
193    #[inline]
194    pub fn handle_retained_syscall(&mut self, syscall_code: SyscallCode) {
195        if let Some(syscall_air_id) = syscall_code.as_air_id() {
196            let rows_per_event = syscall_air_id.rows_per_event() as u64;
197
198            self.syscall_counts[syscall_code] += rows_per_event;
199        }
200    }
201
202    #[inline]
203    pub fn add_global_init_and_finalize_counts(
204        &mut self,
205        final_registers: &[MemoryRecord; 32],
206        mut touched_addresses: HashSet<u64>,
207        hint_init_events_addrs: &HashSet<u64>,
208        memory_image_addrs: &[u64],
209    ) {
210        touched_addresses.extend(memory_image_addrs);
211
212        // Add init for registers
213        self.system_chips_counts[RiscvAirId::MemoryGlobalInit] += 32;
214
215        // Add finalize for registers
216        self.system_chips_counts[RiscvAirId::MemoryGlobalFinalize] +=
217            final_registers.iter().enumerate().filter(|(_, e)| e.timestamp != 0).count() as u64;
218
219        // Add memory init events
220        self.system_chips_counts[RiscvAirId::MemoryGlobalInit] +=
221            hint_init_events_addrs.len() as u64;
222
223        let memory_init_events = touched_addresses
224            .iter()
225            .filter(|addr| !memory_image_addrs.contains(*addr))
226            .filter(|addr| !hint_init_events_addrs.contains(*addr));
227        self.system_chips_counts[RiscvAirId::MemoryGlobalInit] += memory_init_events.count() as u64;
228
229        touched_addresses.extend(hint_init_events_addrs.clone());
230        self.system_chips_counts[RiscvAirId::MemoryGlobalFinalize] +=
231            touched_addresses.len() as u64;
232    }
233
234    /// Increment the trace area for the given instruction.
235    ///
236    /// # Arguments
237    ///
238    /// * `instruction`: The instruction that is being handled.
239    /// * `syscall_sent`: Whether a syscall was sent during this cycle.
240    /// * `bump_clk_high`: Whether the clk's top 24 bits incremented during this cycle.
241    /// * `is_load_x0`: Whether the instruction is a load of x0, if so the riscv air id is `LoadX0`.
242    ///
243    /// # Returns
244    ///
245    /// Whether the shard limit has been reached.
246    #[inline]
247    #[allow(clippy::fn_params_excessive_bools)]
248    pub fn handle_instruction(
249        &mut self,
250        instruction: &Instruction,
251        bump_clk_high: bool,
252        _is_load_x0: bool,
253        needs_state_bump: bool,
254    ) {
255        let touched_addresses: u64 = std::mem::take(&mut self.local_mem_counts);
256        let syscall_sent = std::mem::take(&mut self.syscall_sent);
257
258        // Increment for opcode
259        self.opcode_counts[instruction.opcode] += 1;
260
261        // Increment system chips
262        // Increment by if bump_clk_high is needed
263        let bump_clk_high_num_events = 32 * bump_clk_high as u64;
264        self.system_chips_counts[RiscvAirId::MemoryBump] += bump_clk_high_num_events;
265        self.system_chips_counts[RiscvAirId::MemoryLocal] += touched_addresses;
266        self.system_chips_counts[RiscvAirId::StateBump] += needs_state_bump as u64;
267        self.system_chips_counts[RiscvAirId::Global] += 2 * touched_addresses + syscall_sent as u64;
268
269        // Increment if the syscall is retained
270        self.system_chips_counts[RiscvAirId::SyscallCore] += syscall_sent as u64;
271    }
272
273    #[inline]
274    pub fn syscall_sent(&mut self, syscall_code: SyscallCode) {
275        self.syscall_sent = true;
276        if let Some(syscall_air_id) = syscall_code.as_air_id() {
277            let rows_per_event = syscall_air_id.rows_per_event() as u64;
278
279            self.deferred_syscall_counts[syscall_code] += rows_per_event;
280        }
281    }
282}
283
284/// Returns a mapping of `RiscvAirId` to their associated complexity codes.
285/// This provides the complexity cost for each AIR component in the system.
286#[must_use]
287pub fn get_complexity_mapping() -> EnumMap<RiscvAirId, u64> {
288    let mut mapping = EnumMap::<RiscvAirId, u64>::default();
289
290    // Core program and system components
291    mapping[RiscvAirId::Program] = 0;
292    mapping[RiscvAirId::SyscallCore] = 2;
293    mapping[RiscvAirId::SyscallPrecompile] = 2;
294
295    // SHA components
296    mapping[RiscvAirId::ShaExtend] = 80;
297    mapping[RiscvAirId::ShaExtendControl] = 21;
298    mapping[RiscvAirId::ShaCompress] = 300;
299    mapping[RiscvAirId::ShaCompressControl] = 21;
300
301    // Elliptic curve operations
302    mapping[RiscvAirId::EdAddAssign] = 792;
303    mapping[RiscvAirId::EdDecompress] = 755;
304
305    // Secp256k1 operations
306    mapping[RiscvAirId::Secp256k1Decompress] = 691;
307    mapping[RiscvAirId::Secp256k1AddAssign] = 918;
308    mapping[RiscvAirId::Secp256k1DoubleAssign] = 904;
309
310    // Secp256r1 operations
311    mapping[RiscvAirId::Secp256r1Decompress] = 691;
312    mapping[RiscvAirId::Secp256r1AddAssign] = 918;
313    mapping[RiscvAirId::Secp256r1DoubleAssign] = 904;
314
315    // Keccak operations
316    mapping[RiscvAirId::KeccakPermute] = 2859;
317    mapping[RiscvAirId::KeccakPermuteControl] = 331;
318
319    // Bn254 operations
320    mapping[RiscvAirId::Bn254AddAssign] = 918;
321    mapping[RiscvAirId::Bn254DoubleAssign] = 904;
322
323    // BLS12-381 operations
324    mapping[RiscvAirId::Bls12381AddAssign] = 1374;
325    mapping[RiscvAirId::Bls12381DoubleAssign] = 1356;
326    mapping[RiscvAirId::Bls12381Decompress] = 1237;
327
328    // Uint256 operations
329    mapping[RiscvAirId::Uint256MulMod] = 253;
330    mapping[RiscvAirId::Uint256Ops] = 297;
331    mapping[RiscvAirId::U256XU2048Mul] = 1197;
332
333    // Field operations
334    mapping[RiscvAirId::Bls12381FpOpAssign] = 317;
335    mapping[RiscvAirId::Bls12381Fp2AddSubAssign] = 615;
336    mapping[RiscvAirId::Bls12381Fp2MulAssign] = 994;
337    mapping[RiscvAirId::Bn254FpOpAssign] = 217;
338    mapping[RiscvAirId::Bn254Fp2AddSubAssign] = 415;
339    mapping[RiscvAirId::Bn254Fp2MulAssign] = 666;
340
341    // System operations
342    mapping[RiscvAirId::Mprotect] = 11;
343    mapping[RiscvAirId::Poseidon2] = 497;
344
345    // RISC-V instruction costs
346    mapping[RiscvAirId::DivRem] = 347;
347    mapping[RiscvAirId::Add] = 15;
348    mapping[RiscvAirId::Addi] = 14;
349    mapping[RiscvAirId::Addw] = 20;
350    mapping[RiscvAirId::Sub] = 15;
351    mapping[RiscvAirId::Subw] = 15;
352    mapping[RiscvAirId::Bitwise] = 19;
353    mapping[RiscvAirId::Mul] = 60;
354    mapping[RiscvAirId::ShiftRight] = 77;
355    mapping[RiscvAirId::ShiftLeft] = 68;
356    mapping[RiscvAirId::Lt] = 41;
357
358    // Memory operations
359    mapping[RiscvAirId::LoadByte] = 32;
360    mapping[RiscvAirId::LoadHalf] = 33;
361    mapping[RiscvAirId::LoadWord] = 33;
362    mapping[RiscvAirId::LoadDouble] = 24;
363    mapping[RiscvAirId::LoadX0] = 34;
364    mapping[RiscvAirId::StoreByte] = 32;
365    mapping[RiscvAirId::StoreHalf] = 27;
366    mapping[RiscvAirId::StoreWord] = 27;
367    mapping[RiscvAirId::StoreDouble] = 23;
368
369    // Control flow
370    mapping[RiscvAirId::UType] = 19;
371    mapping[RiscvAirId::Branch] = 49;
372    mapping[RiscvAirId::Jal] = 24;
373    mapping[RiscvAirId::Jalr] = 25;
374
375    // System components
376    mapping[RiscvAirId::InstructionDecode] = 160;
377    mapping[RiscvAirId::InstructionFetch] = 11;
378    mapping[RiscvAirId::SyscallInstrs] = 93;
379    mapping[RiscvAirId::MemoryBump] = 5;
380    mapping[RiscvAirId::PageProt] = 32;
381    mapping[RiscvAirId::PageProtLocal] = 1;
382    mapping[RiscvAirId::StateBump] = 8;
383    mapping[RiscvAirId::MemoryGlobalInit] = 31;
384    mapping[RiscvAirId::MemoryGlobalFinalize] = 31;
385    mapping[RiscvAirId::PageProtGlobalInit] = 26;
386    mapping[RiscvAirId::PageProtGlobalFinalize] = 25;
387    mapping[RiscvAirId::MemoryLocal] = 4;
388    mapping[RiscvAirId::Global] = 216;
389
390    // Memory types
391    mapping[RiscvAirId::Byte] = 0;
392    mapping[RiscvAirId::Range] = 0;
393
394    mapping
395}