Skip to main content

sp1_core_machine/control_flow/trap/
exec.rs

1use hashbrown::HashMap;
2use itertools::Itertools;
3use rayon::iter::{ParallelBridge, ParallelIterator};
4use sp1_core_executor::{
5    events::{ByteLookupEvent, ByteRecord},
6    ByteOpcode, ExecutionRecord, Program, CLK_INC,
7};
8use sp1_derive::AlignedBorrow;
9use sp1_hypercube::air::MachineAir;
10use sp1_primitives::consts::{PROT_EXEC, PROT_FAILURE_EXEC, PROT_FAILURE_READ, PROT_READ};
11use std::borrow::{Borrow, BorrowMut};
12use std::mem::{size_of, MaybeUninit};
13
14use crate::{
15    adapter::state::{CPUState, CPUStateInput},
16    air::{SP1CoreAirBuilder, SP1Operation},
17    operations::{PageProtOperation, TrapOperation},
18    utils::next_multiple_of_32,
19};
20use slop_air::{Air, AirBuilder, BaseAir};
21use slop_algebra::{AbstractField, PrimeField32};
22use slop_matrix::Matrix;
23#[cfg(feature = "mprotect")]
24use sp1_hypercube::addr_to_limbs;
25
26/// The number of main trace columns for `TrapExecChip`.
27pub const NUM_TRAP_EXEC_COLS: usize = size_of::<TrapExecColumns<u8>>();
28
29#[derive(AlignedBorrow, Default, Debug, Clone, Copy)]
30#[repr(C)]
31pub struct TrapExecColumns<T> {
32    /// The current shard, timestamp, program counter of the CPU.
33    pub state: CPUState<T>,
34
35    /// The operation to get the page permission.
36    pub page_prot_operation: PageProtOperation<T>,
37
38    /// The operation to handle the trap.
39    pub trap_operation: TrapOperation<T>,
40
41    /// Addresses for the trap context. Should be removed after GKR supports public values.
42    pub addresses: [[T; 3]; 3],
43
44    /// Whether or not `PROT_EXEC` failed.
45    pub prot_exec_fail: T,
46
47    /// Whether or not `PROT_READ` failed.
48    pub prot_read_fail: T,
49
50    /// The trap code.
51    pub trap_code: T,
52
53    /// Whether or not the current row is a real row.
54    pub is_real: T,
55}
56
/// Chip whose rows handle traps raised when the page at `pc` is missing `PROT_EXEC`
/// and/or `PROT_READ`. Stateless: all data comes from the execution record.
#[derive(Default)]
pub struct TrapExecChip;
59
60impl<F> BaseAir<F> for TrapExecChip {
61    fn width(&self) -> usize {
62        NUM_TRAP_EXEC_COLS
63    }
64}
65
66impl<F: PrimeField32> MachineAir<F> for TrapExecChip {
67    type Record = ExecutionRecord;
68
69    type Program = Program;
70
71    fn name(&self) -> &'static str {
72        "TrapExec"
73    }
74
75    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
76        let nb_rows =
77            next_multiple_of_32(input.trap_exec_events.len(), input.fixed_log2_rows::<F, _>(self));
78        Some(nb_rows)
79    }
80
81    fn generate_trace_into(
82        &self,
83        input: &ExecutionRecord,
84        output: &mut ExecutionRecord,
85        buffer: &mut [MaybeUninit<F>],
86    ) {
87        let chunk_size = std::cmp::max((input.trap_exec_events.len()) / num_cpus::get(), 1);
88        let padded_nb_rows = <TrapExecChip as MachineAir<F>>::num_rows(self, input).unwrap();
89        let width = <TrapExecChip as BaseAir<F>>::width(self);
90        let num_event_rows = input.trap_exec_events.len();
91
92        unsafe {
93            let padding_start = num_event_rows * width;
94            let padding_size = (padded_nb_rows - num_event_rows) * width;
95            if padding_size > 0 {
96                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
97            }
98        }
99
100        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
101        let values = unsafe { core::slice::from_raw_parts_mut(buffer_ptr, padded_nb_rows * width) };
102
103        let blu_events = values
104            .chunks_mut(chunk_size * width)
105            .enumerate()
106            .par_bridge()
107            .map(|(i, rows)| {
108                let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
109                rows.chunks_mut(width).enumerate().for_each(|(j, row)| {
110                    let idx = i * chunk_size + j;
111                    let cols: &mut TrapExecColumns<F> = row.borrow_mut();
112
113                    if idx < input.trap_exec_events.len() {
114                        let event = &input.trap_exec_events[idx];
115                        cols.state.populate(&mut blu, event.clk, event.pc);
116                        cols.page_prot_operation.populate(
117                            &mut blu,
118                            event.pc,
119                            event.clk,
120                            &event.page_prot_record,
121                        );
122                        cols.trap_operation.populate(&mut blu, event.trap_result);
123                        let perm = event.page_prot_record.page_prot;
124                        cols.trap_code = F::from_canonical_u64(event.trap_result.code_record.value);
125                        cols.prot_read_fail = F::from_bool((perm & PROT_READ) == 0);
126                        cols.prot_exec_fail = F::from_bool((perm & PROT_EXEC) == 0);
127                        blu.add_byte_lookup_event(ByteLookupEvent {
128                            opcode: ByteOpcode::AND,
129                            a: (perm & (PROT_READ | PROT_EXEC)) as u16,
130                            b: perm,
131                            c: (PROT_READ | PROT_EXEC),
132                        });
133                        #[cfg(feature = "mprotect")]
134                        for i in 0..3 {
135                            cols.addresses[i] = addr_to_limbs(input.public_values.trap_context[i]);
136                        }
137                        blu.add_u16_range_check((event.pc & 0xFFFF) as u16);
138                        blu.add_u16_range_check(((event.pc >> 16) & 0xFFFF) as u16);
139                        blu.add_u16_range_check(((event.pc >> 32) & 0xFFFF) as u16);
140                        cols.is_real = F::one();
141                    }
142                });
143                blu
144            })
145            .collect::<Vec<_>>();
146
147        output.add_byte_lookup_events_from_maps(blu_events.iter().collect_vec());
148    }
149
150    fn included(&self, shard: &Self::Record) -> bool {
151        if let Some(shape) = shard.shape.as_ref() {
152            shape.included::<F, _>(self)
153        } else {
154            !shard.trap_exec_events.is_empty()
155        }
156    }
157}
158
159impl<AB> Air<AB> for TrapExecChip
160where
161    AB: SP1CoreAirBuilder,
162    AB::Var: Sized,
163{
164    #[inline(never)]
165    fn eval(&self, builder: &mut AB) {
166        let main = builder.main();
167        let local = main.row_slice(0);
168        let local: &TrapExecColumns<AB::Var> = (*local).borrow();
169
170        // Check that `is_real` is boolean.
171        builder.assert_bool(local.is_real);
172
173        // Range check that the `pc` are all valid u16 limbs.
174        builder.slice_range_check_u16(&local.state.pc, local.is_real);
175
176        #[cfg(not(feature = "mprotect"))]
177        builder.assert_zero(local.is_real);
178
179        // Read the currently set page permissions.
180        PageProtOperation::<AB::F>::eval(
181            builder,
182            local.state.clk_high::<AB>(),
183            local.state.clk_low::<AB>(),
184            &local.state.pc.map(Into::into),
185            local.page_prot_operation,
186            local.is_real.into(),
187        );
188
189        // Check that `prot_exec_fail` and `prot_read_fail` are boolean flags.
190        builder.assert_bool(local.prot_exec_fail);
191        builder.assert_bool(local.prot_read_fail);
192        // At least one of the permissions must fail.
193        builder.when(local.is_real).assert_zero(
194            (AB::Expr::one() - local.prot_exec_fail) * (AB::Expr::one() - local.prot_read_fail),
195        );
196
197        // Check the flags with an `OR` lookup.
198        builder.send_byte(
199            AB::Expr::from_canonical_u8(ByteOpcode::AND as u8),
200            AB::Expr::from_canonical_u8(PROT_EXEC) * (AB::Expr::one() - local.prot_exec_fail)
201                + AB::Expr::from_canonical_u8(PROT_READ) * (AB::Expr::one() - local.prot_read_fail),
202            local.page_prot_operation.page_prot_access.prev_prot_bitmap.into(),
203            AB::Expr::from_canonical_u8(PROT_EXEC | PROT_READ),
204            local.is_real.into(),
205        );
206
207        // If `PROT_EXEC` fails, the trap code is `PROT_FAILURE_EXEC`.
208        builder
209            .when(local.prot_exec_fail)
210            .assert_eq(local.trap_code, AB::Expr::from_canonical_u64(PROT_FAILURE_EXEC));
211
212        // If `PROT_EXEC` succeeds but `PROT_READ` fails, the trap code is `PROT_FAILURE_READ`.
213        builder
214            .when_not(local.prot_exec_fail)
215            .when(local.prot_read_fail)
216            .assert_eq(local.trap_code, AB::Expr::from_canonical_u64(PROT_FAILURE_READ));
217
218        let next_pc = TrapOperation::<AB::F>::eval(
219            builder,
220            local.trap_operation,
221            local.state.clk_high::<AB>(),
222            local.state.clk_low::<AB>(),
223            local.trap_code.into(),
224            local.state.pc.map(Into::into),
225            local.addresses,
226            local.is_real.into(),
227        );
228
229        // Constrain the state of the CPU.
230        // The `next_pc` is constrained by the AIR.
231        // The clock is incremented by `8`.
232        <CPUState<AB::F> as SP1Operation<AB>>::eval(
233            builder,
234            CPUStateInput::new(
235                local.state,
236                [next_pc[0].into(), next_pc[1].into(), next_pc[2].into()],
237                AB::Expr::from_canonical_u32(CLK_INC),
238                local.is_real.into(),
239            ),
240        );
241    }
242}