triton_vm/table/
ram.rs

1use std::cmp::Ordering;
2
3use air::challenge_id::ChallengeId;
4use air::cross_table_argument::CrossTableArg;
5use air::cross_table_argument::LookupArg;
6use air::cross_table_argument::PermArg;
7use air::table::TableId;
8use air::table::ram::PADDING_INDICATOR;
9use air::table::ram::RamTable;
10use air::table_column::MasterAuxColumn;
11use air::table_column::MasterMainColumn;
12use arbitrary::Arbitrary;
13use itertools::Itertools;
14use ndarray::parallel::prelude::*;
15use ndarray::prelude::*;
16use num_traits::ConstOne;
17use num_traits::ConstZero;
18use num_traits::One;
19use num_traits::Zero;
20use serde::Deserialize;
21use serde::Serialize;
22use strum::EnumCount;
23use strum::IntoEnumIterator;
24use twenty_first::math::traits::FiniteField;
25use twenty_first::prelude::*;
26
27use crate::aet::AlgebraicExecutionTrace;
28use crate::challenges::Challenges;
29use crate::ndarray_helper::ROW_AXIS;
30use crate::ndarray_helper::contiguous_column_slices;
31use crate::ndarray_helper::horizontal_multi_slice_mut;
32use crate::profiler::profiler;
33use crate::table::TraceTable;
34
35type MainColumn = <RamTable as air::AIR>::MainColumn;
36type AuxColumn = <RamTable as air::AIR>::AuxColumn;
37
38#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
39pub struct RamTableCall {
40    pub clk: u32,
41    pub ram_pointer: BFieldElement,
42    pub ram_value: BFieldElement,
43    pub is_write: bool,
44}
45
46impl RamTableCall {
47    pub fn to_table_row(self) -> Array1<BFieldElement> {
48        let instruction_type = if self.is_write {
49            air::table::ram::INSTRUCTION_TYPE_WRITE
50        } else {
51            air::table::ram::INSTRUCTION_TYPE_READ
52        };
53
54        let mut row = Array1::zeros(MainColumn::COUNT);
55        row[MainColumn::CLK.main_index()] = self.clk.into();
56        row[MainColumn::InstructionType.main_index()] = instruction_type;
57        row[MainColumn::RamPointer.main_index()] = self.ram_pointer;
58        row[MainColumn::RamValue.main_index()] = self.ram_value;
59        row
60    }
61}
62
63impl TraceTable for RamTable {
64    type FillParam = ();
65    type FillReturnInfo = Vec<BFieldElement>;
66
67    fn fill(
68        mut ram_table: ArrayViewMut2<BFieldElement>,
69        aet: &AlgebraicExecutionTrace,
70        _: Self::FillParam,
71    ) -> Self::FillReturnInfo {
72        let mut ram_table = ram_table.slice_mut(s![0..aet.height_of_table(TableId::Ram), ..]);
73        let trace_iter = aet.ram_trace.rows().into_iter();
74
75        let sorted_rows =
76            trace_iter.sorted_by(|row_0, row_1| compare_rows(row_0.view(), row_1.view()));
77        for (row_index, row) in sorted_rows.enumerate() {
78            ram_table.row_mut(row_index).assign(&row);
79        }
80
81        let all_ram_pointers = ram_table.column(MainColumn::RamPointer.main_index());
82        let unique_ram_pointers = all_ram_pointers.iter().unique().copied().collect_vec();
83        let (bezout_0, bezout_1) =
84            bezout_coefficient_polynomials_coefficients(&unique_ram_pointers);
85
86        make_ram_table_consistent(&mut ram_table, bezout_0, bezout_1)
87    }
88
89    fn pad(mut main_table: ArrayViewMut2<BFieldElement>, table_len: usize) {
90        let last_row_index = table_len.saturating_sub(1);
91        let mut padding_row = main_table.row(last_row_index).to_owned();
92        padding_row[MainColumn::InstructionType.main_index()] = PADDING_INDICATOR;
93        if table_len == 0 {
94            padding_row[MainColumn::BezoutCoefficientPolynomialCoefficient1.main_index()] =
95                BFieldElement::ONE;
96        }
97
98        let mut padding_section = main_table.slice_mut(s![table_len.., ..]);
99        padding_section
100            .axis_iter_mut(ROW_AXIS)
101            .into_par_iter()
102            .for_each(|mut row| row.assign(&padding_row));
103    }
104
105    fn extend(
106        main_table: ArrayView2<BFieldElement>,
107        mut aux_table: ArrayViewMut2<XFieldElement>,
108        challenges: &Challenges,
109    ) {
110        profiler!(start "ram table");
111        assert_eq!(MainColumn::COUNT, main_table.ncols());
112        assert_eq!(AuxColumn::COUNT, aux_table.ncols());
113        assert_eq!(main_table.nrows(), aux_table.nrows());
114
115        let auxiliary_column_indices = AuxColumn::iter()
116            // RunningProductOfRAMP + FormalDerivative are constitute one
117            // slice and are populated by the same function
118            .filter(|column| *column != AuxColumn::FormalDerivative)
119            .map(|column| column.aux_index())
120            .collect_vec();
121        let auxiliary_column_slices = horizontal_multi_slice_mut(
122            aux_table.view_mut(),
123            &contiguous_column_slices(&auxiliary_column_indices),
124        );
125        let extension_functions = [
126            auxiliary_column_running_product_of_ramp_and_formal_derivative,
127            auxiliary_column_bezout_coefficient_0,
128            auxiliary_column_bezout_coefficient_1,
129            auxiliary_column_running_product_perm_arg,
130            auxiliary_column_clock_jump_difference_lookup_log_derivative,
131        ];
132        extension_functions
133            .into_par_iter()
134            .zip_eq(auxiliary_column_slices)
135            .for_each(|(generator, slice)| {
136                generator(main_table, challenges).move_into(slice);
137            });
138
139        profiler!(stop "ram table");
140    }
141}
142
143fn compare_rows(row_0: ArrayView1<BFieldElement>, row_1: ArrayView1<BFieldElement>) -> Ordering {
144    let ram_pointer_0 = row_0[MainColumn::RamPointer.main_index()].value();
145    let ram_pointer_1 = row_1[MainColumn::RamPointer.main_index()].value();
146    let compare_ram_pointers = ram_pointer_0.cmp(&ram_pointer_1);
147
148    let clk_0 = row_0[MainColumn::CLK.main_index()].value();
149    let clk_1 = row_1[MainColumn::CLK.main_index()].value();
150    let compare_clocks = clk_0.cmp(&clk_1);
151
152    compare_ram_pointers.then(compare_clocks)
153}
154
155/// Compute the
156/// [Bézout coefficients](https://en.wikipedia.org/wiki/B%C3%A9zout%27s_identity)
157/// of the polynomial with the given roots and its formal derivative.
158///
159/// All roots _must_ be unique. That is, the corresponding polynomial must be
160/// square free.
161#[doc(hidden)] // public for benchmarking purposes only
162pub fn bezout_coefficient_polynomials_coefficients(
163    unique_roots: &[BFieldElement],
164) -> (Vec<BFieldElement>, Vec<BFieldElement>) {
165    if unique_roots.is_empty() {
166        return (vec![], vec![]);
167    }
168
169    // The structure of the problem is exploited heavily to compute the Bézout
170    // coefficients as fast as possible. In the following paragraphs, let `rp`
171    // denote the polynomial with the given `unique_roots` as its roots, and
172    // `fd` the formal derivative of `rp`.
173    //
174    // The naïve approach is to perform the extended Euclidean algorithm (xgcd)
175    // on `rp` and `fd`. This has a time complexity in O(n^2) where `n` is the
176    // number of roots: for the given problem shape, the degrees `rp` and `fd`
177    // are `n` and `n-1`, respectively. Each step of the (x)gcd takes O(n) time
178    // and reduces the degree of the polynomials by one. For programs with a
179    // large number of different RAM accesses, `n` is large.
180    //
181    // The approach taken here is to exploit the structure of the problem.
182    // Concretely, since all roots of `rp` are unique, _i.e._, `rp` is square
183    // free, the gcd of `rp` and `fd` is 1. This implies `∀ r ∈ unique_roots:
184    // fd(r)·b(r) = 1`, where `b` is one of the Bézout coefficients. In other
185    // words, the evaluation of `fd` in `unique_roots` is the inverse of
186    // the evaluation of `b` in `unique_roots`. Furthermore, `b` is a polynomial
187    // of degree `n`, and therefore fully determined by the evaluations in
188    // `unique_roots`. Finally, the other Bézout coefficient `a` is determined
189    // by `a = (1 - fd·b) / rp`. In total, this allows computing the Bézout
190    // coefficients in O(n·(log n)^2) time.
191
192    debug_assert!(unique_roots.iter().all_unique());
193    let rp = Polynomial::par_zerofier(unique_roots);
194    let fd = rp.formal_derivative();
195    let fd_in_roots = fd.par_batch_evaluate(unique_roots);
196    let b_in_roots = BFieldElement::batch_inversion(fd_in_roots);
197    let b = Polynomial::par_interpolate(unique_roots, &b_in_roots);
198    let one_minus_fd_b = Polynomial::one() - fd.multiply(&b);
199    let a = one_minus_fd_b.clean_divide(rp);
200
201    let mut coefficients_0 = a.into_coefficients();
202    let mut coefficients_1 = b.into_coefficients();
203    coefficients_0.resize(unique_roots.len(), BFieldElement::ZERO);
204    coefficients_1.resize(unique_roots.len(), BFieldElement::ZERO);
205    (coefficients_0, coefficients_1)
206}
207
208/// - Set inverse of RAM pointer difference
209/// - Fill in the Bézout coefficients if the RAM pointer changes between two
210///   consecutive rows
211/// - Collect and return all clock jump differences
212fn make_ram_table_consistent(
213    ram_table: &mut ArrayViewMut2<BFieldElement>,
214    mut bezout_coefficient_polynomial_coefficients_0: Vec<BFieldElement>,
215    mut bezout_coefficient_polynomial_coefficients_1: Vec<BFieldElement>,
216) -> Vec<BFieldElement> {
217    if ram_table.nrows() == 0 {
218        assert_eq!(0, bezout_coefficient_polynomial_coefficients_0.len());
219        assert_eq!(0, bezout_coefficient_polynomial_coefficients_1.len());
220        return vec![];
221    }
222
223    let mut current_bcpc_0 = bezout_coefficient_polynomial_coefficients_0.pop().unwrap();
224    let mut current_bcpc_1 = bezout_coefficient_polynomial_coefficients_1.pop().unwrap();
225    ram_table.row_mut(0)[MainColumn::BezoutCoefficientPolynomialCoefficient0.main_index()] =
226        current_bcpc_0;
227    ram_table.row_mut(0)[MainColumn::BezoutCoefficientPolynomialCoefficient1.main_index()] =
228        current_bcpc_1;
229
230    let mut clock_jump_differences = vec![];
231    for row_idx in 0..ram_table.nrows() - 1 {
232        let (mut curr_row, mut next_row) =
233            ram_table.multi_slice_mut((s![row_idx, ..], s![row_idx + 1, ..]));
234
235        let ramp_diff = next_row[MainColumn::RamPointer.main_index()]
236            - curr_row[MainColumn::RamPointer.main_index()];
237        let clk_diff =
238            next_row[MainColumn::CLK.main_index()] - curr_row[MainColumn::CLK.main_index()];
239
240        if ramp_diff.is_zero() {
241            clock_jump_differences.push(clk_diff);
242        } else {
243            current_bcpc_0 = bezout_coefficient_polynomial_coefficients_0.pop().unwrap();
244            current_bcpc_1 = bezout_coefficient_polynomial_coefficients_1.pop().unwrap();
245        }
246
247        curr_row[MainColumn::InverseOfRampDifference.main_index()] = ramp_diff.inverse_or_zero();
248        next_row[MainColumn::BezoutCoefficientPolynomialCoefficient0.main_index()] = current_bcpc_0;
249        next_row[MainColumn::BezoutCoefficientPolynomialCoefficient1.main_index()] = current_bcpc_1;
250    }
251
252    assert_eq!(0, bezout_coefficient_polynomial_coefficients_0.len());
253    assert_eq!(0, bezout_coefficient_polynomial_coefficients_1.len());
254    clock_jump_differences
255}
256
257fn auxiliary_column_running_product_of_ramp_and_formal_derivative(
258    main_table: ArrayView2<BFieldElement>,
259    challenges: &Challenges,
260) -> Array2<XFieldElement> {
261    let bezout_indeterminate = challenges[ChallengeId::RamTableBezoutRelationIndeterminate];
262
263    let mut auxiliary_columns = Vec::with_capacity(2 * main_table.nrows());
264    let mut running_product_ram_pointer =
265        bezout_indeterminate - main_table.row(0)[MainColumn::RamPointer.main_index()];
266    let mut formal_derivative = xfe!(1);
267
268    auxiliary_columns.push(running_product_ram_pointer);
269    auxiliary_columns.push(formal_derivative);
270
271    for (previous_row, current_row) in main_table.rows().into_iter().tuple_windows() {
272        let instruction_type = current_row[MainColumn::InstructionType.main_index()];
273        let is_no_padding_row = instruction_type != PADDING_INDICATOR;
274
275        if is_no_padding_row {
276            let current_ram_pointer = current_row[MainColumn::RamPointer.main_index()];
277            let previous_ram_pointer = previous_row[MainColumn::RamPointer.main_index()];
278            if previous_ram_pointer != current_ram_pointer {
279                formal_derivative = (bezout_indeterminate - current_ram_pointer)
280                    * formal_derivative
281                    + running_product_ram_pointer;
282                running_product_ram_pointer *= bezout_indeterminate - current_ram_pointer;
283            }
284        }
285
286        auxiliary_columns.push(running_product_ram_pointer);
287        auxiliary_columns.push(formal_derivative);
288    }
289
290    Array2::from_shape_vec((main_table.nrows(), 2), auxiliary_columns).unwrap()
291}
292
293fn auxiliary_column_bezout_coefficient_0(
294    main_table: ArrayView2<BFieldElement>,
295    challenges: &Challenges,
296) -> Array2<XFieldElement> {
297    auxiliary_column_bezout_coefficient(
298        main_table,
299        challenges,
300        MainColumn::BezoutCoefficientPolynomialCoefficient0,
301    )
302}
303
304fn auxiliary_column_bezout_coefficient_1(
305    main_table: ArrayView2<BFieldElement>,
306    challenges: &Challenges,
307) -> Array2<XFieldElement> {
308    auxiliary_column_bezout_coefficient(
309        main_table,
310        challenges,
311        MainColumn::BezoutCoefficientPolynomialCoefficient1,
312    )
313}
314
315fn auxiliary_column_bezout_coefficient(
316    main_table: ArrayView2<BFieldElement>,
317    challenges: &Challenges,
318    bezout_cefficient_column: MainColumn,
319) -> Array2<XFieldElement> {
320    let bezout_indeterminate = challenges[ChallengeId::RamTableBezoutRelationIndeterminate];
321
322    let mut bezout_coefficient = main_table.row(0)[bezout_cefficient_column.main_index()].lift();
323    let mut auxiliary_column = Vec::with_capacity(main_table.nrows());
324    auxiliary_column.push(bezout_coefficient);
325
326    for (previous_row, current_row) in main_table.rows().into_iter().tuple_windows() {
327        if current_row[MainColumn::InstructionType.main_index()] == PADDING_INDICATOR {
328            break; // padding marks the end of the trace
329        }
330
331        let previous_ram_pointer = previous_row[MainColumn::RamPointer.main_index()];
332        let current_ram_pointer = current_row[MainColumn::RamPointer.main_index()];
333        if previous_ram_pointer != current_ram_pointer {
334            bezout_coefficient *= bezout_indeterminate;
335            bezout_coefficient += current_row[bezout_cefficient_column.main_index()];
336        }
337        auxiliary_column.push(bezout_coefficient);
338    }
339
340    // fill padding section
341    auxiliary_column.resize(main_table.nrows(), bezout_coefficient);
342    Array2::from_shape_vec((main_table.nrows(), 1), auxiliary_column).unwrap()
343}
344
345fn auxiliary_column_running_product_perm_arg(
346    main_table: ArrayView2<BFieldElement>,
347    challenges: &Challenges,
348) -> Array2<XFieldElement> {
349    let mut running_product_for_perm_arg = PermArg::default_initial();
350    let mut auxiliary_column = Vec::with_capacity(main_table.nrows());
351    for row in main_table.rows() {
352        let instruction_type = row[MainColumn::InstructionType.main_index()];
353        if instruction_type == PADDING_INDICATOR {
354            break; // padding marks the end of the trace
355        }
356
357        let clk = row[MainColumn::CLK.main_index()];
358        let current_ram_pointer = row[MainColumn::RamPointer.main_index()];
359        let ram_value = row[MainColumn::RamValue.main_index()];
360        let compressed_row = clk * challenges[ChallengeId::RamClkWeight]
361            + instruction_type * challenges[ChallengeId::RamInstructionTypeWeight]
362            + current_ram_pointer * challenges[ChallengeId::RamPointerWeight]
363            + ram_value * challenges[ChallengeId::RamValueWeight];
364        running_product_for_perm_arg *= challenges[ChallengeId::RamIndeterminate] - compressed_row;
365        auxiliary_column.push(running_product_for_perm_arg);
366    }
367
368    // fill padding section
369    auxiliary_column.resize(main_table.nrows(), running_product_for_perm_arg);
370    Array2::from_shape_vec((main_table.nrows(), 1), auxiliary_column).unwrap()
371}
372
373fn auxiliary_column_clock_jump_difference_lookup_log_derivative(
374    main_table: ArrayView2<BFieldElement>,
375    challenges: &Challenges,
376) -> Array2<XFieldElement> {
377    let indeterminate = challenges[ChallengeId::ClockJumpDifferenceLookupIndeterminate];
378
379    let mut cjd_lookup_log_derivative = LookupArg::default_initial();
380    let mut auxiliary_column = Vec::with_capacity(main_table.nrows());
381    auxiliary_column.push(cjd_lookup_log_derivative);
382
383    for (previous_row, current_row) in main_table.rows().into_iter().tuple_windows() {
384        if current_row[MainColumn::InstructionType.main_index()] == PADDING_INDICATOR {
385            break; // padding marks the end of the trace
386        }
387
388        let previous_ram_pointer = previous_row[MainColumn::RamPointer.main_index()];
389        let current_ram_pointer = current_row[MainColumn::RamPointer.main_index()];
390        if previous_ram_pointer == current_ram_pointer {
391            let previous_clock = previous_row[MainColumn::CLK.main_index()];
392            let current_clock = current_row[MainColumn::CLK.main_index()];
393            let clock_jump_difference = current_clock - previous_clock;
394            let log_derivative_summand = (indeterminate - clock_jump_difference).inverse();
395            cjd_lookup_log_derivative += log_derivative_summand;
396        }
397        auxiliary_column.push(cjd_lookup_log_derivative);
398    }
399
400    // fill padding section
401    auxiliary_column.resize(main_table.nrows(), cjd_lookup_log_derivative);
402    Array2::from_shape_vec((main_table.nrows(), 1), auxiliary_column).unwrap()
403}
404
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
pub(crate) mod tests {
    use proptest::prelude::*;
    use proptest_arbitrary_interop::arb;
    use test_strategy::proptest;

    use super::*;

    /// The conversion must not panic, and it must place every field of the call
    /// into its designated column.
    #[proptest]
    fn ram_table_call_can_be_converted_to_table_row(
        #[strategy(arb())] ram_table_call: RamTableCall,
    ) {
        let row = ram_table_call.to_table_row();
        prop_assert_eq!(MainColumn::COUNT, row.len());

        let expected_instruction_type = if ram_table_call.is_write {
            air::table::ram::INSTRUCTION_TYPE_WRITE
        } else {
            air::table::ram::INSTRUCTION_TYPE_READ
        };
        prop_assert_eq!(
            BFieldElement::from(ram_table_call.clk),
            row[MainColumn::CLK.main_index()]
        );
        prop_assert_eq!(
            expected_instruction_type,
            row[MainColumn::InstructionType.main_index()]
        );
        prop_assert_eq!(
            ram_table_call.ram_pointer,
            row[MainColumn::RamPointer.main_index()]
        );
        prop_assert_eq!(
            ram_table_call.ram_value,
            row[MainColumn::RamValue.main_index()]
        );
    }

    #[test]
    fn bezout_coefficient_polynomials_of_empty_ram_table_are_default() {
        let (a, b) = bezout_coefficient_polynomials_coefficients(&[]);
        assert_eq!(a, vec![]);
        assert_eq!(b, vec![]);
    }

    #[test]
    fn bezout_coefficient_polynomials_are_as_expected() {
        let rp = bfe_array![1, 2, 3];
        let (a, b) = bezout_coefficient_polynomials_coefficients(&rp);

        let expected_a = bfe_array![9, 0x7fff_ffff_7fff_fffc_u64, 0];
        let expected_b = bfe_array![5, 0xffff_fffe_ffff_fffb_u64, 0x7fff_ffff_8000_0002_u64];

        assert_eq!(expected_a, *a);
        assert_eq!(expected_b, *b);
    }

    #[proptest]
    fn bezout_coefficient_polynomials_agree_with_xgcd(
        #[strategy(arb())]
        #[filter(#ram_pointers.iter().all_unique())]
        ram_pointers: Vec<BFieldElement>,
    ) {
        let (a, b) = bezout_coefficient_polynomials_coefficients(&ram_pointers);

        let rp = Polynomial::zerofier(&ram_pointers);
        let fd = rp.formal_derivative();
        let (_, a_xgcd, b_xgcd) = Polynomial::xgcd(rp, fd);

        let mut a_xgcd = a_xgcd.into_coefficients();
        let mut b_xgcd = b_xgcd.into_coefficients();

        a_xgcd.resize(ram_pointers.len(), BFieldElement::ZERO);
        b_xgcd.resize(ram_pointers.len(), BFieldElement::ZERO);

        prop_assert_eq!(a, a_xgcd);
        prop_assert_eq!(b, b_xgcd);
    }

    #[proptest]
    fn bezout_coefficients_are_actually_bezout_coefficients(
        #[strategy(arb())]
        #[filter(!#ram_pointers.is_empty())]
        #[filter(#ram_pointers.iter().all_unique())]
        ram_pointers: Vec<BFieldElement>,
    ) {
        let (a, b) = bezout_coefficient_polynomials_coefficients(&ram_pointers);

        let rp = Polynomial::zerofier(&ram_pointers);
        let fd = rp.formal_derivative();

        let [a, b] = [a, b].map(Polynomial::new);
        let gcd = rp * a + fd * b;
        prop_assert_eq!(Polynomial::one(), gcd);
    }
}