triton_vm/table/
op_stack.rs

1use std::cmp::Ordering;
2use std::collections::HashMap;
3use std::ops::Range;
4
5use air::challenge_id::ChallengeId;
6use air::cross_table_argument::CrossTableArg;
7use air::cross_table_argument::LookupArg;
8use air::cross_table_argument::PermArg;
9use air::table::op_stack::OpStackTable;
10use air::table::op_stack::PADDING_VALUE;
11use air::table::TableId;
12use air::table_column::MasterAuxColumn;
13use air::table_column::MasterMainColumn;
14use air::table_column::OpStackAuxColumn;
15use arbitrary::Arbitrary;
16use isa::op_stack::OpStackElement;
17use isa::op_stack::UnderflowIO;
18use itertools::Itertools;
19use ndarray::parallel::prelude::*;
20use ndarray::prelude::*;
21use strum::EnumCount;
22use strum::IntoEnumIterator;
23use twenty_first::math::traits::FiniteField;
24use twenty_first::prelude::*;
25
26use crate::aet::AlgebraicExecutionTrace;
27use crate::challenges::Challenges;
28use crate::ndarray_helper::contiguous_column_slices;
29use crate::ndarray_helper::horizontal_multi_slice_mut;
30use crate::profiler::profiler;
31use crate::table::TraceTable;
32
// Shorthand for the op stack table's main-column type, as declared by the table's `air::AIR`
// implementation. Its variants index into rows of `BFieldElement`s throughout this file.
33type MainColumn = <OpStackTable as air::AIR>::MainColumn;
// Shorthand for the op stack table's auxiliary-column type, as declared by the same `air::AIR`
// implementation. `extend` checks the auxiliary table's width against its `COUNT`.
34type AuxColumn = <OpStackTable as air::AIR>::AuxColumn;
35
36#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
37pub struct OpStackTableEntry {
38    pub clk: u32,
39    pub op_stack_pointer: BFieldElement,
40    pub underflow_io: UnderflowIO,
41}
42
43impl OpStackTableEntry {
44    pub fn new(clk: u32, op_stack_pointer: BFieldElement, underflow_io: UnderflowIO) -> Self {
45        Self {
46            clk,
47            op_stack_pointer,
48            underflow_io,
49        }
50    }
51
52    pub fn shrinks_stack(&self) -> bool {
53        self.underflow_io.shrinks_stack()
54    }
55
56    pub fn grows_stack(&self) -> bool {
57        self.underflow_io.grows_stack()
58    }
59
60    pub fn from_underflow_io_sequence(
61        clk: u32,
62        op_stack_pointer_after_sequence_execution: BFieldElement,
63        mut underflow_io_sequence: Vec<UnderflowIO>,
64    ) -> Vec<Self> {
65        UnderflowIO::canonicalize_sequence(&mut underflow_io_sequence);
66        assert!(UnderflowIO::is_uniform_sequence(&underflow_io_sequence));
67
68        let sequence_length: BFieldElement =
69            u32::try_from(underflow_io_sequence.len()).unwrap().into();
70        let mut op_stack_pointer = match UnderflowIO::is_writing_sequence(&underflow_io_sequence) {
71            true => op_stack_pointer_after_sequence_execution - sequence_length,
72            false => op_stack_pointer_after_sequence_execution + sequence_length,
73        };
74        let mut op_stack_table_entries = vec![];
75        for underflow_io in underflow_io_sequence {
76            if underflow_io.shrinks_stack() {
77                op_stack_pointer.decrement();
78            }
79            let op_stack_table_entry = Self::new(clk, op_stack_pointer, underflow_io);
80            op_stack_table_entries.push(op_stack_table_entry);
81            if underflow_io.grows_stack() {
82                op_stack_pointer.increment();
83            }
84        }
85        op_stack_table_entries
86    }
87
88    pub fn to_main_table_row(self) -> Array1<BFieldElement> {
89        let shrink_stack_indicator = if self.shrinks_stack() {
90            bfe!(1)
91        } else {
92            bfe!(0)
93        };
94
95        let mut row = Array1::zeros(MainColumn::COUNT);
96        row[MainColumn::CLK.main_index()] = self.clk.into();
97        row[MainColumn::IB1ShrinkStack.main_index()] = shrink_stack_indicator;
98        row[MainColumn::StackPointer.main_index()] = self.op_stack_pointer;
99        row[MainColumn::FirstUnderflowElement.main_index()] = self.underflow_io.payload();
100        row
101    }
102}
103
104fn auxiliary_column_running_product_permutation_argument(
105    main_table: ArrayView2<BFieldElement>,
106    challenges: &Challenges,
107) -> Array2<XFieldElement> {
108    let perm_arg_indeterminate = challenges[ChallengeId::OpStackIndeterminate];
109
110    let mut running_product = PermArg::default_initial();
111    let mut auxiliary_column = Vec::with_capacity(main_table.nrows());
112    for row in main_table.rows() {
113        if row[MainColumn::IB1ShrinkStack.main_index()] != PADDING_VALUE {
114            let compressed_row = row[MainColumn::CLK.main_index()]
115                * challenges[ChallengeId::OpStackClkWeight]
116                + row[MainColumn::IB1ShrinkStack.main_index()]
117                    * challenges[ChallengeId::OpStackIb1Weight]
118                + row[MainColumn::StackPointer.main_index()]
119                    * challenges[ChallengeId::OpStackPointerWeight]
120                + row[MainColumn::FirstUnderflowElement.main_index()]
121                    * challenges[ChallengeId::OpStackFirstUnderflowElementWeight];
122            running_product *= perm_arg_indeterminate - compressed_row;
123        }
124        auxiliary_column.push(running_product);
125    }
126    Array2::from_shape_vec((main_table.nrows(), 1), auxiliary_column).unwrap()
127}
128
129fn auxiliary_column_clock_jump_diff_lookup_log_derivative(
130    main_table: ArrayView2<BFieldElement>,
131    challenges: &Challenges,
132) -> Array2<XFieldElement> {
133    // - use memoization to avoid recomputing inverses
134    // - precompute common values through batch inversion
135    const PRECOMPUTE_INVERSES_OF: Range<u64> = 0..100;
136    let cjd_lookup_indeterminate = challenges[ChallengeId::ClockJumpDifferenceLookupIndeterminate];
137    let to_invert = PRECOMPUTE_INVERSES_OF
138        .map(|i| cjd_lookup_indeterminate - bfe!(i))
139        .collect_vec();
140    let inverses = XFieldElement::batch_inversion(to_invert);
141    let mut inverses_dictionary = PRECOMPUTE_INVERSES_OF
142        .zip_eq(inverses)
143        .map(|(i, inv)| (bfe!(i), inv))
144        .collect::<HashMap<_, _>>();
145
146    // populate auxiliary column using memoization
147    let mut cjd_lookup_log_derivative = LookupArg::default_initial();
148    let mut auxiliary_column = Vec::with_capacity(main_table.nrows());
149    auxiliary_column.push(cjd_lookup_log_derivative);
150    for (previous_row, current_row) in main_table.rows().into_iter().tuple_windows() {
151        if current_row[MainColumn::IB1ShrinkStack.main_index()] == PADDING_VALUE {
152            break;
153        };
154
155        let previous_stack_pointer = previous_row[MainColumn::StackPointer.main_index()];
156        let current_stack_pointer = current_row[MainColumn::StackPointer.main_index()];
157        if previous_stack_pointer == current_stack_pointer {
158            let previous_clock = previous_row[MainColumn::CLK.main_index()];
159            let current_clock = current_row[MainColumn::CLK.main_index()];
160            let clock_jump_difference = current_clock - previous_clock;
161            let &mut inverse = inverses_dictionary
162                .entry(clock_jump_difference)
163                .or_insert_with(|| (cjd_lookup_indeterminate - clock_jump_difference).inverse());
164            cjd_lookup_log_derivative += inverse;
165        }
166        auxiliary_column.push(cjd_lookup_log_derivative);
167    }
168
169    // fill padding section
170    auxiliary_column.resize(main_table.nrows(), cjd_lookup_log_derivative);
171    Array2::from_shape_vec((main_table.nrows(), 1), auxiliary_column).unwrap()
172}
173
174impl TraceTable for OpStackTable {
175    type FillParam = ();
176    type FillReturnInfo = Vec<BFieldElement>;
177
178    fn fill(
179        mut op_stack_table: ArrayViewMut2<BFieldElement>,
180        aet: &AlgebraicExecutionTrace,
181        _: Self::FillParam,
182    ) -> Vec<BFieldElement> {
183        let mut op_stack_table =
184            op_stack_table.slice_mut(s![0..aet.height_of_table(TableId::OpStack), ..]);
185        let trace_iter = aet.op_stack_underflow_trace.rows().into_iter();
186
187        let sorted_rows =
188            trace_iter.sorted_by(|row_0, row_1| compare_rows(row_0.view(), row_1.view()));
189        for (row_index, row) in sorted_rows.enumerate() {
190            op_stack_table.row_mut(row_index).assign(&row);
191        }
192
193        clock_jump_differences(op_stack_table.view())
194    }
195
196    fn pad(mut op_stack_table: ArrayViewMut2<BFieldElement>, op_stack_table_len: usize) {
197        let last_row_index = op_stack_table_len.saturating_sub(1);
198        let mut padding_row = op_stack_table.row(last_row_index).to_owned();
199        padding_row[MainColumn::IB1ShrinkStack.main_index()] = PADDING_VALUE;
200        if op_stack_table_len == 0 {
201            let first_stack_pointer = u32::try_from(OpStackElement::COUNT).unwrap().into();
202            padding_row[MainColumn::StackPointer.main_index()] = first_stack_pointer;
203        }
204
205        let mut padding_section = op_stack_table.slice_mut(s![op_stack_table_len.., ..]);
206        padding_section
207            .axis_iter_mut(Axis(0))
208            .into_par_iter()
209            .for_each(|mut row| row.assign(&padding_row));
210    }
211
212    fn extend(
213        main_table: ArrayView2<BFieldElement>,
214        mut aux_table: ArrayViewMut2<XFieldElement>,
215        challenges: &Challenges,
216    ) {
217        profiler!(start "op stack table");
218        assert_eq!(MainColumn::COUNT, main_table.ncols());
219        assert_eq!(AuxColumn::COUNT, aux_table.ncols());
220        assert_eq!(main_table.nrows(), aux_table.nrows());
221
222        let auxiliary_column_indices = OpStackAuxColumn::iter()
223            .map(|column| column.aux_index())
224            .collect_vec();
225        let auxiliary_column_slices = horizontal_multi_slice_mut(
226            aux_table.view_mut(),
227            &contiguous_column_slices(&auxiliary_column_indices),
228        );
229        let extension_functions = [
230            auxiliary_column_running_product_permutation_argument,
231            auxiliary_column_clock_jump_diff_lookup_log_derivative,
232        ];
233
234        extension_functions
235            .into_par_iter()
236            .zip_eq(auxiliary_column_slices)
237            .for_each(|(generator, slice)| {
238                generator(main_table, challenges).move_into(slice);
239            });
240
241        profiler!(stop "op stack table");
242    }
243}
244
245fn compare_rows(row_0: ArrayView1<BFieldElement>, row_1: ArrayView1<BFieldElement>) -> Ordering {
246    let stack_pointer_0 = row_0[MainColumn::StackPointer.main_index()].value();
247    let stack_pointer_1 = row_1[MainColumn::StackPointer.main_index()].value();
248    let compare_stack_pointers = stack_pointer_0.cmp(&stack_pointer_1);
249
250    let clk_0 = row_0[MainColumn::CLK.main_index()].value();
251    let clk_1 = row_1[MainColumn::CLK.main_index()].value();
252    let compare_clocks = clk_0.cmp(&clk_1);
253
254    compare_stack_pointers.then(compare_clocks)
255}
256
257fn clock_jump_differences(op_stack_table: ArrayView2<BFieldElement>) -> Vec<BFieldElement> {
258    let mut clock_jump_differences = vec![];
259    for consecutive_rows in op_stack_table.axis_windows(Axis(0), 2) {
260        let current_row = consecutive_rows.row(0);
261        let next_row = consecutive_rows.row(1);
262        let current_stack_pointer = current_row[MainColumn::StackPointer.main_index()];
263        let next_stack_pointer = next_row[MainColumn::StackPointer.main_index()];
264        if current_stack_pointer == next_stack_pointer {
265            let current_clk = current_row[MainColumn::CLK.main_index()];
266            let next_clk = next_row[MainColumn::CLK.main_index()];
267            let clk_difference = next_clk - current_clk;
268            clock_jump_differences.push(clk_difference);
269        }
270    }
271    clock_jump_differences
272}
273
#[cfg(test)]
pub(crate) mod tests {
    use assert2::assert;
    use isa::op_stack::OpStackElement;
    use itertools::Itertools;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest_arbitrary_interop::arb;
    use test_strategy::proptest;

    use super::*;

    /// Wraps every payload in the same kind of underflow access: a write if
    /// `sequence_of_writes` is set, a read otherwise.
    fn uniform_underflow_io_sequence(
        payloads: Vec<BFieldElement>,
        sequence_of_writes: bool,
    ) -> Vec<UnderflowIO> {
        let access: fn(BFieldElement) -> UnderflowIO = if sequence_of_writes {
            UnderflowIO::Write
        } else {
            UnderflowIO::Read
        };
        payloads.into_iter().map(access).collect()
    }

    /// An otherwise all-zero main table row with the given stack pointer and clock.
    fn row_with(stack_pointer: u64, clk: u64) -> Array1<BFieldElement> {
        let mut row = Array1::zeros(MainColumn::COUNT);
        row[MainColumn::StackPointer.main_index()] = stack_pointer.into();
        row[MainColumn::CLK.main_index()] = clk.into();
        row
    }

    #[proptest]
    fn op_stack_table_entry_either_shrinks_stack_or_grows_stack(
        #[strategy(arb())] entry: OpStackTableEntry,
    ) {
        // exactly one of the two must hold
        assert!(entry.shrinks_stack() ^ entry.grows_stack());
    }

    #[proptest]
    fn op_stack_pointer_in_sequence_of_op_stack_table_entries(
        clk: u32,
        #[strategy(OpStackElement::COUNT..1024)] stack_pointer: usize,
        #[strategy(vec(arb(), ..OpStackElement::COUNT))] base_field_elements: Vec<BFieldElement>,
        sequence_of_writes: bool,
    ) {
        let sequence_length = u64::try_from(base_field_elements.len()).unwrap();
        let stack_pointer = u64::try_from(stack_pointer).unwrap();
        let underflow_io = uniform_underflow_io_sequence(base_field_elements, sequence_of_writes);

        let entries = OpStackTableEntry::from_underflow_io_sequence(
            clk,
            stack_pointer.into(),
            underflow_io,
        );
        let mut op_stack_pointers = entries
            .iter()
            .map(|entry| entry.op_stack_pointer.value())
            .collect_vec();
        op_stack_pointers.sort_unstable();

        // Writes end at the given stack pointer, reads start at it.
        let expected_op_stack_pointers = if sequence_of_writes {
            (stack_pointer - sequence_length..stack_pointer).collect_vec()
        } else {
            (stack_pointer..stack_pointer + sequence_length).collect_vec()
        };
        prop_assert_eq!(expected_op_stack_pointers, op_stack_pointers);
    }

    #[proptest]
    fn clk_stays_same_in_sequence_of_op_stack_table_entries(
        clk: u32,
        #[strategy(OpStackElement::COUNT..1024)] stack_pointer: usize,
        #[strategy(vec(arb(), ..OpStackElement::COUNT))] base_field_elements: Vec<BFieldElement>,
        sequence_of_writes: bool,
    ) {
        let underflow_io = uniform_underflow_io_sequence(base_field_elements, sequence_of_writes);
        let op_stack_pointer = u64::try_from(stack_pointer).unwrap().into();

        let entries =
            OpStackTableEntry::from_underflow_io_sequence(clk, op_stack_pointer, underflow_io);
        prop_assert!(entries.iter().all(|entry| entry.clk == clk));
    }

    #[proptest]
    fn compare_rows_with_unequal_stack_pointer_and_equal_clk(
        stack_pointer_0: u64,
        stack_pointer_1: u64,
        clk: u64,
    ) {
        let row_0 = row_with(stack_pointer_0, clk);
        let row_1 = row_with(stack_pointer_1, clk);

        let row_comparison = compare_rows(row_0.view(), row_1.view());
        let stack_pointer_comparison = stack_pointer_0.cmp(&stack_pointer_1);

        prop_assert_eq!(stack_pointer_comparison, row_comparison);
    }

    #[proptest]
    fn compare_rows_with_equal_stack_pointer_and_unequal_clk(
        stack_pointer: u64,
        clk_0: u64,
        clk_1: u64,
    ) {
        let row_0 = row_with(stack_pointer, clk_0);
        let row_1 = row_with(stack_pointer, clk_1);

        let row_comparison = compare_rows(row_0.view(), row_1.view());
        let clk_comparison = clk_0.cmp(&clk_1);

        prop_assert_eq!(clk_comparison, row_comparison);
    }
}