rustfft/algorithm/
raders_algorithm.rs

1use std::sync::Arc;
2
3use num_complex::Complex;
4use num_integer::Integer;
5use num_traits::Zero;
6use primal_check::miller_rabin;
7use strength_reduce::StrengthReducedU64;
8
9use crate::math_utils;
10use crate::{common::FftNum, twiddles, FftDirection};
11use crate::{Direction, Fft, Length};
12
13/// Implementation of Rader's Algorithm
14///
15/// This algorithm computes a prime-sized FFT in O(nlogn) time. It does this by converting this size-N FFT into a
16/// size-(N - 1) FFT, which is guaranteed to be composite.
17///
18/// The worst case for this algorithm is when (N - 1) is 2 * prime, resulting in a
19/// [Cunningham Chain](https://en.wikipedia.org/wiki/Cunningham_chain)
20///
21/// ~~~
22/// // Computes a forward FFT of size 1201 (prime number), using Rader's Algorithm
23/// use rustfft::algorithm::RadersAlgorithm;
24/// use rustfft::{Fft, FftPlanner};
25/// use rustfft::num_complex::Complex;
26///
27/// let mut buffer = vec![Complex{ re: 0.0f32, im: 0.0f32 }; 1201];
28///
29/// // plan a FFT of size n - 1 = 1200
30/// let mut planner = FftPlanner::new();
31/// let inner_fft = planner.plan_fft_forward(1200);
32///
33/// let fft = RadersAlgorithm::new(inner_fft);
34/// fft.process(&mut buffer);
35/// ~~~
36///
37/// Rader's Algorithm is relatively expensive compared to other FFT algorithms. Benchmarking shows that it is up to
38/// an order of magnitude slower than similar composite sizes. In the example size above of 1201, benchmarking shows
39/// that it takes 2.5x more time to compute than a FFT of size 1200.
40
41pub struct RadersAlgorithm<T> {
42    inner_fft: Arc<dyn Fft<T>>,
43    inner_fft_data: Box<[Complex<T>]>,
44
45    primitive_root: u64,
46    primitive_root_inverse: u64,
47
48    len: StrengthReducedU64,
49    inplace_scratch_len: usize,
50    outofplace_scratch_len: usize,
51    immut_scratch_len: usize,
52
53    direction: FftDirection,
54}
55
56impl<T: FftNum> RadersAlgorithm<T> {
57    /// Creates a FFT instance which will process inputs/outputs of size `inner_fft.len() + 1`.
58    ///
59    /// Note that this constructor is quite expensive to run; This algorithm must compute a FFT using `inner_fft` within the
60    /// constructor. This further underlines the fact that Rader's Algorithm is more expensive to run than other
61    /// FFT algorithms
62    ///
63    /// # Panics
64    /// Panics if `inner_fft.len() + 1` is not a prime number.
65    pub fn new(inner_fft: Arc<dyn Fft<T>>) -> Self {
66        let inner_fft_len = inner_fft.len();
67        let len = inner_fft_len + 1;
68        assert!(miller_rabin(len as u64), "For raders algorithm, inner_fft.len() + 1 must be prime. Expected prime number, got {} + 1 = {}", inner_fft_len, len);
69
70        let direction = inner_fft.fft_direction();
71        let reduced_len = StrengthReducedU64::new(len as u64);
72
73        // compute the primitive root and its inverse for this size
74        let primitive_root = math_utils::primitive_root(len as u64).unwrap();
75
76        // compute the multiplicative inverse of primative_root mod len and vice versa.
77        // i64::extended_gcd will compute both the inverse of left mod right, and the inverse of right mod left, but we're only goingto use one of them
78        // the primtive root inverse might be negative, if o make it positive by wrapping
79        let gcd_data = i64::extended_gcd(&(primitive_root as i64), &(len as i64));
80        let primitive_root_inverse = if gcd_data.x >= 0 {
81            gcd_data.x
82        } else {
83            gcd_data.x + len as i64
84        } as u64;
85
86        // precompute the coefficients to use inside the process method
87        let inner_fft_scale = T::one() / T::from_usize(inner_fft_len).unwrap();
88        let mut inner_fft_input = vec![Complex::zero(); inner_fft_len];
89        let mut twiddle_input = 1;
90        for input_cell in &mut inner_fft_input {
91            let twiddle = twiddles::compute_twiddle(twiddle_input, len, direction);
92            *input_cell = twiddle * inner_fft_scale;
93
94            twiddle_input =
95                ((twiddle_input as u64 * primitive_root_inverse) % reduced_len) as usize;
96        }
97
98        let required_inner_scratch = inner_fft.get_inplace_scratch_len();
99        let extra_inner_scratch = if required_inner_scratch <= inner_fft_len {
100            0
101        } else {
102            required_inner_scratch
103        };
104        let inplace_scratch_len = inner_fft_len + extra_inner_scratch;
105        let immut_scratch_len = inner_fft_len + required_inner_scratch;
106
107        //precompute a FFT of our reordered twiddle factors
108        let mut inner_fft_scratch = vec![Zero::zero(); required_inner_scratch];
109        inner_fft.process_with_scratch(&mut inner_fft_input, &mut inner_fft_scratch);
110
111        Self {
112            inner_fft,
113            inner_fft_data: inner_fft_input.into_boxed_slice(),
114
115            primitive_root,
116            primitive_root_inverse,
117
118            len: reduced_len,
119            inplace_scratch_len,
120            outofplace_scratch_len: extra_inner_scratch,
121            immut_scratch_len,
122            direction,
123        }
124    }
125
126    fn perform_fft_immut(
127        &self,
128        input: &[Complex<T>],
129        output: &mut [Complex<T>],
130        scratch: &mut [Complex<T>],
131    ) {
132        // The first output element is just the sum of all the input elements, and we need to store off the first input value
133        let (output_first, output) = output.split_first_mut().unwrap();
134        let (input_first, input) = input.split_first().unwrap();
135        let (scratch, extra_scratch) = scratch.split_at_mut(self.len() - 1);
136
137        // copy the input into the scratch space, reordering as we go
138        let mut input_index = 1;
139        for output_element in scratch.iter_mut() {
140            input_index = ((input_index as u64 * self.primitive_root) % self.len) as usize;
141
142            let input_element = input[input_index - 1];
143            *output_element = input_element;
144        }
145
146        self.inner_fft.process_with_scratch(scratch, extra_scratch);
147
148        // output[0] now contains the sum of elements 1..len. We need the sum of all elements, so all we have to do is add the first input
149        *output_first = *input_first + scratch[0];
150
151        // multiply the inner result with our cached setup data
152        // also conjugate every entry. this sets us up to do an inverse FFT
153        // (because an inverse FFT is equivalent to a normal FFT where you conjugate both the inputs and outputs)
154        for (scratch_cell, &twiddle) in scratch.iter_mut().zip(self.inner_fft_data.iter()) {
155            *scratch_cell = (*scratch_cell * twiddle).conj();
156        }
157
158        // We need to add the first input value to all output values. We can accomplish this by adding it to the DC input of our inner ifft.
159        // Of course, we have to conjugate it, just like we conjugated the complex multiplied above
160        scratch[0] = scratch[0] + input_first.conj();
161
162        // execute the second FFT
163        self.inner_fft.process_with_scratch(scratch, extra_scratch);
164
165        // copy the final values into the output, reordering as we go
166        let mut output_index = 1;
167        for scratch_element in scratch {
168            output_index =
169                ((output_index as u64 * self.primitive_root_inverse) % self.len) as usize;
170            output[output_index - 1] = scratch_element.conj();
171        }
172    }
173
174    fn perform_fft_out_of_place(
175        &self,
176        input: &mut [Complex<T>],
177        output: &mut [Complex<T>],
178        scratch: &mut [Complex<T>],
179    ) {
180        // The first output element is just the sum of all the input elements, and we need to store off the first input value
181        let (output_first, output) = output.split_first_mut().unwrap();
182        let (input_first, input) = input.split_first_mut().unwrap();
183
184        // copy the input into the output, reordering as we go. also compute a sum of all elements
185        let mut input_index = 1;
186        for output_element in output.iter_mut() {
187            input_index = ((input_index as u64 * self.primitive_root) % self.len) as usize;
188
189            let input_element = input[input_index - 1];
190            *output_element = input_element;
191        }
192
193        // perform the first of two inner FFTs
194        let inner_scratch = if scratch.len() > 0 {
195            &mut scratch[..]
196        } else {
197            &mut input[..]
198        };
199        self.inner_fft.process_with_scratch(output, inner_scratch);
200
201        // output[0] now contains the sum of elements 1..len. We need the sum of all elements, so all we have to do is add the first input
202        *output_first = *input_first + output[0];
203
204        // multiply the inner result with our cached setup data
205        // also conjugate every entry. this sets us up to do an inverse FFT
206        // (because an inverse FFT is equivalent to a normal FFT where you conjugate both the inputs and outputs)
207        for ((output_cell, input_cell), &multiple) in output
208            .iter()
209            .zip(input.iter_mut())
210            .zip(self.inner_fft_data.iter())
211        {
212            *input_cell = (*output_cell * multiple).conj();
213        }
214
215        // We need to add the first input value to all output values. We can accomplish this by adding it to the DC input of our inner ifft.
216        // Of course, we have to conjugate it, just like we conjugated the complex multiplied above
217        input[0] = input[0] + input_first.conj();
218
219        // execute the second FFT
220        let inner_scratch = if scratch.len() > 0 {
221            scratch
222        } else {
223            &mut output[..]
224        };
225        self.inner_fft.process_with_scratch(input, inner_scratch);
226
227        // copy the final values into the output, reordering as we go
228        let mut output_index = 1;
229        for input_element in input {
230            output_index =
231                ((output_index as u64 * self.primitive_root_inverse) % self.len) as usize;
232            output[output_index - 1] = input_element.conj();
233        }
234    }
235    fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
236        // The first output element is just the sum of all the input elements, and we need to store off the first input value
237        let (buffer_first, buffer) = buffer.split_first_mut().unwrap();
238        let buffer_first_val = *buffer_first;
239
240        let (scratch, extra_scratch) = scratch.split_at_mut(self.len() - 1);
241
242        // copy the buffer into the scratch, reordering as we go. also compute a sum of all elements
243        let mut input_index = 1;
244        for scratch_element in scratch.iter_mut() {
245            input_index = ((input_index as u64 * self.primitive_root) % self.len) as usize;
246
247            let buffer_element = buffer[input_index - 1];
248            *scratch_element = buffer_element;
249        }
250
251        // perform the first of two inner FFTs
252        let inner_scratch = if extra_scratch.len() > 0 {
253            extra_scratch
254        } else {
255            &mut buffer[..]
256        };
257        self.inner_fft.process_with_scratch(scratch, inner_scratch);
258
259        // scratch[0] now contains the sum of elements 1..len. We need the sum of all elements, so all we have to do is add the first input
260        *buffer_first = *buffer_first + scratch[0];
261
262        // multiply the inner result with our cached setup data
263        // also conjugate every entry. this sets us up to do an inverse FFT
264        // (because an inverse FFT is equivalent to a normal FFT where you conjugate both the inputs and outputs)
265        for (scratch_cell, &twiddle) in scratch.iter_mut().zip(self.inner_fft_data.iter()) {
266            *scratch_cell = (*scratch_cell * twiddle).conj();
267        }
268
269        // We need to add the first input value to all output values. We can accomplish this by adding it to the DC input of our inner ifft.
270        // Of course, we have to conjugate it, just like we conjugated the complex multiplied above
271        scratch[0] = scratch[0] + buffer_first_val.conj();
272
273        // execute the second FFT
274        self.inner_fft.process_with_scratch(scratch, inner_scratch);
275
276        // copy the final values into the output, reordering as we go
277        let mut output_index = 1;
278        for scratch_element in scratch {
279            output_index =
280                ((output_index as u64 * self.primitive_root_inverse) % self.len) as usize;
281            buffer[output_index - 1] = scratch_element.conj();
282        }
283    }
284}
// Hooks `RadersAlgorithm` into the crate's shared FFT boilerplate. The macro itself is defined elsewhere.
// NOTE(review): it presumably generates the `Fft`, `Length` and `Direction` trait impls, forwarding to the
// `perform_fft_*` methods above — confirm against the macro definition.
//
// Arguments, in order: the struct name, then one closure each returning the FFT length, the in-place
// scratch length, the out-of-place scratch length, and the immutable-input scratch length.
285boilerplate_fft!(
286    RadersAlgorithm,
287    |this: &RadersAlgorithm<_>| this.len.get() as usize,
288    |this: &RadersAlgorithm<_>| this.inplace_scratch_len,
289    |this: &RadersAlgorithm<_>| this.outofplace_scratch_len,
290    |this: &RadersAlgorithm<_>| this.immut_scratch_len
291);
292
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::algorithm::Dft;
    use crate::test_utils::check_fft_algorithm;
    use crate::FftPlanner;
    use std::sync::Arc;

    /// Wraps a `Dft` of size `len - 1` in a `RadersAlgorithm` and runs the crate's shared
    /// FFT checker on the result.
    fn test_raders_with_length(len: usize, direction: FftDirection) {
        let reference_inner = Arc::new(Dft::new(len - 1, direction));
        let raders = RadersAlgorithm::new(reference_inner);

        check_fft_algorithm::<f32>(&raders, len, direction);
    }

    /// Checks every prime size in `3..100`, in both directions.
    #[test]
    fn test_raders() {
        let directions = [FftDirection::Forward, FftDirection::Inverse];
        for len in (3usize..100).filter(|&candidate| miller_rabin(candidate as u64)) {
            for &direction in directions.iter() {
                test_raders_with_length(len, direction);
            }
        }
    }

    /// Constructs and runs Raders instances for a few large primes
    /// that could panic due to overflow errors on 32-bit builds.
    #[test]
    fn test_raders_32bit_overflow() {
        let large_primes: [usize; 3] = [112501, 216569, 417623];
        let mut planner = FftPlanner::<f32>::new();
        for &len in large_primes.iter() {
            let planned_inner = planner.plan_fft_forward(len - 1);
            let raders = RadersAlgorithm::<f32>::new(planned_inner);

            let mut samples = vec![Complex::new(0.0, 0.0); len];
            raders.process(&mut samples);
        }
    }
}