winter_prover/matrix/
segments.rs

1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6use alloc::vec::Vec;
7use core::ops::Deref;
8
9use math::{fft::fft_inputs::FftInputs, FieldElement, StarkField};
10#[cfg(feature = "concurrent")]
11use utils::iterators::*;
12use utils::uninit_vector;
13
14use super::ColMatrix;
15
16// CONSTANTS
17// ================================================================================================
18
/// Smallest evaluation-domain size for which segment construction is parallelized; any domain
/// with fewer rows than this is processed on the calling thread only.
const MIN_CONCURRENT_SIZE: usize = 1 << 10;
21
22// SEGMENT OF ROW-MAJOR MATRIX
23// ================================================================================================
24
25/// A set of columns of a matrix stored in row-major form.
26///
27/// The rows are stored in a single vector where each element is an array of size `N`. A segment
28/// can store [StarkField] elements only, but can be instantiated from a [Matrix] of any extension
29/// of the specified [StarkField]. In such a case, extension field elements are decomposed into
30/// base field elements and then added to the segment.
31#[derive(Clone, Debug)]
32pub struct Segment<B: StarkField, const N: usize> {
33    data: Vec<[B; N]>,
34}
35
36impl<B: StarkField, const N: usize> Segment<B, N> {
37    // CONSTRUCTORS
38    // --------------------------------------------------------------------------------------------
39
40    /// Instantiates a new [Segment] by evaluating polynomials from the provided [ColMatrix]
41    /// starting at the specified offset.
42    ///
43    /// The offset is assumed to be an offset into the view of the matrix where extension field
44    /// elements are decomposed into base field elements. This offset must be compatible with the
45    /// values supplied into [Matrix::get_base_element()] method.
46    ///
47    /// Evaluation is performed over the domain specified by the provided twiddles and offsets.
48    ///
49    /// # Panics
50    /// Panics if:
51    /// - `poly_offset` greater than or equal to the number of base field columns in `polys`.
52    /// - Number of offsets is not a power of two.
53    /// - Number of offsets is smaller than or equal to the polynomial size.
54    /// - The number of twiddles is not half the size of the polynomial size.
55    pub fn new<E>(polys: &ColMatrix<E>, poly_offset: usize, offsets: &[B], twiddles: &[B]) -> Self
56    where
57        E: FieldElement<BaseField = B>,
58    {
59        let poly_size = polys.num_rows();
60        let domain_size = offsets.len();
61        assert!(domain_size.is_power_of_two());
62        assert!(domain_size > poly_size);
63        assert_eq!(poly_size, twiddles.len() * 2);
64        assert!(poly_offset < polys.num_base_cols());
65
66        // allocate memory for the segment
67        let data = if polys.num_base_cols() - poly_offset >= N {
68            // if we will fill the entire segment, we allocate uninitialized memory
69            unsafe { uninit_vector::<[B; N]>(domain_size) }
70        } else {
71            // but if some columns in the segment will remain unfilled, we allocate memory
72            // initialized to zeros to make sure we don't end up with memory with
73            // undefined values
74            vec![[B::ZERO; N]; domain_size]
75        };
76
77        Self::new_with_buffer(data, polys, poly_offset, offsets, twiddles)
78    }
79
80    /// Instantiates a new [Segment] using the provided data buffer by evaluating polynomials in
81    /// the [ColMatrix] starting at the specified offset.
82    ///
83    /// The offset is assumed to be an offset into the view of the matrix where extension field
84    /// elements are decomposed into base field elements. This offset must be compatible with the
85    /// values supplied into [Matrix::get_base_element()] method.
86    ///
87    /// Evaluation is performed over the domain specified by the provided twiddles and offsets.
88    ///
89    /// # Panics
90    /// Panics if:
91    /// - `poly_offset` greater than or equal to the number of base field columns in `polys`.
92    /// - Number of offsets is not a power of two.
93    /// - Number of offsets is smaller than or equal to the polynomial size.
94    /// - The number of twiddles is not half the size of the polynomial size.
95    /// - Number of offsets is smaller than the length of the data buffer
96    pub fn new_with_buffer<E>(
97        data_buffer: Vec<[B; N]>,
98        polys: &ColMatrix<E>,
99        poly_offset: usize,
100        offsets: &[B],
101        twiddles: &[B],
102    ) -> Self
103    where
104        E: FieldElement<BaseField = B>,
105    {
106        let poly_size = polys.num_rows();
107        let domain_size = offsets.len();
108        let mut data = data_buffer;
109
110        assert!(domain_size.is_power_of_two());
111        assert!(domain_size > poly_size);
112        assert_eq!(poly_size, twiddles.len() * 2);
113        assert!(poly_offset < polys.num_base_cols());
114        assert_eq!(data.len(), domain_size);
115
116        // determine the number of polynomials to add to this segment; this number can be either N,
117        // or smaller than N when there are fewer than N polynomials remaining to be processed
118        let num_polys_remaining = polys.num_base_cols() - poly_offset;
119        let num_polys = if num_polys_remaining < N {
120            num_polys_remaining
121        } else {
122            N
123        };
124
125        // evaluate the polynomials either in a single thread or multiple threads, depending
126        // on whether `concurrent` feature is enabled and domain size is greater than 1024;
127
128        if cfg!(feature = "concurrent") && domain_size >= MIN_CONCURRENT_SIZE {
129            #[cfg(feature = "concurrent")]
130            data.par_chunks_mut(poly_size).zip(offsets.par_chunks(poly_size)).for_each(
131                |(d_chunk, o_chunk)| {
132                    // TODO: investigate multi-threaded copy
133                    if num_polys == N {
134                        Self::copy_polys(d_chunk, polys, poly_offset, o_chunk);
135                    } else {
136                        Self::copy_polys_partial(d_chunk, polys, poly_offset, num_polys, o_chunk);
137                    }
138                    concurrent::split_radix_fft(d_chunk, twiddles);
139                },
140            );
141            #[cfg(feature = "concurrent")]
142            concurrent::permute(&mut data);
143        } else {
144            data.chunks_mut(poly_size).zip(offsets.chunks(poly_size)).for_each(
145                |(d_chunk, o_chunk)| {
146                    if num_polys == N {
147                        Self::copy_polys(d_chunk, polys, poly_offset, o_chunk);
148                    } else {
149                        Self::copy_polys_partial(d_chunk, polys, poly_offset, num_polys, o_chunk);
150                    }
151                    d_chunk.fft_in_place(twiddles);
152                },
153            );
154            data.permute();
155        }
156
157        Segment { data }
158    }
159
160    // PUBLIC ACCESSORS
161    // --------------------------------------------------------------------------------------------
162
163    /// Returns the number of rows in this segment.
164    pub fn num_rows(&self) -> usize {
165        self.data.len()
166    }
167
168    /// Returns the underlying vector of arrays for this segment.
169    pub fn into_data(self) -> Vec<[B; N]> {
170        self.data
171    }
172
173    // HELPER METHODS
174    // --------------------------------------------------------------------------------------------
175
176    /// Copies N polynomials starting at the specified base column offset (`poly_offset`) into the
177    /// specified destination. Each polynomial coefficient is offset by the specified offset.
178    fn copy_polys<E: FieldElement<BaseField = B>>(
179        dest: &mut [[B; N]],
180        polys: &ColMatrix<E>,
181        poly_offset: usize,
182        offsets: &[B],
183    ) {
184        for row_idx in 0..dest.len() {
185            for i in 0..N {
186                let coeff = polys.get_base_element(poly_offset + i, row_idx);
187                dest[row_idx][i] = coeff * offsets[row_idx];
188            }
189        }
190    }
191
192    /// Similar to `clone_and_shift` method above, but copies `num_polys` polynomials instead of
193    /// `N` polynomials.
194    ///
195    /// Assumes that `num_polys` is smaller than `N`.
196    fn copy_polys_partial<E: FieldElement<BaseField = B>>(
197        dest: &mut [[B; N]],
198        polys: &ColMatrix<E>,
199        poly_offset: usize,
200        num_polys: usize,
201        offsets: &[B],
202    ) {
203        debug_assert!(num_polys < N);
204        for row_idx in 0..dest.len() {
205            for i in 0..num_polys {
206                let coeff = polys.get_base_element(poly_offset + i, row_idx);
207                dest[row_idx][i] = coeff * offsets[row_idx];
208            }
209        }
210    }
211}
212
213impl<B: StarkField, const N: usize> Deref for Segment<B, N> {
214    type Target = Vec<[B; N]>;
215
216    fn deref(&self) -> &Self::Target {
217        &self.data
218    }
219}
220
221// CONCURRENT FFT IMPLEMENTATION
222// ================================================================================================
223
/// Multi-threaded implementations of FFT and permutation algorithms. These are very similar to
/// the functions implemented in `winter-math::fft::concurrent` module, but are adapted to work
/// with slices of element arrays.
#[cfg(feature = "concurrent")]
mod concurrent {
    use math::fft::permute_index;
    use utils::{iterators::*, rayon};

    use super::{FftInputs, StarkField};

    /// In-place recursive FFT with permuted output.
    ///
    /// Each element of `data` is an array of `N` base field values; the twiddle scaling step
    /// applies the same factor to all `N` values of a row.
    ///
    /// Adapted from: https://github.com/0xProject/OpenZKP/tree/master/algebra/primefield/src/fft
    pub fn split_radix_fft<B: StarkField, const N: usize>(data: &mut [[B; N]], twiddles: &[B]) {
        // generator of the domain should be in the middle of twiddles
        let n = data.len();
        let g = twiddles[twiddles.len() / 2];
        debug_assert_eq!(g.exp((n as u32).into()), B::ONE);

        // view `data` as `inner_len` rows of `outer_len` elements, where `outer_len` is either
        // equal to `inner_len` (stretch = 1) or twice as large (stretch = 2)
        let inner_len = 1_usize << (n.ilog2() / 2);
        let outer_len = n / inner_len;
        let stretch = outer_len / inner_len;
        debug_assert!(outer_len == inner_len || outer_len == 2 * inner_len);
        debug_assert_eq!(outer_len * inner_len, n);

        // transpose inner x inner x stretch square matrix
        transpose_square_stretch(data, inner_len, stretch);

        // apply inner FFTs
        data.par_chunks_mut(outer_len)
            .for_each(|row| row.fft_in_place_raw(twiddles, stretch, stretch, 0));

        // transpose inner x inner x stretch square matrix
        transpose_square_stretch(data, inner_len, stretch);

        // apply outer FFTs; every row except the first is scaled element-wise by successive
        // powers of a row-dependent twiddle before being transformed
        data.par_chunks_mut(outer_len).enumerate().for_each(|(i, row)| {
            if i > 0 {
                let i = permute_index(inner_len, i);
                let inner_twiddle = g.exp_vartime((i as u32).into());
                let mut outer_twiddle = inner_twiddle;
                for element in row.iter_mut().skip(1) {
                    for value in element.iter_mut() {
                        *value *= outer_twiddle;
                    }
                    outer_twiddle *= inner_twiddle;
                }
            }
            row.fft_in_place(twiddles)
        });
    }

    // PERMUTATIONS
    // --------------------------------------------------------------------------------------------

    /// Permutes `v` in place: for every index `i`, swaps `v[i]` with `v[permute_index(n, i)]`
    /// whenever the latter index is larger (`n` being the length of `v`).
    ///
    /// The index range is split into equally sized batches processed on the rayon thread pool.
    /// The length of `v` is expected to be a power of two.
    pub fn permute<T: Send>(v: &mut [T]) {
        let n = v.len();
        // Never use more batches than there are elements: otherwise `batch_size` rounds down to
        // zero and no element is permuted at all. With `n` a power of two, both operands of
        // `min` are powers of two, so `num_batches` evenly divides `n`.
        let num_batches = (rayon::current_num_threads().next_power_of_two() * 2).min(n);
        if num_batches == 0 {
            return;
        }
        let batch_size = n / num_batches;
        rayon::scope(|s| {
            for batch_idx in 0..num_batches {
                // SAFETY: creates another mutable reference to the slice for use in a new
                // thread. This relies on `permute_index(n, ·)` mapping `j` back to `i` (as a bit
                // reversal does), so each pair {i, j} is swapped only by the batch that owns the
                // smaller index and no position is written from two threads.
                // NOTE(review): the references still alias at the type level; a raw-pointer
                // based swap would make this sound under Rust's aliasing rules.
                let values = unsafe { &mut *(&mut v[..] as *mut [T]) };
                s.spawn(move |_| {
                    let batch_start = batch_idx * batch_size;
                    let batch_end = batch_start + batch_size;
                    for i in batch_start..batch_end {
                        let j = permute_index(n, i);
                        if j > i {
                            values.swap(i, j);
                        }
                    }
                });
            }
        });
    }

    // TRANSPOSING
    // --------------------------------------------------------------------------------------------

    /// Transposes `data` in place, viewed as a `size x size` matrix whose entries are runs of
    /// `stretch` consecutive elements.
    ///
    /// Panics if `data.len() != size * size * stretch` or if `stretch` is not 1 or 2.
    fn transpose_square_stretch<T>(data: &mut [T], size: usize, stretch: usize) {
        assert_eq!(data.len(), size * size * stretch);
        match stretch {
            1 => transpose_square_1(data, size),
            2 => transpose_square_2(data, size),
            _ => unimplemented!("only stretch sizes 1 and 2 are supported"),
        }
    }

    /// Transposes a `size x size` matrix in place; `size` is expected to be even.
    fn transpose_square_1<T>(data: &mut [T], size: usize) {
        debug_assert_eq!(data.len(), size * size);
        debug_assert_eq!(size % 2, 0, "odd sizes are not supported");

        // iterate over upper-left triangle, working in 2x2 blocks
        // TODO: investigate concurrent implementation
        for row in (0..size).step_by(2) {
            // transpose the 2x2 block on the diagonal
            let i = row * size + row;
            data.swap(i + 1, i + size);
            // swap each off-diagonal 2x2 block with its mirror, transposing both
            for col in (row..size).step_by(2).skip(1) {
                let i = row * size + col;
                let j = col * size + row;
                data.swap(i, j);
                data.swap(i + 1, j + size);
                data.swap(i + size, j + 1);
                data.swap(i + size + 1, j + size + 1);
            }
        }
    }

    /// Transposes a `size x size` matrix whose entries are pairs of consecutive elements.
    fn transpose_square_2<T>(data: &mut [T], size: usize) {
        debug_assert_eq!(data.len(), 2 * size * size);

        // iterate over upper-left triangle, working in 1x2 blocks
        // TODO: investigate concurrent implementation
        for row in 0..size {
            for col in (row..size).skip(1) {
                let i = (row * size + col) * 2;
                let j = (col * size + row) * 2;
                data.swap(i, j);
                data.swap(i + 1, j + 1);
            }
        }
    }
}