Skip to main content

sp1_core_machine/control_flow/branch/
trace.rs

1use std::{borrow::BorrowMut, mem::MaybeUninit};
2
3use hashbrown::HashMap;
4use itertools::Itertools;
5use rayon::iter::{ParallelBridge, ParallelIterator};
6use slop_air::BaseAir;
7use slop_algebra::PrimeField32;
8use sp1_core_executor::{
9    events::{BranchEvent, ByteLookupEvent, ByteRecord},
10    ExecutionRecord, Opcode, Program,
11};
12use sp1_hypercube::air::MachineAir;
13use struct_reflection::StructReflectionHelper;
14
15use crate::{utils::next_multiple_of_32, TrustMode, UserMode};
16
17use super::{BranchChip, BranchColumns};
18
19impl<F: PrimeField32, M: TrustMode> MachineAir<F> for BranchChip<M> {
20    type Record = ExecutionRecord;
21
22    type Program = Program;
23
24    fn name(&self) -> &'static str {
25        if M::IS_TRUSTED {
26            "Branch"
27        } else {
28            "BranchUser"
29        }
30    }
31
32    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
33        if input.program.enable_untrusted_programs == M::IS_TRUSTED {
34            return Some(0);
35        }
36        let nb_rows =
37            next_multiple_of_32(input.branch_events.len(), input.fixed_log2_rows::<F, _>(self));
38        Some(nb_rows)
39    }
40
41    fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
42        if input.program.enable_untrusted_programs == M::IS_TRUSTED {
43            return;
44        }
45
46        let chunk_size = std::cmp::max((input.branch_events.len()) / num_cpus::get(), 1);
47        let width = <BranchChip<M> as BaseAir<F>>::width(self);
48
49        let blu_batches = input
50            .branch_events
51            .chunks(chunk_size)
52            .par_bridge()
53            .map(|events| {
54                let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
55                events.iter().for_each(|event| {
56                    let mut row = vec![F::zero(); width];
57                    let cols: &mut BranchColumns<F, M> = row.as_mut_slice().borrow_mut();
58
59                    self.event_to_row(&event.0, cols, &mut blu);
60                    cols.state.populate(&mut blu, event.0.clk, event.0.pc);
61                    cols.adapter.populate(&mut blu, event.1);
62                });
63                blu
64            })
65            .collect::<Vec<_>>();
66
67        output.add_byte_lookup_events_from_maps(blu_batches.iter().collect_vec());
68    }
69
70    fn generate_trace_into(
71        &self,
72        input: &ExecutionRecord,
73        _output: &mut ExecutionRecord,
74        buffer: &mut [MaybeUninit<F>],
75    ) {
76        if input.program.enable_untrusted_programs == M::IS_TRUSTED {
77            return;
78        }
79
80        // Generate the rows for the trace.
81        let chunk_size = std::cmp::max(input.branch_events.len() / num_cpus::get(), 1);
82        let padded_nb_rows = <BranchChip<M> as MachineAir<F>>::num_rows(self, input).unwrap();
83        let width = <BranchChip<M> as BaseAir<F>>::width(self);
84
85        let num_event_rows = input.branch_events.len();
86
87        unsafe {
88            let padding_start = num_event_rows * width;
89            let padding_size = (padded_nb_rows - num_event_rows) * width;
90            if padding_size > 0 {
91                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
92            }
93        }
94
95        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
96        let values = unsafe { core::slice::from_raw_parts_mut(buffer_ptr, num_event_rows * width) };
97
98        values.chunks_mut(chunk_size * width).enumerate().par_bridge().for_each(|(i, rows)| {
99            let mut blu = Vec::new();
100            rows.chunks_mut(width).enumerate().for_each(|(j, row)| {
101                let idx = i * chunk_size + j;
102                let cols: &mut BranchColumns<F, M> = row.borrow_mut();
103
104                if idx < input.branch_events.len() {
105                    let event = input.branch_events[idx];
106                    self.event_to_row(&event.0, cols, &mut blu);
107                    cols.state.populate(&mut blu, event.0.clk, event.0.pc);
108                    cols.adapter.populate(&mut blu, event.1);
109                    if !M::IS_TRUSTED {
110                        let cols: &mut BranchColumns<F, UserMode> = row.borrow_mut();
111                        cols.adapter_cols.is_trusted = F::from_bool(!event.1.is_untrusted);
112                    }
113                }
114            });
115        });
116    }
117
118    fn included(&self, shard: &Self::Record) -> bool {
119        if let Some(shape) = shard.shape.as_ref() {
120            shape.included::<F, _>(self)
121        } else {
122            !shard.branch_events.is_empty()
123                && (M::IS_TRUSTED != shard.program.enable_untrusted_programs)
124        }
125    }
126
127    fn column_names(&self) -> Vec<String> {
128        BranchColumns::<F, M>::struct_reflection().unwrap()
129    }
130}
131
132impl<M: TrustMode> BranchChip<M> {
133    /// Create a row from an event.
134    fn event_to_row<F: PrimeField32>(
135        &self,
136        event: &BranchEvent,
137        cols: &mut BranchColumns<F, M>,
138        blu: &mut impl ByteRecord,
139    ) {
140        cols.is_beq = F::from_bool(matches!(event.opcode, Opcode::BEQ));
141        cols.is_bne = F::from_bool(matches!(event.opcode, Opcode::BNE));
142        cols.is_blt = F::from_bool(matches!(event.opcode, Opcode::BLT));
143        cols.is_bge = F::from_bool(matches!(event.opcode, Opcode::BGE));
144        cols.is_bltu = F::from_bool(matches!(event.opcode, Opcode::BLTU));
145        cols.is_bgeu = F::from_bool(matches!(event.opcode, Opcode::BGEU));
146
147        let a_eq_b = event.a == event.b;
148
149        let use_signed_comparison = matches!(event.opcode, Opcode::BLT | Opcode::BGE);
150
151        let a_lt_b = if use_signed_comparison {
152            (event.a as i64) < (event.b as i64)
153        } else {
154            event.a < event.b
155        };
156
157        let branching = match event.opcode {
158            Opcode::BEQ => a_eq_b,
159            Opcode::BNE => !a_eq_b,
160            Opcode::BLT | Opcode::BLTU => a_lt_b,
161            Opcode::BGE | Opcode::BGEU => !a_lt_b,
162            _ => unreachable!(),
163        };
164
165        cols.compare_operation.populate_signed(
166            blu,
167            a_lt_b as u64,
168            event.a,
169            event.b,
170            use_signed_comparison,
171        );
172
173        cols.next_pc = [
174            F::from_canonical_u16((event.next_pc & 0xFFFF) as u16),
175            F::from_canonical_u16(((event.next_pc >> 16) & 0xFFFF) as u16),
176            F::from_canonical_u16(((event.next_pc >> 32) & 0xFFFF) as u16),
177        ];
178        blu.add_bit_range_check((event.next_pc & 0xFFFF) as u16 / 4, 14);
179        blu.add_u16_range_checks_field(&cols.next_pc[1..3]);
180
181        if branching {
182            cols.is_branching = F::one();
183        } else {
184            cols.is_branching = F::zero();
185        }
186    }
187}