1use std::cmp::Ordering;
2use std::cmp::max;
3
4use air::challenge_id::ChallengeId;
5use air::cross_table_argument::CrossTableArg;
6use air::cross_table_argument::LookupArg;
7use air::table::u32::U32Table;
8use air::table_column::MasterAuxColumn;
9use air::table_column::MasterMainColumn;
10use arbitrary::Arbitrary;
11use isa::instruction::Instruction;
12use ndarray::Array1;
13use ndarray::Array2;
14use ndarray::ArrayView2;
15use ndarray::ArrayViewMut2;
16use ndarray::parallel::prelude::*;
17use ndarray::s;
18use num_traits::One;
19use num_traits::Zero;
20use strum::EnumCount;
21use twenty_first::prelude::*;
22
23use crate::aet::AlgebraicExecutionTrace;
24use crate::challenges::Challenges;
25use crate::ndarray_helper::ROW_AXIS;
26use crate::profiler::profiler;
27use crate::table::TraceTable;
28
29type MainColumn = <U32Table as air::AIR>::MainColumn;
30type AuxColumn = <U32Table as air::AIR>::AuxColumn;
31
/// A single request to the U32 table: one u32 instruction together with its two operands.
///
/// The algebraic execution trace keeps these entries alongside a lookup multiplicity (see
/// `aet.u32_entries`). When the table is filled, each entry seeds one section of rows.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
pub struct U32TableEntry {
    /// The u32 instruction; its opcode is written to the `CI` column of the section's first row.
    pub instruction: Instruction,

    /// Written to the `LHS` column of the section's first row.
    pub left_operand: BFieldElement,

    /// Written to the `RHS` column of the section's first row.
    pub right_operand: BFieldElement,
}
39
40impl U32TableEntry {
41 pub fn new<L, R>(instruction: Instruction, left_operand: L, right_operand: R) -> Self
42 where
43 L: Into<BFieldElement>,
44 R: Into<BFieldElement>,
45 {
46 Self {
47 instruction,
48 left_operand: left_operand.into(),
49 right_operand: right_operand.into(),
50 }
51 }
52
53 pub(crate) fn table_height_contribution(&self) -> u32 {
55 let lhs = self.left_operand.value();
56 let rhs = self.right_operand.value();
57 let dominant_operand = match self.instruction {
58 Instruction::Pow => rhs, _ => max(lhs, rhs),
60 };
61 match dominant_operand {
62 0 => 1,
63 _ => 2 + dominant_operand.ilog2(),
64 }
65 }
66}
67
impl PartialOrd for U32TableEntry {
    /// Delegates to the [`Ord`] implementation so that both orderings always agree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
73
74impl Ord for U32TableEntry {
75 fn cmp(&self, other: &Self) -> Ordering {
76 let Self {
78 instruction: self_instruction,
79 left_operand: self_left_operand,
80 right_operand: self_right_operand,
81 } = *self;
82 let Self {
83 instruction: other_instruction,
84 left_operand: other_left_operand,
85 right_operand: other_right_operand,
86 } = *other;
87
88 let instruction_cmp = self_instruction.opcode().cmp(&other_instruction.opcode());
92 let left_operand_cmp = self_left_operand.value().cmp(&other_left_operand.value());
93 let right_operand_cmp = self_right_operand.value().cmp(&other_right_operand.value());
94
95 instruction_cmp
96 .then(left_operand_cmp)
97 .then(right_operand_cmp)
98 }
99}
100
101impl TraceTable for U32Table {
102 type FillParam = ();
103 type FillReturnInfo = ();
104
105 fn fill(mut u32_table: ArrayViewMut2<BFieldElement>, aet: &AlgebraicExecutionTrace, _: ()) {
106 let mut next_section_start = 0;
107 for (&u32_table_entry, &multiplicity) in &aet.u32_entries {
108 let mut first_row = Array2::zeros([1, MainColumn::COUNT]);
109 first_row[[0, MainColumn::CopyFlag.main_index()]] = bfe!(1);
110 first_row[[0, MainColumn::Bits.main_index()]] = bfe!(0);
111 first_row[[0, MainColumn::BitsMinus33Inv.main_index()]] = bfe!(-33).inverse();
112 first_row[[0, MainColumn::CI.main_index()]] = u32_table_entry.instruction.opcode_b();
113 first_row[[0, MainColumn::LHS.main_index()]] = u32_table_entry.left_operand;
114 first_row[[0, MainColumn::RHS.main_index()]] = u32_table_entry.right_operand;
115 first_row[[0, MainColumn::LookupMultiplicity.main_index()]] = multiplicity.into();
116 let u32_section = u32_section_next_row(first_row);
117
118 let next_section_end = next_section_start + u32_section.nrows();
119 u32_table
120 .slice_mut(s![next_section_start..next_section_end, ..])
121 .assign(&u32_section);
122 next_section_start = next_section_end;
123 }
124 }
125
126 fn pad(mut main_table: ArrayViewMut2<BFieldElement>, table_len: usize) {
127 let mut padding_row = Array1::zeros([MainColumn::COUNT]);
128 padding_row[[MainColumn::CI.main_index()]] = Instruction::Split.opcode_b();
129 padding_row[[MainColumn::BitsMinus33Inv.main_index()]] = bfe!(-33).inverse();
130
131 if table_len > 0 {
132 let last_row = main_table.row(table_len - 1);
133 padding_row[[MainColumn::CI.main_index()]] = last_row[MainColumn::CI.main_index()];
134 padding_row[[MainColumn::LHS.main_index()]] = last_row[MainColumn::LHS.main_index()];
135 padding_row[[MainColumn::LhsInv.main_index()]] =
136 last_row[MainColumn::LhsInv.main_index()];
137 padding_row[[MainColumn::Result.main_index()]] =
138 last_row[MainColumn::Result.main_index()];
139
140 if padding_row[[MainColumn::CI.main_index()]] == Instruction::Lt.opcode_b() {
145 padding_row[[MainColumn::Result.main_index()]] = bfe!(2);
146 }
147 }
148
149 main_table
150 .slice_mut(s![table_len.., ..])
151 .axis_iter_mut(ROW_AXIS)
152 .into_par_iter()
153 .for_each(|mut row| row.assign(&padding_row));
154 }
155
156 fn extend(
157 main_table: ArrayView2<BFieldElement>,
158 mut aux_table: ArrayViewMut2<XFieldElement>,
159 challenges: &Challenges,
160 ) {
161 profiler!(start "u32 table");
162 assert_eq!(MainColumn::COUNT, main_table.ncols());
163 assert_eq!(AuxColumn::COUNT, aux_table.ncols());
164 assert_eq!(main_table.nrows(), aux_table.nrows());
165
166 let ci_weight = challenges[ChallengeId::U32CiWeight];
167 let lhs_weight = challenges[ChallengeId::U32LhsWeight];
168 let rhs_weight = challenges[ChallengeId::U32RhsWeight];
169 let result_weight = challenges[ChallengeId::U32ResultWeight];
170 let lookup_indeterminate = challenges[ChallengeId::U32Indeterminate];
171
172 let mut running_sum_log_derivative = LookupArg::default_initial();
173 for row_idx in 0..main_table.nrows() {
174 let current_row = main_table.row(row_idx);
175 if current_row[MainColumn::CopyFlag.main_index()].is_one() {
176 let lookup_multiplicity = current_row[MainColumn::LookupMultiplicity.main_index()];
177 let compressed_row = ci_weight * current_row[MainColumn::CI.main_index()]
178 + lhs_weight * current_row[MainColumn::LHS.main_index()]
179 + rhs_weight * current_row[MainColumn::RHS.main_index()]
180 + result_weight * current_row[MainColumn::Result.main_index()];
181 running_sum_log_derivative +=
182 lookup_multiplicity * (lookup_indeterminate - compressed_row).inverse();
183 }
184
185 let mut auxiliary_row = aux_table.row_mut(row_idx);
186 auxiliary_row[AuxColumn::LookupServerLogDerivative.aux_index()] =
187 running_sum_log_derivative;
188 }
189 profiler!(stop "u32 table");
190 }
191}
192
/// Recursively expands one U32 table section and fills in its derived columns.
///
/// `section` holds the rows produced so far. Its last row must already have `CI`, `LHS`, `RHS`,
/// `Bits`, `BitsMinus33Inv`, `CopyFlag`, and `LookupMultiplicity` populated.
///
/// - If that row is terminal, its `Result` and `LhsInv` are set and the section is returned.
/// - Otherwise a successor row is appended and the function recurses on it. The successor has
///   `Bits` incremented, `RHS` shifted right by one bit (as is `LHS`, except for `Pow`), and
///   `CopyFlag` and `LookupMultiplicity` cleared. Once the recursion returns, this row's
///   `LhsInv`, `RhsInv`, and `Result` are derived from the completed successor row and from
///   the least significant bits of this row's operands.
///
/// There is one call per row of the returned section.
///
/// # Panics
///
/// - if `section` is empty
/// - if `CI` does not decode to an [`Instruction`]
/// - if the instruction is not one of `Split`, `Lt`, `And`, `Log2Floor`, `Pow`, `PopCount`
fn u32_section_next_row(mut section: Array2<BFieldElement>) -> Array2<BFieldElement> {
    let row_idx = section.nrows() - 1;
    let current_instruction: Instruction = section[[row_idx, MainColumn::CI.main_index()]]
        .value()
        .try_into()
        .expect("Unknown instruction");

    // Terminal row: all relevant operand bits are consumed. For `Pow`, `LHS` (the base) is
    // never shifted, so only `RHS` (the exponent) has to reach 0.
    if (section[[row_idx, MainColumn::LHS.main_index()]].is_zero()
        || current_instruction == Instruction::Pow)
        && section[[row_idx, MainColumn::RHS.main_index()]].is_zero()
    {
        // Base-case results. For `Lt`, 2 means "not decided yet". For `Log2Floor`, -1 stands in
        // for the logarithm of 0. For `Pow`, anything to the power 0 is 1.
        section[[row_idx, MainColumn::Result.main_index()]] = match current_instruction {
            Instruction::Split => bfe!(0),
            Instruction::Lt => bfe!(2),
            Instruction::And => bfe!(0),
            Instruction::Log2Floor => bfe!(-1),
            Instruction::Pow => bfe!(1),
            Instruction::PopCount => bfe!(0),
            _ => panic!("Must be u32 instruction, not {current_instruction}."),
        };

        // `Bits` is incremented for every successor row, so (assuming the seed row has
        // `Bits` = 0, as in `fill`) `Bits` = 0 here means the terminal row is the first row.
        // For `Lt` that means both operands were 0, and `0 < 0` is false.
        let both_operands_are_0 = section[[row_idx, MainColumn::Bits.main_index()]].is_zero();
        if current_instruction == Instruction::Lt && both_operands_are_0 {
            section[[row_idx, MainColumn::Result.main_index()]] = bfe!(0);
        }

        // `RHS` is 0 here, but for `Pow` the left-hand side may be non-zero. `RhsInv` is
        // left as copied from the predecessor row.
        let lhs_inv_or_0 = section[[row_idx, MainColumn::LHS.main_index()]].inverse_or_zero();
        section[[row_idx, MainColumn::LhsInv.main_index()]] = lhs_inv_or_0;

        return section;
    }

    // The bits shifted out of the operands in the transition to the next row.
    let lhs_lsb = bfe!(section[[row_idx, MainColumn::LHS.main_index()]].value() % 2);
    let rhs_lsb = bfe!(section[[row_idx, MainColumn::RHS.main_index()]].value() % 2);

    // The copy is taken before this row's `LhsInv`/`RhsInv` are computed, which happens after
    // the recursion below. Those columns therefore carry over the seed row's values, which are
    // zero when seeded by `fill`.
    let mut next_row = section.row(row_idx).to_owned();
    next_row[MainColumn::CopyFlag.main_index()] = bfe!(0);
    next_row[MainColumn::Bits.main_index()] += bfe!(1);
    next_row[MainColumn::BitsMinus33Inv.main_index()] =
        (next_row[MainColumn::Bits.main_index()] - bfe!(33)).inverse();
    // `Pow` keeps its base unchanged in every row; all other instructions shift `LHS` right.
    next_row[MainColumn::LHS.main_index()] = match current_instruction == Instruction::Pow {
        true => section[[row_idx, MainColumn::LHS.main_index()]],
        false => (section[[row_idx, MainColumn::LHS.main_index()]] - lhs_lsb) / bfe!(2),
    };
    next_row[MainColumn::RHS.main_index()] =
        (section[[row_idx, MainColumn::RHS.main_index()]] - rhs_lsb) / bfe!(2);
    // Only a section's first row is looked up.
    next_row[MainColumn::LookupMultiplicity.main_index()] = bfe!(0);

    section.push_row(next_row.view()).unwrap();
    section = u32_section_next_row(section);

    // From here on, the successor row (and everything after it) is complete.
    let (mut row, next_row) = section.multi_slice_mut((s![row_idx, ..], s![row_idx + 1, ..]));

    row[MainColumn::LhsInv.main_index()] = row[MainColumn::LHS.main_index()].inverse_or_zero();
    row[MainColumn::RhsInv.main_index()] = row[MainColumn::RHS.main_index()].inverse_or_zero();

    let next_row_result = next_row[MainColumn::Result.main_index()];
    row[MainColumn::Result.main_index()] = match current_instruction {
        Instruction::Split => next_row_result,
        Instruction::Lt => {
            // The successor's result compares the higher bits. It only needs refining while
            // it is still 2, i.e. while all higher bits have been equal.
            match (
                next_row_result.value(),
                lhs_lsb.value(),
                rhs_lsb.value(),
                row[MainColumn::CopyFlag.main_index()].value(),
            ) {
                (0 | 1, _, _, _) => next_row_result, // result already known
                (2, 0, 1, _) => bfe!(1),             // LHS < RHS
                (2, 1, 0, _) => bfe!(0),             // LHS > RHS
                (2, _, _, 1) => bfe!(0),             // first row, all bits equal: LHS == RHS
                (2, _, _, 0) => bfe!(2),             // result still unknown
                _ => panic!("Invalid state"),
            }
        }
        // Re-assemble the AND of the shifted-out bits, least significant bit last.
        Instruction::And => bfe!(2) * next_row_result + lhs_lsb * rhs_lsb,
        Instruction::Log2Floor => {
            if row[MainColumn::LHS.main_index()].is_zero() {
                bfe!(-1)
            } else if !next_row[MainColumn::LHS.main_index()].is_zero() {
                next_row_result
            } else {
                // LHS != 0 but LHS' == 0: this row holds the most significant set bit, and
                // `Bits` is its position.
                row[MainColumn::Bits.main_index()]
            }
        }
        // Square-and-multiply: square the successor's result, and multiply by the base if
        // this exponent bit is set.
        Instruction::Pow => match rhs_lsb.is_zero() {
            true => next_row_result * next_row_result,
            false => next_row_result * next_row_result * row[MainColumn::LHS.main_index()],
        },
        Instruction::PopCount => next_row_result + lhs_lsb,
        _ => panic!("Must be u32 instruction, not {current_instruction}."),
    };

    section
}