1use hashbrown::HashMap;
2use itertools::Itertools;
3use slop_air::{Air, BaseAir};
4use slop_algebra::{AbstractField, PrimeField32};
5use slop_matrix::Matrix;
6use slop_maybe_rayon::prelude::{
7 IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
8};
9use sp1_core_executor::{
10 events::{
11 ByteLookupEvent, ByteRecord, InstructionFetchEvent, MemInstrEvent, MemoryAccessPosition,
12 MemoryRecordEnum,
13 },
14 ByteOpcode, ExecutionRecord, Program,
15};
16use sp1_derive::AlignedBorrow;
17use sp1_hypercube::air::MachineAir;
18use sp1_primitives::consts::{PROT_EXEC, PROT_READ, PROT_WRITE};
19use std::{
20 borrow::{Borrow, BorrowMut},
21 mem::MaybeUninit,
22};
23
24use crate::{air::SP1CoreAirBuilder, operations::PageProtOperation, utils::next_multiple_of_32};
25
/// Mask that rounds an address down to a multiple of 8 by clearing its three lowest bits.
const BITMASK_CLEAR_LOWEST_THREE_BITS: u64 = !0b111;
28
29pub const NUM_PAGE_PROT_ENTRIES_PER_ROW: usize = 4;
30pub(crate) const NUM_PAGE_PROT_COLS: usize = size_of::<PageProtCols<u8>>();
31
/// Columns for a single page-protection check.
///
/// A trace row holds `NUM_PAGE_PROT_ENTRIES_PER_ROW` of these side by side. The layout is
/// `repr(C)`, so field order defines the column order.
#[derive(AlignedBorrow, Clone, Copy)]
#[repr(C)]
pub struct SinglePageProtCols<T: Copy> {
    /// High part of the access clock (populated as `clk >> 24`).
    pub clk_high: T,
    /// Low 24 bits of the access clock (populated as `clk & 0xFFFFFF`).
    pub clk_low: T,

    /// Address with its lowest three bits cleared, as three 16-bit limbs
    /// (bits 0-15, 16-31 and 32-47).
    pub addr: [T; 3],

    /// Requested permission bits: the sum of the `PROT_READ` / `PROT_WRITE` / `PROT_EXEC`
    /// flags required by the access.
    pub permissions: T,

    /// Boolean flag: 1 for an entry backed by an event, 0 for padding.
    pub is_real: T,

    /// Columns populated and constrained by `PageProtOperation`.
    pub page_prot_op: PageProtOperation<T>,
}
51
/// One row of the page-protection trace: `NUM_PAGE_PROT_ENTRIES_PER_ROW` independent
/// entries laid out contiguously.
#[derive(AlignedBorrow, Clone, Copy)]
#[repr(C)]
pub struct PageProtCols<T: Copy> {
    page_prot_entries: [SinglePageProtCols<T>; NUM_PAGE_PROT_ENTRIES_PER_ROW],
}
57
/// Chip that checks memory load/store accesses and untrusted instruction fetches against
/// the protection bits recorded for the accessed page.
#[derive(Default)]
pub struct PageProtChip;
60
61impl<F> BaseAir<F> for PageProtChip {
62 fn width(&self) -> usize {
63 NUM_PAGE_PROT_COLS
64 }
65}
66
67fn nb_rows(count: usize) -> usize {
68 if NUM_PAGE_PROT_ENTRIES_PER_ROW > 1 {
69 count.div_ceil(NUM_PAGE_PROT_ENTRIES_PER_ROW)
70 } else {
71 count
72 }
73}
74
75impl<F: PrimeField32> MachineAir<F> for PageProtChip {
76 type Record = ExecutionRecord;
77
78 type Program = Program;
79
80 fn name(&self) -> &'static str {
81 "PageProt"
82 }
83
84 fn num_rows(&self, input: &Self::Record) -> Option<usize> {
85 let mut count = 0;
86 if input.public_values.is_untrusted_programs_enabled == 1 {
87 count = input.memory_load_byte_events.len()
88 + input.memory_store_byte_events.len()
89 + input.memory_load_word_events.len()
90 + input.memory_store_word_events.len()
91 + input.memory_load_double_events.len()
92 + input.memory_store_double_events.len()
93 + input.memory_load_half_events.len()
94 + input.memory_store_half_events.len()
95 + input.memory_load_x0_events.len()
96 + input.instruction_fetch_events.len();
97 }
98
99 let nb_rows = nb_rows(count);
100 let size_log2 = input.fixed_log2_rows::<F, _>(self);
101 Some(next_multiple_of_32(nb_rows, size_log2))
102 }
103
104 fn generate_trace_into(
105 &self,
106 input: &ExecutionRecord,
107 output: &mut ExecutionRecord,
108 buffer: &mut [MaybeUninit<F>],
109 ) {
110 let mut events = vec![];
111
112 if input.public_values.is_untrusted_programs_enabled == 1 {
113 events = input
114 .memory_load_byte_events
115 .iter()
116 .map(|e| Self::generate_page_prot_event(&e.0, true, false, false))
117 .chain(
118 input
119 .memory_store_byte_events
120 .iter()
121 .map(|e| Self::generate_page_prot_event(&e.0, false, true, false)),
122 )
123 .chain(
124 input
125 .memory_load_word_events
126 .iter()
127 .map(|e| Self::generate_page_prot_event(&e.0, true, false, false)),
128 )
129 .chain(
130 input
131 .memory_store_word_events
132 .iter()
133 .map(|e| Self::generate_page_prot_event(&e.0, false, true, false)),
134 )
135 .chain(
136 input
137 .memory_load_double_events
138 .iter()
139 .map(|e| Self::generate_page_prot_event(&e.0, true, false, false)),
140 )
141 .chain(
142 input
143 .memory_store_double_events
144 .iter()
145 .map(|e| Self::generate_page_prot_event(&e.0, false, true, false)),
146 )
147 .chain(
148 input
149 .memory_load_half_events
150 .iter()
151 .map(|e| Self::generate_page_prot_event(&e.0, true, false, false)),
152 )
153 .chain(
154 input
155 .memory_store_half_events
156 .iter()
157 .map(|e| Self::generate_page_prot_event(&e.0, false, true, false)),
158 )
159 .chain(
160 input
161 .memory_load_x0_events
162 .iter()
163 .map(|e| Self::generate_page_prot_event(&e.0, true, false, false)),
164 )
165 .chain(input.instruction_fetch_events.iter().map(|e| {
166 let (mem_access, _) = e.1.untrusted_instruction.unwrap();
167 Self::generate_fetch_instruction_page_prot_event(
168 &e.0, mem_access, true, false, true,
169 )
170 }))
171 .collect_vec();
172 }
173
174 let nb_rows = nb_rows(events.len());
175 let padded_nb_rows = <PageProtChip as MachineAir<F>>::num_rows(self, input).unwrap();
176
177 unsafe {
178 let padding_start = nb_rows * NUM_PAGE_PROT_COLS;
179 let padding_size = (padded_nb_rows - nb_rows) * NUM_PAGE_PROT_COLS;
180 if padding_size > 0 {
181 core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
182 }
183 }
184
185 let buffer_ptr = buffer.as_mut_ptr() as *mut F;
186 let values =
187 unsafe { core::slice::from_raw_parts_mut(buffer_ptr, nb_rows * NUM_PAGE_PROT_COLS) };
188
189 let chunk_size = std::cmp::max(nb_rows / num_cpus::get(), 0) + 1;
190
191 let mut chunks = values[..nb_rows * NUM_PAGE_PROT_COLS]
192 .chunks_mut(chunk_size * NUM_PAGE_PROT_COLS)
193 .collect::<Vec<_>>();
194
195 let blu_events = chunks
196 .par_iter_mut()
197 .enumerate()
198 .map(|(i, rows)| {
199 let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
200
201 rows.chunks_mut(NUM_PAGE_PROT_COLS).enumerate().for_each(|(j, row)| {
202 unsafe {
203 core::ptr::write_bytes(row.as_mut_ptr(), 0, NUM_PAGE_PROT_COLS);
204 }
205 let idx = (i * chunk_size + j) * NUM_PAGE_PROT_ENTRIES_PER_ROW;
206 let cols: &mut PageProtCols<F> = row.borrow_mut();
207
208 for k in 0..NUM_PAGE_PROT_ENTRIES_PER_ROW {
209 let cols = &mut cols.page_prot_entries[k];
210 if idx + k < events.len() {
211 let event = &events[idx + k];
212 self.event_to_row(event, cols, &mut blu);
213 }
214 }
215 });
216 blu
217 })
218 .collect::<Vec<_>>();
219
220 output.add_byte_lookup_events_from_maps(blu_events.iter().collect_vec());
221 }
222
223 fn included(&self, shard: &Self::Record) -> bool {
224 if let Some(shape) = shard.shape.as_ref() {
225 shape.included::<F, _>(self)
226 } else {
227 (shard.memory_load_byte_events.len()
228 + shard.memory_store_byte_events.len()
229 + shard.memory_load_word_events.len()
230 + shard.memory_store_word_events.len()
231 + shard.memory_load_double_events.len()
232 + shard.memory_store_double_events.len()
233 + shard.memory_load_half_events.len()
234 + shard.memory_store_half_events.len()
235 + shard.memory_load_x0_events.len()
236 + shard.instruction_fetch_events.len()
237 > 0)
238 && shard.program.enable_untrusted_programs
239 }
240 }
241}
242
/// Constraints for the page-protection chip.
///
/// Each row packs `NUM_PAGE_PROT_ENTRIES_PER_ROW` entries, and each entry is constrained
/// independently by the loop below. The order of constraints and interactions is kept
/// as-is, because it determines the constraint folding and lookup layout.
impl<AB> Air<AB> for PageProtChip
where
    AB: SP1CoreAirBuilder,
    AB::Var: Sized,
    AB::F: PrimeField32,
{
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &PageProtCols<AB::Var> = (*local).borrow();

        // Applied on every row, padding included: any trace of this chip requires the
        // untrusted-programs public value to be exactly one.
        builder.assert_eq(
            builder.extract_public_values().is_untrusted_programs_enabled,
            AB::Expr::one(),
        );

        for local in local.page_prot_entries.iter() {
            builder.assert_bool(local.is_real);

            // Without the `mprotect` feature, no entry may be real, so every row is padding.
            #[cfg(not(feature = "mprotect"))]
            builder.assert_zero(local.is_real);

            // Byte-table AND lookup on (permissions, permissions, prev_prot_bitmap), mirroring
            // the `ByteLookupEvent { a: perm, b: perm, c: set_perm }` recorded during trace
            // generation. Presumably this enforces
            // `permissions & prev_prot_bitmap == permissions`, i.e. the requested permissions
            // are a subset of the page's current protection bits; confirm against the byte
            // chip.
            builder.send_byte(
                AB::Expr::from_canonical_u8(ByteOpcode::AND as u8),
                local.permissions,
                local.permissions,
                local.page_prot_op.page_prot_access.prev_prot_bitmap,
                local.is_real,
            );

            // Receive the page-protection request (clock, address limbs, permissions),
            // presumably sent by the chip that performed the access.
            builder.receive_page_prot(
                local.clk_high,
                local.clk_low,
                &local.addr.map(Into::into),
                local.permissions,
                local.is_real,
            );

            // Constrain the page-protection operation columns for this clock and address.
            PageProtOperation::<AB::F>::eval(
                builder,
                local.clk_high.into(),
                local.clk_low.into(),
                &local.addr.map(Into::into),
                local.page_prot_op,
                local.is_real.into(),
            );
        }
    }
}
296
/// One page-protection check to be laid out as a trace entry; built only during trace
/// generation.
struct PageProtEvent {
    /// Clock of the access being checked.
    clk: u64,
    /// Accessed address with its lowest three bits cleared (a multiple of 8).
    addr: u64,
    /// Whether the access requires read permission.
    is_read: bool,
    /// Whether the access requires write permission.
    is_write: bool,
    /// Whether the access requires execute permission.
    is_executable: bool,
    /// Memory record of the access; supplies the previous page-prot record.
    mem_access: MemoryRecordEnum,
}
305
306impl PageProtChip {
307 fn generate_page_prot_event(
308 mem_instr_event: &MemInstrEvent,
309 is_read: bool,
310 is_write: bool,
311 is_executable: bool,
312 ) -> PageProtEvent {
313 PageProtEvent {
314 clk: mem_instr_event.clk + MemoryAccessPosition::Memory as u64,
315 addr: (mem_instr_event.b.wrapping_add(mem_instr_event.c)
316 & BITMASK_CLEAR_LOWEST_THREE_BITS) as u64,
317 is_read,
318 is_write,
319 is_executable,
320 mem_access: mem_instr_event.mem_access,
321 }
322 }
323
324 fn generate_fetch_instruction_page_prot_event(
325 untrusted_program_event: &InstructionFetchEvent,
326 memory_record_enum: MemoryRecordEnum,
327 is_read: bool,
328 is_write: bool,
329 is_executable: bool,
330 ) -> PageProtEvent {
331 PageProtEvent {
332 clk: untrusted_program_event.clk,
333 addr: (untrusted_program_event.pc & BITMASK_CLEAR_LOWEST_THREE_BITS) as u64,
334 is_read,
335 is_write,
336 is_executable,
337 mem_access: memory_record_enum,
338 }
339 }
340
341 fn event_to_row<F: PrimeField32>(
342 &self,
343 event: &PageProtEvent,
344 cols: &mut SinglePageProtCols<F>,
345 blu: &mut HashMap<ByteLookupEvent, usize>,
346 ) {
347 cols.clk_high = F::from_canonical_u32((event.clk >> 24) as u32);
348 cols.clk_low = F::from_canonical_u32((event.clk & 0xFFFFFF) as u32);
349
350 let mut perm: u8 = 0;
351 perm += (event.is_read as u8) * PROT_READ;
352 perm += (event.is_write as u8) * PROT_WRITE;
353 perm += (event.is_executable as u8) * PROT_EXEC;
354
355 let set_perm = event.mem_access.previous_page_prot_record().unwrap().page_prot;
356
357 blu.add_byte_lookup_event(ByteLookupEvent {
358 opcode: ByteOpcode::AND,
359 a: perm as u16,
360 b: perm,
361 c: set_perm,
362 });
363
364 cols.permissions = F::from_canonical_u8(perm);
365 cols.is_real = F::one();
366
367 cols.addr = [
368 F::from_canonical_u64(event.addr & 0xFFFF),
369 F::from_canonical_u64((event.addr >> 16) & 0xFFFF),
370 F::from_canonical_u64((event.addr >> 32) & 0xFFFF),
371 ];
372
373 cols.page_prot_op.populate(
374 blu,
375 event.addr,
376 event.clk,
377 &event.mem_access.previous_page_prot_record().unwrap(),
378 );
379 }
380}