Skip to main content

sp1_recursion_executor/
record.rs

1use std::{
2    array,
3    cell::UnsafeCell,
4    mem::MaybeUninit,
5    ops::{Add, AddAssign},
6    sync::Arc,
7};
8
9use serde::{Deserialize, Serialize};
10use slop_algebra::{AbstractField, Field, PrimeField32};
11use sp1_hypercube::{air::SP1AirBuilder, InteractionKind, MachineRecord, PROOF_MAX_NUM_PVS};
12
13use crate::{
14    instruction::{HintBitsInstr, HintExt2FeltsInstr, HintInstr},
15    public_values::RecursionPublicValues,
16    ExtFeltEvent, Instruction, Poseidon2LinearLayerEvent, Poseidon2SBoxEvent, PrefixSumChecksEvent,
17};
18
19use super::{
20    BaseAluEvent, CommitPublicValuesEvent, ExtAluEvent, MemEvent, Poseidon2Event, RecursionProgram,
21    SelectEvent,
22};
23
24#[derive(Clone, Default, Debug)]
25pub struct ExecutionRecord<F> {
26    pub program: Arc<RecursionProgram<F>>,
27    /// The index of the shard.
28    pub index: u32,
29
30    pub base_alu_events: Vec<BaseAluEvent<F>>,
31    pub ext_alu_events: Vec<ExtAluEvent<F>>,
32    pub mem_const_count: usize,
33    pub mem_var_events: Vec<MemEvent<F>>,
34    /// The public values.
35    pub public_values: RecursionPublicValues<F>,
36
37    pub ext_felt_conversion_events: Vec<ExtFeltEvent<F>>,
38    pub poseidon2_events: Vec<Poseidon2Event<F>>,
39    pub poseidon2_linear_layer_events: Vec<Poseidon2LinearLayerEvent<F>>,
40    pub poseidon2_sbox_events: Vec<Poseidon2SBoxEvent<F>>,
41    pub select_events: Vec<SelectEvent<F>>,
42    pub prefix_sum_checks_events: Vec<PrefixSumChecksEvent<F>>,
43    pub commit_pv_hash_events: Vec<CommitPublicValuesEvent<F>>,
44}
45
46#[derive(Debug)]
47pub struct UnsafeRecord<F> {
48    pub base_alu_events: Vec<MaybeUninit<UnsafeCell<BaseAluEvent<F>>>>,
49    pub ext_alu_events: Vec<MaybeUninit<UnsafeCell<ExtAluEvent<F>>>>,
50    // Can be computed by the analysis step.
51    pub mem_const_count: usize,
52    pub mem_var_events: Vec<MaybeUninit<UnsafeCell<MemEvent<F>>>>,
53    /// The public values.
54    pub public_values: MaybeUninit<UnsafeCell<RecursionPublicValues<F>>>,
55
56    pub ext_felt_conversion_events: Vec<MaybeUninit<UnsafeCell<ExtFeltEvent<F>>>>,
57    pub poseidon2_events: Vec<MaybeUninit<UnsafeCell<Poseidon2Event<F>>>>,
58    pub poseidon2_linear_layer_events: Vec<MaybeUninit<UnsafeCell<Poseidon2LinearLayerEvent<F>>>>,
59    pub poseidon2_sbox_events: Vec<MaybeUninit<UnsafeCell<Poseidon2SBoxEvent<F>>>>,
60    pub select_events: Vec<MaybeUninit<UnsafeCell<SelectEvent<F>>>>,
61    pub prefix_sum_checks_events: Vec<MaybeUninit<UnsafeCell<PrefixSumChecksEvent<F>>>>,
62    pub commit_pv_hash_events: Vec<MaybeUninit<UnsafeCell<CommitPublicValuesEvent<F>>>>,
63}
64
65impl<F> UnsafeRecord<F> {
66    /// # Safety
67    ///
68    /// The caller must ensure that the `UnsafeRecord` is fully initialized, this is
69    /// done by the executor.
70    pub unsafe fn into_record(
71        self,
72        program: Arc<RecursionProgram<F>>,
73        index: u32,
74    ) -> ExecutionRecord<F> {
75        // SAFETY: `T` and `MaybeUninit<UnsafeCell<T>>` have the same memory layout.
76        #[allow(clippy::missing_transmute_annotations)]
77        ExecutionRecord {
78            program,
79            index,
80            base_alu_events: std::mem::transmute(self.base_alu_events),
81            ext_alu_events: std::mem::transmute(self.ext_alu_events),
82            mem_const_count: self.mem_const_count,
83            mem_var_events: std::mem::transmute(self.mem_var_events),
84            public_values: self.public_values.assume_init().into_inner(),
85            ext_felt_conversion_events: std::mem::transmute(self.ext_felt_conversion_events),
86            poseidon2_events: std::mem::transmute(self.poseidon2_events),
87            poseidon2_linear_layer_events: std::mem::transmute(self.poseidon2_linear_layer_events),
88            poseidon2_sbox_events: std::mem::transmute(self.poseidon2_sbox_events),
89            select_events: std::mem::transmute(self.select_events),
90            prefix_sum_checks_events: std::mem::transmute(self.prefix_sum_checks_events),
91            commit_pv_hash_events: std::mem::transmute(self.commit_pv_hash_events),
92        }
93    }
94
95    pub fn new(event_counts: RecursionAirEventCount) -> Self
96    where
97        F: Field,
98    {
99        #[inline]
100        fn create_uninit_vec<T>(len: usize) -> Vec<MaybeUninit<T>> {
101            let mut vec = Vec::with_capacity(len);
102            // SAFETY: The vector has enough capacity to hold the elements as we just allocated it,
103            // and the type `T` is `MaybeUninit` which implies that an "uninitialized" value is OK.
104            unsafe { vec.set_len(len) };
105            vec
106        }
107
108        Self {
109            base_alu_events: create_uninit_vec(event_counts.base_alu_events),
110            ext_alu_events: create_uninit_vec(event_counts.ext_alu_events),
111            mem_const_count: event_counts.mem_const_events,
112            mem_var_events: create_uninit_vec(event_counts.mem_var_events),
113            public_values: MaybeUninit::uninit(),
114            ext_felt_conversion_events: create_uninit_vec(event_counts.ext_felt_conversion_events),
115            poseidon2_events: create_uninit_vec(event_counts.poseidon2_wide_events),
116            poseidon2_linear_layer_events: create_uninit_vec(
117                event_counts.poseidon2_linear_layer_events,
118            ),
119            poseidon2_sbox_events: create_uninit_vec(event_counts.poseidon2_sbox_events),
120            select_events: create_uninit_vec(event_counts.select_events),
121            prefix_sum_checks_events: create_uninit_vec(event_counts.prefix_sum_checks_events),
122            commit_pv_hash_events: create_uninit_vec(event_counts.commit_pv_hash_events),
123        }
124    }
125}
126
127unsafe impl<F> Sync for UnsafeRecord<F> {}
128
129impl<F: PrimeField32> MachineRecord for ExecutionRecord<F> {
130    fn stats(&self) -> hashbrown::HashMap<String, usize> {
131        [
132            ("base_alu_events", self.base_alu_events.len()),
133            ("ext_alu_events", self.ext_alu_events.len()),
134            ("mem_const_count", self.mem_const_count),
135            ("mem_var_events", self.mem_var_events.len()),
136            ("ext_felt_conversion_events", self.ext_felt_conversion_events.len()),
137            ("poseidon2_events", self.poseidon2_events.len()),
138            ("poseidon2_linear_layer_events", self.poseidon2_linear_layer_events.len()),
139            ("poseidon2_sbox_events", self.poseidon2_sbox_events.len()),
140            ("select_events", self.select_events.len()),
141            ("prefix_sum_checks_events", self.prefix_sum_checks_events.len()),
142            ("commit_pv_hash_events", self.commit_pv_hash_events.len()),
143        ]
144        .into_iter()
145        .map(|(k, v)| (k.to_owned(), v))
146        .collect()
147    }
148
149    fn append(&mut self, other: &mut Self) {
150        // Exhaustive destructuring for refactoring purposes.
151        let Self {
152            program: _,
153            index: _,
154            base_alu_events,
155            ext_alu_events,
156            mem_const_count,
157            mem_var_events,
158            public_values: _,
159            ext_felt_conversion_events,
160            poseidon2_events,
161            poseidon2_linear_layer_events,
162            poseidon2_sbox_events,
163            select_events,
164            prefix_sum_checks_events,
165            commit_pv_hash_events,
166        } = self;
167        base_alu_events.append(&mut other.base_alu_events);
168        ext_alu_events.append(&mut other.ext_alu_events);
169        *mem_const_count += other.mem_const_count;
170        mem_var_events.append(&mut other.mem_var_events);
171        ext_felt_conversion_events.append(&mut other.ext_felt_conversion_events);
172        poseidon2_events.append(&mut other.poseidon2_events);
173        poseidon2_linear_layer_events.append(&mut other.poseidon2_linear_layer_events);
174        poseidon2_sbox_events.append(&mut other.poseidon2_sbox_events);
175        select_events.append(&mut other.select_events);
176        prefix_sum_checks_events.append(&mut other.prefix_sum_checks_events);
177        commit_pv_hash_events.append(&mut other.commit_pv_hash_events);
178    }
179
180    fn public_values<T: AbstractField>(&self) -> Vec<T> {
181        let pv_elms = self.public_values.as_array();
182
183        let ret: [T; PROOF_MAX_NUM_PVS] = array::from_fn(|i| {
184            if i < pv_elms.len() {
185                T::from_canonical_u32(pv_elms[i].as_canonical_u32())
186            } else {
187                T::zero()
188            }
189        });
190
191        ret.to_vec()
192    }
193
194    // No public value constraints for recursion public values.
195    fn eval_public_values<AB: SP1AirBuilder>(_builder: &mut AB) {}
196
197    fn interactions_in_public_values() -> Vec<InteractionKind> {
198        vec![]
199    }
200}
201
202impl<F: Field> ExecutionRecord<F> {
203    pub fn compute_event_counts<'a>(
204        instrs: impl Iterator<Item = &'a Instruction<F>> + 'a,
205    ) -> RecursionAirEventCount {
206        instrs.fold(RecursionAirEventCount::default(), Add::add)
207    }
208}
209
210#[derive(Default, Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
211pub struct RecursionAirEventCount {
212    pub mem_const_events: usize,
213    pub mem_var_events: usize,
214    pub base_alu_events: usize,
215    pub ext_alu_events: usize,
216    pub ext_felt_conversion_events: usize,
217    pub poseidon2_wide_events: usize,
218    pub poseidon2_linear_layer_events: usize,
219    pub poseidon2_sbox_events: usize,
220    pub select_events: usize,
221    pub prefix_sum_checks_events: usize,
222    pub commit_pv_hash_events: usize,
223}
224
225impl<F> AddAssign<&Instruction<F>> for RecursionAirEventCount {
226    #[inline]
227    fn add_assign(&mut self, rhs: &Instruction<F>) {
228        match rhs {
229            Instruction::BaseAlu(_) => self.base_alu_events += 1,
230            Instruction::ExtAlu(_) => self.ext_alu_events += 1,
231            Instruction::ExtFelt(_) => self.ext_felt_conversion_events += 1,
232            Instruction::Mem(_) => self.mem_const_events += 1,
233            Instruction::Poseidon2(_) => self.poseidon2_wide_events += 1,
234            Instruction::Poseidon2LinearLayer(_) => self.poseidon2_linear_layer_events += 1,
235            Instruction::Poseidon2SBox(_) => self.poseidon2_sbox_events += 1,
236            Instruction::Select(_) => self.select_events += 1,
237            Instruction::Hint(HintInstr { output_addrs_mults })
238            | Instruction::HintBits(HintBitsInstr {
239                output_addrs_mults,
240                input_addr: _, // No receive interaction for the hint operation
241            }) => self.mem_var_events += output_addrs_mults.len(),
242            Instruction::HintExt2Felts(HintExt2FeltsInstr {
243                output_addrs_mults,
244                input_addr: _, // No receive interaction for the hint operation
245            }) => self.mem_var_events += output_addrs_mults.len(),
246            Instruction::PrefixSumChecks(instr) => {
247                self.prefix_sum_checks_events += instr.addrs.x1.len()
248            }
249            Instruction::HintAddCurve(instr) => {
250                self.mem_var_events += instr.output_x_addrs_mults.len();
251                self.mem_var_events += instr.output_y_addrs_mults.len();
252            }
253            Instruction::CommitPublicValues(_) => self.commit_pv_hash_events += 1,
254            Instruction::Print(_) | Instruction::DebugBacktrace(_) => {}
255        }
256    }
257}
258
259impl<F> Add<&Instruction<F>> for RecursionAirEventCount {
260    type Output = Self;
261
262    #[inline]
263    fn add(mut self, rhs: &Instruction<F>) -> Self::Output {
264        self += rhs;
265        self
266    }
267}