Skip to main content

sp1_core_machine/global/
mod.rs

1use std::{
2    borrow::Borrow,
3    mem::{transmute, MaybeUninit},
4};
5
6use rayon::iter::{
7    IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelBridge,
8    ParallelIterator,
9};
10use rayon_scan::ScanParallelIterator;
11use slop_air::{Air, BaseAir, PairBuilder};
12use slop_algebra::PrimeField32;
13use slop_matrix::Matrix;
14use sp1_core_executor::{
15    events::{ByteLookupEvent, ByteRecord, GlobalInteractionEvent},
16    ExecutionRecord, Program,
17};
18use sp1_hypercube::{
19    air::{AirInteraction, InteractionScope, MachineAir, SP1AirBuilder},
20    septic_curve::{SepticCurve, SepticCurveComplete},
21    septic_digest::SepticDigest,
22    septic_extension::SepticExtension,
23    InteractionKind,
24};
25use std::borrow::BorrowMut;
26
27use crate::{
28    operations::{GlobalAccumulationOperation, GlobalInteractionOperation},
29    utils::{indices_arr, next_multiple_of_32},
30};
31use sp1_derive::AlignedBorrow;
32
33const NUM_GLOBAL_COLS: usize = size_of::<GlobalCols<u8>>();
34
35/// Creates the column map for the CPU.
36const fn make_col_map() -> GlobalCols<usize> {
37    let indices_arr = indices_arr::<NUM_GLOBAL_COLS>();
38    unsafe { transmute::<[usize; NUM_GLOBAL_COLS], GlobalCols<usize>>(indices_arr) }
39}
40
41const GLOBAL_COL_MAP: GlobalCols<usize> = make_col_map();
42
43pub const GLOBAL_INITIAL_DIGEST_POS: usize = GLOBAL_COL_MAP.accumulation.initial_digest[0].0[0];
44
45const GLOBAL_OFFSET_POS: usize = GLOBAL_COL_MAP.interaction.offset;
46
47pub const GLOBAL_INITIAL_DIGEST_POS_COPY: usize = 213;
48
49pub const GLOBAL_OFFSET_POS_COPY: usize = 204;
50
51#[repr(C)]
52pub struct Ghost {
53    pub v: [usize; GLOBAL_INITIAL_DIGEST_POS_COPY],
54}
55
/// The Global chip: a stateless marker type that carries the `MachineAir`, `BaseAir` and `Air`
/// implementations for global interactions.
#[derive(Debug, Clone, Copy, Default)]
pub struct GlobalChip;
58
/// Column layout of a single row of the Global chip's trace.
///
/// The struct is `#[repr(C)]` and the field order is load-bearing: the column map is produced by
/// transmuting an index array into this struct, and the hard-coded `GLOBAL_*_POS_COPY` constants
/// encode positions within it. Reordering, adding or removing fields changes those positions.
#[derive(AlignedBorrow)]
#[repr(C)]
pub struct GlobalCols<T: Copy> {
    /// The eight message limbs of the global interaction event.
    pub message: [T; 8],
    /// The interaction kind of the event.
    pub kind: T,
    /// The low 16 bits of `message[0]`.
    pub message_0_16bit_limb: T,
    /// Bits 16..24 of `message[0]`.
    pub message_0_8bit_limb: T,
    /// Columns of the global interaction operation (includes the curve point coordinates and
    /// the `offset` column).
    pub interaction: GlobalInteractionOperation<T>,
    /// Set to one on rows that hold an event; constrained to be boolean in `eval`.
    pub is_real: T,
    /// Set to one when the event is a receive, zero when it is a send.
    pub is_receive: T,
    /// Set to one when the event is a send, zero when it is a receive.
    pub is_send: T,
    /// Index of the event (row) within the shard's global interaction events.
    pub index: T,
    /// Columns of the running digest accumulation (includes `initial_digest`).
    pub accumulation: GlobalAccumulationOperation<T>,
}
73
74impl<F: PrimeField32> MachineAir<F> for GlobalChip {
75    type Record = ExecutionRecord;
76
77    type Program = Program;
78
79    fn name(&self) -> &'static str {
80        debug_assert_eq!(GLOBAL_INITIAL_DIGEST_POS_COPY, GLOBAL_INITIAL_DIGEST_POS);
81        debug_assert_eq!(GLOBAL_OFFSET_POS_COPY, GLOBAL_OFFSET_POS);
82        "Global"
83    }
84
85    fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
86        let events = &input.global_interaction_events;
87
88        let chunk_size = std::cmp::max(events.len() / num_cpus::get(), 1);
89
90        let blu_batches = events
91            .chunks(chunk_size)
92            .par_bridge()
93            .map(|events| {
94                let mut blu: Vec<ByteLookupEvent> = Vec::new();
95                let mut row = [F::zero(); NUM_GLOBAL_COLS];
96                let cols: &mut GlobalCols<F> = row.as_mut_slice().borrow_mut();
97                events.iter().for_each(|event| {
98                    let message0_16bit_limb = (event.message[0] & 0xffff) as u16;
99                    let message0_8bit_limb = ((event.message[0] >> 16) & 0xff) as u8;
100                    blu.add_u16_range_check(message0_16bit_limb);
101                    blu.add_u16_range_check(event.message[7] as u16);
102                    blu.add_u8_range_check(message0_8bit_limb, 0);
103                    blu.add_bit_range_check(event.kind as u16, 6);
104                    if !input.global_dependencies_opt {
105                        cols.interaction.populate(
106                            &mut blu,
107                            event.message,
108                            event.is_receive,
109                            true,
110                            event.kind,
111                        );
112                    }
113                });
114                blu
115            })
116            .collect::<Vec<_>>();
117
118        output.add_byte_lookup_events(blu_batches.into_iter().flatten().collect());
119        output.public_values.global_count = events.len() as u32;
120    }
121
122    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
123        let events = &input.global_interaction_events;
124        let nb_rows = events.len();
125        let size_log2 = input.fixed_log2_rows::<F, _>(self);
126        let padded_nb_rows = next_multiple_of_32(nb_rows, size_log2);
127
128        Some(padded_nb_rows)
129    }
130
131    fn generate_trace_into(
132        &self,
133        input: &ExecutionRecord,
134        output: &mut ExecutionRecord,
135        buffer: &mut [MaybeUninit<F>],
136    ) {
137        let events = &input.global_interaction_events;
138
139        let nb_rows = events.len();
140
141        let padded_nb_rows = <GlobalChip as MachineAir<F>>::num_rows(self, input).unwrap();
142        let chunk_size = std::cmp::max(nb_rows / num_cpus::get(), 0) + 1;
143
144        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
145        let values = unsafe {
146            core::slice::from_raw_parts_mut(buffer_ptr, padded_nb_rows * NUM_GLOBAL_COLS)
147        };
148        let mut chunks = values[..nb_rows * NUM_GLOBAL_COLS]
149            .chunks_mut(chunk_size * NUM_GLOBAL_COLS)
150            .collect::<Vec<_>>();
151
152        let point_chunks = chunks
153            .par_iter_mut()
154            .enumerate()
155            .map(|(i, rows)| {
156                let mut point_chunks = Vec::with_capacity(chunk_size * NUM_GLOBAL_COLS + 1);
157                if i == 0 {
158                    point_chunks.push(SepticCurveComplete::Affine(SepticDigest::<F>::zero().0));
159                }
160                let mut blu = Vec::new();
161                rows.chunks_mut(NUM_GLOBAL_COLS).enumerate().for_each(|(j, row)| {
162                    let idx = i * chunk_size + j;
163                    let cols: &mut GlobalCols<F> = row.borrow_mut();
164                    let event: &GlobalInteractionEvent = &events[idx];
165                    cols.message = event.message.map(F::from_canonical_u32);
166                    cols.kind = F::from_canonical_u8(event.kind);
167                    cols.index = F::from_canonical_u32(idx as u32);
168                    cols.interaction.populate(
169                        &mut blu,
170                        event.message,
171                        event.is_receive,
172                        true,
173                        event.kind,
174                    );
175                    cols.is_real = F::one();
176                    if event.is_receive {
177                        cols.is_receive = F::one();
178                        cols.is_send = F::zero();
179                    } else {
180                        cols.is_receive = F::zero();
181                        cols.is_send = F::one();
182                    }
183                    cols.message_0_16bit_limb =
184                        F::from_canonical_u16((event.message[0] & 0xffff) as u16);
185                    cols.message_0_8bit_limb =
186                        F::from_canonical_u8(((event.message[0] >> 16) & 0xff) as u8);
187                    point_chunks.push(SepticCurveComplete::Affine(SepticCurve {
188                        x: SepticExtension(cols.interaction.x_coordinate.0),
189                        y: SepticExtension(cols.interaction.y_coordinate.0),
190                    }));
191                });
192                point_chunks
193            })
194            .collect::<Vec<_>>();
195
196        let points = point_chunks.into_iter().flatten().collect::<Vec<_>>();
197        let cumulative_sum = points
198            .into_par_iter()
199            .with_min_len(1 << 15)
200            .scan(|a, b| *a + *b, SepticCurveComplete::Infinity)
201            .collect::<Vec<SepticCurveComplete<F>>>();
202
203        let final_digest = match cumulative_sum.last() {
204            Some(digest) => digest.point(),
205            None => SepticDigest::<F>::zero().0,
206        };
207
208        let mut global_sum = input.global_cumulative_sum.lock().unwrap();
209        *global_sum = SepticDigest(SepticCurve::convert(final_digest, |x| F::as_canonical_u32(&x)));
210
211        output.global_interaction_event_count = nb_rows as u32;
212
213        let start_digest = SepticDigest::<F>::zero().0;
214        let dummy = SepticCurve::<F>::dummy();
215        let start_digest_plus_dummy = start_digest.add_incomplete(dummy);
216
217        let chunk_size = std::cmp::max(padded_nb_rows / num_cpus::get(), 0) + 1;
218        values.chunks_mut(chunk_size * NUM_GLOBAL_COLS).enumerate().par_bridge().for_each(
219            |(i, rows)| {
220                rows.chunks_mut(NUM_GLOBAL_COLS).enumerate().for_each(|(j, row)| {
221                    let idx = i * chunk_size + j;
222                    if idx >= nb_rows {
223                        unsafe {
224                            core::ptr::write_bytes(row.as_mut_ptr(), 0, NUM_GLOBAL_COLS);
225                        }
226                    }
227                    let cols: &mut GlobalCols<F> = row.borrow_mut();
228                    if idx < nb_rows {
229                        cols.accumulation.populate_real(&cumulative_sum[idx..idx + 2]);
230                    } else {
231                        cols.interaction.populate_dummy();
232                        cols.accumulation.populate_dummy(start_digest, start_digest_plus_dummy);
233                    }
234                });
235            },
236        );
237    }
238
239    fn included(&self, _: &Self::Record) -> bool {
240        true
241    }
242}
243
impl<F> BaseAir<F> for GlobalChip {
    /// Returns the number of columns in one row of the Global chip's trace.
    fn width(&self) -> usize {
        NUM_GLOBAL_COLS
    }
}
249
250impl<AB> Air<AB> for GlobalChip
251where
252    AB: SP1AirBuilder + PairBuilder,
253{
254    fn eval(&self, builder: &mut AB) {
255        let main = builder.main();
256        let local = main.row_slice(0);
257        let local: &GlobalCols<AB::Var> = (*local).borrow();
258
259        // Constrain that `local.is_real` is boolean.
260        builder.assert_bool(local.is_real);
261
262        // Receive the arguments, which consists of 8 message columns, `is_send`, `is_receive`, and
263        // `kind`. In MemoryGlobal, MemoryLocal, Syscall chips, `is_send`, `is_receive`,
264        // `kind` are sent with correct constant values. For a global send interaction,
265        // `is_send = 1` and `is_receive = 0` are used. For a global receive interaction,
266        // `is_send = 0` and `is_receive = 1` are used. For a memory global interaction,
267        // `kind = InteractionKind::Memory` is used. For a syscall global interaction, `kind
268        // = InteractionKind::Syscall` is used. Therefore, `is_send`, `is_receive` are
269        // already known to be boolean, and `kind` is also known to be a `u8` value.
270        builder.receive(
271            AirInteraction::new(
272                vec![
273                    local.message[0].into(),
274                    local.message[1].into(),
275                    local.message[2].into(),
276                    local.message[3].into(),
277                    local.message[4].into(),
278                    local.message[5].into(),
279                    local.message[6].into(),
280                    local.message[7].into(),
281                    local.is_send.into(),
282                    local.is_receive.into(),
283                    local.kind.into(),
284                ],
285                local.is_real.into(),
286                InteractionKind::Global,
287            ),
288            InteractionScope::Local,
289        );
290
291        // Evaluate the interaction.
292        GlobalInteractionOperation::<AB::F>::eval_single_digest(
293            builder,
294            local.message.map(Into::into),
295            local.interaction,
296            local.is_receive.into(),
297            local.is_send.into(),
298            local.is_real,
299            local.kind,
300            [local.message_0_16bit_limb, local.message_0_8bit_limb],
301        );
302
303        // Evaluate the accumulation.
304        GlobalAccumulationOperation::<AB::F>::eval_accumulation(
305            builder,
306            local.interaction,
307            local.is_real,
308            local.index,
309            local.accumulation,
310        );
311    }
312}
313
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use std::sync::Arc;

    use super::*;
    use crate::io::SP1Stdin;
    use crate::programs::tests::*;
    use crate::utils::generate_records;

    use slop_matrix::dense::RowMajorMatrix;
    use sp1_core_executor::{ExecutionRecord, SP1CoreOpts};
    use sp1_hypercube::air::MachineAir;
    use sp1_primitives::SP1Field;

    /// Smoke test: generates the Global chip trace for the last record of the simple program
    /// and prints it together with that record's global memory finalize events. It makes no
    /// assertions beyond the run completing without a panic.
    #[test]
    #[allow(clippy::uninlined_format_args)]
    fn test_global_generate_trace() {
        let (records, _) = generate_records::<SP1Field>(
            Arc::new(simple_program()),
            SP1Stdin::new(),
            SP1CoreOpts::default(),
            [0; 4],
        )
        .unwrap();

        // The last record is the one expected to contain global events.
        let last_record = records.into_iter().last().unwrap();

        let global_chip: GlobalChip = GlobalChip;
        let mut scratch_output = ExecutionRecord::default();
        let trace: RowMajorMatrix<SP1Field> =
            global_chip.generate_trace(&last_record, &mut scratch_output);
        println!("{:?}", trace.values);

        for finalize_event in last_record.global_memory_finalize_events {
            println!("{finalize_event:?}");
        }
    }
}