sp1_core_machine/syscall/instructions/trace.rs

1use std::borrow::BorrowMut;
2
3use hashbrown::HashMap;
4use itertools::Itertools;
5use p3_field::PrimeField32;
6use p3_matrix::dense::RowMajorMatrix;
7use rayon::iter::{ParallelBridge, ParallelIterator};
8use sp1_core_executor::{
9    events::{ByteLookupEvent, ByteRecord, MemoryRecordEnum, SyscallEvent},
10    syscalls::SyscallCode,
11    ExecutionRecord, Program,
12};
13use sp1_stark::air::MachineAir;
14
15use crate::utils::{next_power_of_two, zeroed_f_vec};
16
17use super::{
18    columns::{SyscallInstrColumns, NUM_SYSCALL_INSTR_COLS},
19    SyscallInstrsChip,
20};
21
22impl<F: PrimeField32> MachineAir<F> for SyscallInstrsChip {
23    type Record = ExecutionRecord;
24
25    type Program = Program;
26
27    fn name(&self) -> String {
28        "SyscallInstrs".to_string()
29    }
30
31    fn generate_trace(
32        &self,
33        input: &ExecutionRecord,
34        output: &mut ExecutionRecord,
35    ) -> RowMajorMatrix<F> {
36        let chunk_size = std::cmp::max((input.syscall_events.len()) / num_cpus::get(), 1);
37        let nb_rows = input.syscall_events.len();
38        let size_log2 = input.fixed_log2_rows::<F, _>(self);
39        let padded_nb_rows = next_power_of_two(nb_rows, size_log2);
40        let mut values = zeroed_f_vec(padded_nb_rows * NUM_SYSCALL_INSTR_COLS);
41
42        let blu_events = values
43            .chunks_mut(chunk_size * NUM_SYSCALL_INSTR_COLS)
44            .enumerate()
45            .par_bridge()
46            .map(|(i, rows)| {
47                let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
48                rows.chunks_mut(NUM_SYSCALL_INSTR_COLS).enumerate().for_each(|(j, row)| {
49                    let idx = i * chunk_size + j;
50                    let cols: &mut SyscallInstrColumns<F> = row.borrow_mut();
51
52                    if idx < input.syscall_events.len() {
53                        let event = &input.syscall_events[idx];
54                        self.event_to_row(event, cols, &mut blu);
55                    }
56                });
57                blu
58            })
59            .collect::<Vec<_>>();
60
61        output.add_byte_lookup_events_from_maps(blu_events.iter().collect_vec());
62
63        // Convert the trace to a row major matrix.
64        RowMajorMatrix::new(values, NUM_SYSCALL_INSTR_COLS)
65    }
66
67    fn included(&self, shard: &Self::Record) -> bool {
68        if let Some(shape) = shard.shape.as_ref() {
69            shape.included::<F, _>(self)
70        } else {
71            !shard.syscall_events.is_empty()
72        }
73    }
74}
75
76impl SyscallInstrsChip {
77    fn event_to_row<F: PrimeField32>(
78        &self,
79        event: &SyscallEvent,
80        cols: &mut SyscallInstrColumns<F>,
81        blu: &mut impl ByteRecord,
82    ) {
83        cols.is_real = F::one();
84        cols.pc = F::from_canonical_u32(event.pc);
85        cols.next_pc = F::from_canonical_u32(event.next_pc);
86        cols.shard = F::from_canonical_u32(event.shard);
87        cols.clk = F::from_canonical_u32(event.clk);
88
89        cols.op_a_access.populate(MemoryRecordEnum::Write(event.a_record), blu);
90        cols.op_b_value = event.arg1.into();
91        cols.op_c_value = event.arg2.into();
92
93        let syscall_id = cols.op_a_access.prev_value[0];
94        let num_cycles = cols.op_a_access.prev_value[2];
95
96        cols.num_extra_cycles = num_cycles;
97        cols.is_halt =
98            F::from_bool(syscall_id == F::from_canonical_u32(SyscallCode::HALT.syscall_id()));
99
100        // Populate `is_enter_unconstrained`.
101        cols.is_enter_unconstrained.populate_from_field_element(
102            syscall_id - F::from_canonical_u32(SyscallCode::ENTER_UNCONSTRAINED.syscall_id()),
103        );
104
105        // Populate `is_hint_len`.
106        cols.is_hint_len.populate_from_field_element(
107            syscall_id - F::from_canonical_u32(SyscallCode::HINT_LEN.syscall_id()),
108        );
109
110        // Populate `is_halt`.
111        cols.is_halt_check.populate_from_field_element(
112            syscall_id - F::from_canonical_u32(SyscallCode::HALT.syscall_id()),
113        );
114
115        // Populate `is_commit`.
116        cols.is_commit.populate_from_field_element(
117            syscall_id - F::from_canonical_u32(SyscallCode::COMMIT.syscall_id()),
118        );
119
120        // Populate `is_commit_deferred_proofs`.
121        cols.is_commit_deferred_proofs.populate_from_field_element(
122            syscall_id - F::from_canonical_u32(SyscallCode::COMMIT_DEFERRED_PROOFS.syscall_id()),
123        );
124
125        // If the syscall is `COMMIT` or `COMMIT_DEFERRED_PROOFS`, set the index bitmap and
126        // digest word.
127        if syscall_id == F::from_canonical_u32(SyscallCode::COMMIT.syscall_id()) ||
128            syscall_id == F::from_canonical_u32(SyscallCode::COMMIT_DEFERRED_PROOFS.syscall_id())
129        {
130            let digest_idx = cols.op_b_value.to_u32() as usize;
131            cols.index_bitmap[digest_idx] = F::one();
132        }
133
134        // For halt and commit deferred proofs syscalls, we need to baby bear range check one of
135        // it's operands.
136        if cols.is_halt == F::one() {
137            cols.operand_to_check = event.arg1.into();
138            cols.operand_range_check_cols.populate(cols.operand_to_check, blu);
139            cols.ecall_range_check_operand = F::one();
140        }
141
142        if syscall_id == F::from_canonical_u32(SyscallCode::COMMIT_DEFERRED_PROOFS.syscall_id()) {
143            cols.operand_to_check = event.arg2.into();
144            cols.operand_range_check_cols.populate(cols.operand_to_check, blu);
145            cols.ecall_range_check_operand = F::one();
146        }
147    }
148}