1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

use super::{ColMatrix, Segment};
use crate::StarkDomain;
use crypto::{ElementHasher, MerkleTree};
use math::{fft, log2, FieldElement, StarkField};
use utils::collections::Vec;
use utils::{batch_iter_mut, flatten_vector_elements, uninit_vector};

#[cfg(feature = "concurrent")]
use utils::iterators::*;

// ROW-MAJOR MATRIX
// ================================================================================================

/// A two-dimensional matrix of field elements arranged in row-major order.
///
/// The matrix is represented as a single vector of base field elements for the field defined by
/// the `E` type parameter. The first `row_width` base field elements represent the first row of
/// the matrix, the next `row_width` base field elements represent the second row, and so on.
///
/// When rows are returned via the [RowMatrix::row()] method, base field elements are grouped
/// together as appropriate to form elements in `E`.
///
/// In some cases, rows may be padded with extra elements (for example, when a matrix is built
/// from segments of `N` columns and the number of columns is not a multiple of `N`). Only the
/// first `elements_per_row` base field elements of each row are accessible via the
/// [RowMatrix::row()] method; the remaining `row_width - elements_per_row` elements are padding.
#[derive(Clone, Debug)]
pub struct RowMatrix<E: FieldElement> {
    /// Base field elements stored in the matrix, laid out row by row.
    data: Vec<E::BaseField>,
    /// Total number of base field elements stored in a single row (including any padding).
    row_width: usize,
    /// Number of base field elements in a single row accessible via the [RowMatrix::row()]
    /// method. This must be equal to or smaller than `row_width`. The number of `E` elements
    /// returned by [RowMatrix::row()] is this value divided by `E::EXTENSION_DEGREE`.
    elements_per_row: usize,
}

impl<E: FieldElement> RowMatrix<E> {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Returns a new [RowMatrix] constructed by evaluating the provided polynomials over the
    /// domain defined by the specified blowup factor.
    ///
    /// The provided `polys` matrix is assumed to contain polynomials in coefficient form (one
    /// polynomial per column). Columns in the returned matrix will contain evaluations of the
    /// corresponding polynomials over the domain defined by polynomial size (i.e., number of rows
    /// in the `polys` matrix) and the `blowup_factor`. The domain is shifted by the generator of
    /// the base field.
    ///
    /// To improve performance, polynomials are evaluated in batches specified by the `N` type
    /// parameter. Minimum batch size is 1.
    ///
    /// # Panics
    /// Panics if `N` is zero.
    pub fn evaluate_polys<const N: usize>(polys: &ColMatrix<E>, blowup_factor: usize) -> Self {
        assert!(N > 0, "batch size N must be greater than zero");

        // pre-compute offsets for each row; the evaluation domain is shifted by the generator
        // of the base field
        let poly_size = polys.num_rows();
        let offsets = get_offsets::<E>(poly_size, blowup_factor, E::BaseField::GENERATOR);

        // compute twiddles for polynomial evaluation
        let twiddles = fft::get_twiddles::<E::BaseField>(poly_size);

        // build matrix segments by evaluating all polynomials
        let segments = build_segments::<E, N>(polys, &twiddles, &offsets);

        // transpose data in individual segments into a single row-major matrix
        Self::from_segments(segments, polys.num_base_cols())
    }

    /// Returns a new [RowMatrix] constructed by evaluating the provided polynomials over the
    /// specified [StarkDomain].
    ///
    /// The provided `polys` matrix is assumed to contain polynomials in coefficient form (one
    /// polynomial per column). Columns in the returned matrix will contain evaluations of the
    /// corresponding polynomials over the LDE domain defined by the provided [StarkDomain].
    ///
    /// To improve performance, polynomials are evaluated in batches specified by the `N` type
    /// parameter. Minimum batch size is 1.
    ///
    /// # Panics
    /// Panics if `N` is zero.
    pub fn evaluate_polys_over<const N: usize>(
        polys: &ColMatrix<E>,
        domain: &StarkDomain<E::BaseField>,
    ) -> Self {
        assert!(N > 0, "batch size N must be greater than zero");

        // pre-compute offsets for each row using the blowup factor and offset of the domain
        let poly_size = polys.num_rows();
        let offsets = get_offsets::<E>(poly_size, domain.trace_to_lde_blowup(), domain.offset());

        // build matrix segments by evaluating all polynomials; twiddles come from the domain
        let segments = build_segments::<E, N>(polys, domain.trace_twiddles(), &offsets);

        // transpose data in individual segments into a single row-major matrix
        Self::from_segments(segments, polys.num_base_cols())
    }

    /// Returns a new [RowMatrix] instantiated from the specified matrix segments.
    ///
    /// `elements_per_row` specifies how many base field elements are considered to form a single
    /// row in the matrix. The stored row width is `segments.len() * N`; any elements beyond
    /// `elements_per_row` are padding and are not exposed via [RowMatrix::row()].
    ///
    /// # Panics
    /// Panics if
    /// - `N` is zero.
    /// - `segments` is an empty vector.
    /// - `elements_per_row` is greater than the row width implied by the number of segments and
    ///   `N` type parameter.
    pub fn from_segments<const N: usize>(
        segments: Vec<Segment<E::BaseField, N>>,
        elements_per_row: usize,
    ) -> Self {
        assert!(N > 0, "batch size N must be greater than zero");
        assert!(!segments.is_empty(), "a list of segments cannot be empty");

        // compute the size of each row
        let row_width = segments.len() * N;
        assert!(
            elements_per_row <= row_width,
            "elements per row cannot exceed {row_width}, but was {elements_per_row}"
        );

        // transpose the segments into a single vector of arrays
        let result = transpose(segments);

        // flatten the result to be a simple vector of elements and return
        RowMatrix {
            data: flatten_vector_elements(result),
            row_width,
            elements_per_row,
        }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the number of columns in this matrix, i.e., the number of `E` elements in each
    /// row returned by [RowMatrix::row()].
    pub fn num_cols(&self) -> usize {
        self.elements_per_row / E::EXTENSION_DEGREE
    }

    /// Returns the number of rows in this matrix.
    pub fn num_rows(&self) -> usize {
        self.data.len() / self.row_width
    }

    /// Returns the element located at the specified column and row indexes in this matrix.
    ///
    /// # Panics
    /// Panics if either `col_idx` or `row_idx` are out of bounds for this matrix.
    pub fn get(&self, col_idx: usize, row_idx: usize) -> E {
        self.row(row_idx)[col_idx]
    }

    /// Returns a reference to a row at the specified index in this matrix.
    ///
    /// Only the first `elements_per_row` base field elements of the row are returned (grouped
    /// into elements of `E`); row padding is excluded.
    ///
    /// # Panics
    /// Panics if the specified row index is out of bounds.
    pub fn row(&self, row_idx: usize) -> &[E] {
        let num_rows = self.num_rows();
        assert!(
            row_idx < num_rows,
            "row index {row_idx} is out of bounds for a matrix with {num_rows} rows"
        );
        let start = row_idx * self.row_width;
        E::slice_from_base_elements(&self.data[start..start + self.elements_per_row])
    }

    /// Returns the data in this matrix as a slice of base field elements (including any row
    /// padding).
    pub fn data(&self) -> &[E::BaseField] {
        &self.data
    }

    // COMMITMENTS
    // --------------------------------------------------------------------------------------------

    /// Returns a commitment to this matrix.
    ///
    /// The commitment is built as follows:
    /// * Each row of the matrix is hashed into a single digest of the specified hash function.
    /// * The resulting values are used to build a binary Merkle tree such that each row digest
    ///   becomes a leaf in the tree. Thus, the number of leaves in the tree is equal to the
    ///   number of rows in the matrix.
    /// * The resulting Merkle tree is returned as the commitment to the entire matrix.
    ///
    /// # Panics
    /// Panics if the Merkle tree cannot be constructed from the row digests.
    pub fn commit_to_rows<H>(&self) -> MerkleTree<H>
    where
        H: ElementHasher<BaseField = E::BaseField>,
    {
        // allocate vector to store row hashes
        // SAFETY: every digest is assigned by the batch loop below before the vector is read,
        // assuming `batch_iter_mut!` visits all batches of `row_hashes`.
        let mut row_hashes = unsafe { uninit_vector::<H::Digest>(self.num_rows()) };

        // iterate through matrix rows, hashing each row
        batch_iter_mut!(
            &mut row_hashes,
            128, // min batch size
            |batch: &mut [H::Digest], batch_offset: usize| {
                for (i, row_hash) in batch.iter_mut().enumerate() {
                    *row_hash = H::hash_elements(self.row(batch_offset + i));
                }
            }
        );

        // build Merkle tree out of hashed rows
        MerkleTree::new(row_hashes).expect("failed to construct trace Merkle tree")
    }
}

// HELPER FUNCTIONS
// ================================================================================================

/// Computes the per-row offsets for evaluating polynomials of size `poly_size` over a domain of
/// `poly_size * blowup_factor` points shifted by `domain_offset`.
///
/// The output is split into `blowup_factor` consecutive chunks of `poly_size` elements. The chunk
/// with index `k` holds the successive powers `1, s, s^2, ...` of the shift
/// `s = root^permute_index(blowup_factor, k) * domain_offset`, where `root` is obtained from
/// `get_root_of_unity(log2(poly_size * blowup_factor))`.
///
/// With the `concurrent` feature enabled, chunks are filled in multiple threads.
fn get_offsets<E: FieldElement>(
    poly_size: usize,
    blowup_factor: usize,
    domain_offset: E::BaseField,
) -> Vec<E::BaseField> {
    let lde_size = poly_size * blowup_factor;
    let lde_root = E::BaseField::get_root_of_unity(log2(lde_size));

    // SAFETY: the chunks below cover all `lde_size` slots, and `fill_chunk` writes every slot
    // of each chunk before the vector is returned
    let mut result = unsafe { uninit_vector(lde_size) };

    // writes consecutive powers of this chunk's shift into the chunk; there is one chunk per
    // unit of blowup (e.g., 2 chunks for blowup factor 2, 8 chunks for blowup factor 8)
    let fill_chunk = |(chunk_idx, chunk): (usize, &mut [E::BaseField])| {
        let exponent = fft::permute_index(blowup_factor, chunk_idx) as u64;
        let shift = lde_root.exp_vartime(exponent.into()) * domain_offset;
        let mut power = E::BaseField::ONE;
        for slot in chunk {
            *slot = power;
            power *= shift;
        }
    };

    #[cfg(feature = "concurrent")]
    result
        .par_chunks_mut(poly_size)
        .enumerate()
        .for_each(fill_chunk);

    #[cfg(not(feature = "concurrent"))]
    result
        .chunks_mut(poly_size)
        .enumerate()
        .for_each(fill_chunk);

    result
}

/// Evaluates the polynomials in `polys` over the domain described by `twiddles` and `offsets`,
/// grouping base field columns into segments of `N` columns each.
///
/// A segment is created for every `N`-th base column starting at column 0, so the number of
/// segments is the column count divided by `N`, rounded up.
///
/// # Panics
/// Panics if `N` is zero.
fn build_segments<E: FieldElement, const N: usize>(
    polys: &ColMatrix<E>,
    twiddles: &[E::BaseField],
    offsets: &[E::BaseField],
) -> Vec<Segment<E::BaseField, N>> {
    assert!(N > 0, "batch size N must be greater than zero");
    debug_assert_eq!(polys.num_rows(), twiddles.len() * 2);
    debug_assert_eq!(offsets.len() % polys.num_rows(), 0);

    // walk the base columns in strides of N; each stride start is the first column of a segment
    (0..polys.num_base_cols())
        .step_by(N)
        .map(|first_col| Segment::new(polys, first_col, offsets, twiddles))
        .collect()
}

/// Transposes a vector of segments into a single vector of fixed-size arrays.
///
/// Row `r` of the result consists of the `r`-th array of every segment, in segment order; i.e.,
/// element `r * segments.len() + j` of the result is `segments[j].data()[r]`. All segments are
/// expected to have the same number of rows as the first one.
///
/// When `concurrent` feature is enabled, transposition is performed in multiple threads.
///
/// # Panics
/// Panics if `segments` is empty.
fn transpose<B: StarkField, const N: usize>(mut segments: Vec<Segment<B, N>>) -> Vec<[B; N]> {
    let num_rows = segments[0].num_rows();
    let num_segs = segments.len();
    let result_len = num_rows * num_segs;

    // if there is only one segment, there is nothing to transpose as it is already in row
    // major form
    if num_segs == 1 {
        return segments.remove(0).into_data();
    }

    // allocate memory to hold the transposed result
    // SAFETY: the batches below cover all `num_rows` rows (each `num_segs` arrays wide), and
    // `transpose_batch` writes every array of every row before the vector is returned.
    // TODO: investigate transposing in-place
    let mut result = unsafe { uninit_vector::<[B; N]>(result_len) };

    // determine number of batches in which transposition will be performed; if `concurrent`
    // feature is not enabled, the number of batches will always be 1. The batch count is capped
    // at the number of rows, and rows are distributed using ceiling division, so every row is
    // written even when `num_rows` is smaller than (or not a multiple of) the thread-derived
    // batch count.
    let num_batches = get_num_batches(result_len).min(num_rows).max(1);
    let rows_per_batch = ((num_rows + num_batches - 1) / num_batches).max(1);

    // define a closure for transposing a given batch; the last batch may hold fewer than
    // `rows_per_batch` rows, so the rows are derived from the batch length itself
    let transpose_batch = |(batch_idx, batch): (usize, &mut [[B; N]])| {
        let row_offset = batch_idx * rows_per_batch;
        for (i, row) in batch.chunks_mut(num_segs).enumerate() {
            let row_idx = row_offset + i;
            for (dst, segment) in row.iter_mut().zip(segments.iter()) {
                dst.copy_from_slice(&segment.data()[row_idx]);
            }
        }
    };

    // call the closure either once (for single-threaded transposition) or in a parallel
    // iterator (for multi-threaded transposition)

    #[cfg(not(feature = "concurrent"))]
    transpose_batch((0, &mut result));

    #[cfg(feature = "concurrent")]
    result
        .par_chunks_mut(rows_per_batch * num_segs)
        .enumerate()
        .for_each(transpose_batch);

    result
}

#[cfg(not(feature = "concurrent"))]
fn get_num_batches(_input_size: usize) -> usize {
    // without the `concurrent` feature all work runs on a single thread, so it is never split
    const SINGLE_BATCH: usize = 1;
    SINGLE_BATCH
}

#[cfg(feature = "concurrent")]
fn get_num_batches(input_size: usize) -> usize {
    // inputs below this size are processed as a single batch
    const MIN_PARALLEL_INPUT: usize = 1024;

    if input_size >= MIN_PARALLEL_INPUT {
        // two batches per (power-of-two rounded) worker thread
        2 * utils::rayon::current_num_threads().next_power_of_two()
    } else {
        1
    }
}