1use std::cmp::Ordering;
2use std::collections::HashMap;
3use std::ops::Range;
4
5use air::challenge_id::ChallengeId;
6use air::cross_table_argument::CrossTableArg;
7use air::cross_table_argument::LookupArg;
8use air::cross_table_argument::PermArg;
9use air::table::op_stack::OpStackTable;
10use air::table::op_stack::PADDING_VALUE;
11use air::table::TableId;
12use air::table_column::MasterAuxColumn;
13use air::table_column::MasterMainColumn;
14use air::table_column::OpStackAuxColumn;
15use arbitrary::Arbitrary;
16use isa::op_stack::OpStackElement;
17use isa::op_stack::UnderflowIO;
18use itertools::Itertools;
19use ndarray::parallel::prelude::*;
20use ndarray::prelude::*;
21use strum::EnumCount;
22use strum::IntoEnumIterator;
23use twenty_first::math::traits::FiniteField;
24use twenty_first::prelude::*;
25
26use crate::aet::AlgebraicExecutionTrace;
27use crate::challenges::Challenges;
28use crate::ndarray_helper::contiguous_column_slices;
29use crate::ndarray_helper::horizontal_multi_slice_mut;
30use crate::profiler::profiler;
31use crate::table::TraceTable;
32
/// Shorthand for the main-table column type of [`OpStackTable`], as declared by its
/// `air::AIR` implementation.
type MainColumn = <OpStackTable as air::AIR>::MainColumn;
/// Shorthand for the auxiliary-table column type of [`OpStackTable`], as declared by its
/// `air::AIR` implementation.
type AuxColumn = <OpStackTable as air::AIR>::AuxColumn;
35
/// A single op stack table entry: one underflow I/O operation together with the clock cycle and
/// the op-stack pointer it is associated with.
///
/// Use [`OpStackTableEntry::to_main_table_row`] to turn an entry into a main-table row.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
pub struct OpStackTableEntry {
    /// Clock cycle of this entry; written to the `CLK` column.
    pub clk: u32,
    /// Op-stack pointer of this entry; written to the `StackPointer` column.
    pub op_stack_pointer: BFieldElement,
    /// The underflow I/O operation. Its payload is written to the `FirstUnderflowElement` column,
    /// and whether it shrinks the stack is written to the `IB1ShrinkStack` column.
    pub underflow_io: UnderflowIO,
}
42
43impl OpStackTableEntry {
44 pub fn new(clk: u32, op_stack_pointer: BFieldElement, underflow_io: UnderflowIO) -> Self {
45 Self {
46 clk,
47 op_stack_pointer,
48 underflow_io,
49 }
50 }
51
52 pub fn shrinks_stack(&self) -> bool {
53 self.underflow_io.shrinks_stack()
54 }
55
56 pub fn grows_stack(&self) -> bool {
57 self.underflow_io.grows_stack()
58 }
59
60 pub fn from_underflow_io_sequence(
61 clk: u32,
62 op_stack_pointer_after_sequence_execution: BFieldElement,
63 mut underflow_io_sequence: Vec<UnderflowIO>,
64 ) -> Vec<Self> {
65 UnderflowIO::canonicalize_sequence(&mut underflow_io_sequence);
66 assert!(UnderflowIO::is_uniform_sequence(&underflow_io_sequence));
67
68 let sequence_length: BFieldElement =
69 u32::try_from(underflow_io_sequence.len()).unwrap().into();
70 let mut op_stack_pointer = match UnderflowIO::is_writing_sequence(&underflow_io_sequence) {
71 true => op_stack_pointer_after_sequence_execution - sequence_length,
72 false => op_stack_pointer_after_sequence_execution + sequence_length,
73 };
74 let mut op_stack_table_entries = vec![];
75 for underflow_io in underflow_io_sequence {
76 if underflow_io.shrinks_stack() {
77 op_stack_pointer.decrement();
78 }
79 let op_stack_table_entry = Self::new(clk, op_stack_pointer, underflow_io);
80 op_stack_table_entries.push(op_stack_table_entry);
81 if underflow_io.grows_stack() {
82 op_stack_pointer.increment();
83 }
84 }
85 op_stack_table_entries
86 }
87
88 pub fn to_main_table_row(self) -> Array1<BFieldElement> {
89 let shrink_stack_indicator = if self.shrinks_stack() {
90 bfe!(1)
91 } else {
92 bfe!(0)
93 };
94
95 let mut row = Array1::zeros(MainColumn::COUNT);
96 row[MainColumn::CLK.main_index()] = self.clk.into();
97 row[MainColumn::IB1ShrinkStack.main_index()] = shrink_stack_indicator;
98 row[MainColumn::StackPointer.main_index()] = self.op_stack_pointer;
99 row[MainColumn::FirstUnderflowElement.main_index()] = self.underflow_io.payload();
100 row
101 }
102}
103
104fn auxiliary_column_running_product_permutation_argument(
105 main_table: ArrayView2<BFieldElement>,
106 challenges: &Challenges,
107) -> Array2<XFieldElement> {
108 let perm_arg_indeterminate = challenges[ChallengeId::OpStackIndeterminate];
109
110 let mut running_product = PermArg::default_initial();
111 let mut auxiliary_column = Vec::with_capacity(main_table.nrows());
112 for row in main_table.rows() {
113 if row[MainColumn::IB1ShrinkStack.main_index()] != PADDING_VALUE {
114 let compressed_row = row[MainColumn::CLK.main_index()]
115 * challenges[ChallengeId::OpStackClkWeight]
116 + row[MainColumn::IB1ShrinkStack.main_index()]
117 * challenges[ChallengeId::OpStackIb1Weight]
118 + row[MainColumn::StackPointer.main_index()]
119 * challenges[ChallengeId::OpStackPointerWeight]
120 + row[MainColumn::FirstUnderflowElement.main_index()]
121 * challenges[ChallengeId::OpStackFirstUnderflowElementWeight];
122 running_product *= perm_arg_indeterminate - compressed_row;
123 }
124 auxiliary_column.push(running_product);
125 }
126 Array2::from_shape_vec((main_table.nrows(), 1), auxiliary_column).unwrap()
127}
128
129fn auxiliary_column_clock_jump_diff_lookup_log_derivative(
130 main_table: ArrayView2<BFieldElement>,
131 challenges: &Challenges,
132) -> Array2<XFieldElement> {
133 const PRECOMPUTE_INVERSES_OF: Range<u64> = 0..100;
136 let cjd_lookup_indeterminate = challenges[ChallengeId::ClockJumpDifferenceLookupIndeterminate];
137 let to_invert = PRECOMPUTE_INVERSES_OF
138 .map(|i| cjd_lookup_indeterminate - bfe!(i))
139 .collect_vec();
140 let inverses = XFieldElement::batch_inversion(to_invert);
141 let mut inverses_dictionary = PRECOMPUTE_INVERSES_OF
142 .zip_eq(inverses)
143 .map(|(i, inv)| (bfe!(i), inv))
144 .collect::<HashMap<_, _>>();
145
146 let mut cjd_lookup_log_derivative = LookupArg::default_initial();
148 let mut auxiliary_column = Vec::with_capacity(main_table.nrows());
149 auxiliary_column.push(cjd_lookup_log_derivative);
150 for (previous_row, current_row) in main_table.rows().into_iter().tuple_windows() {
151 if current_row[MainColumn::IB1ShrinkStack.main_index()] == PADDING_VALUE {
152 break;
153 };
154
155 let previous_stack_pointer = previous_row[MainColumn::StackPointer.main_index()];
156 let current_stack_pointer = current_row[MainColumn::StackPointer.main_index()];
157 if previous_stack_pointer == current_stack_pointer {
158 let previous_clock = previous_row[MainColumn::CLK.main_index()];
159 let current_clock = current_row[MainColumn::CLK.main_index()];
160 let clock_jump_difference = current_clock - previous_clock;
161 let &mut inverse = inverses_dictionary
162 .entry(clock_jump_difference)
163 .or_insert_with(|| (cjd_lookup_indeterminate - clock_jump_difference).inverse());
164 cjd_lookup_log_derivative += inverse;
165 }
166 auxiliary_column.push(cjd_lookup_log_derivative);
167 }
168
169 auxiliary_column.resize(main_table.nrows(), cjd_lookup_log_derivative);
171 Array2::from_shape_vec((main_table.nrows(), 1), auxiliary_column).unwrap()
172}
173
174impl TraceTable for OpStackTable {
175 type FillParam = ();
176 type FillReturnInfo = Vec<BFieldElement>;
177
178 fn fill(
179 mut op_stack_table: ArrayViewMut2<BFieldElement>,
180 aet: &AlgebraicExecutionTrace,
181 _: Self::FillParam,
182 ) -> Vec<BFieldElement> {
183 let mut op_stack_table =
184 op_stack_table.slice_mut(s![0..aet.height_of_table(TableId::OpStack), ..]);
185 let trace_iter = aet.op_stack_underflow_trace.rows().into_iter();
186
187 let sorted_rows =
188 trace_iter.sorted_by(|row_0, row_1| compare_rows(row_0.view(), row_1.view()));
189 for (row_index, row) in sorted_rows.enumerate() {
190 op_stack_table.row_mut(row_index).assign(&row);
191 }
192
193 clock_jump_differences(op_stack_table.view())
194 }
195
196 fn pad(mut op_stack_table: ArrayViewMut2<BFieldElement>, op_stack_table_len: usize) {
197 let last_row_index = op_stack_table_len.saturating_sub(1);
198 let mut padding_row = op_stack_table.row(last_row_index).to_owned();
199 padding_row[MainColumn::IB1ShrinkStack.main_index()] = PADDING_VALUE;
200 if op_stack_table_len == 0 {
201 let first_stack_pointer = u32::try_from(OpStackElement::COUNT).unwrap().into();
202 padding_row[MainColumn::StackPointer.main_index()] = first_stack_pointer;
203 }
204
205 let mut padding_section = op_stack_table.slice_mut(s![op_stack_table_len.., ..]);
206 padding_section
207 .axis_iter_mut(Axis(0))
208 .into_par_iter()
209 .for_each(|mut row| row.assign(&padding_row));
210 }
211
212 fn extend(
213 main_table: ArrayView2<BFieldElement>,
214 mut aux_table: ArrayViewMut2<XFieldElement>,
215 challenges: &Challenges,
216 ) {
217 profiler!(start "op stack table");
218 assert_eq!(MainColumn::COUNT, main_table.ncols());
219 assert_eq!(AuxColumn::COUNT, aux_table.ncols());
220 assert_eq!(main_table.nrows(), aux_table.nrows());
221
222 let auxiliary_column_indices = OpStackAuxColumn::iter()
223 .map(|column| column.aux_index())
224 .collect_vec();
225 let auxiliary_column_slices = horizontal_multi_slice_mut(
226 aux_table.view_mut(),
227 &contiguous_column_slices(&auxiliary_column_indices),
228 );
229 let extension_functions = [
230 auxiliary_column_running_product_permutation_argument,
231 auxiliary_column_clock_jump_diff_lookup_log_derivative,
232 ];
233
234 extension_functions
235 .into_par_iter()
236 .zip_eq(auxiliary_column_slices)
237 .for_each(|(generator, slice)| {
238 generator(main_table, challenges).move_into(slice);
239 });
240
241 profiler!(stop "op stack table");
242 }
243}
244
245fn compare_rows(row_0: ArrayView1<BFieldElement>, row_1: ArrayView1<BFieldElement>) -> Ordering {
246 let stack_pointer_0 = row_0[MainColumn::StackPointer.main_index()].value();
247 let stack_pointer_1 = row_1[MainColumn::StackPointer.main_index()].value();
248 let compare_stack_pointers = stack_pointer_0.cmp(&stack_pointer_1);
249
250 let clk_0 = row_0[MainColumn::CLK.main_index()].value();
251 let clk_1 = row_1[MainColumn::CLK.main_index()].value();
252 let compare_clocks = clk_0.cmp(&clk_1);
253
254 compare_stack_pointers.then(compare_clocks)
255}
256
257fn clock_jump_differences(op_stack_table: ArrayView2<BFieldElement>) -> Vec<BFieldElement> {
258 let mut clock_jump_differences = vec![];
259 for consecutive_rows in op_stack_table.axis_windows(Axis(0), 2) {
260 let current_row = consecutive_rows.row(0);
261 let next_row = consecutive_rows.row(1);
262 let current_stack_pointer = current_row[MainColumn::StackPointer.main_index()];
263 let next_stack_pointer = next_row[MainColumn::StackPointer.main_index()];
264 if current_stack_pointer == next_stack_pointer {
265 let current_clk = current_row[MainColumn::CLK.main_index()];
266 let next_clk = next_row[MainColumn::CLK.main_index()];
267 let clk_difference = next_clk - current_clk;
268 clock_jump_differences.push(clk_difference);
269 }
270 }
271 clock_jump_differences
272}
273
#[cfg(test)]
pub(crate) mod tests {
    use assert2::assert;
    use isa::op_stack::OpStackElement;
    use itertools::Itertools;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest_arbitrary_interop::arb;
    use test_strategy::proptest;

    use super::*;

    #[proptest]
    fn op_stack_table_entry_either_shrinks_stack_or_grows_stack(
        #[strategy(arb())] entry: OpStackTableEntry,
    ) {
        let shrinks_stack = entry.shrinks_stack();
        let grows_stack = entry.grows_stack();
        assert!(shrinks_stack ^ grows_stack);
    }

    #[proptest]
    fn op_stack_pointer_in_sequence_of_op_stack_table_entries(
        clk: u32,
        #[strategy(OpStackElement::COUNT..1024)] stack_pointer: usize,
        #[strategy(vec(arb(), ..OpStackElement::COUNT))] base_field_elements: Vec<BFieldElement>,
        sequence_of_writes: bool,
    ) {
        let sequence_length = u64::try_from(base_field_elements.len()).unwrap();
        let stack_pointer = u64::try_from(stack_pointer).unwrap();

        let underflow_io_operation = match sequence_of_writes {
            true => UnderflowIO::Write,
            false => UnderflowIO::Read,
        };
        let underflow_io = base_field_elements
            .into_iter()
            .map(underflow_io_operation)
            .collect();

        let op_stack_pointer = stack_pointer.into();
        let entries =
            OpStackTableEntry::from_underflow_io_sequence(clk, op_stack_pointer, underflow_io);
        let op_stack_pointers = entries
            .iter()
            .map(|entry| entry.op_stack_pointer.value())
            .sorted()
            .collect_vec();

        let expected_stack_pointer_range = match sequence_of_writes {
            true => stack_pointer - sequence_length..stack_pointer,
            false => stack_pointer..stack_pointer + sequence_length,
        };
        let expected_op_stack_pointers = expected_stack_pointer_range.collect_vec();
        prop_assert_eq!(expected_op_stack_pointers, op_stack_pointers);
    }

    #[proptest]
    fn clk_stays_same_in_sequence_of_op_stack_table_entries(
        clk: u32,
        #[strategy(OpStackElement::COUNT..1024)] stack_pointer: usize,
        #[strategy(vec(arb(), ..OpStackElement::COUNT))] base_field_elements: Vec<BFieldElement>,
        sequence_of_writes: bool,
    ) {
        let underflow_io_operation = match sequence_of_writes {
            true => UnderflowIO::Write,
            false => UnderflowIO::Read,
        };
        let underflow_io = base_field_elements
            .into_iter()
            .map(underflow_io_operation)
            .collect();

        let op_stack_pointer = u64::try_from(stack_pointer).unwrap().into();
        let entries =
            OpStackTableEntry::from_underflow_io_sequence(clk, op_stack_pointer, underflow_io);
        let clk_values = entries.iter().map(|entry| entry.clk).collect_vec();
        let all_clk_values_are_clk = clk_values.iter().all(|&c| c == clk);
        prop_assert!(all_clk_values_are_clk);
    }

    /// Builds a main-table row in which only the stack pointer and clock columns are set.
    fn row_with_stack_pointer_and_clk(
        stack_pointer: BFieldElement,
        clk: BFieldElement,
    ) -> Array1<BFieldElement> {
        let mut row = Array1::zeros(MainColumn::COUNT);
        row[MainColumn::StackPointer.main_index()] = stack_pointer;
        row[MainColumn::CLK.main_index()] = clk;
        row
    }

    #[proptest]
    fn compare_rows_with_unequal_stack_pointer_and_equal_clk(
        stack_pointer_0: u64,
        stack_pointer_1: u64,
        clk: u64,
    ) {
        // Converting a `u64` into a `BFieldElement` reduces it modulo the field's prime, and
        // `compare_rows` compares canonical `value()`s. The expected ordering must be computed
        // on those same canonical values; comparing the raw `u64`s yields wrong expectations for
        // inputs at or above the prime.
        let stack_pointer_0 = BFieldElement::from(stack_pointer_0);
        let stack_pointer_1 = BFieldElement::from(stack_pointer_1);
        let clk = BFieldElement::from(clk);

        let row_0 = row_with_stack_pointer_and_clk(stack_pointer_0, clk);
        let row_1 = row_with_stack_pointer_and_clk(stack_pointer_1, clk);

        let stack_pointer_comparison = stack_pointer_0.value().cmp(&stack_pointer_1.value());
        let row_comparison = compare_rows(row_0.view(), row_1.view());

        prop_assert_eq!(stack_pointer_comparison, row_comparison);
    }

    #[proptest]
    fn compare_rows_with_equal_stack_pointer_and_unequal_clk(
        stack_pointer: u64,
        clk_0: u64,
        clk_1: u64,
    ) {
        // See `compare_rows_with_unequal_stack_pointer_and_equal_clk` for why the expected
        // ordering is computed on canonical field-element values.
        let stack_pointer = BFieldElement::from(stack_pointer);
        let clk_0 = BFieldElement::from(clk_0);
        let clk_1 = BFieldElement::from(clk_1);

        let row_0 = row_with_stack_pointer_and_clk(stack_pointer, clk_0);
        let row_1 = row_with_stack_pointer_and_clk(stack_pointer, clk_1);

        let clk_comparison = clk_0.value().cmp(&clk_1.value());
        let row_comparison = compare_rows(row_0.view(), row_1.view());

        prop_assert_eq!(clk_comparison, row_comparison);
    }
}