winter_math/fft/
fft_inputs.rs1use super::{permute_index, FieldElement};
/// Upper bound on how many interleaved sub-transforms `fft_in_place` fuses into a single
/// recursive call. While `count` is below this value (and `stride == count`) both halves are
/// handled by one call with a doubled `count`; otherwise the recursion splits into two calls.
const MAX_LOOP: usize = 1 << 8;

12#[allow(clippy::len_without_is_empty)]
17pub trait FftInputs<E: FieldElement> {
18 fn len(&self) -> usize;
20
21 fn butterfly(&mut self, offset: usize, stride: usize);
23
24 fn butterfly_twiddle(&mut self, twiddle: E::BaseField, offset: usize, stride: usize);
27
28 fn swap(&mut self, i: usize, j: usize);
35
36 fn shift_by_series(&mut self, offset: E::BaseField, increment: E::BaseField);
40
41 fn shift_by(&mut self, offset: E::BaseField);
45
46 fn permute(&mut self) {
57 let n = self.len();
58 for i in 0..n {
59 let j = permute_index(n, i);
60 if j > i {
61 self.swap(i, j);
62 }
63 }
64 }
65
66 fn fft_in_place(&mut self, twiddles: &[E::BaseField]) {
76 fft_in_place(self, twiddles, 1, 1, 0);
77 }
78
79 fn fft_in_place_raw(
87 &mut self,
88 twiddles: &[E::BaseField],
89 count: usize,
90 stride: usize,
91 offset: usize,
92 ) {
93 fft_in_place(self, twiddles, count, stride, offset)
94 }
95}
96
97impl<E: FieldElement> FftInputs<E> for [E] {
102 fn len(&self) -> usize {
103 self.len()
104 }
105
106 #[inline(always)]
107 fn butterfly(&mut self, offset: usize, stride: usize) {
108 let i = offset;
109 let j = offset + stride;
110 let temp = self[i];
111 self[i] = temp + self[j];
112 self[j] = temp - self[j];
113 }
114
115 #[inline(always)]
116 fn butterfly_twiddle(&mut self, twiddle: E::BaseField, offset: usize, stride: usize) {
117 let i = offset;
118 let j = offset + stride;
119 let temp = self[i];
120 self[j] = self[j].mul_base(twiddle);
121 self[i] = temp + self[j];
122 self[j] = temp - self[j];
123 }
124
125 fn swap(&mut self, i: usize, j: usize) {
126 self.swap(i, j)
127 }
128
129 fn shift_by_series(&mut self, offset: E::BaseField, increment: E::BaseField) {
130 let mut offset = E::from(offset);
131 let increment = E::from(increment);
132 for d in self.iter_mut() {
133 *d *= offset;
134 offset *= increment;
135 }
136 }
137
138 fn shift_by(&mut self, offset: E::BaseField) {
139 let offset = E::from(offset);
140 for d in self.iter_mut() {
141 *d *= offset;
142 }
143 }
144}
145
146#[allow(clippy::needless_range_loop)]
151impl<E: FieldElement, const N: usize> FftInputs<E> for [[E; N]] {
152 fn len(&self) -> usize {
153 self.len()
154 }
155
156 #[inline(always)]
157 fn butterfly(&mut self, offset: usize, stride: usize) {
158 let i = offset;
159 let j = offset + stride;
160
161 let temp = self[i];
162 for col_idx in 0..N {
163 self[i][col_idx] = temp[col_idx] + self[j][col_idx];
164 self[j][col_idx] = temp[col_idx] - self[j][col_idx];
165 }
166 }
167
168 #[inline(always)]
169 fn butterfly_twiddle(&mut self, twiddle: E::BaseField, offset: usize, stride: usize) {
170 let i = offset;
171 let j = offset + stride;
172
173 let twiddle = E::from(twiddle);
174 let temp = self[i];
175
176 for col_idx in 0..N {
177 self[j][col_idx] *= twiddle;
178 self[i][col_idx] = temp[col_idx] + self[j][col_idx];
179 self[j][col_idx] = temp[col_idx] - self[j][col_idx];
180 }
181 }
182
183 fn swap(&mut self, i: usize, j: usize) {
184 self.swap(i, j)
185 }
186
187 fn shift_by(&mut self, offset: E::BaseField) {
188 let offset = E::from(offset);
189 for row_idx in 0..self.len() {
190 for col_idx in 0..N {
191 self[row_idx][col_idx] *= offset;
192 }
193 }
194 }
195
196 fn shift_by_series(&mut self, offset: E::BaseField, increment: E::BaseField) {
197 let increment = E::from(increment);
198 let mut offset = E::from(offset);
199
200 for row_idx in 0..self.len() {
201 for col_idx in 0..N {
202 self[row_idx][col_idx] *= offset;
203 }
204 offset *= increment;
205 }
206 }
207}
208
209fn fft_in_place<E, I>(
216 values: &mut I,
217 twiddles: &[E::BaseField],
218 count: usize,
219 stride: usize,
220 offset: usize,
221) where
222 E: FieldElement,
223 I: FftInputs<E> + ?Sized,
224{
225 let size = values.len() / stride;
226 debug_assert!(size.is_power_of_two());
227 debug_assert!(offset < stride);
228 debug_assert_eq!(values.len() % size, 0);
229
230 if size > 2 {
232 if stride == count && count < MAX_LOOP {
233 fft_in_place(values, twiddles, 2 * count, 2 * stride, offset);
234 } else {
235 fft_in_place(values, twiddles, count, 2 * stride, offset);
236 fft_in_place(values, twiddles, count, 2 * stride, offset + stride);
237 }
238 }
239
240 for offset in offset..(offset + count) {
242 I::butterfly(values, offset, stride);
243 }
244
245 let last_offset = offset + size * stride;
247 for (i, offset) in (offset..last_offset).step_by(2 * stride).enumerate().skip(1) {
248 for j in offset..(offset + count) {
249 I::butterfly_twiddle(values, twiddles[i], j, stride);
250 }
251 }
252}