rustfft/algorithm/
raders_algorithm.rs1use std::sync::Arc;
2
3use num_complex::Complex;
4use num_integer::Integer;
5use num_traits::Zero;
6use primal_check::miller_rabin;
7use strength_reduce::StrengthReducedU64;
8
9use crate::math_utils;
10use crate::{common::FftNum, twiddles, FftDirection};
11use crate::{Direction, Fft, Length};
12
13pub struct RadersAlgorithm<T> {
42 inner_fft: Arc<dyn Fft<T>>,
43 inner_fft_data: Box<[Complex<T>]>,
44
45 primitive_root: u64,
46 primitive_root_inverse: u64,
47
48 len: StrengthReducedU64,
49 inplace_scratch_len: usize,
50 outofplace_scratch_len: usize,
51 immut_scratch_len: usize,
52
53 direction: FftDirection,
54}
55
56impl<T: FftNum> RadersAlgorithm<T> {
57 pub fn new(inner_fft: Arc<dyn Fft<T>>) -> Self {
66 let inner_fft_len = inner_fft.len();
67 let len = inner_fft_len + 1;
68 assert!(miller_rabin(len as u64), "For raders algorithm, inner_fft.len() + 1 must be prime. Expected prime number, got {} + 1 = {}", inner_fft_len, len);
69
70 let direction = inner_fft.fft_direction();
71 let reduced_len = StrengthReducedU64::new(len as u64);
72
73 let primitive_root = math_utils::primitive_root(len as u64).unwrap();
75
76 let gcd_data = i64::extended_gcd(&(primitive_root as i64), &(len as i64));
80 let primitive_root_inverse = if gcd_data.x >= 0 {
81 gcd_data.x
82 } else {
83 gcd_data.x + len as i64
84 } as u64;
85
86 let inner_fft_scale = T::one() / T::from_usize(inner_fft_len).unwrap();
88 let mut inner_fft_input = vec![Complex::zero(); inner_fft_len];
89 let mut twiddle_input = 1;
90 for input_cell in &mut inner_fft_input {
91 let twiddle = twiddles::compute_twiddle(twiddle_input, len, direction);
92 *input_cell = twiddle * inner_fft_scale;
93
94 twiddle_input =
95 ((twiddle_input as u64 * primitive_root_inverse) % reduced_len) as usize;
96 }
97
98 let required_inner_scratch = inner_fft.get_inplace_scratch_len();
99 let extra_inner_scratch = if required_inner_scratch <= inner_fft_len {
100 0
101 } else {
102 required_inner_scratch
103 };
104 let inplace_scratch_len = inner_fft_len + extra_inner_scratch;
105 let immut_scratch_len = inner_fft_len + required_inner_scratch;
106
107 let mut inner_fft_scratch = vec![Zero::zero(); required_inner_scratch];
109 inner_fft.process_with_scratch(&mut inner_fft_input, &mut inner_fft_scratch);
110
111 Self {
112 inner_fft,
113 inner_fft_data: inner_fft_input.into_boxed_slice(),
114
115 primitive_root,
116 primitive_root_inverse,
117
118 len: reduced_len,
119 inplace_scratch_len,
120 outofplace_scratch_len: extra_inner_scratch,
121 immut_scratch_len,
122 direction,
123 }
124 }
125
126 fn perform_fft_immut(
127 &self,
128 input: &[Complex<T>],
129 output: &mut [Complex<T>],
130 scratch: &mut [Complex<T>],
131 ) {
132 let (output_first, output) = output.split_first_mut().unwrap();
134 let (input_first, input) = input.split_first().unwrap();
135 let (scratch, extra_scratch) = scratch.split_at_mut(self.len() - 1);
136
137 let mut input_index = 1;
139 for output_element in scratch.iter_mut() {
140 input_index = ((input_index as u64 * self.primitive_root) % self.len) as usize;
141
142 let input_element = input[input_index - 1];
143 *output_element = input_element;
144 }
145
146 self.inner_fft.process_with_scratch(scratch, extra_scratch);
147
148 *output_first = *input_first + scratch[0];
150
151 for (scratch_cell, &twiddle) in scratch.iter_mut().zip(self.inner_fft_data.iter()) {
155 *scratch_cell = (*scratch_cell * twiddle).conj();
156 }
157
158 scratch[0] = scratch[0] + input_first.conj();
161
162 self.inner_fft.process_with_scratch(scratch, extra_scratch);
164
165 let mut output_index = 1;
167 for scratch_element in scratch {
168 output_index =
169 ((output_index as u64 * self.primitive_root_inverse) % self.len) as usize;
170 output[output_index - 1] = scratch_element.conj();
171 }
172 }
173
174 fn perform_fft_out_of_place(
175 &self,
176 input: &mut [Complex<T>],
177 output: &mut [Complex<T>],
178 scratch: &mut [Complex<T>],
179 ) {
180 let (output_first, output) = output.split_first_mut().unwrap();
182 let (input_first, input) = input.split_first_mut().unwrap();
183
184 let mut input_index = 1;
186 for output_element in output.iter_mut() {
187 input_index = ((input_index as u64 * self.primitive_root) % self.len) as usize;
188
189 let input_element = input[input_index - 1];
190 *output_element = input_element;
191 }
192
193 let inner_scratch = if scratch.len() > 0 {
195 &mut scratch[..]
196 } else {
197 &mut input[..]
198 };
199 self.inner_fft.process_with_scratch(output, inner_scratch);
200
201 *output_first = *input_first + output[0];
203
204 for ((output_cell, input_cell), &multiple) in output
208 .iter()
209 .zip(input.iter_mut())
210 .zip(self.inner_fft_data.iter())
211 {
212 *input_cell = (*output_cell * multiple).conj();
213 }
214
215 input[0] = input[0] + input_first.conj();
218
219 let inner_scratch = if scratch.len() > 0 {
221 scratch
222 } else {
223 &mut output[..]
224 };
225 self.inner_fft.process_with_scratch(input, inner_scratch);
226
227 let mut output_index = 1;
229 for input_element in input {
230 output_index =
231 ((output_index as u64 * self.primitive_root_inverse) % self.len) as usize;
232 output[output_index - 1] = input_element.conj();
233 }
234 }
235 fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
236 let (buffer_first, buffer) = buffer.split_first_mut().unwrap();
238 let buffer_first_val = *buffer_first;
239
240 let (scratch, extra_scratch) = scratch.split_at_mut(self.len() - 1);
241
242 let mut input_index = 1;
244 for scratch_element in scratch.iter_mut() {
245 input_index = ((input_index as u64 * self.primitive_root) % self.len) as usize;
246
247 let buffer_element = buffer[input_index - 1];
248 *scratch_element = buffer_element;
249 }
250
251 let inner_scratch = if extra_scratch.len() > 0 {
253 extra_scratch
254 } else {
255 &mut buffer[..]
256 };
257 self.inner_fft.process_with_scratch(scratch, inner_scratch);
258
259 *buffer_first = *buffer_first + scratch[0];
261
262 for (scratch_cell, &twiddle) in scratch.iter_mut().zip(self.inner_fft_data.iter()) {
266 *scratch_cell = (*scratch_cell * twiddle).conj();
267 }
268
269 scratch[0] = scratch[0] + buffer_first_val.conj();
272
273 self.inner_fft.process_with_scratch(scratch, inner_scratch);
275
276 let mut output_index = 1;
278 for scratch_element in scratch {
279 output_index =
280 ((output_index as u64 * self.primitive_root_inverse) % self.len) as usize;
281 buffer[output_index - 1] = scratch_element.conj();
282 }
283 }
284}
285boilerplate_fft!(
286 RadersAlgorithm,
287 |this: &RadersAlgorithm<_>| this.len.get() as usize,
288 |this: &RadersAlgorithm<_>| this.inplace_scratch_len,
289 |this: &RadersAlgorithm<_>| this.outofplace_scratch_len,
290 |this: &RadersAlgorithm<_>| this.immut_scratch_len
291);
292
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::algorithm::Dft;
    use crate::test_utils::check_fft_algorithm;
    use crate::FftPlanner;
    use std::sync::Arc;

    /// Checks every prime length in 3..100, forward and then inverse, against the shared
    /// FFT test harness.
    #[test]
    fn test_raders() {
        let prime_lengths = (3..100usize).filter(|&candidate| miller_rabin(candidate as u64));
        for len in prime_lengths {
            for direction in [FftDirection::Forward, FftDirection::Inverse] {
                test_raders_with_length(len, direction);
            }
        }
    }

    /// Large prime sizes whose index products exceed u32::MAX must run without panicking.
    #[test]
    fn test_raders_32bit_overflow() {
        let mut planner = FftPlanner::<f32>::new();
        for len in [112501, 216569, 417623] {
            let raders: RadersAlgorithm<f32> =
                RadersAlgorithm::new(planner.plan_fft_forward(len - 1));
            let mut buffer = vec![Complex::new(0.0, 0.0); len];
            raders.process(&mut buffer);
        }
    }

    /// Builds a size-`len` Rader's FFT on top of a naive size-`len - 1` `Dft` and verifies it.
    fn test_raders_with_length(len: usize, direction: FftDirection) {
        let dft = Arc::new(Dft::new(len - 1, direction));
        let raders = RadersAlgorithm::new(dft);

        check_fft_algorithm::<f32>(&raders, len, direction);
    }
}