sp1_core_machine/control_flow/branch/
trace.rs1use std::{borrow::BorrowMut, mem::MaybeUninit};
2
3use hashbrown::HashMap;
4use itertools::Itertools;
5use rayon::iter::{ParallelBridge, ParallelIterator};
6use slop_air::BaseAir;
7use slop_algebra::PrimeField32;
8use sp1_core_executor::{
9 events::{BranchEvent, ByteLookupEvent, ByteRecord},
10 ExecutionRecord, Opcode, Program,
11};
12use sp1_hypercube::air::MachineAir;
13use struct_reflection::StructReflectionHelper;
14
15use crate::{utils::next_multiple_of_32, TrustMode, UserMode};
16
17use super::{BranchChip, BranchColumns};
18
19impl<F: PrimeField32, M: TrustMode> MachineAir<F> for BranchChip<M> {
20 type Record = ExecutionRecord;
21
22 type Program = Program;
23
24 fn name(&self) -> &'static str {
25 if M::IS_TRUSTED {
26 "Branch"
27 } else {
28 "BranchUser"
29 }
30 }
31
32 fn num_rows(&self, input: &Self::Record) -> Option<usize> {
33 if input.program.enable_untrusted_programs == M::IS_TRUSTED {
34 return Some(0);
35 }
36 let nb_rows =
37 next_multiple_of_32(input.branch_events.len(), input.fixed_log2_rows::<F, _>(self));
38 Some(nb_rows)
39 }
40
41 fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
42 if input.program.enable_untrusted_programs == M::IS_TRUSTED {
43 return;
44 }
45
46 let chunk_size = std::cmp::max((input.branch_events.len()) / num_cpus::get(), 1);
47 let width = <BranchChip<M> as BaseAir<F>>::width(self);
48
49 let blu_batches = input
50 .branch_events
51 .chunks(chunk_size)
52 .par_bridge()
53 .map(|events| {
54 let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
55 events.iter().for_each(|event| {
56 let mut row = vec![F::zero(); width];
57 let cols: &mut BranchColumns<F, M> = row.as_mut_slice().borrow_mut();
58
59 self.event_to_row(&event.0, cols, &mut blu);
60 cols.state.populate(&mut blu, event.0.clk, event.0.pc);
61 cols.adapter.populate(&mut blu, event.1);
62 });
63 blu
64 })
65 .collect::<Vec<_>>();
66
67 output.add_byte_lookup_events_from_maps(blu_batches.iter().collect_vec());
68 }
69
70 fn generate_trace_into(
71 &self,
72 input: &ExecutionRecord,
73 _output: &mut ExecutionRecord,
74 buffer: &mut [MaybeUninit<F>],
75 ) {
76 if input.program.enable_untrusted_programs == M::IS_TRUSTED {
77 return;
78 }
79
80 let chunk_size = std::cmp::max(input.branch_events.len() / num_cpus::get(), 1);
82 let padded_nb_rows = <BranchChip<M> as MachineAir<F>>::num_rows(self, input).unwrap();
83 let width = <BranchChip<M> as BaseAir<F>>::width(self);
84
85 let num_event_rows = input.branch_events.len();
86
87 unsafe {
88 let padding_start = num_event_rows * width;
89 let padding_size = (padded_nb_rows - num_event_rows) * width;
90 if padding_size > 0 {
91 core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
92 }
93 }
94
95 let buffer_ptr = buffer.as_mut_ptr() as *mut F;
96 let values = unsafe { core::slice::from_raw_parts_mut(buffer_ptr, num_event_rows * width) };
97
98 values.chunks_mut(chunk_size * width).enumerate().par_bridge().for_each(|(i, rows)| {
99 let mut blu = Vec::new();
100 rows.chunks_mut(width).enumerate().for_each(|(j, row)| {
101 let idx = i * chunk_size + j;
102 let cols: &mut BranchColumns<F, M> = row.borrow_mut();
103
104 if idx < input.branch_events.len() {
105 let event = input.branch_events[idx];
106 self.event_to_row(&event.0, cols, &mut blu);
107 cols.state.populate(&mut blu, event.0.clk, event.0.pc);
108 cols.adapter.populate(&mut blu, event.1);
109 if !M::IS_TRUSTED {
110 let cols: &mut BranchColumns<F, UserMode> = row.borrow_mut();
111 cols.adapter_cols.is_trusted = F::from_bool(!event.1.is_untrusted);
112 }
113 }
114 });
115 });
116 }
117
118 fn included(&self, shard: &Self::Record) -> bool {
119 if let Some(shape) = shard.shape.as_ref() {
120 shape.included::<F, _>(self)
121 } else {
122 !shard.branch_events.is_empty()
123 && (M::IS_TRUSTED != shard.program.enable_untrusted_programs)
124 }
125 }
126
127 fn column_names(&self) -> Vec<String> {
128 BranchColumns::<F, M>::struct_reflection().unwrap()
129 }
130}
131
132impl<M: TrustMode> BranchChip<M> {
133 fn event_to_row<F: PrimeField32>(
135 &self,
136 event: &BranchEvent,
137 cols: &mut BranchColumns<F, M>,
138 blu: &mut impl ByteRecord,
139 ) {
140 cols.is_beq = F::from_bool(matches!(event.opcode, Opcode::BEQ));
141 cols.is_bne = F::from_bool(matches!(event.opcode, Opcode::BNE));
142 cols.is_blt = F::from_bool(matches!(event.opcode, Opcode::BLT));
143 cols.is_bge = F::from_bool(matches!(event.opcode, Opcode::BGE));
144 cols.is_bltu = F::from_bool(matches!(event.opcode, Opcode::BLTU));
145 cols.is_bgeu = F::from_bool(matches!(event.opcode, Opcode::BGEU));
146
147 let a_eq_b = event.a == event.b;
148
149 let use_signed_comparison = matches!(event.opcode, Opcode::BLT | Opcode::BGE);
150
151 let a_lt_b = if use_signed_comparison {
152 (event.a as i64) < (event.b as i64)
153 } else {
154 event.a < event.b
155 };
156
157 let branching = match event.opcode {
158 Opcode::BEQ => a_eq_b,
159 Opcode::BNE => !a_eq_b,
160 Opcode::BLT | Opcode::BLTU => a_lt_b,
161 Opcode::BGE | Opcode::BGEU => !a_lt_b,
162 _ => unreachable!(),
163 };
164
165 cols.compare_operation.populate_signed(
166 blu,
167 a_lt_b as u64,
168 event.a,
169 event.b,
170 use_signed_comparison,
171 );
172
173 cols.next_pc = [
174 F::from_canonical_u16((event.next_pc & 0xFFFF) as u16),
175 F::from_canonical_u16(((event.next_pc >> 16) & 0xFFFF) as u16),
176 F::from_canonical_u16(((event.next_pc >> 32) & 0xFFFF) as u16),
177 ];
178 blu.add_bit_range_check((event.next_pc & 0xFFFF) as u16 / 4, 14);
179 blu.add_u16_range_checks_field(&cols.next_pc[1..3]);
180
181 if branching {
182 cols.is_branching = F::one();
183 } else {
184 cols.is_branching = F::zero();
185 }
186 }
187}