sp1_core_machine/control_flow/branch/
trace.rs1use std::borrow::BorrowMut;
2
3use hashbrown::HashMap;
4use itertools::Itertools;
5use p3_field::PrimeField32;
6use p3_matrix::dense::RowMajorMatrix;
7use rayon::iter::{ParallelBridge, ParallelIterator};
8use sp1_core_executor::{
9 events::{BranchEvent, ByteLookupEvent, ByteRecord},
10 ExecutionRecord, Opcode, Program,
11};
12use sp1_stark::air::MachineAir;
13
14use crate::utils::{next_power_of_two, zeroed_f_vec};
15
16use super::{BranchChip, BranchColumns, NUM_BRANCH_COLS};
17
18impl<F: PrimeField32> MachineAir<F> for BranchChip {
19 type Record = ExecutionRecord;
20
21 type Program = Program;
22
23 fn name(&self) -> String {
24 "Branch".to_string()
25 }
26
27 fn generate_trace(
28 &self,
29 input: &ExecutionRecord,
30 output: &mut ExecutionRecord,
31 ) -> RowMajorMatrix<F> {
32 let chunk_size = std::cmp::max((input.branch_events.len()) / num_cpus::get(), 1);
33 let nb_rows = input.branch_events.len();
34 let size_log2 = input.fixed_log2_rows::<F, _>(self);
35 let padded_nb_rows = next_power_of_two(nb_rows, size_log2);
36 let mut values = zeroed_f_vec(padded_nb_rows * NUM_BRANCH_COLS);
37
38 let blu_events = values
39 .chunks_mut(chunk_size * NUM_BRANCH_COLS)
40 .enumerate()
41 .par_bridge()
42 .map(|(i, rows)| {
43 let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
44 rows.chunks_mut(NUM_BRANCH_COLS).enumerate().for_each(|(j, row)| {
45 let idx = i * chunk_size + j;
46 let cols: &mut BranchColumns<F> = row.borrow_mut();
47
48 if idx < input.branch_events.len() {
49 let event = &input.branch_events[idx];
50 self.event_to_row(event, cols, &mut blu);
51 }
52 });
53 blu
54 })
55 .collect::<Vec<_>>();
56
57 output.add_byte_lookup_events_from_maps(blu_events.iter().collect_vec());
58
59 RowMajorMatrix::new(values, NUM_BRANCH_COLS)
61 }
62
63 fn included(&self, shard: &Self::Record) -> bool {
64 if let Some(shape) = shard.shape.as_ref() {
65 shape.included::<F, _>(self)
66 } else {
67 !shard.branch_events.is_empty()
68 }
69 }
70
71 fn local_only(&self) -> bool {
72 true
73 }
74}
75
76impl BranchChip {
77 fn event_to_row<F: PrimeField32>(
79 &self,
80 event: &BranchEvent,
81 cols: &mut BranchColumns<F>,
82 blu: &mut HashMap<ByteLookupEvent, usize>,
83 ) {
84 cols.is_beq = F::from_bool(matches!(event.opcode, Opcode::BEQ));
85 cols.is_bne = F::from_bool(matches!(event.opcode, Opcode::BNE));
86 cols.is_blt = F::from_bool(matches!(event.opcode, Opcode::BLT));
87 cols.is_bge = F::from_bool(matches!(event.opcode, Opcode::BGE));
88 cols.is_bltu = F::from_bool(matches!(event.opcode, Opcode::BLTU));
89 cols.is_bgeu = F::from_bool(matches!(event.opcode, Opcode::BGEU));
90
91 cols.op_a_value = event.a.into();
92 cols.op_b_value = event.b.into();
93 cols.op_c_value = event.c.into();
94 cols.op_a_0 = F::from_bool(event.op_a_0);
95
96 let a_eq_b = event.a == event.b;
97
98 let use_signed_comparison = matches!(event.opcode, Opcode::BLT | Opcode::BGE);
99
100 let a_lt_b = if use_signed_comparison {
101 (event.a as i32) < (event.b as i32)
102 } else {
103 event.a < event.b
104 };
105 let a_gt_b = if use_signed_comparison {
106 (event.a as i32) > (event.b as i32)
107 } else {
108 event.a > event.b
109 };
110
111 cols.a_eq_b = F::from_bool(a_eq_b);
112 cols.a_lt_b = F::from_bool(a_lt_b);
113 cols.a_gt_b = F::from_bool(a_gt_b);
114
115 let branching = match event.opcode {
116 Opcode::BEQ => a_eq_b,
117 Opcode::BNE => !a_eq_b,
118 Opcode::BLT | Opcode::BLTU => a_lt_b,
119 Opcode::BGE | Opcode::BGEU => a_eq_b || a_gt_b,
120 _ => unreachable!(),
121 };
122
123 cols.pc = event.pc.into();
124 cols.next_pc = event.next_pc.into();
125 cols.pc_range_checker.populate(cols.pc, blu);
126 cols.next_pc_range_checker.populate(cols.next_pc, blu);
127
128 if branching {
129 cols.is_branching = F::one();
130 } else {
131 cols.not_branching = F::one();
132 }
133 }
134}