sp1_core_machine/control_flow/jump/trace.rs

1use std::borrow::BorrowMut;
2
3use hashbrown::HashMap;
4use itertools::Itertools;
5use p3_field::PrimeField32;
6use p3_matrix::dense::RowMajorMatrix;
7use rayon::iter::{ParallelBridge, ParallelIterator};
8use sp1_core_executor::{
9    events::{ByteLookupEvent, ByteRecord, JumpEvent},
10    ExecutionRecord, Opcode, Program,
11};
12use sp1_stark::air::MachineAir;
13
14use crate::utils::{next_power_of_two, zeroed_f_vec};
15
16use super::{JumpChip, JumpColumns, NUM_JUMP_COLS};
17
18impl<F: PrimeField32> MachineAir<F> for JumpChip {
19    type Record = ExecutionRecord;
20
21    type Program = Program;
22
23    fn name(&self) -> String {
24        "Jump".to_string()
25    }
26
27    fn generate_trace(
28        &self,
29        input: &ExecutionRecord,
30        output: &mut ExecutionRecord,
31    ) -> RowMajorMatrix<F> {
32        let chunk_size = std::cmp::max((input.jump_events.len()) / num_cpus::get(), 1);
33        let nb_rows = input.jump_events.len();
34        let size_log2 = input.fixed_log2_rows::<F, _>(self);
35        let padded_nb_rows = next_power_of_two(nb_rows, size_log2);
36        let mut values = zeroed_f_vec(padded_nb_rows * NUM_JUMP_COLS);
37
38        let blu_events = values
39            .chunks_mut(chunk_size * NUM_JUMP_COLS)
40            .enumerate()
41            .par_bridge()
42            .map(|(i, rows)| {
43                let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
44                rows.chunks_mut(NUM_JUMP_COLS).enumerate().for_each(|(j, row)| {
45                    let idx = i * chunk_size + j;
46                    let cols: &mut JumpColumns<F> = row.borrow_mut();
47
48                    if idx < input.jump_events.len() {
49                        let event = &input.jump_events[idx];
50                        self.event_to_row(event, cols, &mut blu);
51                    }
52                });
53                blu
54            })
55            .collect::<Vec<_>>();
56
57        output.add_byte_lookup_events_from_maps(blu_events.iter().collect_vec());
58
59        // Convert the trace to a row major matrix.
60        RowMajorMatrix::new(values, NUM_JUMP_COLS)
61    }
62
63    fn included(&self, shard: &Self::Record) -> bool {
64        if let Some(shape) = shard.shape.as_ref() {
65            shape.included::<F, _>(self)
66        } else {
67            !shard.jump_events.is_empty()
68        }
69    }
70
71    fn local_only(&self) -> bool {
72        true
73    }
74}
75
76impl JumpChip {
77    /// Create a row from an event.
78    fn event_to_row<F: PrimeField32>(
79        &self,
80        event: &JumpEvent,
81        cols: &mut JumpColumns<F>,
82        blu: &mut HashMap<ByteLookupEvent, usize>,
83    ) {
84        cols.is_jal = F::from_bool(matches!(event.opcode, Opcode::JAL));
85        cols.is_jalr = F::from_bool(matches!(event.opcode, Opcode::JALR));
86
87        cols.op_a_value = event.a.into();
88        cols.op_b_value = event.b.into();
89        cols.op_c_value = event.c.into();
90        cols.op_a_0 = F::from_bool(event.op_a_0);
91
92        cols.op_a_range_checker.populate(cols.op_a_value, blu);
93
94        cols.pc = event.pc.into();
95        cols.pc_range_checker.populate(cols.pc, blu);
96
97        let next_pc = match event.opcode {
98            Opcode::JAL => event.pc.wrapping_add(event.b),
99            Opcode::JALR => event.b.wrapping_add(event.c),
100            _ => unreachable!(),
101        };
102
103        cols.next_pc = next_pc.into();
104        cols.next_pc_range_checker.populate(cols.next_pc, blu);
105    }
106}