Skip to main content

sp1_core_machine/syscall/
chip.rs

1use crate::{air::WordAirBuilder, utils::next_multiple_of_32, TrustMode};
2use core::fmt;
3use itertools::Itertools;
4use slop_air::{Air, BaseAir};
5use slop_algebra::{AbstractField, PrimeField32};
6use slop_matrix::Matrix;
7use slop_maybe_rayon::prelude::{IndexedParallelIterator, ParallelIterator, ParallelSliceMut};
8use sp1_core_executor::{
9    events::{ByteRecord, GlobalInteractionEvent, SyscallEvent},
10    ExecutionRecord, Program, SupervisorMode, TrapError, UserMode,
11};
12use sp1_derive::AlignedBorrow;
13use sp1_hypercube::{
14    air::{AirInteraction, InteractionScope, MachineAir, SP1AirBuilder},
15    InteractionKind,
16};
17use std::{
18    borrow::{Borrow, BorrowMut},
19    marker::PhantomData,
20    mem::{size_of, MaybeUninit},
21};
22use struct_reflection::{StructReflection, StructReflectionHelper};
23
24/// The number of main trace columns for `SyscallChip` in supervisor mode.
25pub const NUM_SYSCALL_COLS_SUPERVISOR: usize = size_of::<SyscallCols<u8, SupervisorMode>>();
26/// The number of main trace columns for `SyscallChip` in user mode.
27pub const NUM_SYSCALL_COLS_USER: usize = size_of::<SyscallCols<u8, UserMode>>();
28
/// Which kind of shard a `SyscallChip` instance serves.
///
/// Variant order is significant for the derived `Ord`/`PartialOrd`: `Core < Precompile`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum SyscallShardKind {
    /// Shards that issue syscalls (their syscall events are read from `syscall_events`).
    Core,
    /// Shards that execute precompiles (their events are read from `precompile_events`).
    Precompile,
}
34
35/// A chip that stores the syscall invocations.
36pub struct SyscallChip<M: TrustMode> {
37    shard_kind: SyscallShardKind,
38    _phantom: PhantomData<M>,
39}
40
41impl<M: TrustMode> SyscallChip<M> {
42    pub const fn new(shard_kind: SyscallShardKind) -> Self {
43        Self { shard_kind, _phantom: std::marker::PhantomData }
44    }
45
46    pub const fn core() -> Self {
47        Self::new(SyscallShardKind::Core)
48    }
49
50    pub const fn precompile() -> Self {
51        Self::new(SyscallShardKind::Precompile)
52    }
53
54    pub fn shard_kind(&self) -> SyscallShardKind {
55        self.shard_kind
56    }
57}
58
59/// The column layout for the chip.
60#[derive(AlignedBorrow, Clone, Copy, StructReflection)]
61#[repr(C)]
62pub struct SyscallCols<T: Copy, M: TrustMode> {
63    /// The high bits of the clk of the syscall.
64    pub clk_high: T,
65
66    /// The low bits of clk of the syscall.
67    pub clk_low: T,
68
69    /// The syscall_id of the syscall.
70    pub syscall_id: T,
71
72    /// The arg1.
73    pub arg1: [T; 3],
74
75    /// The arg2.
76    pub arg2: [T; 3],
77
78    pub is_real: T,
79
80    /// The trap code of the syscall.
81    pub trap_code: M::TrapCodeCols<T>,
82}
83
84impl<F: PrimeField32, M: TrustMode> MachineAir<F> for SyscallChip<M> {
85    type Record = ExecutionRecord;
86
87    type Program = Program;
88
89    fn name(&self) -> &'static str {
90        if M::IS_TRUSTED {
91            match self.shard_kind {
92                SyscallShardKind::Core => "SyscallCore",
93                SyscallShardKind::Precompile => "SyscallPrecompile",
94            }
95        } else {
96            match self.shard_kind {
97                SyscallShardKind::Core => "SyscallCoreUser",
98                SyscallShardKind::Precompile => "SyscallPrecompileUser",
99            }
100        }
101    }
102
103    fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) {
104        if input.program.enable_untrusted_programs == M::IS_TRUSTED {
105            return;
106        }
107        let events = match self.shard_kind {
108            SyscallShardKind::Core => &input
109                .syscall_events
110                .iter()
111                .map(|(event, _)| event)
112                .filter(|e| e.should_send)
113                .copied()
114                .collect::<Vec<_>>(),
115            SyscallShardKind::Precompile => &input
116                .precompile_events
117                .all_events()
118                .map(|(event, _)| event.to_owned())
119                .collect::<Vec<_>>(),
120        };
121
122        let events = events
123            .iter()
124            .filter(|e| e.should_send)
125            .map(|event| {
126                let trap_code =
127                    if let Some(TrapError::PagePermissionViolation(code)) = event.trap_error {
128                        code as u8
129                    } else {
130                        0
131                    };
132
133                let mut blu = Vec::new();
134                blu.add_u8_range_checks(&[event.syscall_id as u8, trap_code]);
135                blu.add_u16_range_checks(&[(event.arg1 & 0xFFFF) as u16]);
136                if !M::IS_TRUSTED {
137                    blu.add_u16_range_checks(&[((event.arg1 >> 16) & 0xFFFF) as u16]);
138                }
139
140                let global_event = GlobalInteractionEvent {
141                    message: [
142                        (event.clk >> 24) as u32,
143                        (event.clk & 0xFFFFFF) as u32,
144                        event.syscall_id + (1 << 8) * (event.arg1 & 0xFFFF) as u32,
145                        ((event.arg1 >> 16) & 0xFFFF) as u32 + ((trap_code as u32) << 16),
146                        ((event.arg1 >> 32) & 0xFFFF) as u32,
147                        (event.arg2 & 0xFFFF) as u32,
148                        ((event.arg2 >> 16) & 0xFFFF) as u32,
149                        ((event.arg2 >> 32) & 0xFFFF) as u32,
150                    ],
151                    is_receive: self.shard_kind == SyscallShardKind::Precompile,
152                    kind: InteractionKind::Syscall as u8,
153                };
154                output.add_byte_lookup_events(blu);
155                global_event
156            })
157            .collect_vec();
158        output.global_interaction_events.extend(events);
159    }
160
161    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
162        if input.program.enable_untrusted_programs == M::IS_TRUSTED {
163            return Some(0);
164        }
165        let events = match self.shard_kind {
166            SyscallShardKind::Core => &input
167                .syscall_events
168                .iter()
169                .map(|(event, _)| event)
170                .filter(|e| e.should_send)
171                .copied()
172                .collect::<Vec<_>>(),
173            SyscallShardKind::Precompile => &input
174                .precompile_events
175                .all_events()
176                .map(|(event, _)| event.to_owned())
177                .collect::<Vec<_>>(),
178        };
179        let nb_rows = events.len();
180        let size_log2 = input.fixed_log2_rows::<F, _>(self);
181        let padded_nb_rows = next_multiple_of_32(nb_rows, size_log2);
182        Some(padded_nb_rows)
183    }
184
185    fn generate_trace_into(
186        &self,
187        input: &ExecutionRecord,
188        _output: &mut ExecutionRecord,
189        buffer: &mut [MaybeUninit<F>],
190    ) {
191        if input.program.enable_untrusted_programs == M::IS_TRUSTED {
192            return;
193        }
194        let row_fn = |syscall_event: &SyscallEvent, cols: &mut SyscallCols<F, M>| {
195            cols.clk_high = F::from_canonical_u32((syscall_event.clk >> 24) as u32);
196            cols.clk_low = F::from_canonical_u32((syscall_event.clk & 0xFFFFFF) as u32);
197            cols.syscall_id = F::from_canonical_u32(syscall_event.syscall_code.syscall_id());
198            cols.arg1 = [
199                F::from_canonical_u64((syscall_event.arg1 & 0xFFFF) as u64),
200                F::from_canonical_u64(((syscall_event.arg1 >> 16) & 0xFFFF) as u64),
201                F::from_canonical_u64(((syscall_event.arg1 >> 32) & 0xFFFF) as u64),
202            ];
203            cols.arg2 = [
204                F::from_canonical_u64((syscall_event.arg2 & 0xFFFF) as u64),
205                F::from_canonical_u64(((syscall_event.arg2 >> 16) & 0xFFFF) as u64),
206                F::from_canonical_u64(((syscall_event.arg2 >> 32) & 0xFFFF) as u64),
207            ];
208
209            cols.is_real = F::one();
210        };
211
212        let padded_nb_rows = <SyscallChip<M> as MachineAir<F>>::num_rows(self, input).unwrap();
213        let width = <Self as BaseAir<F>>::width(self);
214
215        // Get event slice based on shard kind
216        let events: Vec<&SyscallEvent> = match self.shard_kind {
217            SyscallShardKind::Core => input
218                .syscall_events
219                .iter()
220                .map(|(event, _)| event)
221                .filter(|e| e.should_send)
222                .collect(),
223            SyscallShardKind::Precompile => {
224                input.precompile_events.all_events().map(|(event, _)| event).collect()
225            }
226        };
227
228        let num_event_rows = events.len();
229
230        unsafe {
231            let padding_start = num_event_rows * width;
232            let padding_size = (padded_nb_rows - num_event_rows) * width;
233            if padding_size > 0 {
234                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
235            }
236        }
237
238        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
239        let values = unsafe { core::slice::from_raw_parts_mut(buffer_ptr, num_event_rows * width) };
240
241        values.par_chunks_mut(width).enumerate().for_each(|(idx, row)| {
242            if idx < events.len() {
243                let cols: &mut SyscallCols<F, M> = row.borrow_mut();
244                row_fn(events[idx], cols);
245                if !M::IS_TRUSTED {
246                    let cols: &mut SyscallCols<F, UserMode> = row.borrow_mut();
247                    let trap_code = if let Some(TrapError::PagePermissionViolation(code)) =
248                        events[idx].trap_error
249                    {
250                        code
251                    } else {
252                        0
253                    };
254                    cols.trap_code.trap_code = F::from_canonical_u64(trap_code);
255                }
256            }
257        });
258    }
259
260    fn included(&self, shard: &Self::Record) -> bool {
261        if shard.program.enable_untrusted_programs == M::IS_TRUSTED {
262            return false;
263        }
264        if let Some(shape) = shard.shape.as_ref() {
265            shape.included::<F, _>(self)
266        } else {
267            match self.shard_kind {
268                SyscallShardKind::Core => {
269                    shard
270                        .syscall_events
271                        .iter()
272                        .map(|(event, _)| event)
273                        .filter(|e| e.should_send)
274                        .take(1)
275                        .count()
276                        > 0
277                }
278                SyscallShardKind::Precompile => {
279                    !shard.precompile_events.is_empty()
280                        && !shard.contains_cpu()
281                        && shard.global_memory_initialize_events.is_empty()
282                        && shard.global_memory_finalize_events.is_empty()
283                        && shard.global_page_prot_initialize_events.is_empty()
284                        && shard.global_page_prot_finalize_events.is_empty()
285                }
286            }
287        }
288    }
289
290    fn column_names(&self) -> Vec<String> {
291        SyscallCols::<F, M>::struct_reflection().unwrap()
292    }
293}
294
295impl<AB, M: TrustMode> Air<AB> for SyscallChip<M>
296where
297    AB: SP1AirBuilder,
298{
299    fn eval(&self, builder: &mut AB) {
300        let main = builder.main();
301        let local = main.row_slice(0);
302        let local: &SyscallCols<AB::Var, M> = (*local).borrow();
303
304        #[cfg(feature = "mprotect")]
305        builder.assert_eq(
306            builder.extract_public_values().is_untrusted_programs_enabled,
307            AB::Expr::from_bool(!M::IS_TRUSTED),
308        );
309
310        let mut trap_code = AB::Expr::zero();
311        if !M::IS_TRUSTED {
312            let local = main.row_slice(0);
313            let local: &SyscallCols<AB::Var, UserMode> = (*local).borrow();
314
315            #[cfg(not(feature = "mprotect"))]
316            builder.assert_zero(local.is_real);
317
318            trap_code = local.trap_code.trap_code.into();
319        }
320
321        // Constrain that `local.is_real` is boolean.
322        builder.assert_bool(local.is_real);
323
324        builder.assert_eq(
325            local.is_real * local.is_real * local.is_real,
326            local.is_real * local.is_real * local.is_real,
327        );
328
329        // Constrain that the syscall id and trap code is 8 bits.
330        builder.slice_range_check_u8(&[local.syscall_id.into(), trap_code.clone()], local.is_real);
331        // Constrain that the arg1[0] is 16 bits.
332        builder.slice_range_check_u16(&[local.arg1[0]], local.is_real);
333
334        if !M::IS_TRUSTED {
335            // Constrain that the arg1[1] is 16 bits.
336            builder.slice_range_check_u16(&[local.arg1[1]], local.is_real);
337        }
338
339        #[cfg(not(feature = "mprotect"))]
340        let arg4: AB::Expr = local.arg1[1].into().clone();
341        #[cfg(feature = "mprotect")]
342        let arg4: AB::Expr =
343            local.arg1[1].into().clone() + trap_code.clone() * AB::F::from_canonical_u32(1 << 16);
344
345        match self.shard_kind {
346            SyscallShardKind::Core => {
347                builder.receive_syscall(
348                    local.clk_high,
349                    local.clk_low,
350                    local.syscall_id,
351                    trap_code.clone(),
352                    local.arg1.map(Into::into),
353                    local.arg2.map(Into::into),
354                    local.is_real,
355                    InteractionScope::Local,
356                );
357
358                // Send the "send interaction" to the global table.
359                builder.send(
360                    AirInteraction::new(
361                        vec![
362                            local.clk_high.into(),
363                            local.clk_low.into(),
364                            local.syscall_id + local.arg1[0] * AB::F::from_canonical_u32(1 << 8),
365                            arg4,
366                            local.arg1[2].into(),
367                            local.arg2[0].into(),
368                            local.arg2[1].into(),
369                            local.arg2[2].into(),
370                            AB::Expr::one(),
371                            AB::Expr::zero(),
372                            AB::Expr::from_canonical_u8(InteractionKind::Syscall as u8),
373                        ],
374                        local.is_real.into(),
375                        InteractionKind::Global,
376                    ),
377                    InteractionScope::Local,
378                );
379            }
380            SyscallShardKind::Precompile => {
381                builder.send_syscall(
382                    local.clk_high,
383                    local.clk_low,
384                    local.syscall_id,
385                    trap_code.clone(),
386                    local.arg1.map(Into::into),
387                    local.arg2.map(Into::into),
388                    local.is_real,
389                    InteractionScope::Local,
390                );
391
392                // Send the "receive interaction" to the global table.
393                builder.send(
394                    AirInteraction::new(
395                        vec![
396                            local.clk_high.into(),
397                            local.clk_low.into(),
398                            local.syscall_id + local.arg1[0] * AB::F::from_canonical_u32(1 << 8),
399                            arg4,
400                            local.arg1[2].into(),
401                            local.arg2[0].into(),
402                            local.arg2[1].into(),
403                            local.arg2[2].into(),
404                            AB::Expr::zero(),
405                            AB::Expr::one(),
406                            AB::Expr::from_canonical_u8(InteractionKind::Syscall as u8),
407                        ],
408                        local.is_real.into(),
409                        InteractionKind::Global,
410                    ),
411                    InteractionScope::Local,
412                );
413            }
414        }
415    }
416}
417
418impl<F, M: TrustMode> BaseAir<F> for SyscallChip<M> {
419    fn width(&self) -> usize {
420        if M::IS_TRUSTED {
421            NUM_SYSCALL_COLS_SUPERVISOR
422        } else {
423            NUM_SYSCALL_COLS_USER
424        }
425    }
426}
427
428impl fmt::Display for SyscallShardKind {
429    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
430        match self {
431            SyscallShardKind::Core => write!(f, "Core"),
432            SyscallShardKind::Precompile => write!(f, "Precompile"),
433        }
434    }
435}