1pub mod instruction;
2mod memory;
3mod opcode;
4mod program;
5mod record;
6
7use backtrace::Backtrace as Trace;
9pub use instruction::Instruction;
10use instruction::{FieldEltType, HintBitsInstr, HintExt2FeltsInstr, HintInstr, PrintInstr};
11use memory::*;
12pub use opcode::*;
13pub use program::*;
14pub use record::*;
15
16use std::{
17 array,
18 borrow::Borrow,
19 collections::VecDeque,
20 fmt::Debug,
21 io::{stdout, Write},
22 iter::zip,
23 marker::PhantomData,
24 sync::Arc,
25};
26
27use hashbrown::HashMap;
28use itertools::Itertools;
29use p3_field::{AbstractField, ExtensionField, PrimeField32};
30use p3_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral};
31use p3_symmetric::{CryptographicPermutation, Permutation};
32use p3_util::reverse_bits_len;
33use thiserror::Error;
34
35use sp1_recursion_core::air::{Block, RECURSIVE_PROOF_NUM_PV_ELTS};
36
37use crate::*;
39
/// Size of the stack region: 2^24 (presumably counted in memory cells — confirm with the
/// compiler that lays out memory).
pub const STACK_SIZE: usize = 1 << 24;

/// First heap address: the end of the stack region plus 4.
///
/// NOTE(review): the purpose of the extra 4 is not visible in this module.
pub const HEAP_START_ADDRESS: usize = STACK_SIZE + 4;

/// Heap-pointer location (a negative value).
///
/// NOTE(review): how this value is interpreted is decided outside this module — confirm.
pub const HEAP_PTR: i32 = -4;

/// Total memory size: 2^28. NOTE(review): not referenced in this module.
pub const MEMORY_SIZE: usize = 1 << 28;

/// State width of the Poseidon2 permutation used by `Runtime`.
pub const PERMUTATION_WIDTH: usize = 16;

/// S-box degree parameter passed to the Poseidon2 permutation type used by `Runtime`.
pub const POSEIDON2_SBOX_DEGREE: u64 = 7;

/// Hashing rate — presumably the number of field elements absorbed per permutation.
pub const HASH_RATE: usize = 8;

/// Digest length — presumably in field elements.
pub const DIGEST_SIZE: usize = 8;

/// Bit width used for field-element bit decompositions — presumably.
pub const NUM_BITS: usize = 31;

/// Extension degree — presumably equal to the number of base elements in a `Block`.
pub const D: usize = 4;
59
/// Per-span bookkeeping stored in `Runtime::cycle_tracker`.
///
/// `Runtime::print_stats` reports `cumulative_cycles` for each entry. The other fields are not
/// read in this module; their meaning below is inferred from their names (TODO confirm).
#[derive(Clone, Debug, Default)]
pub struct CycleTrackerEntry {
    /// Presumably `true` while the span is open.
    pub span_entered: bool,
    /// Presumably the cycle at which the currently open span was entered.
    pub span_enter_cycle: usize,
    /// Total cycles attributed to this span.
    pub cumulative_cycles: usize,
}
66
/// Interpreter for a [`RecursionProgram`].
///
/// It executes the program's instructions in order against `memory` and appends one event per
/// operation to `record`.
///
/// Type parameters:
/// - `F` is the base field.
/// - `EF` is the extension field used by extension-field instructions.
/// - `Diffusion` is passed through to the Poseidon2 permutation type. Presumably it is the
///   internal diffusion layer; see `p3_poseidon2`.
/// - `'a` bounds the lifetime of `debug_stdout`.
pub struct Runtime<'a, F: PrimeField32, EF: ExtensionField<F>, Diffusion> {
    /// Number of instructions executed so far (incremented once per step of `run`).
    pub timestamp: usize,

    /// Number of `Poseidon2` instructions executed.
    pub nb_poseidons: usize,

    /// Wide Poseidon2 counter. NOTE(review): never incremented in this module.
    pub nb_wide_poseidons: usize,

    /// Number of `HintBits` and `HintExt2Felts` instructions executed.
    pub nb_bit_decompositions: usize,

    /// Number of `ExtAlu` instructions executed.
    pub nb_ext_ops: usize,

    /// Number of `BaseAlu` instructions executed.
    pub nb_base_ops: usize,

    /// Number of `Mem` instructions executed.
    pub nb_memory_ops: usize,

    /// Branch counter. NOTE(review): never incremented in this module.
    pub nb_branch_ops: usize,

    /// Number of `ExpReverseBitsLen` instructions executed.
    pub nb_exp_reverse_bits: usize,

    /// Number of `FriFold` instructions executed.
    pub nb_fri_fold: usize,

    /// Number of base-field `Print` instructions executed.
    pub nb_print_f: usize,

    /// Number of extension-field `Print` instructions executed.
    pub nb_print_e: usize,

    /// Clock value; `run` advances it by 4 per executed instruction.
    pub clk: F,

    /// Program counter: index of the next instruction in `program.instructions`.
    pub pc: F,

    /// The program being executed (shared with `record.program`).
    pub program: Arc<RecursionProgram<F>>,

    /// Memory that instructions read from and write to.
    pub memory: MemVecMap<F>,

    /// Events produced during execution.
    pub record: ExecutionRecord<F>,

    /// Values consumed front-to-back by `Hint` instructions.
    pub witness_stream: VecDeque<Block<F>>,

    /// Per-span cycle statistics reported by `print_stats`.
    ///
    /// NOTE(review): nothing in this module populates it.
    pub cycle_tracker: HashMap<String, CycleTrackerEntry>,

    /// Destination for `Print` instruction output (stdout by default).
    pub debug_stdout: Box<dyn Write + 'a>,

    /// Poseidon2 permutation used by `Poseidon2` instructions; `new` sets it to `Some`.
    perm: Option<
        Poseidon2<
            F,
            Poseidon2ExternalMatrixGeneral,
            Diffusion,
            PERMUTATION_WIDTH,
            POSEIDON2_SBOX_DEGREE,
        >,
    >,

    /// Ties `EF` to the struct; no stored field otherwise mentions it.
    _marker_ef: PhantomData<EF>,

    /// Marker for `Diffusion`.
    ///
    /// NOTE(review): `Diffusion` already appears in `perm`'s type, so this looks redundant.
    _marker_diffusion: PhantomData<Diffusion>,
}
132
133#[derive(Error, Debug)]
134pub enum RuntimeError<F: Debug, EF: Debug> {
135 #[error(
136 "attempted to perform base field division {in1:?}/{in2:?} \
137 from instruction {instr:?} at pc {pc:?}\nnearest pc with backtrace:\n{trace:?}"
138 )]
139 DivFOutOfDomain {
140 in1: F,
141 in2: F,
142 instr: BaseAluInstr<F>,
143 pc: usize,
144 trace: Option<(usize, Trace)>,
145 },
146 #[error(
147 "attempted to perform extension field division {in1:?}/{in2:?} \
148 from instruction {instr:?} at pc {pc:?}\nnearest pc with backtrace:\n{trace:?}"
149 )]
150 DivEOutOfDomain {
151 in1: EF,
152 in2: EF,
153 instr: ExtAluInstr<F>,
154 pc: usize,
155 trace: Option<(usize, Trace)>,
156 },
157 #[error("failed to print to `debug_stdout`: {0}")]
158 DebugPrint(#[from] std::io::Error),
159 #[error("attempted to read from empty witness stream")]
160 EmptyWitnessStream,
161}
162
163impl<'a, F: PrimeField32, EF: ExtensionField<F>, Diffusion> Runtime<'a, F, EF, Diffusion>
164where
165 Poseidon2<
166 F,
167 Poseidon2ExternalMatrixGeneral,
168 Diffusion,
169 PERMUTATION_WIDTH,
170 POSEIDON2_SBOX_DEGREE,
171 >: CryptographicPermutation<[F; PERMUTATION_WIDTH]>,
172{
173 pub fn new(
174 program: Arc<RecursionProgram<F>>,
175 perm: Poseidon2<
176 F,
177 Poseidon2ExternalMatrixGeneral,
178 Diffusion,
179 PERMUTATION_WIDTH,
180 POSEIDON2_SBOX_DEGREE,
181 >,
182 ) -> Self {
183 let record = ExecutionRecord::<F> { program: program.clone(), ..Default::default() };
184 let memory = Memory::with_capacity(program.total_memory);
185 Self {
186 timestamp: 0,
187 nb_poseidons: 0,
188 nb_wide_poseidons: 0,
189 nb_bit_decompositions: 0,
190 nb_exp_reverse_bits: 0,
191 nb_ext_ops: 0,
192 nb_base_ops: 0,
193 nb_memory_ops: 0,
194 nb_branch_ops: 0,
195 nb_fri_fold: 0,
196 nb_print_f: 0,
197 nb_print_e: 0,
198 clk: F::zero(),
199 program,
200 pc: F::zero(),
201 memory,
202 record,
203 witness_stream: VecDeque::new(),
204 cycle_tracker: HashMap::new(),
205 debug_stdout: Box::new(stdout()),
206 perm: Some(perm),
207 _marker_ef: PhantomData,
208 _marker_diffusion: PhantomData,
209 }
210 }
211
212 pub fn print_stats(&self) {
213 tracing::debug!("Total Cycles: {}", self.timestamp);
214 tracing::debug!("Poseidon Skinny Operations: {}", self.nb_poseidons);
215 tracing::debug!("Poseidon Wide Operations: {}", self.nb_wide_poseidons);
216 tracing::debug!("Exp Reverse Bits Operations: {}", self.nb_exp_reverse_bits);
217 tracing::debug!("FriFold Operations: {}", self.nb_fri_fold);
218 tracing::debug!("Field Operations: {}", self.nb_base_ops);
219 tracing::debug!("Extension Operations: {}", self.nb_ext_ops);
220 tracing::debug!("Memory Operations: {}", self.nb_memory_ops);
221 tracing::debug!("Branch Operations: {}", self.nb_branch_ops);
222 for (name, entry) in self.cycle_tracker.iter().sorted_by_key(|(name, _)| *name) {
223 tracing::debug!("> {}: {}", name, entry.cumulative_cycles);
224 }
225 }
226
227 fn nearest_pc_backtrace(&mut self) -> Option<(usize, Trace)> {
228 let trap_pc = self.pc.as_canonical_u32() as usize;
229 let trace = self.program.traces[trap_pc].clone();
230 if let Some(mut trace) = trace {
231 trace.resolve();
232 Some((trap_pc, trace))
233 } else {
234 (0..trap_pc)
235 .rev()
236 .filter_map(|nearby_pc| {
237 let mut trace = self.program.traces.get(nearby_pc)?.clone()?;
238 trace.resolve();
239 Some((nearby_pc, trace))
240 })
241 .next()
242 }
243 }
244
245 pub fn run(&mut self) -> Result<(), RuntimeError<F, EF>> {
247 let early_exit_ts = std::env::var("RECURSION_EARLY_EXIT_TS")
248 .map_or(usize::MAX, |ts: String| ts.parse().unwrap());
249 while self.pc < F::from_canonical_u32(self.program.instructions.len() as u32) {
250 let idx = self.pc.as_canonical_u32() as usize;
251 let instruction = self.program.instructions[idx].clone();
252
253 let next_clk = self.clk + F::from_canonical_u32(4);
254 let next_pc = self.pc + F::one();
255 match instruction {
256 Instruction::BaseAlu(instr @ BaseAluInstr { opcode, mult, addrs }) => {
257 self.nb_base_ops += 1;
258 let in1 = self.memory.mr(addrs.in1).val[0];
259 let in2 = self.memory.mr(addrs.in2).val[0];
260 let out = match opcode {
262 BaseAluOpcode::AddF => in1 + in2,
263 BaseAluOpcode::SubF => in1 - in2,
264 BaseAluOpcode::MulF => in1 * in2,
265 BaseAluOpcode::DivF => match in1.try_div(in2) {
266 Some(x) => x,
267 None => {
268 if in1.is_zero() {
271 AbstractField::one()
272 } else {
273 return Err(RuntimeError::DivFOutOfDomain {
274 in1,
275 in2,
276 instr,
277 pc: self.pc.as_canonical_u32() as usize,
278 trace: self.nearest_pc_backtrace(),
279 });
280 }
281 }
282 },
283 };
284 self.memory.mw(addrs.out, Block::from(out), mult);
285 self.record.base_alu_events.push(BaseAluEvent { out, in1, in2 });
286 }
287 Instruction::ExtAlu(instr @ ExtAluInstr { opcode, mult, addrs }) => {
288 self.nb_ext_ops += 1;
289 let in1 = self.memory.mr(addrs.in1).val;
290 let in2 = self.memory.mr(addrs.in2).val;
291 let in1_ef = EF::from_base_slice(&in1.0);
293 let in2_ef = EF::from_base_slice(&in2.0);
294 let out_ef = match opcode {
295 ExtAluOpcode::AddE => in1_ef + in2_ef,
296 ExtAluOpcode::SubE => in1_ef - in2_ef,
297 ExtAluOpcode::MulE => in1_ef * in2_ef,
298 ExtAluOpcode::DivE => match in1_ef.try_div(in2_ef) {
299 Some(x) => x,
300 None => {
301 if in1_ef.is_zero() {
304 AbstractField::one()
305 } else {
306 return Err(RuntimeError::DivEOutOfDomain {
307 in1: in1_ef,
308 in2: in2_ef,
309 instr,
310 pc: self.pc.as_canonical_u32() as usize,
311 trace: self.nearest_pc_backtrace(),
312 });
313 }
314 }
315 },
316 };
317 let out = Block::from(out_ef.as_base_slice());
318 self.memory.mw(addrs.out, out, mult);
319 self.record.ext_alu_events.push(ExtAluEvent { out, in1, in2 });
320 }
321 Instruction::Mem(MemInstr {
322 addrs: MemIo { inner: addr },
323 vals: MemIo { inner: val },
324 mult,
325 kind,
326 }) => {
327 self.nb_memory_ops += 1;
328 match kind {
329 MemAccessKind::Read => {
330 let mem_entry = self.memory.mr_mult(addr, mult);
331 assert_eq!(
332 mem_entry.val, val,
333 "stored memory value should be the specified value"
334 );
335 }
336 MemAccessKind::Write => drop(self.memory.mw(addr, val, mult)),
337 }
338 self.record.mem_const_count += 1;
339 }
340 Instruction::Poseidon2(instr) => {
341 let Poseidon2Instr { addrs: Poseidon2Io { input, output }, mults } = *instr;
342 self.nb_poseidons += 1;
343 let in_vals = std::array::from_fn(|i| self.memory.mr(input[i]).val[0]);
344 let perm_output = self.perm.as_ref().unwrap().permute(in_vals);
345
346 perm_output.iter().zip(output).zip(mults).for_each(|((&val, addr), mult)| {
347 self.memory.mw(addr, Block::from(val), mult);
348 });
349 self.record
350 .poseidon2_events
351 .push(Poseidon2Event { input: in_vals, output: perm_output });
352 }
353 Instruction::ExpReverseBitsLen(ExpReverseBitsInstr {
354 addrs: ExpReverseBitsIo { base, exp, result },
355 mult,
356 }) => {
357 self.nb_exp_reverse_bits += 1;
358 let base_val = self.memory.mr(base).val[0];
359 let exp_bits: Vec<_> =
360 exp.iter().map(|bit| self.memory.mr(*bit).val[0]).collect();
361 let exp_val = exp_bits
362 .iter()
363 .enumerate()
364 .fold(0, |acc, (i, &val)| acc + val.as_canonical_u32() * (1 << i));
365 let out =
366 base_val.exp_u64(reverse_bits_len(exp_val as usize, exp_bits.len()) as u64);
367 self.memory.mw(result, Block::from(out), mult);
368 self.record.exp_reverse_bits_len_events.push(ExpReverseBitsEvent {
369 result: out,
370 base: base_val,
371 exp: exp_bits,
372 });
373 }
374 Instruction::HintBits(HintBitsInstr { output_addrs_mults, input_addr }) => {
375 self.nb_bit_decompositions += 1;
376 let num = self.memory.mr_mult(input_addr, F::zero()).val[0].as_canonical_u32();
377 let bits = (0..output_addrs_mults.len())
379 .map(|i| Block::from(F::from_canonical_u32((num >> i) & 1)))
380 .collect::<Vec<_>>();
381 for (bit, (addr, mult)) in bits.into_iter().zip(output_addrs_mults) {
383 self.memory.mw(addr, bit, mult);
384 self.record.mem_var_events.push(MemEvent { inner: bit });
385 }
386 }
387
388 Instruction::FriFold(instr) => {
389 let FriFoldInstr {
390 base_single_addrs,
391 ext_single_addrs,
392 ext_vec_addrs,
393 alpha_pow_mults,
394 ro_mults,
395 } = *instr;
396 self.nb_fri_fold += 1;
397 let x = self.memory.mr(base_single_addrs.x).val[0];
398 let z = self.memory.mr(ext_single_addrs.z).val;
399 let z: EF = z.ext();
400 let alpha = self.memory.mr(ext_single_addrs.alpha).val;
401 let alpha: EF = alpha.ext();
402 let mat_opening = ext_vec_addrs
403 .mat_opening
404 .iter()
405 .map(|addr| self.memory.mr(*addr).val)
406 .collect_vec();
407 let ps_at_z = ext_vec_addrs
408 .ps_at_z
409 .iter()
410 .map(|addr| self.memory.mr(*addr).val)
411 .collect_vec();
412
413 for m in 0..ps_at_z.len() {
414 let p_at_x = mat_opening[m];
417 let p_at_x: EF = p_at_x.ext();
418 let p_at_z = ps_at_z[m];
419 let p_at_z: EF = p_at_z.ext();
420
421 let quotient = (-p_at_z + p_at_x) / (-z + x);
423
424 let alpha_pow: EF =
426 self.memory.mr(ext_vec_addrs.alpha_pow_input[m]).val.ext();
427
428 let ro: EF = self.memory.mr(ext_vec_addrs.ro_input[m]).val.ext();
429
430 let new_ro = ro + alpha_pow * quotient;
431 let new_alpha_pow = alpha_pow * alpha;
432
433 let _ = self.memory.mw(
434 ext_vec_addrs.ro_output[m],
435 Block::from(new_ro.as_base_slice()),
436 ro_mults[m],
437 );
438
439 let _ = self.memory.mw(
440 ext_vec_addrs.alpha_pow_output[m],
441 Block::from(new_alpha_pow.as_base_slice()),
442 alpha_pow_mults[m],
443 );
444
445 self.record.fri_fold_events.push(FriFoldEvent {
446 base_single: FriFoldBaseIo { x },
447 ext_single: FriFoldExtSingleIo {
448 z: Block::from(z.as_base_slice()),
449 alpha: Block::from(alpha.as_base_slice()),
450 },
451 ext_vec: FriFoldExtVecIo {
452 mat_opening: Block::from(p_at_x.as_base_slice()),
453 ps_at_z: Block::from(p_at_z.as_base_slice()),
454 alpha_pow_input: Block::from(alpha_pow.as_base_slice()),
455 ro_input: Block::from(ro.as_base_slice()),
456 alpha_pow_output: Block::from(new_alpha_pow.as_base_slice()),
457 ro_output: Block::from(new_ro.as_base_slice()),
458 },
459 });
460 }
461 }
462
463 Instruction::CommitPublicValues(instr) => {
464 let pv_addrs = instr.pv_addrs.to_vec();
465 let pv_values: [F; RECURSIVE_PROOF_NUM_PV_ELTS] =
466 array::from_fn(|i| self.memory.mr(pv_addrs[i]).val[0]);
467 self.record.public_values = *pv_values.as_slice().borrow();
468 self.record
469 .commit_pv_hash_events
470 .push(CommitPublicValuesEvent { public_values: self.record.public_values });
471 }
472
473 Instruction::Print(PrintInstr { field_elt_type, addr }) => match field_elt_type {
474 FieldEltType::Base => {
475 self.nb_print_f += 1;
476 let f = self.memory.mr_mult(addr, F::zero()).val[0];
477 writeln!(self.debug_stdout, "PRINTF={f}")
478 }
479 FieldEltType::Extension => {
480 self.nb_print_e += 1;
481 let ef = self.memory.mr_mult(addr, F::zero()).val;
482 writeln!(self.debug_stdout, "PRINTEF={ef:?}")
483 }
484 }
485 .map_err(RuntimeError::DebugPrint)?,
486 Instruction::HintExt2Felts(HintExt2FeltsInstr {
487 output_addrs_mults,
488 input_addr,
489 }) => {
490 self.nb_bit_decompositions += 1;
491 let fs = self.memory.mr_mult(input_addr, F::zero()).val;
492 for (f, (addr, mult)) in fs.into_iter().zip(output_addrs_mults) {
494 let felt = Block::from(f);
495 self.memory.mw(addr, felt, mult);
496 self.record.mem_var_events.push(MemEvent { inner: felt });
497 }
498 }
499 Instruction::Hint(HintInstr { output_addrs_mults }) => {
500 if self.witness_stream.len() < output_addrs_mults.len() {
502 return Err(RuntimeError::EmptyWitnessStream);
503 }
504 let witness = self.witness_stream.drain(0..output_addrs_mults.len());
505 for ((addr, mult), val) in zip(output_addrs_mults, witness) {
506 self.memory.mw(addr, val, mult);
508 self.record.mem_var_events.push(MemEvent { inner: val });
509 }
510 }
511 }
512
513 self.pc = next_pc;
514 self.clk = next_clk;
515 self.timestamp += 1;
516
517 if self.timestamp >= early_exit_ts {
518 break;
519 }
520 }
521 Ok(())
522 }
523}