1use std::cmp::Ordering;
2use std::collections::HashMap;
3use std::ops::Range;
4
5use air::challenge_id::ChallengeId;
6use air::cross_table_argument::CrossTableArg;
7use air::cross_table_argument::LookupArg;
8use air::cross_table_argument::PermArg;
9use air::table::TableId;
10use air::table::op_stack::OpStackTable;
11use air::table::op_stack::PADDING_VALUE;
12use air::table_column::MasterAuxColumn;
13use air::table_column::MasterMainColumn;
14use air::table_column::OpStackAuxColumn;
15use arbitrary::Arbitrary;
16use isa::op_stack::OpStackElement;
17use isa::op_stack::UnderflowIO;
18use itertools::Itertools;
19use ndarray::parallel::prelude::*;
20use ndarray::prelude::*;
21use strum::EnumCount;
22use strum::IntoEnumIterator;
23use twenty_first::math::traits::FiniteField;
24use twenty_first::prelude::*;
25
26use crate::aet::AlgebraicExecutionTrace;
27use crate::challenges::Challenges;
28use crate::ndarray_helper::ROW_AXIS;
29use crate::ndarray_helper::contiguous_column_slices;
30use crate::ndarray_helper::horizontal_multi_slice_mut;
31use crate::profiler::profiler;
32use crate::table::TraceTable;
33
34type MainColumn = <OpStackTable as air::AIR>::MainColumn;
35type AuxColumn = <OpStackTable as air::AIR>::AuxColumn;
36
37#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
38pub struct OpStackTableEntry {
39 pub clk: u32,
40 pub op_stack_pointer: BFieldElement,
41 pub underflow_io: UnderflowIO,
42}
43
44impl OpStackTableEntry {
45 pub fn new(clk: u32, op_stack_pointer: BFieldElement, underflow_io: UnderflowIO) -> Self {
46 Self {
47 clk,
48 op_stack_pointer,
49 underflow_io,
50 }
51 }
52
53 pub fn shrinks_stack(&self) -> bool {
54 self.underflow_io.shrinks_stack()
55 }
56
57 pub fn grows_stack(&self) -> bool {
58 self.underflow_io.grows_stack()
59 }
60
61 pub fn from_underflow_io_sequence(
62 clk: u32,
63 op_stack_pointer_after_sequence_execution: BFieldElement,
64 mut underflow_io_sequence: Vec<UnderflowIO>,
65 ) -> Vec<Self> {
66 UnderflowIO::canonicalize_sequence(&mut underflow_io_sequence);
67 assert!(UnderflowIO::is_uniform_sequence(&underflow_io_sequence));
68
69 let sequence_length: BFieldElement =
70 u32::try_from(underflow_io_sequence.len()).unwrap().into();
71 let mut op_stack_pointer = match UnderflowIO::is_writing_sequence(&underflow_io_sequence) {
72 true => op_stack_pointer_after_sequence_execution - sequence_length,
73 false => op_stack_pointer_after_sequence_execution + sequence_length,
74 };
75 let mut op_stack_table_entries = vec![];
76 for underflow_io in underflow_io_sequence {
77 if underflow_io.shrinks_stack() {
78 op_stack_pointer.decrement();
79 }
80 let op_stack_table_entry = Self::new(clk, op_stack_pointer, underflow_io);
81 op_stack_table_entries.push(op_stack_table_entry);
82 if underflow_io.grows_stack() {
83 op_stack_pointer.increment();
84 }
85 }
86 op_stack_table_entries
87 }
88
89 pub fn to_main_table_row(self) -> Array1<BFieldElement> {
90 let shrink_stack_indicator = if self.shrinks_stack() {
91 bfe!(1)
92 } else {
93 bfe!(0)
94 };
95
96 let mut row = Array1::zeros(MainColumn::COUNT);
97 row[MainColumn::CLK.main_index()] = self.clk.into();
98 row[MainColumn::IB1ShrinkStack.main_index()] = shrink_stack_indicator;
99 row[MainColumn::StackPointer.main_index()] = self.op_stack_pointer;
100 row[MainColumn::FirstUnderflowElement.main_index()] = self.underflow_io.payload();
101 row
102 }
103}
104
105fn auxiliary_column_running_product_permutation_argument(
106 main_table: ArrayView2<BFieldElement>,
107 challenges: &Challenges,
108) -> Array2<XFieldElement> {
109 let perm_arg_indeterminate = challenges[ChallengeId::OpStackIndeterminate];
110
111 let mut running_product = PermArg::default_initial();
112 let mut auxiliary_column = Vec::with_capacity(main_table.nrows());
113 for row in main_table.rows() {
114 if row[MainColumn::IB1ShrinkStack.main_index()] != PADDING_VALUE {
115 let compressed_row = row[MainColumn::CLK.main_index()]
116 * challenges[ChallengeId::OpStackClkWeight]
117 + row[MainColumn::IB1ShrinkStack.main_index()]
118 * challenges[ChallengeId::OpStackIb1Weight]
119 + row[MainColumn::StackPointer.main_index()]
120 * challenges[ChallengeId::OpStackPointerWeight]
121 + row[MainColumn::FirstUnderflowElement.main_index()]
122 * challenges[ChallengeId::OpStackFirstUnderflowElementWeight];
123 running_product *= perm_arg_indeterminate - compressed_row;
124 }
125 auxiliary_column.push(running_product);
126 }
127 Array2::from_shape_vec((main_table.nrows(), 1), auxiliary_column).unwrap()
128}
129
130fn auxiliary_column_clock_jump_diff_lookup_log_derivative(
131 main_table: ArrayView2<BFieldElement>,
132 challenges: &Challenges,
133) -> Array2<XFieldElement> {
134 const PRECOMPUTE_INVERSES_OF: Range<u64> = 0..100;
137 let cjd_lookup_indeterminate = challenges[ChallengeId::ClockJumpDifferenceLookupIndeterminate];
138 let to_invert = PRECOMPUTE_INVERSES_OF
139 .map(|i| cjd_lookup_indeterminate - bfe!(i))
140 .collect_vec();
141 let inverses = XFieldElement::batch_inversion(to_invert);
142 let mut inverses_dictionary = PRECOMPUTE_INVERSES_OF
143 .zip_eq(inverses)
144 .map(|(i, inv)| (bfe!(i), inv))
145 .collect::<HashMap<_, _>>();
146
147 let mut cjd_lookup_log_derivative = LookupArg::default_initial();
149 let mut auxiliary_column = Vec::with_capacity(main_table.nrows());
150 auxiliary_column.push(cjd_lookup_log_derivative);
151 for (previous_row, current_row) in main_table.rows().into_iter().tuple_windows() {
152 if current_row[MainColumn::IB1ShrinkStack.main_index()] == PADDING_VALUE {
153 break;
154 };
155
156 let previous_stack_pointer = previous_row[MainColumn::StackPointer.main_index()];
157 let current_stack_pointer = current_row[MainColumn::StackPointer.main_index()];
158 if previous_stack_pointer == current_stack_pointer {
159 let previous_clock = previous_row[MainColumn::CLK.main_index()];
160 let current_clock = current_row[MainColumn::CLK.main_index()];
161 let clock_jump_difference = current_clock - previous_clock;
162 let &mut inverse = inverses_dictionary
163 .entry(clock_jump_difference)
164 .or_insert_with(|| (cjd_lookup_indeterminate - clock_jump_difference).inverse());
165 cjd_lookup_log_derivative += inverse;
166 }
167 auxiliary_column.push(cjd_lookup_log_derivative);
168 }
169
170 auxiliary_column.resize(main_table.nrows(), cjd_lookup_log_derivative);
172 Array2::from_shape_vec((main_table.nrows(), 1), auxiliary_column).unwrap()
173}
174
175impl TraceTable for OpStackTable {
176 type FillParam = ();
177 type FillReturnInfo = Vec<BFieldElement>;
178
179 fn fill(
180 mut op_stack_table: ArrayViewMut2<BFieldElement>,
181 aet: &AlgebraicExecutionTrace,
182 _: Self::FillParam,
183 ) -> Vec<BFieldElement> {
184 let mut op_stack_table =
185 op_stack_table.slice_mut(s![0..aet.height_of_table(TableId::OpStack), ..]);
186 let trace_iter = aet.op_stack_underflow_trace.rows().into_iter();
187
188 let sorted_rows =
189 trace_iter.sorted_by(|row_0, row_1| compare_rows(row_0.view(), row_1.view()));
190 for (row_index, row) in sorted_rows.enumerate() {
191 op_stack_table.row_mut(row_index).assign(&row);
192 }
193
194 clock_jump_differences(op_stack_table.view())
195 }
196
197 fn pad(mut op_stack_table: ArrayViewMut2<BFieldElement>, op_stack_table_len: usize) {
198 let last_row_index = op_stack_table_len.saturating_sub(1);
199 let mut padding_row = op_stack_table.row(last_row_index).to_owned();
200 padding_row[MainColumn::IB1ShrinkStack.main_index()] = PADDING_VALUE;
201 if op_stack_table_len == 0 {
202 let first_stack_pointer = u32::try_from(OpStackElement::COUNT).unwrap().into();
203 padding_row[MainColumn::StackPointer.main_index()] = first_stack_pointer;
204 }
205
206 let mut padding_section = op_stack_table.slice_mut(s![op_stack_table_len.., ..]);
207 padding_section
208 .axis_iter_mut(ROW_AXIS)
209 .into_par_iter()
210 .for_each(|mut row| row.assign(&padding_row));
211 }
212
213 fn extend(
214 main_table: ArrayView2<BFieldElement>,
215 mut aux_table: ArrayViewMut2<XFieldElement>,
216 challenges: &Challenges,
217 ) {
218 profiler!(start "op stack table");
219 assert_eq!(MainColumn::COUNT, main_table.ncols());
220 assert_eq!(AuxColumn::COUNT, aux_table.ncols());
221 assert_eq!(main_table.nrows(), aux_table.nrows());
222
223 let auxiliary_column_indices = OpStackAuxColumn::iter()
224 .map(|column| column.aux_index())
225 .collect_vec();
226 let auxiliary_column_slices = horizontal_multi_slice_mut(
227 aux_table.view_mut(),
228 &contiguous_column_slices(&auxiliary_column_indices),
229 );
230 let extension_functions = [
231 auxiliary_column_running_product_permutation_argument,
232 auxiliary_column_clock_jump_diff_lookup_log_derivative,
233 ];
234
235 extension_functions
236 .into_par_iter()
237 .zip_eq(auxiliary_column_slices)
238 .for_each(|(generator, slice)| {
239 generator(main_table, challenges).move_into(slice);
240 });
241
242 profiler!(stop "op stack table");
243 }
244}
245
246fn compare_rows(row_0: ArrayView1<BFieldElement>, row_1: ArrayView1<BFieldElement>) -> Ordering {
247 let stack_pointer_0 = row_0[MainColumn::StackPointer.main_index()].value();
248 let stack_pointer_1 = row_1[MainColumn::StackPointer.main_index()].value();
249 let compare_stack_pointers = stack_pointer_0.cmp(&stack_pointer_1);
250
251 let clk_0 = row_0[MainColumn::CLK.main_index()].value();
252 let clk_1 = row_1[MainColumn::CLK.main_index()].value();
253 let compare_clocks = clk_0.cmp(&clk_1);
254
255 compare_stack_pointers.then(compare_clocks)
256}
257
258fn clock_jump_differences(op_stack_table: ArrayView2<BFieldElement>) -> Vec<BFieldElement> {
259 let mut clock_jump_differences = vec![];
260 for consecutive_rows in op_stack_table.axis_windows(ROW_AXIS, 2) {
261 let current_row = consecutive_rows.row(0);
262 let next_row = consecutive_rows.row(1);
263 let current_stack_pointer = current_row[MainColumn::StackPointer.main_index()];
264 let next_stack_pointer = next_row[MainColumn::StackPointer.main_index()];
265 if current_stack_pointer == next_stack_pointer {
266 let current_clk = current_row[MainColumn::CLK.main_index()];
267 let next_clk = next_row[MainColumn::CLK.main_index()];
268 let clk_difference = next_clk - current_clk;
269 clock_jump_differences.push(clk_difference);
270 }
271 }
272 clock_jump_differences
273}
274
275#[cfg(test)]
276#[cfg_attr(coverage_nightly, coverage(off))]
277pub(crate) mod tests {
278 use assert2::assert;
279 use isa::op_stack::OpStackElement;
280 use itertools::Itertools;
281 use proptest::collection::vec;
282 use proptest::prelude::*;
283 use proptest_arbitrary_interop::arb;
284 use test_strategy::proptest;
285
286 use super::*;
287
288 #[proptest]
289 fn op_stack_table_entry_either_shrinks_stack_or_grows_stack(
290 #[strategy(arb())] entry: OpStackTableEntry,
291 ) {
292 let shrinks_stack = entry.shrinks_stack();
293 let grows_stack = entry.grows_stack();
294 assert!(shrinks_stack ^ grows_stack);
295 }
296
297 #[proptest]
298 fn op_stack_pointer_in_sequence_of_op_stack_table_entries(
299 clk: u32,
300 #[strategy(OpStackElement::COUNT..1024)] stack_pointer: usize,
301 #[strategy(vec(arb(), ..OpStackElement::COUNT))] base_field_elements: Vec<BFieldElement>,
302 sequence_of_writes: bool,
303 ) {
304 let sequence_length = u64::try_from(base_field_elements.len()).unwrap();
305 let stack_pointer = u64::try_from(stack_pointer).unwrap();
306
307 let underflow_io_operation = match sequence_of_writes {
308 true => UnderflowIO::Write,
309 false => UnderflowIO::Read,
310 };
311 let underflow_io = base_field_elements
312 .into_iter()
313 .map(underflow_io_operation)
314 .collect();
315
316 let op_stack_pointer = stack_pointer.into();
317 let entries =
318 OpStackTableEntry::from_underflow_io_sequence(clk, op_stack_pointer, underflow_io);
319 let op_stack_pointers = entries
320 .iter()
321 .map(|entry| entry.op_stack_pointer.value())
322 .sorted()
323 .collect_vec();
324
325 let expected_stack_pointer_range = match sequence_of_writes {
326 true => stack_pointer - sequence_length..stack_pointer,
327 false => stack_pointer..stack_pointer + sequence_length,
328 };
329 let expected_op_stack_pointers = expected_stack_pointer_range.collect_vec();
330 prop_assert_eq!(expected_op_stack_pointers, op_stack_pointers);
331 }
332
333 #[proptest]
334 fn clk_stays_same_in_sequence_of_op_stack_table_entries(
335 clk: u32,
336 #[strategy(OpStackElement::COUNT..1024)] stack_pointer: usize,
337 #[strategy(vec(arb(), ..OpStackElement::COUNT))] base_field_elements: Vec<BFieldElement>,
338 sequence_of_writes: bool,
339 ) {
340 let underflow_io_operation = match sequence_of_writes {
341 true => UnderflowIO::Write,
342 false => UnderflowIO::Read,
343 };
344 let underflow_io = base_field_elements
345 .into_iter()
346 .map(underflow_io_operation)
347 .collect();
348
349 let op_stack_pointer = u64::try_from(stack_pointer).unwrap().into();
350 let entries =
351 OpStackTableEntry::from_underflow_io_sequence(clk, op_stack_pointer, underflow_io);
352 let clk_values = entries.iter().map(|entry| entry.clk).collect_vec();
353 let all_clk_values_are_clk = clk_values.iter().all(|&c| c == clk);
354 prop_assert!(all_clk_values_are_clk);
355 }
356
357 #[proptest]
358 fn compare_rows_with_unequal_stack_pointer_and_equal_clk(
359 stack_pointer_0: u64,
360 stack_pointer_1: u64,
361 clk: u64,
362 ) {
363 let mut row_0 = Array1::zeros(MainColumn::COUNT);
364 row_0[MainColumn::StackPointer.main_index()] = stack_pointer_0.into();
365 row_0[MainColumn::CLK.main_index()] = clk.into();
366
367 let mut row_1 = Array1::zeros(MainColumn::COUNT);
368 row_1[MainColumn::StackPointer.main_index()] = stack_pointer_1.into();
369 row_1[MainColumn::CLK.main_index()] = clk.into();
370
371 let stack_pointer_comparison = stack_pointer_0.cmp(&stack_pointer_1);
372 let row_comparison = compare_rows(row_0.view(), row_1.view());
373
374 prop_assert_eq!(stack_pointer_comparison, row_comparison);
375 }
376
377 #[proptest]
378 fn compare_rows_with_equal_stack_pointer_and_unequal_clk(
379 stack_pointer: u64,
380 clk_0: u64,
381 clk_1: u64,
382 ) {
383 let mut row_0 = Array1::zeros(MainColumn::COUNT);
384 row_0[MainColumn::StackPointer.main_index()] = stack_pointer.into();
385 row_0[MainColumn::CLK.main_index()] = clk_0.into();
386
387 let mut row_1 = Array1::zeros(MainColumn::COUNT);
388 row_1[MainColumn::StackPointer.main_index()] = stack_pointer.into();
389 row_1[MainColumn::CLK.main_index()] = clk_1.into();
390
391 let clk_comparison = clk_0.cmp(&clk_1);
392 let row_comparison = compare_rows(row_0.view(), row_1.view());
393
394 prop_assert_eq!(clk_comparison, row_comparison);
395 }
396}