triton_vm/table/
u32.rs

1use std::cmp::Ordering;
2use std::cmp::max;
3
4use air::challenge_id::ChallengeId;
5use air::cross_table_argument::CrossTableArg;
6use air::cross_table_argument::LookupArg;
7use air::table::u32::U32Table;
8use air::table_column::MasterAuxColumn;
9use air::table_column::MasterMainColumn;
10use arbitrary::Arbitrary;
11use isa::instruction::Instruction;
12use ndarray::Array1;
13use ndarray::Array2;
14use ndarray::ArrayView2;
15use ndarray::ArrayViewMut2;
16use ndarray::parallel::prelude::*;
17use ndarray::s;
18use num_traits::One;
19use num_traits::Zero;
20use strum::EnumCount;
21use twenty_first::prelude::*;
22
23use crate::aet::AlgebraicExecutionTrace;
24use crate::challenges::Challenges;
25use crate::ndarray_helper::ROW_AXIS;
26use crate::profiler::profiler;
27use crate::table::TraceTable;
28
29type MainColumn = <U32Table as air::AIR>::MainColumn;
30type AuxColumn = <U32Table as air::AIR>::AuxColumn;
31
32/// An executed u32 instruction as well as its operands.
33#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
34pub struct U32TableEntry {
35    pub instruction: Instruction,
36    pub left_operand: BFieldElement,
37    pub right_operand: BFieldElement,
38}
39
40impl U32TableEntry {
41    pub fn new<L, R>(instruction: Instruction, left_operand: L, right_operand: R) -> Self
42    where
43        L: Into<BFieldElement>,
44        R: Into<BFieldElement>,
45    {
46        Self {
47            instruction,
48            left_operand: left_operand.into(),
49            right_operand: right_operand.into(),
50        }
51    }
52
53    /// The number of rows this entry contributes to the U32 Table.
54    pub(crate) fn table_height_contribution(&self) -> u32 {
55        let lhs = self.left_operand.value();
56        let rhs = self.right_operand.value();
57        let dominant_operand = match self.instruction {
58            Instruction::Pow => rhs, // left operand doesn't change across rows
59            _ => max(lhs, rhs),
60        };
61        match dominant_operand {
62            0 => 1,
63            _ => 2 + dominant_operand.ilog2(),
64        }
65    }
66}
67
68impl PartialOrd for U32TableEntry {
69    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
70        Some(self.cmp(other))
71    }
72}
73
74impl Ord for U32TableEntry {
75    fn cmp(&self, other: &Self) -> Ordering {
76        // destructure to get compilation errors if fields change
77        let Self {
78            instruction: self_instruction,
79            left_operand: self_left_operand,
80            right_operand: self_right_operand,
81        } = *self;
82        let Self {
83            instruction: other_instruction,
84            left_operand: other_left_operand,
85            right_operand: other_right_operand,
86        } = *other;
87
88        // Even though field elements (like `BFieldElement`) do not have a
89        // natural ordering, the operands of any valid `Self` are `u32`s, which
90        // _do_ have a natural ordering.
91        let instruction_cmp = self_instruction.opcode().cmp(&other_instruction.opcode());
92        let left_operand_cmp = self_left_operand.value().cmp(&other_left_operand.value());
93        let right_operand_cmp = self_right_operand.value().cmp(&other_right_operand.value());
94
95        instruction_cmp
96            .then(left_operand_cmp)
97            .then(right_operand_cmp)
98    }
99}
100
101impl TraceTable for U32Table {
102    type FillParam = ();
103    type FillReturnInfo = ();
104
105    fn fill(mut u32_table: ArrayViewMut2<BFieldElement>, aet: &AlgebraicExecutionTrace, _: ()) {
106        let mut next_section_start = 0;
107        for (&u32_table_entry, &multiplicity) in &aet.u32_entries {
108            let mut first_row = Array2::zeros([1, MainColumn::COUNT]);
109            first_row[[0, MainColumn::CopyFlag.main_index()]] = bfe!(1);
110            first_row[[0, MainColumn::Bits.main_index()]] = bfe!(0);
111            first_row[[0, MainColumn::BitsMinus33Inv.main_index()]] = bfe!(-33).inverse();
112            first_row[[0, MainColumn::CI.main_index()]] = u32_table_entry.instruction.opcode_b();
113            first_row[[0, MainColumn::LHS.main_index()]] = u32_table_entry.left_operand;
114            first_row[[0, MainColumn::RHS.main_index()]] = u32_table_entry.right_operand;
115            first_row[[0, MainColumn::LookupMultiplicity.main_index()]] = multiplicity.into();
116            let u32_section = u32_section_next_row(first_row);
117
118            let next_section_end = next_section_start + u32_section.nrows();
119            u32_table
120                .slice_mut(s![next_section_start..next_section_end, ..])
121                .assign(&u32_section);
122            next_section_start = next_section_end;
123        }
124    }
125
126    fn pad(mut main_table: ArrayViewMut2<BFieldElement>, table_len: usize) {
127        let mut padding_row = Array1::zeros([MainColumn::COUNT]);
128        padding_row[[MainColumn::CI.main_index()]] = Instruction::Split.opcode_b();
129        padding_row[[MainColumn::BitsMinus33Inv.main_index()]] = bfe!(-33).inverse();
130
131        if table_len > 0 {
132            let last_row = main_table.row(table_len - 1);
133            padding_row[[MainColumn::CI.main_index()]] = last_row[MainColumn::CI.main_index()];
134            padding_row[[MainColumn::LHS.main_index()]] = last_row[MainColumn::LHS.main_index()];
135            padding_row[[MainColumn::LhsInv.main_index()]] =
136                last_row[MainColumn::LhsInv.main_index()];
137            padding_row[[MainColumn::Result.main_index()]] =
138                last_row[MainColumn::Result.main_index()];
139
140            // In the edge case that the last non-padding row comes from
141            // executing instruction `lt` on operands 0 and 0, the `Result`
142            // column is 0. For the padding section, where the `CopyFlag` is
143            // always 0, the `Result` needs to be set to 2 instead.
144            if padding_row[[MainColumn::CI.main_index()]] == Instruction::Lt.opcode_b() {
145                padding_row[[MainColumn::Result.main_index()]] = bfe!(2);
146            }
147        }
148
149        main_table
150            .slice_mut(s![table_len.., ..])
151            .axis_iter_mut(ROW_AXIS)
152            .into_par_iter()
153            .for_each(|mut row| row.assign(&padding_row));
154    }
155
156    fn extend(
157        main_table: ArrayView2<BFieldElement>,
158        mut aux_table: ArrayViewMut2<XFieldElement>,
159        challenges: &Challenges,
160    ) {
161        profiler!(start "u32 table");
162        assert_eq!(MainColumn::COUNT, main_table.ncols());
163        assert_eq!(AuxColumn::COUNT, aux_table.ncols());
164        assert_eq!(main_table.nrows(), aux_table.nrows());
165
166        let ci_weight = challenges[ChallengeId::U32CiWeight];
167        let lhs_weight = challenges[ChallengeId::U32LhsWeight];
168        let rhs_weight = challenges[ChallengeId::U32RhsWeight];
169        let result_weight = challenges[ChallengeId::U32ResultWeight];
170        let lookup_indeterminate = challenges[ChallengeId::U32Indeterminate];
171
172        let mut running_sum_log_derivative = LookupArg::default_initial();
173        for row_idx in 0..main_table.nrows() {
174            let current_row = main_table.row(row_idx);
175            if current_row[MainColumn::CopyFlag.main_index()].is_one() {
176                let lookup_multiplicity = current_row[MainColumn::LookupMultiplicity.main_index()];
177                let compressed_row = ci_weight * current_row[MainColumn::CI.main_index()]
178                    + lhs_weight * current_row[MainColumn::LHS.main_index()]
179                    + rhs_weight * current_row[MainColumn::RHS.main_index()]
180                    + result_weight * current_row[MainColumn::Result.main_index()];
181                running_sum_log_derivative +=
182                    lookup_multiplicity * (lookup_indeterminate - compressed_row).inverse();
183            }
184
185            let mut auxiliary_row = aux_table.row_mut(row_idx);
186            auxiliary_row[AuxColumn::LookupServerLogDerivative.aux_index()] =
187                running_sum_log_derivative;
188        }
189        profiler!(stop "u32 table");
190    }
191}
192
193fn u32_section_next_row(mut section: Array2<BFieldElement>) -> Array2<BFieldElement> {
194    let row_idx = section.nrows() - 1;
195    let current_instruction: Instruction = section[[row_idx, MainColumn::CI.main_index()]]
196        .value()
197        .try_into()
198        .expect("Unknown instruction");
199
200    // Is the last row in this section reached?
201    if (section[[row_idx, MainColumn::LHS.main_index()]].is_zero()
202        || current_instruction == Instruction::Pow)
203        && section[[row_idx, MainColumn::RHS.main_index()]].is_zero()
204    {
205        section[[row_idx, MainColumn::Result.main_index()]] = match current_instruction {
206            Instruction::Split => bfe!(0),
207            Instruction::Lt => bfe!(2),
208            Instruction::And => bfe!(0),
209            Instruction::Log2Floor => bfe!(-1),
210            Instruction::Pow => bfe!(1),
211            Instruction::PopCount => bfe!(0),
212            _ => panic!("Must be u32 instruction, not {current_instruction}."),
213        };
214
215        // If instruction `lt` is executed on operands 0 and 0, the result is
216        // known to be 0. The edge case can be reliably detected by checking
217        // whether column `Bits` is 0.
218        let both_operands_are_0 = section[[row_idx, MainColumn::Bits.main_index()]].is_zero();
219        if current_instruction == Instruction::Lt && both_operands_are_0 {
220            section[[row_idx, MainColumn::Result.main_index()]] = bfe!(0);
221        }
222
223        // The right hand side is guaranteed to be 0. However, if the current
224        // instruction is `pow`, then the left hand side might be non-zero.
225        let lhs_inv_or_0 = section[[row_idx, MainColumn::LHS.main_index()]].inverse_or_zero();
226        section[[row_idx, MainColumn::LhsInv.main_index()]] = lhs_inv_or_0;
227
228        return section;
229    }
230
231    let lhs_lsb = bfe!(section[[row_idx, MainColumn::LHS.main_index()]].value() % 2);
232    let rhs_lsb = bfe!(section[[row_idx, MainColumn::RHS.main_index()]].value() % 2);
233    let mut next_row = section.row(row_idx).to_owned();
234    next_row[MainColumn::CopyFlag.main_index()] = bfe!(0);
235    next_row[MainColumn::Bits.main_index()] += bfe!(1);
236    next_row[MainColumn::BitsMinus33Inv.main_index()] =
237        (next_row[MainColumn::Bits.main_index()] - bfe!(33)).inverse();
238    next_row[MainColumn::LHS.main_index()] = match current_instruction == Instruction::Pow {
239        true => section[[row_idx, MainColumn::LHS.main_index()]],
240        false => (section[[row_idx, MainColumn::LHS.main_index()]] - lhs_lsb) / bfe!(2),
241    };
242    next_row[MainColumn::RHS.main_index()] =
243        (section[[row_idx, MainColumn::RHS.main_index()]] - rhs_lsb) / bfe!(2);
244    next_row[MainColumn::LookupMultiplicity.main_index()] = bfe!(0);
245
246    section.push_row(next_row.view()).unwrap();
247    section = u32_section_next_row(section);
248    let (mut row, next_row) = section.multi_slice_mut((s![row_idx, ..], s![row_idx + 1, ..]));
249
250    row[MainColumn::LhsInv.main_index()] = row[MainColumn::LHS.main_index()].inverse_or_zero();
251    row[MainColumn::RhsInv.main_index()] = row[MainColumn::RHS.main_index()].inverse_or_zero();
252
253    let next_row_result = next_row[MainColumn::Result.main_index()];
254    row[MainColumn::Result.main_index()] = match current_instruction {
255        Instruction::Split => next_row_result,
256        Instruction::Lt => {
257            match (
258                next_row_result.value(),
259                lhs_lsb.value(),
260                rhs_lsb.value(),
261                row[MainColumn::CopyFlag.main_index()].value(),
262            ) {
263                (0 | 1, _, _, _) => next_row_result, // result already known
264                (2, 0, 1, _) => bfe!(1),             // LHS < RHS
265                (2, 1, 0, _) => bfe!(0),             // LHS > RHS
266                (2, _, _, 1) => bfe!(0),             // LHS == RHS
267                (2, _, _, 0) => bfe!(2),             // result still unknown
268                _ => panic!("Invalid state"),
269            }
270        }
271        Instruction::And => bfe!(2) * next_row_result + lhs_lsb * rhs_lsb,
272        Instruction::Log2Floor => {
273            if row[MainColumn::LHS.main_index()].is_zero() {
274                bfe!(-1)
275            } else if !next_row[MainColumn::LHS.main_index()].is_zero() {
276                next_row_result
277            } else {
278                // LHS != 0 && LHS' == 0
279                row[MainColumn::Bits.main_index()]
280            }
281        }
282        Instruction::Pow => match rhs_lsb.is_zero() {
283            true => next_row_result * next_row_result,
284            false => next_row_result * next_row_result * row[MainColumn::LHS.main_index()],
285        },
286        Instruction::PopCount => next_row_result + lhs_lsb,
287        _ => panic!("Must be u32 instruction, not {current_instruction}."),
288    };
289
290    section
291}