Skip to main content

sp1_core_machine/memory/
page_prot.rs

1use hashbrown::HashMap;
2use itertools::Itertools;
3use slop_air::{Air, BaseAir};
4use slop_algebra::{AbstractField, PrimeField32};
5use slop_matrix::Matrix;
6use slop_maybe_rayon::prelude::{
7    IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
8};
9use sp1_core_executor::{
10    events::{
11        ByteLookupEvent, ByteRecord, InstructionFetchEvent, MemInstrEvent, MemoryAccessPosition,
12        MemoryRecordEnum,
13    },
14    ByteOpcode, ExecutionRecord, Program,
15};
16use sp1_derive::AlignedBorrow;
17use sp1_hypercube::air::MachineAir;
18use sp1_primitives::consts::{PROT_EXEC, PROT_READ, PROT_WRITE};
19use std::{
20    borrow::{Borrow, BorrowMut},
21    mem::MaybeUninit,
22};
23
24use crate::{air::SP1CoreAirBuilder, operations::PageProtOperation, utils::next_multiple_of_32};
25
/// Mask that clears the lowest three bits of an address, aligning it down to an 8-byte
/// boundary (`!0b111 == 0xFFFF_FFFF_FFFF_FFF8`).
///
/// This does not align to a page boundary. Dropping the low three bits never moves an address
/// into a different page, as long as the page size is a multiple of 8.
const BITMASK_CLEAR_LOWEST_THREE_BITS: u64 = !0b111;
28
29pub const NUM_PAGE_PROT_ENTRIES_PER_ROW: usize = 4;
30pub(crate) const NUM_PAGE_PROT_COLS: usize = size_of::<PageProtCols<u8>>();
31
32#[derive(AlignedBorrow, Clone, Copy)]
33#[repr(C)]
34pub struct SinglePageProtCols<T: Copy> {
35    /// The clock of the memory access.
36    pub clk_high: T,
37    pub clk_low: T,
38
39    /// The address of the memory access.
40    pub addr: [T; 3],
41
42    /// The permissions of the page.
43    pub permissions: T,
44
45    /// Whether or not the row is a real row or a padding row.
46    pub is_real: T,
47
48    /// The page prot operation.
49    pub page_prot_op: PageProtOperation<T>,
50}
51
52#[derive(AlignedBorrow, Clone, Copy)]
53#[repr(C)]
54pub struct PageProtCols<T: Copy> {
55    page_prot_entries: [SinglePageProtCols<T>; NUM_PAGE_PROT_ENTRIES_PER_ROW],
56}
57
/// Chip that checks every memory access and untrusted instruction fetch against the page
/// permissions that were set for the accessed page.
#[derive(Default)]
pub struct PageProtChip;
60
61impl<F> BaseAir<F> for PageProtChip {
62    fn width(&self) -> usize {
63        NUM_PAGE_PROT_COLS
64    }
65}
66
67fn nb_rows(count: usize) -> usize {
68    if NUM_PAGE_PROT_ENTRIES_PER_ROW > 1 {
69        count.div_ceil(NUM_PAGE_PROT_ENTRIES_PER_ROW)
70    } else {
71        count
72    }
73}
74
75impl<F: PrimeField32> MachineAir<F> for PageProtChip {
76    type Record = ExecutionRecord;
77
78    type Program = Program;
79
80    fn name(&self) -> &'static str {
81        "PageProt"
82    }
83
84    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
85        let mut count = 0;
86        if input.public_values.is_untrusted_programs_enabled == 1 {
87            count = input.memory_load_byte_events.len()
88                + input.memory_store_byte_events.len()
89                + input.memory_load_word_events.len()
90                + input.memory_store_word_events.len()
91                + input.memory_load_double_events.len()
92                + input.memory_store_double_events.len()
93                + input.memory_load_half_events.len()
94                + input.memory_store_half_events.len()
95                + input.memory_load_x0_events.len()
96                + input.instruction_fetch_events.len();
97        }
98
99        let nb_rows = nb_rows(count);
100        let size_log2 = input.fixed_log2_rows::<F, _>(self);
101        Some(next_multiple_of_32(nb_rows, size_log2))
102    }
103
104    fn generate_trace_into(
105        &self,
106        input: &ExecutionRecord,
107        output: &mut ExecutionRecord,
108        buffer: &mut [MaybeUninit<F>],
109    ) {
110        let mut events = vec![];
111
112        if input.public_values.is_untrusted_programs_enabled == 1 {
113            events = input
114                .memory_load_byte_events
115                .iter()
116                .map(|e| Self::generate_page_prot_event(&e.0, true, false, false))
117                .chain(
118                    input
119                        .memory_store_byte_events
120                        .iter()
121                        .map(|e| Self::generate_page_prot_event(&e.0, false, true, false)),
122                )
123                .chain(
124                    input
125                        .memory_load_word_events
126                        .iter()
127                        .map(|e| Self::generate_page_prot_event(&e.0, true, false, false)),
128                )
129                .chain(
130                    input
131                        .memory_store_word_events
132                        .iter()
133                        .map(|e| Self::generate_page_prot_event(&e.0, false, true, false)),
134                )
135                .chain(
136                    input
137                        .memory_load_double_events
138                        .iter()
139                        .map(|e| Self::generate_page_prot_event(&e.0, true, false, false)),
140                )
141                .chain(
142                    input
143                        .memory_store_double_events
144                        .iter()
145                        .map(|e| Self::generate_page_prot_event(&e.0, false, true, false)),
146                )
147                .chain(
148                    input
149                        .memory_load_half_events
150                        .iter()
151                        .map(|e| Self::generate_page_prot_event(&e.0, true, false, false)),
152                )
153                .chain(
154                    input
155                        .memory_store_half_events
156                        .iter()
157                        .map(|e| Self::generate_page_prot_event(&e.0, false, true, false)),
158                )
159                .chain(
160                    input
161                        .memory_load_x0_events
162                        .iter()
163                        .map(|e| Self::generate_page_prot_event(&e.0, true, false, false)),
164                )
165                .chain(input.instruction_fetch_events.iter().map(|e| {
166                    let (mem_access, _) = e.1.untrusted_instruction.unwrap();
167                    Self::generate_fetch_instruction_page_prot_event(
168                        &e.0, mem_access, true, false, true,
169                    )
170                }))
171                .collect_vec();
172        }
173
174        let nb_rows = nb_rows(events.len());
175        let padded_nb_rows = <PageProtChip as MachineAir<F>>::num_rows(self, input).unwrap();
176
177        unsafe {
178            let padding_start = nb_rows * NUM_PAGE_PROT_COLS;
179            let padding_size = (padded_nb_rows - nb_rows) * NUM_PAGE_PROT_COLS;
180            if padding_size > 0 {
181                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
182            }
183        }
184
185        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
186        let values =
187            unsafe { core::slice::from_raw_parts_mut(buffer_ptr, nb_rows * NUM_PAGE_PROT_COLS) };
188
189        let chunk_size = std::cmp::max(nb_rows / num_cpus::get(), 0) + 1;
190
191        let mut chunks = values[..nb_rows * NUM_PAGE_PROT_COLS]
192            .chunks_mut(chunk_size * NUM_PAGE_PROT_COLS)
193            .collect::<Vec<_>>();
194
195        let blu_events = chunks
196            .par_iter_mut()
197            .enumerate()
198            .map(|(i, rows)| {
199                let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
200
201                rows.chunks_mut(NUM_PAGE_PROT_COLS).enumerate().for_each(|(j, row)| {
202                    unsafe {
203                        core::ptr::write_bytes(row.as_mut_ptr(), 0, NUM_PAGE_PROT_COLS);
204                    }
205                    let idx = (i * chunk_size + j) * NUM_PAGE_PROT_ENTRIES_PER_ROW;
206                    let cols: &mut PageProtCols<F> = row.borrow_mut();
207
208                    for k in 0..NUM_PAGE_PROT_ENTRIES_PER_ROW {
209                        let cols = &mut cols.page_prot_entries[k];
210                        if idx + k < events.len() {
211                            let event = &events[idx + k];
212                            self.event_to_row(event, cols, &mut blu);
213                        }
214                    }
215                });
216                blu
217            })
218            .collect::<Vec<_>>();
219
220        output.add_byte_lookup_events_from_maps(blu_events.iter().collect_vec());
221    }
222
223    fn included(&self, shard: &Self::Record) -> bool {
224        if let Some(shape) = shard.shape.as_ref() {
225            shape.included::<F, _>(self)
226        } else {
227            (shard.memory_load_byte_events.len()
228                + shard.memory_store_byte_events.len()
229                + shard.memory_load_word_events.len()
230                + shard.memory_store_word_events.len()
231                + shard.memory_load_double_events.len()
232                + shard.memory_store_double_events.len()
233                + shard.memory_load_half_events.len()
234                + shard.memory_store_half_events.len()
235                + shard.memory_load_x0_events.len()
236                + shard.instruction_fetch_events.len()
237                > 0)
238                && shard.program.enable_untrusted_programs
239        }
240    }
241}
242
243impl<AB> Air<AB> for PageProtChip
244where
245    AB: SP1CoreAirBuilder,
246    AB::Var: Sized,
247    AB::F: PrimeField32,
248{
249    fn eval(&self, builder: &mut AB) {
250        let main = builder.main();
251        let local = main.row_slice(0);
252        let local: &PageProtCols<AB::Var> = (*local).borrow();
253
254        builder.assert_eq(
255            builder.extract_public_values().is_untrusted_programs_enabled,
256            AB::Expr::one(),
257        );
258
259        for local in local.page_prot_entries.iter() {
260            // Assert that `is_real` is boolean.
261            builder.assert_bool(local.is_real);
262
263            #[cfg(not(feature = "mprotect"))]
264            builder.assert_zero(local.is_real);
265
266            // Ensure requested permission matches the set permission.
267            builder.send_byte(
268                AB::Expr::from_canonical_u8(ByteOpcode::AND as u8),
269                local.permissions,
270                local.permissions,
271                local.page_prot_op.page_prot_access.prev_prot_bitmap,
272                local.is_real,
273            );
274
275            // Receive the page prot access.
276            builder.receive_page_prot(
277                local.clk_high,
278                local.clk_low,
279                &local.addr.map(Into::into),
280                local.permissions,
281                local.is_real,
282            );
283
284            // Read the currently set page permissions.
285            PageProtOperation::<AB::F>::eval(
286                builder,
287                local.clk_high.into(),
288                local.clk_low.into(),
289                &local.addr.map(Into::into),
290                local.page_prot_op,
291                local.is_real.into(),
292            );
293        }
294    }
295}
296
297struct PageProtEvent {
298    clk: u64,
299    addr: u64,
300    is_read: bool,
301    is_write: bool,
302    is_executable: bool,
303    mem_access: MemoryRecordEnum,
304}
305
306impl PageProtChip {
307    fn generate_page_prot_event(
308        mem_instr_event: &MemInstrEvent,
309        is_read: bool,
310        is_write: bool,
311        is_executable: bool,
312    ) -> PageProtEvent {
313        PageProtEvent {
314            clk: mem_instr_event.clk + MemoryAccessPosition::Memory as u64,
315            addr: (mem_instr_event.b.wrapping_add(mem_instr_event.c)
316                & BITMASK_CLEAR_LOWEST_THREE_BITS) as u64,
317            is_read,
318            is_write,
319            is_executable,
320            mem_access: mem_instr_event.mem_access,
321        }
322    }
323
324    fn generate_fetch_instruction_page_prot_event(
325        untrusted_program_event: &InstructionFetchEvent,
326        memory_record_enum: MemoryRecordEnum,
327        is_read: bool,
328        is_write: bool,
329        is_executable: bool,
330    ) -> PageProtEvent {
331        PageProtEvent {
332            clk: untrusted_program_event.clk,
333            addr: (untrusted_program_event.pc & BITMASK_CLEAR_LOWEST_THREE_BITS) as u64,
334            is_read,
335            is_write,
336            is_executable,
337            mem_access: memory_record_enum,
338        }
339    }
340
341    fn event_to_row<F: PrimeField32>(
342        &self,
343        event: &PageProtEvent,
344        cols: &mut SinglePageProtCols<F>,
345        blu: &mut HashMap<ByteLookupEvent, usize>,
346    ) {
347        cols.clk_high = F::from_canonical_u32((event.clk >> 24) as u32);
348        cols.clk_low = F::from_canonical_u32((event.clk & 0xFFFFFF) as u32);
349
350        let mut perm: u8 = 0;
351        perm += (event.is_read as u8) * PROT_READ;
352        perm += (event.is_write as u8) * PROT_WRITE;
353        perm += (event.is_executable as u8) * PROT_EXEC;
354
355        let set_perm = event.mem_access.previous_page_prot_record().unwrap().page_prot;
356
357        blu.add_byte_lookup_event(ByteLookupEvent {
358            opcode: ByteOpcode::AND,
359            a: perm as u16,
360            b: perm,
361            c: set_perm,
362        });
363
364        cols.permissions = F::from_canonical_u8(perm);
365        cols.is_real = F::one();
366
367        cols.addr = [
368            F::from_canonical_u64(event.addr & 0xFFFF),
369            F::from_canonical_u64((event.addr >> 16) & 0xFFFF),
370            F::from_canonical_u64((event.addr >> 32) & 0xFFFF),
371        ];
372
373        cols.page_prot_op.populate(
374            blu,
375            event.addr,
376            event.clk,
377            &event.mem_access.previous_page_prot_record().unwrap(),
378        );
379    }
380}