1use std::{
2 array,
3 cell::UnsafeCell,
4 mem::MaybeUninit,
5 ops::{Add, AddAssign},
6 sync::Arc,
7};
8
9use serde::{Deserialize, Serialize};
10use slop_algebra::{AbstractField, Field, PrimeField32};
11use sp1_hypercube::{air::SP1AirBuilder, InteractionKind, MachineRecord, PROOF_MAX_NUM_PVS};
12
13use crate::{
14 instruction::{HintBitsInstr, HintExt2FeltsInstr, HintInstr},
15 public_values::RecursionPublicValues,
16 ExtFeltEvent, Instruction, Poseidon2LinearLayerEvent, Poseidon2SBoxEvent, PrefixSumChecksEvent,
17};
18
19use super::{
20 BaseAluEvent, CommitPublicValuesEvent, ExtAluEvent, MemEvent, Poseidon2Event, RecursionProgram,
21 SelectEvent,
22};
23
/// Events recorded for a [`RecursionProgram`], stored as one vector per event kind.
///
/// Records can be merged with `MachineRecord::append`. A record can also be produced from a
/// pre-sized [`UnsafeRecord`] via `UnsafeRecord::into_record`.
#[derive(Clone, Default, Debug)]
pub struct ExecutionRecord<F> {
    /// The program this record belongs to.
    pub program: Arc<RecursionProgram<F>>,
    /// Index of this record. NOTE(review): presumably a shard/record index — confirm.
    pub index: u32,

    /// Base-field ALU events.
    pub base_alu_events: Vec<BaseAluEvent<F>>,
    /// Extension-field ALU events.
    pub ext_alu_events: Vec<ExtAluEvent<F>>,
    /// Number of constant-memory events. Only a count is stored, no per-event data; it is
    /// sized from `RecursionAirEventCount::mem_const_events` (one per `Instruction::Mem`).
    pub mem_const_count: usize,
    /// Variable-memory events (sized from the outputs of the hint instructions).
    pub mem_var_events: Vec<MemEvent<F>>,
    /// Public values of this record; `append` leaves them untouched.
    pub public_values: RecursionPublicValues<F>,

    /// Extension-to-felt conversion events.
    pub ext_felt_conversion_events: Vec<ExtFeltEvent<F>>,
    /// Poseidon2 events (sized from `RecursionAirEventCount::poseidon2_wide_events`).
    pub poseidon2_events: Vec<Poseidon2Event<F>>,
    /// Poseidon2 linear-layer events.
    pub poseidon2_linear_layer_events: Vec<Poseidon2LinearLayerEvent<F>>,
    /// Poseidon2 S-box events.
    pub poseidon2_sbox_events: Vec<Poseidon2SBoxEvent<F>>,
    /// Select events.
    pub select_events: Vec<SelectEvent<F>>,
    /// Prefix-sum-check events.
    pub prefix_sum_checks_events: Vec<PrefixSumChecksEvent<F>>,
    /// Commit-public-values events.
    pub commit_pv_hash_events: Vec<CommitPublicValuesEvent<F>>,
}
45
/// Pre-sized, initially uninitialized counterpart of [`ExecutionRecord`].
///
/// `UnsafeRecord::new` allocates every event vector at its final length and leaves all slots
/// uninitialized. `UnsafeRecord::into_record` later treats every slot as initialized. Each slot
/// is wrapped in an `UnsafeCell`, and the type is declared `Sync`. This is presumably so that
/// slots can be written in parallel through a shared reference.
/// NOTE(review): confirm the write discipline against the executor.
#[derive(Debug)]
pub struct UnsafeRecord<F> {
    /// Base-field ALU event slots.
    pub base_alu_events: Vec<MaybeUninit<UnsafeCell<BaseAluEvent<F>>>>,
    /// Extension-field ALU event slots.
    pub ext_alu_events: Vec<MaybeUninit<UnsafeCell<ExtAluEvent<F>>>>,
    /// Number of constant-memory events; only the count is kept, no per-event data.
    pub mem_const_count: usize,

    /// Variable-memory event slots.
    pub mem_var_events: Vec<MaybeUninit<UnsafeCell<MemEvent<F>>>>,
    /// Public values; must be initialized before `into_record` is called.
    pub public_values: MaybeUninit<UnsafeCell<RecursionPublicValues<F>>>,

    /// Extension-to-felt conversion event slots.
    pub ext_felt_conversion_events: Vec<MaybeUninit<UnsafeCell<ExtFeltEvent<F>>>>,
    /// Poseidon2 event slots.
    pub poseidon2_events: Vec<MaybeUninit<UnsafeCell<Poseidon2Event<F>>>>,
    /// Poseidon2 linear-layer event slots.
    pub poseidon2_linear_layer_events: Vec<MaybeUninit<UnsafeCell<Poseidon2LinearLayerEvent<F>>>>,
    /// Poseidon2 S-box event slots.
    pub poseidon2_sbox_events: Vec<MaybeUninit<UnsafeCell<Poseidon2SBoxEvent<F>>>>,
    /// Select event slots.
    pub select_events: Vec<MaybeUninit<UnsafeCell<SelectEvent<F>>>>,
    /// Prefix-sum-check event slots.
    pub prefix_sum_checks_events: Vec<MaybeUninit<UnsafeCell<PrefixSumChecksEvent<F>>>>,
    /// Commit-public-values event slots.
    pub commit_pv_hash_events: Vec<MaybeUninit<UnsafeCell<CommitPublicValuesEvent<F>>>>,
}
64
65impl<F> UnsafeRecord<F> {
66 pub unsafe fn into_record(
71 self,
72 program: Arc<RecursionProgram<F>>,
73 index: u32,
74 ) -> ExecutionRecord<F> {
75 #[allow(clippy::missing_transmute_annotations)]
77 ExecutionRecord {
78 program,
79 index,
80 base_alu_events: std::mem::transmute(self.base_alu_events),
81 ext_alu_events: std::mem::transmute(self.ext_alu_events),
82 mem_const_count: self.mem_const_count,
83 mem_var_events: std::mem::transmute(self.mem_var_events),
84 public_values: self.public_values.assume_init().into_inner(),
85 ext_felt_conversion_events: std::mem::transmute(self.ext_felt_conversion_events),
86 poseidon2_events: std::mem::transmute(self.poseidon2_events),
87 poseidon2_linear_layer_events: std::mem::transmute(self.poseidon2_linear_layer_events),
88 poseidon2_sbox_events: std::mem::transmute(self.poseidon2_sbox_events),
89 select_events: std::mem::transmute(self.select_events),
90 prefix_sum_checks_events: std::mem::transmute(self.prefix_sum_checks_events),
91 commit_pv_hash_events: std::mem::transmute(self.commit_pv_hash_events),
92 }
93 }
94
95 pub fn new(event_counts: RecursionAirEventCount) -> Self
96 where
97 F: Field,
98 {
99 #[inline]
100 fn create_uninit_vec<T>(len: usize) -> Vec<MaybeUninit<T>> {
101 let mut vec = Vec::with_capacity(len);
102 unsafe { vec.set_len(len) };
105 vec
106 }
107
108 Self {
109 base_alu_events: create_uninit_vec(event_counts.base_alu_events),
110 ext_alu_events: create_uninit_vec(event_counts.ext_alu_events),
111 mem_const_count: event_counts.mem_const_events,
112 mem_var_events: create_uninit_vec(event_counts.mem_var_events),
113 public_values: MaybeUninit::uninit(),
114 ext_felt_conversion_events: create_uninit_vec(event_counts.ext_felt_conversion_events),
115 poseidon2_events: create_uninit_vec(event_counts.poseidon2_wide_events),
116 poseidon2_linear_layer_events: create_uninit_vec(
117 event_counts.poseidon2_linear_layer_events,
118 ),
119 poseidon2_sbox_events: create_uninit_vec(event_counts.poseidon2_sbox_events),
120 select_events: create_uninit_vec(event_counts.select_events),
121 prefix_sum_checks_events: create_uninit_vec(event_counts.prefix_sum_checks_events),
122 commit_pv_hash_events: create_uninit_vec(event_counts.commit_pv_hash_events),
123 }
124 }
125}
126
// SAFETY: NOTE(review): this asserts `Sync` for every `F`, with no `F: Send + Sync` bound.
// A shared `&UnsafeRecord<F>` exposes the `UnsafeCell` slots to other threads. Soundness
// therefore presumably relies on callers writing each slot from at most one thread, and not
// reading any slot until `into_record` — confirm against the executor. Consider adding
// `F: Send + Sync` bounds.
unsafe impl<F> Sync for UnsafeRecord<F> {}
128
129impl<F: PrimeField32> MachineRecord for ExecutionRecord<F> {
130 fn stats(&self) -> hashbrown::HashMap<String, usize> {
131 [
132 ("base_alu_events", self.base_alu_events.len()),
133 ("ext_alu_events", self.ext_alu_events.len()),
134 ("mem_const_count", self.mem_const_count),
135 ("mem_var_events", self.mem_var_events.len()),
136 ("ext_felt_conversion_events", self.ext_felt_conversion_events.len()),
137 ("poseidon2_events", self.poseidon2_events.len()),
138 ("poseidon2_linear_layer_events", self.poseidon2_linear_layer_events.len()),
139 ("poseidon2_sbox_events", self.poseidon2_sbox_events.len()),
140 ("select_events", self.select_events.len()),
141 ("prefix_sum_checks_events", self.prefix_sum_checks_events.len()),
142 ("commit_pv_hash_events", self.commit_pv_hash_events.len()),
143 ]
144 .into_iter()
145 .map(|(k, v)| (k.to_owned(), v))
146 .collect()
147 }
148
149 fn append(&mut self, other: &mut Self) {
150 let Self {
152 program: _,
153 index: _,
154 base_alu_events,
155 ext_alu_events,
156 mem_const_count,
157 mem_var_events,
158 public_values: _,
159 ext_felt_conversion_events,
160 poseidon2_events,
161 poseidon2_linear_layer_events,
162 poseidon2_sbox_events,
163 select_events,
164 prefix_sum_checks_events,
165 commit_pv_hash_events,
166 } = self;
167 base_alu_events.append(&mut other.base_alu_events);
168 ext_alu_events.append(&mut other.ext_alu_events);
169 *mem_const_count += other.mem_const_count;
170 mem_var_events.append(&mut other.mem_var_events);
171 ext_felt_conversion_events.append(&mut other.ext_felt_conversion_events);
172 poseidon2_events.append(&mut other.poseidon2_events);
173 poseidon2_linear_layer_events.append(&mut other.poseidon2_linear_layer_events);
174 poseidon2_sbox_events.append(&mut other.poseidon2_sbox_events);
175 select_events.append(&mut other.select_events);
176 prefix_sum_checks_events.append(&mut other.prefix_sum_checks_events);
177 commit_pv_hash_events.append(&mut other.commit_pv_hash_events);
178 }
179
180 fn public_values<T: AbstractField>(&self) -> Vec<T> {
181 let pv_elms = self.public_values.as_array();
182
183 let ret: [T; PROOF_MAX_NUM_PVS] = array::from_fn(|i| {
184 if i < pv_elms.len() {
185 T::from_canonical_u32(pv_elms[i].as_canonical_u32())
186 } else {
187 T::zero()
188 }
189 });
190
191 ret.to_vec()
192 }
193
194 fn eval_public_values<AB: SP1AirBuilder>(_builder: &mut AB) {}
196
197 fn interactions_in_public_values() -> Vec<InteractionKind> {
198 vec![]
199 }
200}
201
202impl<F: Field> ExecutionRecord<F> {
203 pub fn compute_event_counts<'a>(
204 instrs: impl Iterator<Item = &'a Instruction<F>> + 'a,
205 ) -> RecursionAirEventCount {
206 instrs.fold(RecursionAirEventCount::default(), Add::add)
207 }
208}
209
/// Number of events of each kind that a sequence of recursion instructions produces.
///
/// Accumulated one instruction at a time through the `AddAssign<&Instruction<F>>` and
/// `Add<&Instruction<F>>` impls (see `ExecutionRecord::compute_event_counts`). It is consumed
/// by `UnsafeRecord::new` to pre-size every event vector.
#[derive(Default, Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
pub struct RecursionAirEventCount {
    /// One per `Instruction::Mem`.
    pub mem_const_events: usize,
    /// One per `output_addrs_mults` entry of `Hint`, `HintBits` and `HintExt2Felts`, plus one
    /// per x/y output entry of `HintAddCurve`.
    pub mem_var_events: usize,
    /// One per `Instruction::BaseAlu`.
    pub base_alu_events: usize,
    /// One per `Instruction::ExtAlu`.
    pub ext_alu_events: usize,
    /// One per `Instruction::ExtFelt`.
    pub ext_felt_conversion_events: usize,
    /// One per `Instruction::Poseidon2`; sizes the record's `poseidon2_events`.
    pub poseidon2_wide_events: usize,
    /// One per `Instruction::Poseidon2LinearLayer`.
    pub poseidon2_linear_layer_events: usize,
    /// One per `Instruction::Poseidon2SBox`.
    pub poseidon2_sbox_events: usize,
    /// One per `Instruction::Select`.
    pub select_events: usize,
    /// One per element of each `PrefixSumChecks` instruction's `addrs.x1`.
    pub prefix_sum_checks_events: usize,
    /// One per `Instruction::CommitPublicValues`.
    pub commit_pv_hash_events: usize,
}
224
225impl<F> AddAssign<&Instruction<F>> for RecursionAirEventCount {
226 #[inline]
227 fn add_assign(&mut self, rhs: &Instruction<F>) {
228 match rhs {
229 Instruction::BaseAlu(_) => self.base_alu_events += 1,
230 Instruction::ExtAlu(_) => self.ext_alu_events += 1,
231 Instruction::ExtFelt(_) => self.ext_felt_conversion_events += 1,
232 Instruction::Mem(_) => self.mem_const_events += 1,
233 Instruction::Poseidon2(_) => self.poseidon2_wide_events += 1,
234 Instruction::Poseidon2LinearLayer(_) => self.poseidon2_linear_layer_events += 1,
235 Instruction::Poseidon2SBox(_) => self.poseidon2_sbox_events += 1,
236 Instruction::Select(_) => self.select_events += 1,
237 Instruction::Hint(HintInstr { output_addrs_mults })
238 | Instruction::HintBits(HintBitsInstr {
239 output_addrs_mults,
240 input_addr: _, }) => self.mem_var_events += output_addrs_mults.len(),
242 Instruction::HintExt2Felts(HintExt2FeltsInstr {
243 output_addrs_mults,
244 input_addr: _, }) => self.mem_var_events += output_addrs_mults.len(),
246 Instruction::PrefixSumChecks(instr) => {
247 self.prefix_sum_checks_events += instr.addrs.x1.len()
248 }
249 Instruction::HintAddCurve(instr) => {
250 self.mem_var_events += instr.output_x_addrs_mults.len();
251 self.mem_var_events += instr.output_y_addrs_mults.len();
252 }
253 Instruction::CommitPublicValues(_) => self.commit_pv_hash_events += 1,
254 Instruction::Print(_) | Instruction::DebugBacktrace(_) => {}
255 }
256 }
257}
258
259impl<F> Add<&Instruction<F>> for RecursionAirEventCount {
260 type Output = Self;
261
262 #[inline]
263 fn add(mut self, rhs: &Instruction<F>) -> Self::Output {
264 self += rhs;
265 self
266 }
267}