../../.cargo/katex-header.html

winter_math/fft/
fft_inputs.rs

// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
5
6use super::{permute_index, FieldElement};
7
// CONSTANTS
// ================================================================================================

/// Upper bound on how many interleaved sub-transforms `fft_in_place` batches into one loop.
///
/// While `stride == count` and `count` is below this bound, `fft_in_place` recurses once with
/// both `count` and `stride` doubled, instead of making two recursive calls. As a result, each
/// butterfly pass iterates over `count` adjacent offsets. Once `count` reaches this bound, the
/// recursion falls back to explicit two-way splitting.
const MAX_LOOP: usize = 256;
11
12// FFT INPUTS TRAIT
13// ================================================================================================
14
15/// Defines the interface that must be implemented by the input to fft_in_place method.
16#[allow(clippy::len_without_is_empty)]
17pub trait FftInputs<E: FieldElement> {
18    /// Returns the number of elements in this input.
19    fn len(&self) -> usize;
20
21    /// Combines the result of smaller number theoretic transform into a larger NTT.
22    fn butterfly(&mut self, offset: usize, stride: usize);
23
24    /// Combines the result of smaller number theoretic transform multiplied with a
25    /// twiddle factor into a larger NTT.
26    fn butterfly_twiddle(&mut self, twiddle: E::BaseField, offset: usize, stride: usize);
27
28    /// Swaps the element at index i with the element at index j. Specifically:
29    ///
30    /// elem_i <-> elem_j
31    ///
32    /// # Panics
33    /// Panics if i or j are out of bounds.
34    fn swap(&mut self, i: usize, j: usize);
35
36    /// Multiplies every element in this input by a series of increment. Specifically:
37    ///
38    /// elem_i = elem_i * offset * increment^i
39    fn shift_by_series(&mut self, offset: E::BaseField, increment: E::BaseField);
40
41    /// Multiplies every element in this input by `offset`. Specifically:
42    ///
43    /// elem_i = elem_i * offset
44    fn shift_by(&mut self, offset: E::BaseField);
45
46    /// Permutes the elements in this input using the permutation defined by the given
47    /// permutation index.
48    ///
49    /// The permutation index is a number between 0 and `self.len() - 1` that specifies the
50    /// permutation to apply to the input. The permutation is applied in place, so the input
51    /// is replaced with the result of the permutation. The permutation is applied by swapping
52    /// elements in the input.
53    ///
54    /// # Panics
55    /// Panics if the permutation index is out of bounds.
56    fn permute(&mut self) {
57        let n = self.len();
58        for i in 0..n {
59            let j = permute_index(n, i);
60            if j > i {
61                self.swap(i, j);
62            }
63        }
64    }
65
66    /// Applies the FFT to this input.
67    ///
68    /// The FFT is applied in place, so the input is replaced with the result of the FFT. The
69    /// `twiddles` parameter specifies the twiddle factors to use for the FFT.
70    ///
71    /// This is a convenience method equivalent to calling fft_in_place_raw(twiddles, 1, 1, 0).
72    ///
73    /// # Panics
74    /// Panics if length of the `twiddles` parameter is not self.len() / 2.
75    fn fft_in_place(&mut self, twiddles: &[E::BaseField]) {
76        fft_in_place(self, twiddles, 1, 1, 0);
77    }
78
79    /// Applies the FFT to this input.
80    ///
81    /// The FFT is applied in place, so the input is replaced with the result of the FFT. The
82    /// `twiddles` parameter specifies the twiddle factors to use for the FFT.
83    ///
84    /// # Panics
85    /// Panics if length of the `twiddles` parameter is not self.len() / 2.
86    fn fft_in_place_raw(
87        &mut self,
88        twiddles: &[E::BaseField],
89        count: usize,
90        stride: usize,
91        offset: usize,
92    ) {
93        fft_in_place(self, twiddles, count, stride, offset)
94    }
95}
96
97// SLICE IMPLEMENTATION
98// ================================================================================================
99
100/// Implements FftInputs for a slice of field elements.
101impl<E: FieldElement> FftInputs<E> for [E] {
102    fn len(&self) -> usize {
103        self.len()
104    }
105
106    #[inline(always)]
107    fn butterfly(&mut self, offset: usize, stride: usize) {
108        let i = offset;
109        let j = offset + stride;
110        let temp = self[i];
111        self[i] = temp + self[j];
112        self[j] = temp - self[j];
113    }
114
115    #[inline(always)]
116    fn butterfly_twiddle(&mut self, twiddle: E::BaseField, offset: usize, stride: usize) {
117        let i = offset;
118        let j = offset + stride;
119        let temp = self[i];
120        self[j] = self[j].mul_base(twiddle);
121        self[i] = temp + self[j];
122        self[j] = temp - self[j];
123    }
124
125    fn swap(&mut self, i: usize, j: usize) {
126        self.swap(i, j)
127    }
128
129    fn shift_by_series(&mut self, offset: E::BaseField, increment: E::BaseField) {
130        let mut offset = E::from(offset);
131        let increment = E::from(increment);
132        for d in self.iter_mut() {
133            *d *= offset;
134            offset *= increment;
135        }
136    }
137
138    fn shift_by(&mut self, offset: E::BaseField) {
139        let offset = E::from(offset);
140        for d in self.iter_mut() {
141            *d *= offset;
142        }
143    }
144}
145
146// SLICE OF ARRAYS IMPLEMENTATION
147// ================================================================================================
148
149/// Implements [FftInputs] for a slice of field element arrays.
150#[allow(clippy::needless_range_loop)]
151impl<E: FieldElement, const N: usize> FftInputs<E> for [[E; N]] {
152    fn len(&self) -> usize {
153        self.len()
154    }
155
156    #[inline(always)]
157    fn butterfly(&mut self, offset: usize, stride: usize) {
158        let i = offset;
159        let j = offset + stride;
160
161        let temp = self[i];
162        for col_idx in 0..N {
163            self[i][col_idx] = temp[col_idx] + self[j][col_idx];
164            self[j][col_idx] = temp[col_idx] - self[j][col_idx];
165        }
166    }
167
168    #[inline(always)]
169    fn butterfly_twiddle(&mut self, twiddle: E::BaseField, offset: usize, stride: usize) {
170        let i = offset;
171        let j = offset + stride;
172
173        let twiddle = E::from(twiddle);
174        let temp = self[i];
175
176        for col_idx in 0..N {
177            self[j][col_idx] *= twiddle;
178            self[i][col_idx] = temp[col_idx] + self[j][col_idx];
179            self[j][col_idx] = temp[col_idx] - self[j][col_idx];
180        }
181    }
182
183    fn swap(&mut self, i: usize, j: usize) {
184        self.swap(i, j)
185    }
186
187    fn shift_by(&mut self, offset: E::BaseField) {
188        let offset = E::from(offset);
189        for row_idx in 0..self.len() {
190            for col_idx in 0..N {
191                self[row_idx][col_idx] *= offset;
192            }
193        }
194    }
195
196    fn shift_by_series(&mut self, offset: E::BaseField, increment: E::BaseField) {
197        let increment = E::from(increment);
198        let mut offset = E::from(offset);
199
200        for row_idx in 0..self.len() {
201            for col_idx in 0..N {
202                self[row_idx][col_idx] *= offset;
203            }
204            offset *= increment;
205        }
206    }
207}
208
209// CORE FFT ALGORITHM
210// ================================================================================================
211
212/// In-place recursive FFT with permuted output.
213///
214/// Adapted from: https://github.com/0xProject/OpenZKP/tree/master/algebra/primefield/src/fft
215fn fft_in_place<E, I>(
216    values: &mut I,
217    twiddles: &[E::BaseField],
218    count: usize,
219    stride: usize,
220    offset: usize,
221) where
222    E: FieldElement,
223    I: FftInputs<E> + ?Sized,
224{
225    let size = values.len() / stride;
226    debug_assert!(size.is_power_of_two());
227    debug_assert!(offset < stride);
228    debug_assert_eq!(values.len() % size, 0);
229
230    // Keep recursing until size is 2
231    if size > 2 {
232        if stride == count && count < MAX_LOOP {
233            fft_in_place(values, twiddles, 2 * count, 2 * stride, offset);
234        } else {
235            fft_in_place(values, twiddles, count, 2 * stride, offset);
236            fft_in_place(values, twiddles, count, 2 * stride, offset + stride);
237        }
238    }
239
240    // Apply butterfly operations.
241    for offset in offset..(offset + count) {
242        I::butterfly(values, offset, stride);
243    }
244
245    // Apply butterfly operations with twiddle factors.
246    let last_offset = offset + size * stride;
247    for (i, offset) in (offset..last_offset).step_by(2 * stride).enumerate().skip(1) {
248        for j in offset..(offset + count) {
249            I::butterfly_twiddle(values, twiddles[i], j, stride);
250        }
251    }
252}