triton_vm/table/
op_stack.rs

1use std::cmp::Ordering;
2use std::collections::HashMap;
3use std::ops::Range;
4
5use air::challenge_id::ChallengeId;
6use air::cross_table_argument::CrossTableArg;
7use air::cross_table_argument::LookupArg;
8use air::cross_table_argument::PermArg;
9use air::table::TableId;
10use air::table::op_stack::OpStackTable;
11use air::table::op_stack::PADDING_VALUE;
12use air::table_column::MasterAuxColumn;
13use air::table_column::MasterMainColumn;
14use air::table_column::OpStackAuxColumn;
15use arbitrary::Arbitrary;
16use isa::op_stack::OpStackElement;
17use isa::op_stack::UnderflowIO;
18use itertools::Itertools;
19use ndarray::parallel::prelude::*;
20use ndarray::prelude::*;
21use strum::EnumCount;
22use strum::IntoEnumIterator;
23use twenty_first::math::traits::FiniteField;
24use twenty_first::prelude::*;
25
26use crate::aet::AlgebraicExecutionTrace;
27use crate::challenges::Challenges;
28use crate::ndarray_helper::ROW_AXIS;
29use crate::ndarray_helper::contiguous_column_slices;
30use crate::ndarray_helper::horizontal_multi_slice_mut;
31use crate::profiler::profiler;
32use crate::table::TraceTable;
33
34type MainColumn = <OpStackTable as air::AIR>::MainColumn;
35type AuxColumn = <OpStackTable as air::AIR>::AuxColumn;
36
37#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
38pub struct OpStackTableEntry {
39    pub clk: u32,
40    pub op_stack_pointer: BFieldElement,
41    pub underflow_io: UnderflowIO,
42}
43
44impl OpStackTableEntry {
45    pub fn new(clk: u32, op_stack_pointer: BFieldElement, underflow_io: UnderflowIO) -> Self {
46        Self {
47            clk,
48            op_stack_pointer,
49            underflow_io,
50        }
51    }
52
53    pub fn shrinks_stack(&self) -> bool {
54        self.underflow_io.shrinks_stack()
55    }
56
57    pub fn grows_stack(&self) -> bool {
58        self.underflow_io.grows_stack()
59    }
60
61    pub fn from_underflow_io_sequence(
62        clk: u32,
63        op_stack_pointer_after_sequence_execution: BFieldElement,
64        mut underflow_io_sequence: Vec<UnderflowIO>,
65    ) -> Vec<Self> {
66        UnderflowIO::canonicalize_sequence(&mut underflow_io_sequence);
67        assert!(UnderflowIO::is_uniform_sequence(&underflow_io_sequence));
68
69        let sequence_length: BFieldElement =
70            u32::try_from(underflow_io_sequence.len()).unwrap().into();
71        let mut op_stack_pointer = match UnderflowIO::is_writing_sequence(&underflow_io_sequence) {
72            true => op_stack_pointer_after_sequence_execution - sequence_length,
73            false => op_stack_pointer_after_sequence_execution + sequence_length,
74        };
75        let mut op_stack_table_entries = vec![];
76        for underflow_io in underflow_io_sequence {
77            if underflow_io.shrinks_stack() {
78                op_stack_pointer.decrement();
79            }
80            let op_stack_table_entry = Self::new(clk, op_stack_pointer, underflow_io);
81            op_stack_table_entries.push(op_stack_table_entry);
82            if underflow_io.grows_stack() {
83                op_stack_pointer.increment();
84            }
85        }
86        op_stack_table_entries
87    }
88
89    pub fn to_main_table_row(self) -> Array1<BFieldElement> {
90        let shrink_stack_indicator = if self.shrinks_stack() {
91            bfe!(1)
92        } else {
93            bfe!(0)
94        };
95
96        let mut row = Array1::zeros(MainColumn::COUNT);
97        row[MainColumn::CLK.main_index()] = self.clk.into();
98        row[MainColumn::IB1ShrinkStack.main_index()] = shrink_stack_indicator;
99        row[MainColumn::StackPointer.main_index()] = self.op_stack_pointer;
100        row[MainColumn::FirstUnderflowElement.main_index()] = self.underflow_io.payload();
101        row
102    }
103}
104
105fn auxiliary_column_running_product_permutation_argument(
106    main_table: ArrayView2<BFieldElement>,
107    challenges: &Challenges,
108) -> Array2<XFieldElement> {
109    let perm_arg_indeterminate = challenges[ChallengeId::OpStackIndeterminate];
110
111    let mut running_product = PermArg::default_initial();
112    let mut auxiliary_column = Vec::with_capacity(main_table.nrows());
113    for row in main_table.rows() {
114        if row[MainColumn::IB1ShrinkStack.main_index()] != PADDING_VALUE {
115            let compressed_row = row[MainColumn::CLK.main_index()]
116                * challenges[ChallengeId::OpStackClkWeight]
117                + row[MainColumn::IB1ShrinkStack.main_index()]
118                    * challenges[ChallengeId::OpStackIb1Weight]
119                + row[MainColumn::StackPointer.main_index()]
120                    * challenges[ChallengeId::OpStackPointerWeight]
121                + row[MainColumn::FirstUnderflowElement.main_index()]
122                    * challenges[ChallengeId::OpStackFirstUnderflowElementWeight];
123            running_product *= perm_arg_indeterminate - compressed_row;
124        }
125        auxiliary_column.push(running_product);
126    }
127    Array2::from_shape_vec((main_table.nrows(), 1), auxiliary_column).unwrap()
128}
129
130fn auxiliary_column_clock_jump_diff_lookup_log_derivative(
131    main_table: ArrayView2<BFieldElement>,
132    challenges: &Challenges,
133) -> Array2<XFieldElement> {
134    // - use memoization to avoid recomputing inverses
135    // - precompute common values through batch inversion
136    const PRECOMPUTE_INVERSES_OF: Range<u64> = 0..100;
137    let cjd_lookup_indeterminate = challenges[ChallengeId::ClockJumpDifferenceLookupIndeterminate];
138    let to_invert = PRECOMPUTE_INVERSES_OF
139        .map(|i| cjd_lookup_indeterminate - bfe!(i))
140        .collect_vec();
141    let inverses = XFieldElement::batch_inversion(to_invert);
142    let mut inverses_dictionary = PRECOMPUTE_INVERSES_OF
143        .zip_eq(inverses)
144        .map(|(i, inv)| (bfe!(i), inv))
145        .collect::<HashMap<_, _>>();
146
147    // populate auxiliary column using memoization
148    let mut cjd_lookup_log_derivative = LookupArg::default_initial();
149    let mut auxiliary_column = Vec::with_capacity(main_table.nrows());
150    auxiliary_column.push(cjd_lookup_log_derivative);
151    for (previous_row, current_row) in main_table.rows().into_iter().tuple_windows() {
152        if current_row[MainColumn::IB1ShrinkStack.main_index()] == PADDING_VALUE {
153            break;
154        };
155
156        let previous_stack_pointer = previous_row[MainColumn::StackPointer.main_index()];
157        let current_stack_pointer = current_row[MainColumn::StackPointer.main_index()];
158        if previous_stack_pointer == current_stack_pointer {
159            let previous_clock = previous_row[MainColumn::CLK.main_index()];
160            let current_clock = current_row[MainColumn::CLK.main_index()];
161            let clock_jump_difference = current_clock - previous_clock;
162            let &mut inverse = inverses_dictionary
163                .entry(clock_jump_difference)
164                .or_insert_with(|| (cjd_lookup_indeterminate - clock_jump_difference).inverse());
165            cjd_lookup_log_derivative += inverse;
166        }
167        auxiliary_column.push(cjd_lookup_log_derivative);
168    }
169
170    // fill padding section
171    auxiliary_column.resize(main_table.nrows(), cjd_lookup_log_derivative);
172    Array2::from_shape_vec((main_table.nrows(), 1), auxiliary_column).unwrap()
173}
174
175impl TraceTable for OpStackTable {
176    type FillParam = ();
177    type FillReturnInfo = Vec<BFieldElement>;
178
179    fn fill(
180        mut op_stack_table: ArrayViewMut2<BFieldElement>,
181        aet: &AlgebraicExecutionTrace,
182        _: Self::FillParam,
183    ) -> Vec<BFieldElement> {
184        let mut op_stack_table =
185            op_stack_table.slice_mut(s![0..aet.height_of_table(TableId::OpStack), ..]);
186        let trace_iter = aet.op_stack_underflow_trace.rows().into_iter();
187
188        let sorted_rows =
189            trace_iter.sorted_by(|row_0, row_1| compare_rows(row_0.view(), row_1.view()));
190        for (row_index, row) in sorted_rows.enumerate() {
191            op_stack_table.row_mut(row_index).assign(&row);
192        }
193
194        clock_jump_differences(op_stack_table.view())
195    }
196
197    fn pad(mut op_stack_table: ArrayViewMut2<BFieldElement>, op_stack_table_len: usize) {
198        let last_row_index = op_stack_table_len.saturating_sub(1);
199        let mut padding_row = op_stack_table.row(last_row_index).to_owned();
200        padding_row[MainColumn::IB1ShrinkStack.main_index()] = PADDING_VALUE;
201        if op_stack_table_len == 0 {
202            let first_stack_pointer = u32::try_from(OpStackElement::COUNT).unwrap().into();
203            padding_row[MainColumn::StackPointer.main_index()] = first_stack_pointer;
204        }
205
206        let mut padding_section = op_stack_table.slice_mut(s![op_stack_table_len.., ..]);
207        padding_section
208            .axis_iter_mut(ROW_AXIS)
209            .into_par_iter()
210            .for_each(|mut row| row.assign(&padding_row));
211    }
212
213    fn extend(
214        main_table: ArrayView2<BFieldElement>,
215        mut aux_table: ArrayViewMut2<XFieldElement>,
216        challenges: &Challenges,
217    ) {
218        profiler!(start "op stack table");
219        assert_eq!(MainColumn::COUNT, main_table.ncols());
220        assert_eq!(AuxColumn::COUNT, aux_table.ncols());
221        assert_eq!(main_table.nrows(), aux_table.nrows());
222
223        let auxiliary_column_indices = OpStackAuxColumn::iter()
224            .map(|column| column.aux_index())
225            .collect_vec();
226        let auxiliary_column_slices = horizontal_multi_slice_mut(
227            aux_table.view_mut(),
228            &contiguous_column_slices(&auxiliary_column_indices),
229        );
230        let extension_functions = [
231            auxiliary_column_running_product_permutation_argument,
232            auxiliary_column_clock_jump_diff_lookup_log_derivative,
233        ];
234
235        extension_functions
236            .into_par_iter()
237            .zip_eq(auxiliary_column_slices)
238            .for_each(|(generator, slice)| {
239                generator(main_table, challenges).move_into(slice);
240            });
241
242        profiler!(stop "op stack table");
243    }
244}
245
246fn compare_rows(row_0: ArrayView1<BFieldElement>, row_1: ArrayView1<BFieldElement>) -> Ordering {
247    let stack_pointer_0 = row_0[MainColumn::StackPointer.main_index()].value();
248    let stack_pointer_1 = row_1[MainColumn::StackPointer.main_index()].value();
249    let compare_stack_pointers = stack_pointer_0.cmp(&stack_pointer_1);
250
251    let clk_0 = row_0[MainColumn::CLK.main_index()].value();
252    let clk_1 = row_1[MainColumn::CLK.main_index()].value();
253    let compare_clocks = clk_0.cmp(&clk_1);
254
255    compare_stack_pointers.then(compare_clocks)
256}
257
258fn clock_jump_differences(op_stack_table: ArrayView2<BFieldElement>) -> Vec<BFieldElement> {
259    let mut clock_jump_differences = vec![];
260    for consecutive_rows in op_stack_table.axis_windows(ROW_AXIS, 2) {
261        let current_row = consecutive_rows.row(0);
262        let next_row = consecutive_rows.row(1);
263        let current_stack_pointer = current_row[MainColumn::StackPointer.main_index()];
264        let next_stack_pointer = next_row[MainColumn::StackPointer.main_index()];
265        if current_stack_pointer == next_stack_pointer {
266            let current_clk = current_row[MainColumn::CLK.main_index()];
267            let next_clk = next_row[MainColumn::CLK.main_index()];
268            let clk_difference = next_clk - current_clk;
269            clock_jump_differences.push(clk_difference);
270        }
271    }
272    clock_jump_differences
273}
274
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
pub(crate) mod tests {
    use assert2::assert;
    use isa::op_stack::OpStackElement;
    use itertools::Itertools;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest_arbitrary_interop::arb;
    use test_strategy::proptest;

    use super::*;

    #[proptest]
    fn op_stack_table_entry_either_shrinks_stack_or_grows_stack(
        #[strategy(arb())] entry: OpStackTableEntry,
    ) {
        let shrinks_stack = entry.shrinks_stack();
        let grows_stack = entry.grows_stack();
        assert!(shrinks_stack ^ grows_stack);
    }

    #[proptest]
    fn op_stack_pointer_in_sequence_of_op_stack_table_entries(
        clk: u32,
        #[strategy(OpStackElement::COUNT..1024)] stack_pointer: usize,
        #[strategy(vec(arb(), ..OpStackElement::COUNT))] base_field_elements: Vec<BFieldElement>,
        sequence_of_writes: bool,
    ) {
        let sequence_length = u64::try_from(base_field_elements.len()).unwrap();
        let stack_pointer = u64::try_from(stack_pointer).unwrap();

        let underflow_io_operation = match sequence_of_writes {
            true => UnderflowIO::Write,
            false => UnderflowIO::Read,
        };
        let underflow_io = base_field_elements
            .into_iter()
            .map(underflow_io_operation)
            .collect();

        let op_stack_pointer = stack_pointer.into();
        let entries =
            OpStackTableEntry::from_underflow_io_sequence(clk, op_stack_pointer, underflow_io);
        let op_stack_pointers = entries
            .iter()
            .map(|entry| entry.op_stack_pointer.value())
            .sorted()
            .collect_vec();

        let expected_stack_pointer_range = match sequence_of_writes {
            true => stack_pointer - sequence_length..stack_pointer,
            false => stack_pointer..stack_pointer + sequence_length,
        };
        let expected_op_stack_pointers = expected_stack_pointer_range.collect_vec();
        prop_assert_eq!(expected_op_stack_pointers, op_stack_pointers);
    }

    #[proptest]
    fn clk_stays_same_in_sequence_of_op_stack_table_entries(
        clk: u32,
        #[strategy(OpStackElement::COUNT..1024)] stack_pointer: usize,
        #[strategy(vec(arb(), ..OpStackElement::COUNT))] base_field_elements: Vec<BFieldElement>,
        sequence_of_writes: bool,
    ) {
        let underflow_io_operation = match sequence_of_writes {
            true => UnderflowIO::Write,
            false => UnderflowIO::Read,
        };
        let underflow_io = base_field_elements
            .into_iter()
            .map(underflow_io_operation)
            .collect();

        let op_stack_pointer = u64::try_from(stack_pointer).unwrap().into();
        let entries =
            OpStackTableEntry::from_underflow_io_sequence(clk, op_stack_pointer, underflow_io);
        let clk_values = entries.iter().map(|entry| entry.clk).collect_vec();
        let all_clk_values_are_clk = clk_values.iter().all(|&c| c == clk);
        prop_assert!(all_clk_values_are_clk);
    }

    /// Rows with distinct stack pointers and identical clocks are ordered by stack pointer.
    ///
    /// Inputs are drawn from `0..BFieldElement::P`, which makes every value already canonical.
    /// As a result, `.value()` of the converted field element equals the original `u64`, and
    /// the `u64` ordering is a valid oracle.
    #[proptest]
    fn compare_rows_with_unequal_stack_pointer_and_equal_clk(
        #[strategy(0..BFieldElement::P)] stack_pointer_0: u64,
        #[strategy(0..BFieldElement::P)] stack_pointer_1: u64,
        #[strategy(0..BFieldElement::P)] clk: u64,
    ) {
        prop_assume!(stack_pointer_0 != stack_pointer_1);

        let mut row_0 = Array1::zeros(MainColumn::COUNT);
        row_0[MainColumn::StackPointer.main_index()] = stack_pointer_0.into();
        row_0[MainColumn::CLK.main_index()] = clk.into();

        let mut row_1 = Array1::zeros(MainColumn::COUNT);
        row_1[MainColumn::StackPointer.main_index()] = stack_pointer_1.into();
        row_1[MainColumn::CLK.main_index()] = clk.into();

        let stack_pointer_comparison = stack_pointer_0.cmp(&stack_pointer_1);
        let row_comparison = compare_rows(row_0.view(), row_1.view());

        prop_assert_eq!(stack_pointer_comparison, row_comparison);
    }

    /// Rows with identical stack pointers and distinct clocks are ordered by clock.
    ///
    /// Inputs are drawn from `0..BFieldElement::P` for the same reason as in the test above.
    #[proptest]
    fn compare_rows_with_equal_stack_pointer_and_unequal_clk(
        #[strategy(0..BFieldElement::P)] stack_pointer: u64,
        #[strategy(0..BFieldElement::P)] clk_0: u64,
        #[strategy(0..BFieldElement::P)] clk_1: u64,
    ) {
        prop_assume!(clk_0 != clk_1);

        let mut row_0 = Array1::zeros(MainColumn::COUNT);
        row_0[MainColumn::StackPointer.main_index()] = stack_pointer.into();
        row_0[MainColumn::CLK.main_index()] = clk_0.into();

        let mut row_1 = Array1::zeros(MainColumn::COUNT);
        row_1[MainColumn::StackPointer.main_index()] = stack_pointer.into();
        row_1[MainColumn::CLK.main_index()] = clk_1.into();

        let clk_comparison = clk_0.cmp(&clk_1);
        let row_comparison = compare_rows(row_0.view(), row_1.view());

        prop_assert_eq!(clk_comparison, row_comparison);
    }
}