winter_prover/matrix/col_matrix.rs
1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6use alloc::vec::Vec;
7use core::{iter::FusedIterator, slice};
8
9use crypto::{ElementHasher, VectorCommitment};
10use math::{fft, polynom, FieldElement};
11#[cfg(feature = "concurrent")]
12use utils::iterators::*;
13use utils::{batch_iter_mut, iter, iter_mut, uninit_vector};
14
15use crate::StarkDomain;
16
17// COLUMN-MAJOR MATRIX
18// ================================================================================================
19
20/// A two-dimensional matrix of field elements arranged in column-major order.
21///
22/// This struct is used as a backing type for many objects manipulated by the prover. The matrix
23/// itself does not assign any contextual meaning to the values stored in it. For example, columns
24/// may contain evaluations of polynomials, or polynomial coefficients, or really anything else.
25/// However, the matrix does expose a number of methods which make assumptions about the underlying
26/// data.
27///
28/// A matrix imposes the following restrictions on its content:
29/// - A matrix must consist of at least 1 column and at least 2 rows.
30/// - All columns must be of the same length.
31/// - Number of rows must be a power of two.
32#[derive(Debug, Clone)]
33pub struct ColMatrix<E: FieldElement> {
34 columns: Vec<Vec<E>>,
35}
36
37impl<E: FieldElement> ColMatrix<E> {
38 // CONSTRUCTOR
39 // --------------------------------------------------------------------------------------------
40 /// Returns a new [Matrix] instantiated with the data from the specified columns.
41 ///
42 /// # Panics
43 /// Panics if:
44 /// * The provided vector of columns is empty.
45 /// * Not all of the columns have the same number of elements.
46 /// * Number of rows is smaller than or equal to 1.
47 /// * Number of rows is not a power of two.
48 pub fn new(columns: Vec<Vec<E>>) -> Self {
49 assert!(!columns.is_empty(), "a matrix must contain at least one column");
50 let num_rows = columns[0].len();
51 assert!(num_rows > 1, "number of rows in a matrix must be greater than one");
52 assert!(num_rows.is_power_of_two(), "number of rows in a matrix must be a power of 2");
53 for column in columns.iter().skip(1) {
54 assert_eq!(column.len(), num_rows, "all matrix columns must have the same length");
55 }
56
57 Self { columns }
58 }
59
60 // PUBLIC ACCESSORS
61 // --------------------------------------------------------------------------------------------
62
63 /// Returns the number of columns in this matrix.
64 pub fn num_cols(&self) -> usize {
65 self.columns.len()
66 }
67
68 /// Returns the number of base field columns in this matrix.
69 ///
70 /// The number of base field columns is defined as the number of columns multiplied by the
71 /// extension degree of field elements contained in this matrix.
72 pub fn num_base_cols(&self) -> usize {
73 self.num_cols() * E::EXTENSION_DEGREE
74 }
75
76 /// Returns the number of rows in this matrix.
77 pub fn num_rows(&self) -> usize {
78 self.columns[0].len()
79 }
80
81 /// Returns the element located at the specified column and row indexes in this matrix.
82 ///
83 /// # Panics
84 /// Panics if either `col_idx` or `row_idx` are out of bounds for this matrix.
85 pub fn get(&self, col_idx: usize, row_idx: usize) -> E {
86 self.columns[col_idx][row_idx]
87 }
88
89 /// Returns base field elements located at the specified column and row indexes in this matrix.
90 ///
91 /// For STARK fields, `base_col_idx` is the same as `col_idx` used in `Self::get` method. For
92 /// extension fields, each column in the matrix is viewed as 2 or more columns in the base
93 /// field.
94 ///
95 /// Thus, for example, if we are in a degree 2 extension field, `base_col_idx = 0` would refer
96 /// to the first base element of the first column, `base_col_idx = 1` would refer to the second
97 /// base element of the first column, `base_col_idx = 2` would refer to the first base element
98 /// of the second column etc.
99 ///
100 /// # Panics
101 /// Panics if either `base_col_idx` or `row_idx` are out of bounds for this matrix.
102 pub fn get_base_element(&self, base_col_idx: usize, row_idx: usize) -> E::BaseField {
103 let (col_idx, elem_idx) =
104 (base_col_idx / E::EXTENSION_DEGREE, base_col_idx % E::EXTENSION_DEGREE);
105 self.columns[col_idx][row_idx].base_element(elem_idx)
106 }
107
108 /// Set the cell in this matrix at the specified column and row indexes to the provided value.
109 ///
110 /// # Panics
111 /// Panics if either `col_idx` or `row_idx` are out of bounds for this matrix.
112 pub fn set(&mut self, col_idx: usize, row_idx: usize, value: E) {
113 self.columns[col_idx][row_idx] = value;
114 }
115
116 /// Returns a reference to the column at the specified index.
117 pub fn get_column(&self, col_idx: usize) -> &[E] {
118 &self.columns[col_idx]
119 }
120
121 /// Returns a reference to the column at the specified index.
122 pub fn get_column_mut(&mut self, col_idx: usize) -> &mut [E] {
123 &mut self.columns[col_idx]
124 }
125
126 /// Copies values of all columns at the specified row into the specified row slice.
127 ///
128 /// # Panics
129 /// Panics if `row_idx` is out of bounds for this matrix.
130 pub fn read_row_into(&self, row_idx: usize, row: &mut [E]) {
131 for (column, value) in self.columns.iter().zip(row.iter_mut()) {
132 *value = column[row_idx];
133 }
134 }
135
136 /// Updates a row in this matrix at the specified index to the provided data.
137 ///
138 /// # Panics
139 /// Panics if `row_idx` is out of bounds for this matrix.
140 pub fn update_row(&mut self, row_idx: usize, row: &[E]) {
141 for (column, &value) in self.columns.iter_mut().zip(row) {
142 column[row_idx] = value;
143 }
144 }
145
146 /// Merges a column to the end of the matrix provided its length matches the matrix.
147 ///
148 /// # Panics
149 /// Panics if the column has a different length to other columns in the matrix.
150 pub fn merge_column(&mut self, column: Vec<E>) {
151 if let Some(first_column) = self.columns.first() {
152 assert_eq!(first_column.len(), column.len());
153 }
154 self.columns.push(column);
155 }
156
157 /// Removes a column of the matrix given its index.
158 ///
159 /// # Panics
160 /// Panics if the column index is out of range.
161 pub fn remove_column(&mut self, index: usize) -> Vec<E> {
162 assert!(index < self.num_cols(), "column index out of range");
163 self.columns.remove(index)
164 }
165
166 // ITERATION
167 // --------------------------------------------------------------------------------------------
168
169 /// Returns an iterator over the columns of this matrix.
170 pub fn columns(&self) -> ColumnIter<'_, E> {
171 ColumnIter::new(self)
172 }
173
174 /// Returns a mutable iterator over the columns of this matrix.
175 pub fn columns_mut(&mut self) -> ColumnIterMut<'_, E> {
176 ColumnIterMut::new(self)
177 }
178
179 // POLYNOMIAL METHODS
180 // --------------------------------------------------------------------------------------------
181
182 /// Interpolates columns of the matrix into polynomials in coefficient form and returns the
183 /// result.
184 ///
185 /// The interpolation is performed as follows:
186 /// * Each column of the matrix is interpreted as evaluations of degree `num_rows - 1`
187 /// polynomial over a subgroup of size `num_rows`.
188 /// * Then each column is interpolated using iFFT algorithm into a polynomial in coefficient
189 /// form.
190 /// * The resulting polynomials are returned as a single matrix where each column contains
191 /// coefficients of a degree `num_rows - 1` polynomial.
192 pub fn interpolate_columns(&self) -> Self {
193 let inv_twiddles = fft::get_inv_twiddles::<E::BaseField>(self.num_rows());
194 let columns = iter!(self.columns)
195 .map(|evaluations| {
196 let mut column = evaluations.clone();
197 fft::interpolate_poly(&mut column, &inv_twiddles);
198 column
199 })
200 .collect();
201 Self { columns }
202 }
203
204 /// Interpolates columns of the matrix into polynomials in coefficient form and returns the
205 /// result. The input matrix is consumed in the process.
206 ///
207 /// The interpolation is performed as follows:
208 /// * Each column of the matrix is interpreted as evaluations of degree `num_rows - 1`
209 /// polynomial over a subgroup of size `num_rows`.
210 /// * Then each column is interpolated (in place) using iFFT algorithm into a polynomial in
211 /// coefficient form.
212 /// * The resulting polynomials are returned as a single matrix where each column contains
213 /// coefficients of a degree `num_rows - 1` polynomial.
214 pub fn interpolate_columns_into(mut self) -> Self {
215 let inv_twiddles = fft::get_inv_twiddles::<E::BaseField>(self.num_rows());
216 iter_mut!(self.columns).for_each(|column| fft::interpolate_poly(column, &inv_twiddles));
217 self
218 }
219
220 /// Evaluates polynomials contained in the columns of this matrix over the specified domain
221 /// and returns the result.
222 ///
223 /// The evaluation is done as follows:
224 /// * Each column of the matrix is interpreted as coefficients of degree `num_rows - 1`
225 /// polynomial.
226 /// * These polynomials are evaluated over the LDE domain defined by the specified [StarkDomain]
227 /// using FFT algorithm. The domain specification includes the size of the subgroup as well as
228 /// the domain offset (to define a coset).
229 /// * The resulting evaluations are returned in a new Matrix.
230 pub fn evaluate_columns_over(&self, domain: &StarkDomain<E::BaseField>) -> Self {
231 let columns = iter!(self.columns)
232 .map(|poly| {
233 fft::evaluate_poly_with_offset(
234 poly,
235 domain.trace_twiddles(),
236 domain.offset(),
237 domain.trace_to_lde_blowup(),
238 )
239 })
240 .collect();
241 Self { columns }
242 }
243
244 /// Evaluates polynomials contained in the columns of this matrix at a single point `x`.
245 pub fn evaluate_columns_at<F>(&self, x: F) -> Vec<F>
246 where
247 F: FieldElement + From<E>,
248 {
249 iter!(self.columns).map(|p| polynom::eval(p, x)).collect()
250 }
251
252 // COMMITMENTS
253 // --------------------------------------------------------------------------------------------
254
255 /// Returns a commitment to this matrix.
256 ///
257 /// The commitment is built as follows:
258 /// * Each row of the matrix is hashed into a single digest of the specified hash function.
259 /// * The resulting vector of digests is committed to using the specified vector commitment
260 /// scheme.
261 /// * The resulting commitment is returned as the commitment to the entire matrix.
262 pub fn commit_to_rows<H, V>(&self) -> V
263 where
264 H: ElementHasher<BaseField = E::BaseField>,
265 V: VectorCommitment<H>,
266 {
267 // allocate vector to store row hashes
268 let mut row_hashes = unsafe { uninit_vector::<H::Digest>(self.num_rows()) };
269
270 // iterate though matrix rows, hashing each row; the hashing is done by first copying a
271 // row into row_buf to avoid heap allocations, and then by applying the hash function to
272 // the buffer.
273 batch_iter_mut!(
274 &mut row_hashes,
275 128, // min batch size
276 |batch: &mut [H::Digest], batch_offset: usize| {
277 let mut row_buf = vec![E::ZERO; self.num_cols()];
278 for (i, row_hash) in batch.iter_mut().enumerate() {
279 self.read_row_into(i + batch_offset, &mut row_buf);
280 *row_hash = H::hash_elements(&row_buf);
281 }
282 }
283 );
284
285 V::new(row_hashes).expect("failed to construct trace vector commitment")
286 }
287
288 // CONVERSIONS
289 // --------------------------------------------------------------------------------------------
290
291 /// Returns the columns of this matrix as a list of vectors.
292 ///
293 /// TODO: replace this with an iterator.
294 pub fn into_columns(self) -> Vec<Vec<E>> {
295 self.columns
296 }
297}
298
299// COLUMN ITERATOR
300// ================================================================================================
301
302/// Iterator over columns of [ColMatrix].
303pub struct ColumnIter<'a, E: FieldElement> {
304 matrix: Option<&'a ColMatrix<E>>,
305 cursor: usize,
306}
307
308impl<'a, E: FieldElement> ColumnIter<'a, E> {
309 pub fn new(matrix: &'a ColMatrix<E>) -> Self {
310 Self { matrix: Some(matrix), cursor: 0 }
311 }
312
313 pub fn empty() -> Self {
314 Self { matrix: None, cursor: 0 }
315 }
316}
317
318impl<'a, E: FieldElement> Iterator for ColumnIter<'a, E> {
319 type Item = &'a [E];
320
321 fn next(&mut self) -> Option<Self::Item> {
322 match self.matrix {
323 Some(matrix) => match matrix.num_cols() - self.cursor {
324 0 => None,
325 _ => {
326 let column = matrix.get_column(self.cursor);
327 self.cursor += 1;
328 Some(column)
329 },
330 },
331 None => None,
332 }
333 }
334}
335
336impl<E: FieldElement> ExactSizeIterator for ColumnIter<'_, E> {
337 fn len(&self) -> usize {
338 self.matrix.map(|matrix| matrix.num_cols()).unwrap_or_default()
339 }
340}
341
342impl<E: FieldElement> FusedIterator for ColumnIter<'_, E> {}
343
344impl<E: FieldElement> Default for ColumnIter<'_, E> {
345 fn default() -> Self {
346 Self::empty()
347 }
348}
349
350// MUTABLE COLUMN ITERATOR
351// ================================================================================================
352
353/// Iterator over mutable columns of [ColMatrix].
354pub struct ColumnIterMut<'a, E: FieldElement> {
355 matrix: &'a mut ColMatrix<E>,
356 cursor: usize,
357}
358
359impl<'a, E: FieldElement> ColumnIterMut<'a, E> {
360 pub fn new(matrix: &'a mut ColMatrix<E>) -> Self {
361 Self { matrix, cursor: 0 }
362 }
363}
364
365impl<'a, E: FieldElement> Iterator for ColumnIterMut<'a, E> {
366 type Item = &'a mut [E];
367
368 fn next(&mut self) -> Option<Self::Item> {
369 match self.matrix.num_cols() - self.cursor {
370 0 => None,
371 _ => {
372 let column = self.matrix.get_column_mut(self.cursor);
373 self.cursor += 1;
374
375 // this is needed to get around mutable iterator lifetime issues; this is safe
376 // because the iterator can never yield a reference to the same column twice
377 let p = column.as_ptr();
378 let len = column.len();
379 Some(unsafe { slice::from_raw_parts_mut(p as *mut E, len) })
380 },
381 }
382 }
383}
384
385impl<E: FieldElement> ExactSizeIterator for ColumnIterMut<'_, E> {
386 fn len(&self) -> usize {
387 self.matrix.num_cols()
388 }
389}
390
391impl<E: FieldElement> FusedIterator for ColumnIterMut<'_, E> {}