// sp1_recursion_core/runtime/record.rs

1use std::{array, ops::Add, sync::Arc};
2
3use p3_field::{AbstractField, Field, PrimeField32};
4use sp1_stark::{air::MachineAir, MachineRecord, SP1CoreOpts, PROOF_MAX_NUM_PVS};
5
6use super::{
7    machine::RecursionAirEventCount, BaseAluEvent, BatchFRIEvent, CommitPublicValuesEvent,
8    ExpReverseBitsEvent, ExtAluEvent, FriFoldEvent, MemEvent, Poseidon2Event, RecursionProgram,
9    RecursionPublicValues, SelectEvent,
10};
11
12#[derive(Clone, Default, Debug)]
13pub struct ExecutionRecord<F> {
14    pub program: Arc<RecursionProgram<F>>,
15    /// The index of the shard.
16    pub index: u32,
17
18    pub base_alu_events: Vec<BaseAluEvent<F>>,
19    pub ext_alu_events: Vec<ExtAluEvent<F>>,
20    pub mem_const_count: usize,
21    pub mem_var_events: Vec<MemEvent<F>>,
22    /// The public values.
23    pub public_values: RecursionPublicValues<F>,
24
25    pub poseidon2_events: Vec<Poseidon2Event<F>>,
26    pub select_events: Vec<SelectEvent<F>>,
27    pub exp_reverse_bits_len_events: Vec<ExpReverseBitsEvent<F>>,
28    pub fri_fold_events: Vec<FriFoldEvent<F>>,
29    pub batch_fri_events: Vec<BatchFRIEvent<F>>,
30    pub commit_pv_hash_events: Vec<CommitPublicValuesEvent<F>>,
31}
32
33impl<F: PrimeField32> MachineRecord for ExecutionRecord<F> {
34    type Config = SP1CoreOpts;
35
36    fn stats(&self) -> hashbrown::HashMap<String, usize> {
37        [
38            ("base_alu_events", self.base_alu_events.len()),
39            ("ext_alu_events", self.ext_alu_events.len()),
40            ("mem_const_count", self.mem_const_count),
41            ("mem_var_events", self.mem_var_events.len()),
42            ("poseidon2_events", self.poseidon2_events.len()),
43            ("select_events", self.select_events.len()),
44            ("exp_reverse_bits_len_events", self.exp_reverse_bits_len_events.len()),
45            ("fri_fold_events", self.fri_fold_events.len()),
46            ("batch_fri_events", self.batch_fri_events.len()),
47            ("commit_pv_hash_events", self.commit_pv_hash_events.len()),
48        ]
49        .into_iter()
50        .map(|(k, v)| (k.to_owned(), v))
51        .collect()
52    }
53
54    fn append(&mut self, other: &mut Self) {
55        // Exhaustive destructuring for refactoring purposes.
56        let Self {
57            program: _,
58            index: _,
59            base_alu_events,
60            ext_alu_events,
61            mem_const_count,
62            mem_var_events,
63            public_values: _,
64            poseidon2_events,
65            select_events,
66            exp_reverse_bits_len_events,
67            fri_fold_events,
68            batch_fri_events,
69            commit_pv_hash_events,
70        } = self;
71        base_alu_events.append(&mut other.base_alu_events);
72        ext_alu_events.append(&mut other.ext_alu_events);
73        *mem_const_count += other.mem_const_count;
74        mem_var_events.append(&mut other.mem_var_events);
75        poseidon2_events.append(&mut other.poseidon2_events);
76        select_events.append(&mut other.select_events);
77        exp_reverse_bits_len_events.append(&mut other.exp_reverse_bits_len_events);
78        fri_fold_events.append(&mut other.fri_fold_events);
79        batch_fri_events.append(&mut other.batch_fri_events);
80        commit_pv_hash_events.append(&mut other.commit_pv_hash_events);
81    }
82
83    fn public_values<T: AbstractField>(&self) -> Vec<T> {
84        let pv_elms = self.public_values.as_array();
85
86        let ret: [T; PROOF_MAX_NUM_PVS] = array::from_fn(|i| {
87            if i < pv_elms.len() {
88                T::from_canonical_u32(pv_elms[i].as_canonical_u32())
89            } else {
90                T::zero()
91            }
92        });
93
94        ret.to_vec()
95    }
96}
97
98impl<F: Field> ExecutionRecord<F> {
99    #[inline]
100    pub fn fixed_log2_rows<A: MachineAir<F>>(&self, air: &A) -> Option<usize> {
101        self.program.fixed_log2_rows(air)
102    }
103
104    pub fn preallocate(&mut self) {
105        let event_counts =
106            self.program.inner.iter().fold(RecursionAirEventCount::default(), Add::add);
107        self.poseidon2_events.reserve(event_counts.poseidon2_wide_events);
108        self.mem_var_events.reserve(event_counts.mem_var_events);
109        self.base_alu_events.reserve(event_counts.base_alu_events);
110        self.ext_alu_events.reserve(event_counts.ext_alu_events);
111        self.exp_reverse_bits_len_events.reserve(event_counts.exp_reverse_bits_len_events);
112        self.select_events.reserve(event_counts.select_events);
113    }
114}