sp1_core_machine/global/
mod.rs

1use std::{borrow::Borrow, mem::transmute};
2
3use p3_air::{Air, BaseAir, PairBuilder};
4use p3_field::PrimeField32;
5use p3_matrix::{dense::RowMajorMatrix, Matrix};
6use rayon::iter::{
7    IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelBridge,
8    ParallelIterator,
9};
10use rayon_scan::ScanParallelIterator;
11use sp1_core_executor::{
12    events::{ByteLookupEvent, ByteRecord, GlobalInteractionEvent},
13    ExecutionRecord, Program,
14};
15use sp1_stark::{
16    air::{AirInteraction, InteractionScope, MachineAir, SP1AirBuilder},
17    septic_curve::{SepticCurve, SepticCurveComplete},
18    septic_digest::SepticDigest,
19    septic_extension::{SepticBlock, SepticExtension},
20    InteractionKind,
21};
22use std::borrow::BorrowMut;
23
24use crate::{
25    operations::{GlobalAccumulationOperation, GlobalInteractionOperation},
26    utils::{indices_arr, next_power_of_two, zeroed_f_vec},
27};
28use sp1_derive::AlignedBorrow;
29
30const NUM_GLOBAL_COLS: usize = size_of::<GlobalCols<u8>>();
31
32/// Creates the column map for the CPU.
33const fn make_col_map() -> GlobalCols<usize> {
34    let indices_arr = indices_arr::<NUM_GLOBAL_COLS>();
35    unsafe { transmute::<[usize; NUM_GLOBAL_COLS], GlobalCols<usize>>(indices_arr) }
36}
37
38const GLOBAL_COL_MAP: GlobalCols<usize> = make_col_map();
39
40pub const GLOBAL_INITIAL_DIGEST_POS: usize = GLOBAL_COL_MAP.accumulation.initial_digest[0].0[0];
41
42pub const GLOBAL_INITIAL_DIGEST_POS_COPY: usize = 377;
43
44#[repr(C)]
45pub struct Ghost {
46    pub v: [usize; GLOBAL_INITIAL_DIGEST_POS_COPY],
47}
48
/// Chip whose trace is generated from a record's `global_interaction_events` and which is
/// committed under `InteractionScope::Global`.
#[derive(Default)]
pub struct GlobalChip;
51
52#[derive(AlignedBorrow)]
53#[repr(C)]
54pub struct GlobalCols<T: Copy> {
55    pub message: [T; 7],
56    pub kind: T,
57    pub interaction: GlobalInteractionOperation<T>,
58    pub is_receive: T,
59    pub is_send: T,
60    pub is_real: T,
61    pub accumulation: GlobalAccumulationOperation<T, 1>,
62}
63
64impl<F: PrimeField32> MachineAir<F> for GlobalChip {
65    type Record = ExecutionRecord;
66
67    type Program = Program;
68
69    fn name(&self) -> String {
70        assert_eq!(GLOBAL_INITIAL_DIGEST_POS_COPY, GLOBAL_INITIAL_DIGEST_POS);
71        "Global".to_string()
72    }
73
74    fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
75        let events = &input.global_interaction_events;
76
77        let chunk_size = std::cmp::max(events.len() / num_cpus::get(), 1);
78
79        let blu_batches = events
80            .chunks(chunk_size)
81            .par_bridge()
82            .map(|events| {
83                let mut blu: Vec<ByteLookupEvent> = Vec::new();
84                events.iter().for_each(|event| {
85                    blu.add_u16_range_check(event.message[0].try_into().unwrap());
86                });
87                blu
88            })
89            .collect::<Vec<_>>();
90
91        output.add_byte_lookup_events(blu_batches.into_iter().flatten().collect());
92    }
93
94    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
95        let events = &input.global_interaction_events;
96        let nb_rows = events.len();
97        let size_log2 = input.fixed_log2_rows::<F, _>(self);
98        let padded_nb_rows = next_power_of_two(nb_rows, size_log2);
99        Some(padded_nb_rows)
100    }
101
102    fn generate_trace(&self, input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
103        let events = &input.global_interaction_events;
104
105        let nb_rows = events.len();
106        let padded_nb_rows = <GlobalChip as MachineAir<F>>::num_rows(self, input).unwrap();
107        let mut values = zeroed_f_vec(padded_nb_rows * NUM_GLOBAL_COLS);
108        let chunk_size = std::cmp::max(nb_rows / num_cpus::get(), 0) + 1;
109
110        let mut chunks = values[..nb_rows * NUM_GLOBAL_COLS]
111            .chunks_mut(chunk_size * NUM_GLOBAL_COLS)
112            .collect::<Vec<_>>();
113
114        let point_chunks = chunks
115            .par_iter_mut()
116            .enumerate()
117            .map(|(i, rows)| {
118                let mut point_chunks = Vec::with_capacity(chunk_size * NUM_GLOBAL_COLS + 1);
119                if i == 0 {
120                    point_chunks.push(SepticCurveComplete::Affine(SepticDigest::<F>::zero().0));
121                }
122                rows.chunks_mut(NUM_GLOBAL_COLS).enumerate().for_each(|(j, row)| {
123                    let idx = i * chunk_size + j;
124                    let cols: &mut GlobalCols<F> = row.borrow_mut();
125                    let event: &GlobalInteractionEvent = &events[idx];
126                    cols.message = event.message.map(F::from_canonical_u32);
127                    cols.kind = F::from_canonical_u8(event.kind);
128                    cols.interaction.populate(
129                        SepticBlock(event.message),
130                        event.is_receive,
131                        true,
132                        event.kind,
133                    );
134                    cols.is_real = F::one();
135                    if event.is_receive {
136                        cols.is_receive = F::one();
137                    } else {
138                        cols.is_send = F::one();
139                    }
140                    point_chunks.push(SepticCurveComplete::Affine(SepticCurve {
141                        x: SepticExtension(cols.interaction.x_coordinate.0),
142                        y: SepticExtension(cols.interaction.y_coordinate.0),
143                    }));
144                });
145                point_chunks
146            })
147            .collect::<Vec<_>>();
148
149        let points = point_chunks.into_iter().flatten().collect::<Vec<_>>();
150        let cumulative_sum = points
151            .into_par_iter()
152            .with_min_len(1 << 15)
153            .scan(|a, b| *a + *b, SepticCurveComplete::Infinity)
154            .collect::<Vec<SepticCurveComplete<F>>>();
155
156        let final_digest = match cumulative_sum.last() {
157            Some(digest) => digest.point(),
158            None => SepticCurve::<F>::dummy(),
159        };
160        let dummy = SepticCurve::<F>::dummy();
161        let final_sum_checker = SepticCurve::<F>::sum_checker_x(final_digest, dummy, final_digest);
162
163        let chunk_size = std::cmp::max(padded_nb_rows / num_cpus::get(), 0) + 1;
164        values.chunks_mut(chunk_size * NUM_GLOBAL_COLS).enumerate().par_bridge().for_each(
165            |(i, rows)| {
166                rows.chunks_mut(NUM_GLOBAL_COLS).enumerate().for_each(|(j, row)| {
167                    let idx = i * chunk_size + j;
168                    let cols: &mut GlobalCols<F> = row.borrow_mut();
169                    if idx < nb_rows {
170                        cols.accumulation.populate_real(
171                            &cumulative_sum[idx..idx + 2],
172                            final_digest,
173                            final_sum_checker,
174                        );
175                    } else {
176                        cols.interaction.populate_dummy();
177                        cols.accumulation.populate_dummy(final_digest, final_sum_checker);
178                    }
179                });
180            },
181        );
182
183        RowMajorMatrix::new(values, NUM_GLOBAL_COLS)
184    }
185
186    fn included(&self, _: &Self::Record) -> bool {
187        true
188    }
189
190    fn commit_scope(&self) -> InteractionScope {
191        InteractionScope::Global
192    }
193}
194
195impl<F> BaseAir<F> for GlobalChip {
196    fn width(&self) -> usize {
197        NUM_GLOBAL_COLS
198    }
199}
200
201impl<AB> Air<AB> for GlobalChip
202where
203    AB: SP1AirBuilder + PairBuilder,
204{
205    fn eval(&self, builder: &mut AB) {
206        let main = builder.main();
207        let local = main.row_slice(0);
208        let local: &GlobalCols<AB::Var> = (*local).borrow();
209        let next = main.row_slice(1);
210        let next: &GlobalCols<AB::Var> = (*next).borrow();
211
212        // Receive the arguments, which consists of 7 message columns, `is_send`, `is_receive`, and
213        // `kind`. In MemoryGlobal, MemoryLocal, Syscall chips, `is_send`, `is_receive`,
214        // `kind` are sent with correct constant values. For a global send interaction,
215        // `is_send = 1` and `is_receive = 0` are used. For a global receive interaction,
216        // `is_send = 0` and `is_receive = 1` are used. For a memory global interaction,
217        // `kind = InteractionKind::Memory` is used. For a syscall global interaction, `kind
218        // = InteractionKind::Syscall` is used. Therefore, `is_send`, `is_receive` are
219        // already known to be boolean, and `kind` is also known to be a `u8` value.
220        // Note that `local.is_real` is constrained to be boolean in `eval_single_digest`.
221        builder.receive(
222            AirInteraction::new(
223                vec![
224                    local.message[0].into(),
225                    local.message[1].into(),
226                    local.message[2].into(),
227                    local.message[3].into(),
228                    local.message[4].into(),
229                    local.message[5].into(),
230                    local.message[6].into(),
231                    local.is_send.into(),
232                    local.is_receive.into(),
233                    local.kind.into(),
234                ],
235                local.is_real.into(),
236                InteractionKind::Global,
237            ),
238            InteractionScope::Local,
239        );
240
241        // Evaluate the interaction.
242        GlobalInteractionOperation::<AB::F>::eval_single_digest(
243            builder,
244            local.message.map(Into::into),
245            local.interaction,
246            local.is_receive.into(),
247            local.is_send.into(),
248            local.is_real,
249            local.kind,
250        );
251
252        // Evaluate the accumulation.
253        GlobalAccumulationOperation::<AB::F, 1>::eval_accumulation(
254            builder,
255            [local.interaction],
256            [local.is_real],
257            [next.is_real],
258            local.accumulation,
259            next.accumulation,
260        );
261    }
262}
263
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use super::*;
    use crate::programs::tests::*;
    use p3_baby_bear::BabyBear;
    use p3_matrix::dense::RowMajorMatrix;
    use sp1_core_executor::{ExecutionRecord, Executor};
    use sp1_stark::{air::MachineAir, SP1CoreOpts};

    /// Smoke test: executes the simple program, generates the global chip trace for the first
    /// record and prints it together with the record's global memory finalize events.
    #[test]
    fn test_global_generate_trace() {
        let mut executor = Executor::new(simple_program(), SP1CoreOpts::default());
        executor.run().unwrap();
        let record = executor.records[0].clone();

        let mut scratch_output = ExecutionRecord::default();
        let global_chip: GlobalChip = GlobalChip;
        let trace: RowMajorMatrix<BabyBear> =
            global_chip.generate_trace(&record, &mut scratch_output);
        println!("{:?}", trace.values);

        for mem_event in &record.global_memory_finalize_events {
            println!("{mem_event:?}");
        }
    }
}