1use sp1_hypercube::air::PROOF_NONCE_NUM_WORDS;
2use sp1_jit::MinimalTrace;
3use std::sync::Arc;
4
5use crate::{
6 events::{MemoryReadRecord, MemoryWriteRecord},
7 vm::{
8 gas::ReportGenerator,
9 results::{
10 AluResult, BranchResult, CycleResult, JumpResult, LoadResult, MaybeImmediate,
11 StoreResult, UTypeResult,
12 },
13 syscall::SyscallRuntime,
14 CoreVM,
15 },
16 ExecutionError, ExecutionReport, Instruction, Opcode, Program, Register, SP1CoreOpts,
17 SyscallCode,
18};
19
/// A wrapper around [`CoreVM`] that executes instructions and reports every instruction,
/// memory/register access, and syscall to a [`ReportGenerator`] to build an
/// [`ExecutionReport`] with gas estimates.
pub struct GasEstimatingVM<'a> {
    /// The core VM that actually executes each instruction.
    pub core: CoreVM<'a>,
    /// Accumulates instruction, memory-event and syscall costs; `generate_report` turns the
    /// accumulated state into the returned [`ExecutionReport`].
    pub gas_calculator: ReportGenerator,
    /// Number of `HINT_LEN` syscalls executed so far (starts at 0, incremented in
    /// `execute_ecall`). NOTE(review): presumably an index into hint lengths recorded in the
    /// trace — confirm against its readers.
    pub hint_lens_idx: usize,
}
29
30impl GasEstimatingVM<'_> {
31 pub fn execute(&mut self) -> Result<ExecutionReport, ExecutionError> {
33 if self.core.is_done() {
34 return Ok(self.gas_calculator.generate_report());
35 }
36
37 loop {
38 match self.execute_instruction()? {
39 CycleResult::Done(false) => {}
40 CycleResult::TraceEnd | CycleResult::ShardBoundary | CycleResult::Done(true) => {
41 return Ok(self.gas_calculator.generate_report());
42 }
43 }
44 }
45 }
46
47 pub fn execute_instruction(&mut self) -> Result<CycleResult, ExecutionError> {
49 let instruction = self.core.fetch();
50 if instruction.is_none() {
51 unreachable!("Fetching the next instruction failed");
52 }
53
54 let instruction = unsafe { *instruction.unwrap_unchecked() };
56
57 match &instruction.opcode {
58 Opcode::ADD
59 | Opcode::ADDI
60 | Opcode::SUB
61 | Opcode::XOR
62 | Opcode::OR
63 | Opcode::AND
64 | Opcode::SLL
65 | Opcode::SLLW
66 | Opcode::SRL
67 | Opcode::SRA
68 | Opcode::SRLW
69 | Opcode::SRAW
70 | Opcode::SLT
71 | Opcode::SLTU
72 | Opcode::MUL
73 | Opcode::MULHU
74 | Opcode::MULHSU
75 | Opcode::MULH
76 | Opcode::MULW
77 | Opcode::DIVU
78 | Opcode::REMU
79 | Opcode::DIV
80 | Opcode::REM
81 | Opcode::DIVW
82 | Opcode::ADDW
83 | Opcode::SUBW
84 | Opcode::DIVUW
85 | Opcode::REMUW
86 | Opcode::REMW => {
87 self.execute_alu(&instruction);
88 }
89 Opcode::LB
90 | Opcode::LBU
91 | Opcode::LH
92 | Opcode::LHU
93 | Opcode::LW
94 | Opcode::LWU
95 | Opcode::LD => self.execute_load(&instruction)?,
96 Opcode::SB | Opcode::SH | Opcode::SW | Opcode::SD => {
97 self.execute_store(&instruction)?;
98 }
99 Opcode::JAL | Opcode::JALR => {
100 self.execute_jump(&instruction);
101 }
102 Opcode::BEQ | Opcode::BNE | Opcode::BLT | Opcode::BGE | Opcode::BLTU | Opcode::BGEU => {
103 self.execute_branch(&instruction);
104 }
105 Opcode::LUI | Opcode::AUIPC => {
106 self.execute_utype(&instruction);
107 }
108 Opcode::ECALL => self.execute_ecall(&instruction)?,
109 Opcode::EBREAK | Opcode::UNIMP => {
110 unreachable!("Invalid opcode for `execute_instruction`: {:?}", instruction.opcode)
111 }
112 }
113
114 Ok(self.core.advance())
115 }
116}
117
118impl<'a> GasEstimatingVM<'a> {
119 pub fn new<T: MinimalTrace>(
121 trace: &'a T,
122 program: Arc<Program>,
123 proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
124 opts: SP1CoreOpts,
125 ) -> Self {
126 Self {
127 core: CoreVM::new(trace, program, opts, proof_nonce),
128 hint_lens_idx: 0,
129 gas_calculator: ReportGenerator::new(trace.clk_start()),
130 }
131 }
132
133 pub fn execute_load(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
140 let LoadResult { addr, rd, mr_record, rr_record, rw_record, rs1, .. } =
141 self.core.execute_load(instruction)?;
142
143 self.gas_calculator.handle_instruction(
144 instruction,
145 self.core.needs_bump_clk_high(),
146 rd == Register::X0,
147 self.core.needs_state_bump(instruction),
148 );
149
150 self.gas_calculator.handle_mem_event(addr, mr_record.prev_timestamp);
151 self.gas_calculator.handle_mem_event(rs1 as u64, rr_record.prev_timestamp);
152 self.gas_calculator.handle_mem_event(rd as u64, rw_record.prev_timestamp);
153
154 Ok(())
155 }
156
157 pub fn execute_store(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
164 let StoreResult { addr, mw_record, rs1_record, rs2_record, rs1, rs2, .. } =
165 self.core.execute_store(instruction)?;
166
167 self.gas_calculator.handle_instruction(
168 instruction,
169 self.core.needs_bump_clk_high(),
170 false, self.core.needs_state_bump(instruction),
172 );
173
174 self.gas_calculator.handle_mem_event(addr, mw_record.prev_timestamp);
175 self.gas_calculator.handle_mem_event(rs1 as u64, rs1_record.prev_timestamp);
176 self.gas_calculator.handle_mem_event(rs2 as u64, rs2_record.prev_timestamp);
177
178 Ok(())
179 }
180
181 #[inline]
183 pub fn execute_alu(&mut self, instruction: &Instruction) {
184 let AluResult { rd, rw_record, rs1, rs2, .. } = self.core.execute_alu(instruction);
185
186 self.gas_calculator.handle_mem_event(rd as u64, rw_record.prev_timestamp);
187
188 if let MaybeImmediate::Register(register, record) = rs1 {
189 self.gas_calculator.handle_mem_event(register as u64, record.prev_timestamp);
190 }
191
192 if let MaybeImmediate::Register(register, record) = rs2 {
193 self.gas_calculator.handle_mem_event(register as u64, record.prev_timestamp);
194 }
195
196 self.gas_calculator.handle_instruction(
197 instruction,
198 self.core.needs_bump_clk_high(),
199 false, self.core.needs_state_bump(instruction),
201 );
202 }
203
204 #[inline]
206 pub fn execute_jump(&mut self, instruction: &Instruction) {
207 let JumpResult { rd, rd_record, rs1, .. } = self.core.execute_jump(instruction);
208
209 self.gas_calculator.handle_mem_event(rd as u64, rd_record.prev_timestamp);
210
211 if let MaybeImmediate::Register(register, record) = rs1 {
212 self.gas_calculator.handle_mem_event(register as u64, record.prev_timestamp);
213 }
214
215 self.gas_calculator.handle_instruction(
216 instruction,
217 self.core.needs_bump_clk_high(),
218 false, self.core.needs_state_bump(instruction),
220 );
221 }
222
223 #[inline]
225 pub fn execute_branch(&mut self, instruction: &Instruction) {
226 let BranchResult { rs1, a_record, rs2, b_record, .. } =
227 self.core.execute_branch(instruction);
228
229 self.gas_calculator.handle_mem_event(rs1 as u64, a_record.prev_timestamp);
230 self.gas_calculator.handle_mem_event(rs2 as u64, b_record.prev_timestamp);
231
232 self.gas_calculator.handle_instruction(
233 instruction,
234 self.core.needs_bump_clk_high(),
235 false, self.core.needs_state_bump(instruction),
237 );
238 }
239
240 #[inline]
242 pub fn execute_utype(&mut self, instruction: &Instruction) {
243 let UTypeResult { rd, rw_record, .. } = self.core.execute_utype(instruction);
244
245 self.gas_calculator.handle_mem_event(rd as u64, rw_record.prev_timestamp);
246
247 self.gas_calculator.handle_instruction(
248 instruction,
249 self.core.needs_bump_clk_high(),
250 false, self.core.needs_state_bump(instruction),
252 );
253 }
254
255 #[inline]
257 pub fn execute_ecall(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
258 let code = self.core.read_code();
259
260 let result = CoreVM::execute_ecall(self, instruction, code)?;
261
262 if code == SyscallCode::HINT_LEN {
263 self.hint_lens_idx += 1;
264 }
265
266 if code == SyscallCode::HALT {
267 self.gas_calculator.set_exit_code(result.b);
268 }
269
270 if code.should_send() == 1 {
271 if self.core.is_retained_syscall(code) {
272 self.gas_calculator.handle_retained_syscall(code);
273 } else {
274 self.gas_calculator.syscall_sent(code);
275 }
276 }
277
278 self.gas_calculator.handle_instruction(
279 instruction,
280 self.core.needs_bump_clk_high(),
281 false, self.core.needs_state_bump(instruction),
283 );
284
285 Ok(())
286 }
287}
288
289impl<'a> SyscallRuntime<'a> for GasEstimatingVM<'a> {
290 const TRACING: bool = false;
291
292 fn core(&self) -> &CoreVM<'a> {
293 &self.core
294 }
295
296 fn core_mut(&mut self) -> &mut CoreVM<'a> {
297 &mut self.core
298 }
299
300 fn mr(&mut self, addr: u64) -> MemoryReadRecord {
301 let record = SyscallRuntime::mr(self.core_mut(), addr);
302
303 self.gas_calculator.handle_mem_event(addr, record.prev_timestamp);
304
305 record
306 }
307
308 fn mw_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryWriteRecord> {
309 let records = SyscallRuntime::mw_slice(self.core_mut(), addr, len);
310
311 for (i, record) in records.iter().enumerate() {
312 self.gas_calculator.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
313 }
314
315 records
316 }
317
318 fn mr_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryReadRecord> {
319 let records = SyscallRuntime::mr_slice(self.core_mut(), addr, len);
320
321 for (i, record) in records.iter().enumerate() {
322 self.gas_calculator.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
323 }
324
325 records
326 }
327
328 fn rr(&mut self, register: usize) -> MemoryReadRecord {
329 let record = SyscallRuntime::rr(self.core_mut(), register);
330
331 self.gas_calculator.handle_mem_event(register as u64, record.prev_timestamp);
332
333 record
334 }
335
336 fn mw(&mut self, addr: u64) -> MemoryWriteRecord {
337 let record = SyscallRuntime::mw(self.core_mut(), addr);
338
339 self.gas_calculator.handle_mem_event(addr, record.prev_timestamp);
340
341 record
342 }
343}