sp1_recursion_core_v2/runtime/
mod.rs

1pub mod instruction;
2mod memory;
3mod opcode;
4mod program;
5mod record;
6
7// Avoid triggering annoying branch of thiserror derive macro.
8use backtrace::Backtrace as Trace;
9pub use instruction::Instruction;
10use instruction::{FieldEltType, HintBitsInstr, HintExt2FeltsInstr, HintInstr, PrintInstr};
11use memory::*;
12pub use opcode::*;
13pub use program::*;
14pub use record::*;
15
16use std::{
17    array,
18    borrow::Borrow,
19    collections::VecDeque,
20    fmt::Debug,
21    io::{stdout, Write},
22    iter::zip,
23    marker::PhantomData,
24    sync::Arc,
25};
26
27use hashbrown::HashMap;
28use itertools::Itertools;
29use p3_field::{AbstractField, ExtensionField, PrimeField32};
30use p3_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral};
31use p3_symmetric::{CryptographicPermutation, Permutation};
32use p3_util::reverse_bits_len;
33use thiserror::Error;
34
35use sp1_recursion_core::air::{Block, RECURSIVE_PROOF_NUM_PV_ELTS};
36
37/// TODO expand glob import once things are organized enough
38use crate::*;
39
/// The heap pointer address.
pub const HEAP_PTR: i32 = -4;

/// Size of the stack region; the heap begins just past it (see [`HEAP_START_ADDRESS`]).
pub const STACK_SIZE: usize = 1 << 24;
/// First heap address: four past the end of the stack region.
pub const HEAP_START_ADDRESS: usize = STACK_SIZE + 4;
/// Total memory size.
// NOTE(review): units (field elements vs. blocks) are not visible from this file -- confirm.
pub const MEMORY_SIZE: usize = 1 << 28;

/// The width of the Poseidon2 permutation.
pub const PERMUTATION_WIDTH: usize = 16;
/// S-box degree supplied as the last const parameter of the `Poseidon2` permutation type.
pub const POSEIDON2_SBOX_DEGREE: u64 = 7;
/// Hash rate.
// NOTE(review): presumably the sponge rate (half of `PERMUTATION_WIDTH`) -- confirm.
pub const HASH_RATE: usize = 8;

/// The current verifier implementation assumes that we are using a 256-bit hash with 32-bit
/// elements.
pub const DIGEST_SIZE: usize = 8;

/// Number of bits.
// NOTE(review): presumably the bit length used when decomposing a base field element -- confirm.
pub const NUM_BITS: usize = 31;

/// Extension degree.
// NOTE(review): presumably the degree of `EF` over `F`, i.e. the number of base elements per
// extension element -- confirm against `Block`.
pub const D: usize = 4;
59
/// Bookkeeping for one named cycle-tracking span.
///
/// `Runtime::print_stats` reports `cumulative_cycles` for every entry in the runtime's
/// `cycle_tracker` map.
#[derive(Clone, Debug, Default)]
pub struct CycleTrackerEntry {
    /// Whether the span is currently open.
    // NOTE(review): presumed from the name; nothing in this file reads or writes it -- confirm.
    pub span_entered: bool,
    /// Cycle at which the span was last entered.
    // NOTE(review): presumed from the name; nothing in this file reads or writes it -- confirm.
    pub span_enter_cycle: usize,
    /// Total number of cycles attributed to this span.
    pub cumulative_cycles: usize,
}
66
67/// TODO fully document.
68/// Taken from [`sp1_recursion_core::runtime::Runtime`].
69/// Many missing things (compared to the old `Runtime`) will need to be implemented.
70pub struct Runtime<'a, F: PrimeField32, EF: ExtensionField<F>, Diffusion> {
71    pub timestamp: usize,
72
73    pub nb_poseidons: usize,
74
75    pub nb_wide_poseidons: usize,
76
77    pub nb_bit_decompositions: usize,
78
79    pub nb_ext_ops: usize,
80
81    pub nb_base_ops: usize,
82
83    pub nb_memory_ops: usize,
84
85    pub nb_branch_ops: usize,
86
87    pub nb_exp_reverse_bits: usize,
88
89    pub nb_fri_fold: usize,
90
91    pub nb_print_f: usize,
92
93    pub nb_print_e: usize,
94
95    /// The current clock.
96    pub clk: F,
97
98    /// The program counter.
99    pub pc: F,
100
101    /// The program.
102    pub program: Arc<RecursionProgram<F>>,
103
104    /// Memory. From canonical usize of an Address to a MemoryEntry.
105    pub memory: MemVecMap<F>,
106
107    /// The execution record.
108    pub record: ExecutionRecord<F>,
109
110    pub witness_stream: VecDeque<Block<F>>,
111
112    pub cycle_tracker: HashMap<String, CycleTrackerEntry>,
113
114    /// The stream that print statements write to.
115    pub debug_stdout: Box<dyn Write + 'a>,
116
117    /// Entries for dealing with the Poseidon2 hash state.
118    perm: Option<
119        Poseidon2<
120            F,
121            Poseidon2ExternalMatrixGeneral,
122            Diffusion,
123            PERMUTATION_WIDTH,
124            POSEIDON2_SBOX_DEGREE,
125        >,
126    >,
127
128    _marker_ef: PhantomData<EF>,
129
130    _marker_diffusion: PhantomData<Diffusion>,
131}
132
133#[derive(Error, Debug)]
134pub enum RuntimeError<F: Debug, EF: Debug> {
135    #[error(
136        "attempted to perform base field division {in1:?}/{in2:?} \
137        from instruction {instr:?} at pc {pc:?}\nnearest pc with backtrace:\n{trace:?}"
138    )]
139    DivFOutOfDomain {
140        in1: F,
141        in2: F,
142        instr: BaseAluInstr<F>,
143        pc: usize,
144        trace: Option<(usize, Trace)>,
145    },
146    #[error(
147        "attempted to perform extension field division {in1:?}/{in2:?} \
148        from instruction {instr:?} at pc {pc:?}\nnearest pc with backtrace:\n{trace:?}"
149    )]
150    DivEOutOfDomain {
151        in1: EF,
152        in2: EF,
153        instr: ExtAluInstr<F>,
154        pc: usize,
155        trace: Option<(usize, Trace)>,
156    },
157    #[error("failed to print to `debug_stdout`: {0}")]
158    DebugPrint(#[from] std::io::Error),
159    #[error("attempted to read from empty witness stream")]
160    EmptyWitnessStream,
161}
162
163impl<'a, F: PrimeField32, EF: ExtensionField<F>, Diffusion> Runtime<'a, F, EF, Diffusion>
164where
165    Poseidon2<
166        F,
167        Poseidon2ExternalMatrixGeneral,
168        Diffusion,
169        PERMUTATION_WIDTH,
170        POSEIDON2_SBOX_DEGREE,
171    >: CryptographicPermutation<[F; PERMUTATION_WIDTH]>,
172{
173    pub fn new(
174        program: Arc<RecursionProgram<F>>,
175        perm: Poseidon2<
176            F,
177            Poseidon2ExternalMatrixGeneral,
178            Diffusion,
179            PERMUTATION_WIDTH,
180            POSEIDON2_SBOX_DEGREE,
181        >,
182    ) -> Self {
183        let record = ExecutionRecord::<F> { program: program.clone(), ..Default::default() };
184        let memory = Memory::with_capacity(program.total_memory);
185        Self {
186            timestamp: 0,
187            nb_poseidons: 0,
188            nb_wide_poseidons: 0,
189            nb_bit_decompositions: 0,
190            nb_exp_reverse_bits: 0,
191            nb_ext_ops: 0,
192            nb_base_ops: 0,
193            nb_memory_ops: 0,
194            nb_branch_ops: 0,
195            nb_fri_fold: 0,
196            nb_print_f: 0,
197            nb_print_e: 0,
198            clk: F::zero(),
199            program,
200            pc: F::zero(),
201            memory,
202            record,
203            witness_stream: VecDeque::new(),
204            cycle_tracker: HashMap::new(),
205            debug_stdout: Box::new(stdout()),
206            perm: Some(perm),
207            _marker_ef: PhantomData,
208            _marker_diffusion: PhantomData,
209        }
210    }
211
212    pub fn print_stats(&self) {
213        tracing::debug!("Total Cycles: {}", self.timestamp);
214        tracing::debug!("Poseidon Skinny Operations: {}", self.nb_poseidons);
215        tracing::debug!("Poseidon Wide Operations: {}", self.nb_wide_poseidons);
216        tracing::debug!("Exp Reverse Bits Operations: {}", self.nb_exp_reverse_bits);
217        tracing::debug!("FriFold Operations: {}", self.nb_fri_fold);
218        tracing::debug!("Field Operations: {}", self.nb_base_ops);
219        tracing::debug!("Extension Operations: {}", self.nb_ext_ops);
220        tracing::debug!("Memory Operations: {}", self.nb_memory_ops);
221        tracing::debug!("Branch Operations: {}", self.nb_branch_ops);
222        for (name, entry) in self.cycle_tracker.iter().sorted_by_key(|(name, _)| *name) {
223            tracing::debug!("> {}: {}", name, entry.cumulative_cycles);
224        }
225    }
226
227    fn nearest_pc_backtrace(&mut self) -> Option<(usize, Trace)> {
228        let trap_pc = self.pc.as_canonical_u32() as usize;
229        let trace = self.program.traces[trap_pc].clone();
230        if let Some(mut trace) = trace {
231            trace.resolve();
232            Some((trap_pc, trace))
233        } else {
234            (0..trap_pc)
235                .rev()
236                .filter_map(|nearby_pc| {
237                    let mut trace = self.program.traces.get(nearby_pc)?.clone()?;
238                    trace.resolve();
239                    Some((nearby_pc, trace))
240                })
241                .next()
242        }
243    }
244
245    /// Compare to [sp1_recursion_core::runtime::Runtime::run].
246    pub fn run(&mut self) -> Result<(), RuntimeError<F, EF>> {
247        let early_exit_ts = std::env::var("RECURSION_EARLY_EXIT_TS")
248            .map_or(usize::MAX, |ts: String| ts.parse().unwrap());
249        while self.pc < F::from_canonical_u32(self.program.instructions.len() as u32) {
250            let idx = self.pc.as_canonical_u32() as usize;
251            let instruction = self.program.instructions[idx].clone();
252
253            let next_clk = self.clk + F::from_canonical_u32(4);
254            let next_pc = self.pc + F::one();
255            match instruction {
256                Instruction::BaseAlu(instr @ BaseAluInstr { opcode, mult, addrs }) => {
257                    self.nb_base_ops += 1;
258                    let in1 = self.memory.mr(addrs.in1).val[0];
259                    let in2 = self.memory.mr(addrs.in2).val[0];
260                    // Do the computation.
261                    let out = match opcode {
262                        BaseAluOpcode::AddF => in1 + in2,
263                        BaseAluOpcode::SubF => in1 - in2,
264                        BaseAluOpcode::MulF => in1 * in2,
265                        BaseAluOpcode::DivF => match in1.try_div(in2) {
266                            Some(x) => x,
267                            None => {
268                                // Check for division exceptions and error. Note that 0/0 is defined
269                                // to be 1.
270                                if in1.is_zero() {
271                                    AbstractField::one()
272                                } else {
273                                    return Err(RuntimeError::DivFOutOfDomain {
274                                        in1,
275                                        in2,
276                                        instr,
277                                        pc: self.pc.as_canonical_u32() as usize,
278                                        trace: self.nearest_pc_backtrace(),
279                                    });
280                                }
281                            }
282                        },
283                    };
284                    self.memory.mw(addrs.out, Block::from(out), mult);
285                    self.record.base_alu_events.push(BaseAluEvent { out, in1, in2 });
286                }
287                Instruction::ExtAlu(instr @ ExtAluInstr { opcode, mult, addrs }) => {
288                    self.nb_ext_ops += 1;
289                    let in1 = self.memory.mr(addrs.in1).val;
290                    let in2 = self.memory.mr(addrs.in2).val;
291                    // Do the computation.
292                    let in1_ef = EF::from_base_slice(&in1.0);
293                    let in2_ef = EF::from_base_slice(&in2.0);
294                    let out_ef = match opcode {
295                        ExtAluOpcode::AddE => in1_ef + in2_ef,
296                        ExtAluOpcode::SubE => in1_ef - in2_ef,
297                        ExtAluOpcode::MulE => in1_ef * in2_ef,
298                        ExtAluOpcode::DivE => match in1_ef.try_div(in2_ef) {
299                            Some(x) => x,
300                            None => {
301                                // Check for division exceptions and error. Note that 0/0 is defined
302                                // to be 1.
303                                if in1_ef.is_zero() {
304                                    AbstractField::one()
305                                } else {
306                                    return Err(RuntimeError::DivEOutOfDomain {
307                                        in1: in1_ef,
308                                        in2: in2_ef,
309                                        instr,
310                                        pc: self.pc.as_canonical_u32() as usize,
311                                        trace: self.nearest_pc_backtrace(),
312                                    });
313                                }
314                            }
315                        },
316                    };
317                    let out = Block::from(out_ef.as_base_slice());
318                    self.memory.mw(addrs.out, out, mult);
319                    self.record.ext_alu_events.push(ExtAluEvent { out, in1, in2 });
320                }
321                Instruction::Mem(MemInstr {
322                    addrs: MemIo { inner: addr },
323                    vals: MemIo { inner: val },
324                    mult,
325                    kind,
326                }) => {
327                    self.nb_memory_ops += 1;
328                    match kind {
329                        MemAccessKind::Read => {
330                            let mem_entry = self.memory.mr_mult(addr, mult);
331                            assert_eq!(
332                                mem_entry.val, val,
333                                "stored memory value should be the specified value"
334                            );
335                        }
336                        MemAccessKind::Write => drop(self.memory.mw(addr, val, mult)),
337                    }
338                    self.record.mem_const_count += 1;
339                }
340                Instruction::Poseidon2(instr) => {
341                    let Poseidon2Instr { addrs: Poseidon2Io { input, output }, mults } = *instr;
342                    self.nb_poseidons += 1;
343                    let in_vals = std::array::from_fn(|i| self.memory.mr(input[i]).val[0]);
344                    let perm_output = self.perm.as_ref().unwrap().permute(in_vals);
345
346                    perm_output.iter().zip(output).zip(mults).for_each(|((&val, addr), mult)| {
347                        self.memory.mw(addr, Block::from(val), mult);
348                    });
349                    self.record
350                        .poseidon2_events
351                        .push(Poseidon2Event { input: in_vals, output: perm_output });
352                }
353                Instruction::ExpReverseBitsLen(ExpReverseBitsInstr {
354                    addrs: ExpReverseBitsIo { base, exp, result },
355                    mult,
356                }) => {
357                    self.nb_exp_reverse_bits += 1;
358                    let base_val = self.memory.mr(base).val[0];
359                    let exp_bits: Vec<_> =
360                        exp.iter().map(|bit| self.memory.mr(*bit).val[0]).collect();
361                    let exp_val = exp_bits
362                        .iter()
363                        .enumerate()
364                        .fold(0, |acc, (i, &val)| acc + val.as_canonical_u32() * (1 << i));
365                    let out =
366                        base_val.exp_u64(reverse_bits_len(exp_val as usize, exp_bits.len()) as u64);
367                    self.memory.mw(result, Block::from(out), mult);
368                    self.record.exp_reverse_bits_len_events.push(ExpReverseBitsEvent {
369                        result: out,
370                        base: base_val,
371                        exp: exp_bits,
372                    });
373                }
374                Instruction::HintBits(HintBitsInstr { output_addrs_mults, input_addr }) => {
375                    self.nb_bit_decompositions += 1;
376                    let num = self.memory.mr_mult(input_addr, F::zero()).val[0].as_canonical_u32();
377                    // Decompose the num into LE bits.
378                    let bits = (0..output_addrs_mults.len())
379                        .map(|i| Block::from(F::from_canonical_u32((num >> i) & 1)))
380                        .collect::<Vec<_>>();
381                    // Write the bits to the array at dst.
382                    for (bit, (addr, mult)) in bits.into_iter().zip(output_addrs_mults) {
383                        self.memory.mw(addr, bit, mult);
384                        self.record.mem_var_events.push(MemEvent { inner: bit });
385                    }
386                }
387
388                Instruction::FriFold(instr) => {
389                    let FriFoldInstr {
390                        base_single_addrs,
391                        ext_single_addrs,
392                        ext_vec_addrs,
393                        alpha_pow_mults,
394                        ro_mults,
395                    } = *instr;
396                    self.nb_fri_fold += 1;
397                    let x = self.memory.mr(base_single_addrs.x).val[0];
398                    let z = self.memory.mr(ext_single_addrs.z).val;
399                    let z: EF = z.ext();
400                    let alpha = self.memory.mr(ext_single_addrs.alpha).val;
401                    let alpha: EF = alpha.ext();
402                    let mat_opening = ext_vec_addrs
403                        .mat_opening
404                        .iter()
405                        .map(|addr| self.memory.mr(*addr).val)
406                        .collect_vec();
407                    let ps_at_z = ext_vec_addrs
408                        .ps_at_z
409                        .iter()
410                        .map(|addr| self.memory.mr(*addr).val)
411                        .collect_vec();
412
413                    for m in 0..ps_at_z.len() {
414                        // let m = F::from_canonical_u32(m);
415                        // Get the opening values.
416                        let p_at_x = mat_opening[m];
417                        let p_at_x: EF = p_at_x.ext();
418                        let p_at_z = ps_at_z[m];
419                        let p_at_z: EF = p_at_z.ext();
420
421                        // Calculate the quotient and update the values
422                        let quotient = (-p_at_z + p_at_x) / (-z + x);
423
424                        // First we peek to get the current value.
425                        let alpha_pow: EF =
426                            self.memory.mr(ext_vec_addrs.alpha_pow_input[m]).val.ext();
427
428                        let ro: EF = self.memory.mr(ext_vec_addrs.ro_input[m]).val.ext();
429
430                        let new_ro = ro + alpha_pow * quotient;
431                        let new_alpha_pow = alpha_pow * alpha;
432
433                        let _ = self.memory.mw(
434                            ext_vec_addrs.ro_output[m],
435                            Block::from(new_ro.as_base_slice()),
436                            ro_mults[m],
437                        );
438
439                        let _ = self.memory.mw(
440                            ext_vec_addrs.alpha_pow_output[m],
441                            Block::from(new_alpha_pow.as_base_slice()),
442                            alpha_pow_mults[m],
443                        );
444
445                        self.record.fri_fold_events.push(FriFoldEvent {
446                            base_single: FriFoldBaseIo { x },
447                            ext_single: FriFoldExtSingleIo {
448                                z: Block::from(z.as_base_slice()),
449                                alpha: Block::from(alpha.as_base_slice()),
450                            },
451                            ext_vec: FriFoldExtVecIo {
452                                mat_opening: Block::from(p_at_x.as_base_slice()),
453                                ps_at_z: Block::from(p_at_z.as_base_slice()),
454                                alpha_pow_input: Block::from(alpha_pow.as_base_slice()),
455                                ro_input: Block::from(ro.as_base_slice()),
456                                alpha_pow_output: Block::from(new_alpha_pow.as_base_slice()),
457                                ro_output: Block::from(new_ro.as_base_slice()),
458                            },
459                        });
460                    }
461                }
462
463                Instruction::CommitPublicValues(instr) => {
464                    let pv_addrs = instr.pv_addrs.to_vec();
465                    let pv_values: [F; RECURSIVE_PROOF_NUM_PV_ELTS] =
466                        array::from_fn(|i| self.memory.mr(pv_addrs[i]).val[0]);
467                    self.record.public_values = *pv_values.as_slice().borrow();
468                    self.record
469                        .commit_pv_hash_events
470                        .push(CommitPublicValuesEvent { public_values: self.record.public_values });
471                }
472
473                Instruction::Print(PrintInstr { field_elt_type, addr }) => match field_elt_type {
474                    FieldEltType::Base => {
475                        self.nb_print_f += 1;
476                        let f = self.memory.mr_mult(addr, F::zero()).val[0];
477                        writeln!(self.debug_stdout, "PRINTF={f}")
478                    }
479                    FieldEltType::Extension => {
480                        self.nb_print_e += 1;
481                        let ef = self.memory.mr_mult(addr, F::zero()).val;
482                        writeln!(self.debug_stdout, "PRINTEF={ef:?}")
483                    }
484                }
485                .map_err(RuntimeError::DebugPrint)?,
486                Instruction::HintExt2Felts(HintExt2FeltsInstr {
487                    output_addrs_mults,
488                    input_addr,
489                }) => {
490                    self.nb_bit_decompositions += 1;
491                    let fs = self.memory.mr_mult(input_addr, F::zero()).val;
492                    // Write the bits to the array at dst.
493                    for (f, (addr, mult)) in fs.into_iter().zip(output_addrs_mults) {
494                        let felt = Block::from(f);
495                        self.memory.mw(addr, felt, mult);
496                        self.record.mem_var_events.push(MemEvent { inner: felt });
497                    }
498                }
499                Instruction::Hint(HintInstr { output_addrs_mults }) => {
500                    // Check that enough Blocks can be read, so `drain` does not panic.
501                    if self.witness_stream.len() < output_addrs_mults.len() {
502                        return Err(RuntimeError::EmptyWitnessStream);
503                    }
504                    let witness = self.witness_stream.drain(0..output_addrs_mults.len());
505                    for ((addr, mult), val) in zip(output_addrs_mults, witness) {
506                        // Inline [`Self::mw`] to mutably borrow multiple fields of `self`.
507                        self.memory.mw(addr, val, mult);
508                        self.record.mem_var_events.push(MemEvent { inner: val });
509                    }
510                }
511            }
512
513            self.pc = next_pc;
514            self.clk = next_clk;
515            self.timestamp += 1;
516
517            if self.timestamp >= early_exit_ts {
518                break;
519            }
520        }
521        Ok(())
522    }
523}