1use std::{borrow::Borrow, mem::transmute};
2
3use p3_air::{Air, BaseAir, PairBuilder};
4use p3_field::PrimeField32;
5use p3_matrix::{dense::RowMajorMatrix, Matrix};
6use rayon::iter::{
7 IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelBridge,
8 ParallelIterator,
9};
10use rayon_scan::ScanParallelIterator;
11use sp1_core_executor::{
12 events::{ByteLookupEvent, ByteRecord, GlobalInteractionEvent},
13 ExecutionRecord, Program,
14};
15use sp1_stark::{
16 air::{AirInteraction, InteractionScope, MachineAir, SP1AirBuilder},
17 septic_curve::{SepticCurve, SepticCurveComplete},
18 septic_digest::SepticDigest,
19 septic_extension::{SepticBlock, SepticExtension},
20 InteractionKind,
21};
22use std::borrow::BorrowMut;
23
24use crate::{
25 operations::{GlobalAccumulationOperation, GlobalInteractionOperation},
26 utils::{indices_arr, next_power_of_two, zeroed_f_vec},
27};
28use sp1_derive::AlignedBorrow;
29
30const NUM_GLOBAL_COLS: usize = size_of::<GlobalCols<u8>>();
31
32const fn make_col_map() -> GlobalCols<usize> {
34 let indices_arr = indices_arr::<NUM_GLOBAL_COLS>();
35 unsafe { transmute::<[usize; NUM_GLOBAL_COLS], GlobalCols<usize>>(indices_arr) }
36}
37
38const GLOBAL_COL_MAP: GlobalCols<usize> = make_col_map();
39
40pub const GLOBAL_INITIAL_DIGEST_POS: usize = GLOBAL_COL_MAP.accumulation.initial_digest[0].0[0];
41
42pub const GLOBAL_INITIAL_DIGEST_POS_COPY: usize = 377;
43
/// `#[repr(C)]` wrapper around an array of `GLOBAL_INITIAL_DIGEST_POS_COPY` `usize`s.
///
/// NOTE(review): not referenced in the code shown here; presumably used elsewhere (e.g. to
/// measure or address the columns preceding `initial_digest`) — confirm its purpose.
#[repr(C)]
pub struct Ghost {
    pub v: [usize; GLOBAL_INITIAL_DIGEST_POS_COPY],
}
48
/// Chip whose trace has one row per entry of the record's `global_interaction_events`.
///
/// The trace is padded up to the height computed by `next_power_of_two` in `num_rows`.
#[derive(Default)]
pub struct GlobalChip;
51
/// Main-trace columns of the global chip.
///
/// `#[repr(C)]` fixes the field order, so a trace row of `NUM_GLOBAL_COLS` cells can be borrowed
/// as this struct (the `row.borrow()` / `row.borrow_mut()` conversions used in this file). The
/// same order also lets a `GlobalCols<usize>` serve as a map from field to column index.
#[derive(AlignedBorrow)]
#[repr(C)]
pub struct GlobalCols<T: Copy> {
    /// The event's 7-element message.
    pub message: [T; 7],
    /// The event's kind.
    pub kind: T,
    /// Operation columns populated from the message; its `x_coordinate` / `y_coordinate` give
    /// the row's curve point.
    pub interaction: GlobalInteractionOperation<T>,
    /// 1 on real rows whose event is a receive, 0 otherwise.
    pub is_receive: T,
    /// 1 on real rows whose event is a send, 0 otherwise.
    pub is_send: T,
    /// 1 on rows backed by an event, 0 on padding rows.
    pub is_real: T,
    /// Columns for accumulating the per-row curve points into the running digest.
    pub accumulation: GlobalAccumulationOperation<T, 1>,
}
63
64impl<F: PrimeField32> MachineAir<F> for GlobalChip {
65 type Record = ExecutionRecord;
66
67 type Program = Program;
68
69 fn name(&self) -> String {
70 assert_eq!(GLOBAL_INITIAL_DIGEST_POS_COPY, GLOBAL_INITIAL_DIGEST_POS);
71 "Global".to_string()
72 }
73
74 fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
75 let events = &input.global_interaction_events;
76
77 let chunk_size = std::cmp::max(events.len() / num_cpus::get(), 1);
78
79 let blu_batches = events
80 .chunks(chunk_size)
81 .par_bridge()
82 .map(|events| {
83 let mut blu: Vec<ByteLookupEvent> = Vec::new();
84 events.iter().for_each(|event| {
85 blu.add_u16_range_check(event.message[0].try_into().unwrap());
86 });
87 blu
88 })
89 .collect::<Vec<_>>();
90
91 output.add_byte_lookup_events(blu_batches.into_iter().flatten().collect());
92 }
93
94 fn num_rows(&self, input: &Self::Record) -> Option<usize> {
95 let events = &input.global_interaction_events;
96 let nb_rows = events.len();
97 let size_log2 = input.fixed_log2_rows::<F, _>(self);
98 let padded_nb_rows = next_power_of_two(nb_rows, size_log2);
99 Some(padded_nb_rows)
100 }
101
102 fn generate_trace(&self, input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
103 let events = &input.global_interaction_events;
104
105 let nb_rows = events.len();
106 let padded_nb_rows = <GlobalChip as MachineAir<F>>::num_rows(self, input).unwrap();
107 let mut values = zeroed_f_vec(padded_nb_rows * NUM_GLOBAL_COLS);
108 let chunk_size = std::cmp::max(nb_rows / num_cpus::get(), 0) + 1;
109
110 let mut chunks = values[..nb_rows * NUM_GLOBAL_COLS]
111 .chunks_mut(chunk_size * NUM_GLOBAL_COLS)
112 .collect::<Vec<_>>();
113
114 let point_chunks = chunks
115 .par_iter_mut()
116 .enumerate()
117 .map(|(i, rows)| {
118 let mut point_chunks = Vec::with_capacity(chunk_size * NUM_GLOBAL_COLS + 1);
119 if i == 0 {
120 point_chunks.push(SepticCurveComplete::Affine(SepticDigest::<F>::zero().0));
121 }
122 rows.chunks_mut(NUM_GLOBAL_COLS).enumerate().for_each(|(j, row)| {
123 let idx = i * chunk_size + j;
124 let cols: &mut GlobalCols<F> = row.borrow_mut();
125 let event: &GlobalInteractionEvent = &events[idx];
126 cols.message = event.message.map(F::from_canonical_u32);
127 cols.kind = F::from_canonical_u8(event.kind);
128 cols.interaction.populate(
129 SepticBlock(event.message),
130 event.is_receive,
131 true,
132 event.kind,
133 );
134 cols.is_real = F::one();
135 if event.is_receive {
136 cols.is_receive = F::one();
137 } else {
138 cols.is_send = F::one();
139 }
140 point_chunks.push(SepticCurveComplete::Affine(SepticCurve {
141 x: SepticExtension(cols.interaction.x_coordinate.0),
142 y: SepticExtension(cols.interaction.y_coordinate.0),
143 }));
144 });
145 point_chunks
146 })
147 .collect::<Vec<_>>();
148
149 let points = point_chunks.into_iter().flatten().collect::<Vec<_>>();
150 let cumulative_sum = points
151 .into_par_iter()
152 .with_min_len(1 << 15)
153 .scan(|a, b| *a + *b, SepticCurveComplete::Infinity)
154 .collect::<Vec<SepticCurveComplete<F>>>();
155
156 let final_digest = match cumulative_sum.last() {
157 Some(digest) => digest.point(),
158 None => SepticCurve::<F>::dummy(),
159 };
160 let dummy = SepticCurve::<F>::dummy();
161 let final_sum_checker = SepticCurve::<F>::sum_checker_x(final_digest, dummy, final_digest);
162
163 let chunk_size = std::cmp::max(padded_nb_rows / num_cpus::get(), 0) + 1;
164 values.chunks_mut(chunk_size * NUM_GLOBAL_COLS).enumerate().par_bridge().for_each(
165 |(i, rows)| {
166 rows.chunks_mut(NUM_GLOBAL_COLS).enumerate().for_each(|(j, row)| {
167 let idx = i * chunk_size + j;
168 let cols: &mut GlobalCols<F> = row.borrow_mut();
169 if idx < nb_rows {
170 cols.accumulation.populate_real(
171 &cumulative_sum[idx..idx + 2],
172 final_digest,
173 final_sum_checker,
174 );
175 } else {
176 cols.interaction.populate_dummy();
177 cols.accumulation.populate_dummy(final_digest, final_sum_checker);
178 }
179 });
180 },
181 );
182
183 RowMajorMatrix::new(values, NUM_GLOBAL_COLS)
184 }
185
186 fn included(&self, _: &Self::Record) -> bool {
187 true
188 }
189
190 fn commit_scope(&self) -> InteractionScope {
191 InteractionScope::Global
192 }
193}
194
195impl<F> BaseAir<F> for GlobalChip {
196 fn width(&self) -> usize {
197 NUM_GLOBAL_COLS
198 }
199}
200
/// Constraints of the global chip.
///
/// Each row is evaluated against the next row so that the accumulation can chain from one row
/// to the following one.
impl<AB> Air<AB> for GlobalChip
where
    AB: SP1AirBuilder + PairBuilder,
{
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &GlobalCols<AB::Var> = (*local).borrow();
        let next = main.row_slice(1);
        let next: &GlobalCols<AB::Var> = (*next).borrow();

        // Receive this row's message in local scope as a `Global`-kind interaction. Its
        // multiplicity is `is_real`, so padding rows (is_real = 0) contribute nothing.
        // The payload is the 7 message elements, then `is_send`, `is_receive`, `kind`. Note that
        // this flag order is the reverse of the column order in `GlobalCols`.
        // NOTE(review): the sending side must use the same payload order — confirm.
        builder.receive(
            AirInteraction::new(
                vec![
                    local.message[0].into(),
                    local.message[1].into(),
                    local.message[2].into(),
                    local.message[3].into(),
                    local.message[4].into(),
                    local.message[5].into(),
                    local.message[6].into(),
                    local.is_send.into(),
                    local.is_receive.into(),
                    local.kind.into(),
                ],
                local.is_real.into(),
                InteractionKind::Global,
            ),
            InteractionScope::Local,
        );

        // Constrain the `interaction` operation columns for this row's message. These columns
        // hold the curve-point coordinates that the trace generator accumulates.
        GlobalInteractionOperation::<AB::F>::eval_single_digest(
            builder,
            local.message.map(Into::into),
            local.interaction,
            local.is_receive.into(),
            local.is_send.into(),
            local.is_real,
            local.kind,
        );

        // Link this row's accumulation columns to the next row's, using the `is_real` flags of
        // both rows.
        GlobalAccumulationOperation::<AB::F, 1>::eval_accumulation(
            builder,
            [local.interaction],
            [local.is_real],
            [next.is_real],
            local.accumulation,
            next.accumulation,
        );
    }
}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::programs::tests::*;
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeField32;
    use p3_matrix::dense::RowMajorMatrix;
    use sp1_core_executor::{ExecutionRecord, Executor};
    use sp1_stark::{air::MachineAir, SP1CoreOpts};

    /// Generates the global chip's trace for the first shard of `simple_program`, then checks
    /// the trace dimensions and that exactly the event-backed rows are flagged `is_real`.
    #[test]
    fn test_global_generate_trace() {
        let program = simple_program();
        let mut runtime = Executor::new(program, SP1CoreOpts::default());
        runtime.run().unwrap();
        let shard = runtime.records[0].clone();

        let chip: GlobalChip = GlobalChip;

        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());

        let nb_events = shard.global_interaction_events.len();
        let padded_nb_rows =
            <GlobalChip as MachineAir<BabyBear>>::num_rows(&chip, &shard).unwrap();
        assert_eq!(trace.width, NUM_GLOBAL_COLS);
        assert_eq!(trace.values.len(), padded_nb_rows * NUM_GLOBAL_COLS);

        // Rows backed by an event are real; padding rows are not.
        for (row_idx, row) in trace.values.chunks_exact(NUM_GLOBAL_COLS).enumerate() {
            let expected_is_real = u32::from(row_idx < nb_events);
            assert_eq!(row[GLOBAL_COL_MAP.is_real].as_canonical_u32(), expected_is_real);
        }
    }
}
292}