1use std::cmp::max;
2use std::sync::Arc;
3
4use num_complex::Complex;
5use num_integer::Integer;
6use strength_reduce::StrengthReducedUsize;
7use transpose;
8
9use crate::array_utils;
10use crate::{common::FftNum, FftDirection};
11use crate::{Direction, Fft, Length};
12
13pub struct GoodThomasAlgorithm<T> {
41 width: usize,
42 width_size_fft: Arc<dyn Fft<T>>,
43
44 height: usize,
45 height_size_fft: Arc<dyn Fft<T>>,
46
47 reduced_width: StrengthReducedUsize,
48 reduced_width_plus_one: StrengthReducedUsize,
49
50 inplace_scratch_len: usize,
51 outofplace_scratch_len: usize,
52 immut_scratch_len: usize,
53
54 len: usize,
55 direction: FftDirection,
56}
57
58impl<T: FftNum> GoodThomasAlgorithm<T> {
59 pub fn new(mut width_fft: Arc<dyn Fft<T>>, mut height_fft: Arc<dyn Fft<T>>) -> Self {
63 assert_eq!(
64 width_fft.fft_direction(), height_fft.fft_direction(),
65 "width_fft and height_fft must have the same direction. got width direction={}, height direction={}",
66 width_fft.fft_direction(), height_fft.fft_direction());
67
68 let mut width = width_fft.len();
69 let mut height = height_fft.len();
70 let direction = width_fft.fft_direction();
71
72 let gcd = num_integer::gcd(width as i64, height as i64);
74 assert!(gcd == 1,
75 "Invalid width and height for Good-Thomas Algorithm (width={}, height={}): Inputs must be coprime",
76 width,
77 height);
78
79 if width > height {
81 std::mem::swap(&mut width, &mut height);
82 std::mem::swap(&mut width_fft, &mut height_fft);
83 }
84
85 let len = width * height;
86
87 let width_inplace_scratch = width_fft.get_inplace_scratch_len();
89 let height_inplace_scratch = height_fft.get_inplace_scratch_len();
90 let height_outofplace_scratch = height_fft.get_outofplace_scratch_len();
91
92 let max_inner_inplace_scratch = max(height_inplace_scratch, width_inplace_scratch);
100 let outofplace_scratch_len = if max_inner_inplace_scratch > len {
101 max_inner_inplace_scratch
102 } else {
103 0
104 };
105
106 let inplace_scratch_len = len
111 + max(
112 if width_inplace_scratch > len {
113 width_inplace_scratch
114 } else {
115 0
116 },
117 height_outofplace_scratch,
118 );
119
120 let immut_scratch_len = max(
121 width_fft.get_inplace_scratch_len(),
122 len + height_fft.get_inplace_scratch_len(),
123 );
124
125 Self {
126 width,
127 width_size_fft: width_fft,
128
129 height,
130 height_size_fft: height_fft,
131
132 reduced_width: StrengthReducedUsize::new(width),
133 reduced_width_plus_one: StrengthReducedUsize::new(width + 1),
134
135 inplace_scratch_len,
136 outofplace_scratch_len,
137 immut_scratch_len,
138
139 len,
140 direction,
141 }
142 }
143
144 fn reindex_input(&self, source: &[Complex<T>], destination: &mut [Complex<T>]) {
145 let mut destination_index = 0;
160 for mut source_row in source.chunks_exact(self.width) {
161 let increments_until_cycle =
162 1 + (self.len() - destination_index) / self.reduced_width_plus_one;
163
164 if increments_until_cycle < self.width {
166 let (pre_cycle_row, post_cycle_row) = source_row.split_at(increments_until_cycle);
167
168 for input_element in pre_cycle_row {
169 destination[destination_index] = *input_element;
170 destination_index += self.reduced_width_plus_one.get();
171 }
172
173 source_row = post_cycle_row;
175 destination_index -= self.len();
176 }
177
178 for input_element in source_row {
180 destination[destination_index] = *input_element;
181 destination_index += self.reduced_width_plus_one.get();
182 }
183
184 destination_index -= self.width;
187 }
188 }
189
190 fn reindex_output(&self, source: &[Complex<T>], destination: &mut [Complex<T>]) {
191 for (y, source_chunk) in source.chunks_exact(self.height).enumerate() {
203 let (quotient, remainder) =
204 StrengthReducedUsize::div_rem(y * self.height, self.reduced_width);
205
206 let mut destination_index = remainder;
208 let start_x = self.height - quotient;
209
210 for x in start_x..self.height {
212 destination[destination_index] = source_chunk[x];
213 destination_index += self.width;
214 }
215
216 for x in 0..start_x {
218 destination[destination_index] = source_chunk[x];
219 destination_index += self.width;
220 }
221 }
222 }
223
224 fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
225 let (scratch, inner_scratch) = scratch.split_at_mut(self.len());
226
227 self.reindex_input(buffer, scratch);
229
230 let width_scratch = if inner_scratch.len() > buffer.len() {
232 &mut inner_scratch[..]
233 } else {
234 &mut buffer[..]
235 };
236 self.width_size_fft
237 .process_with_scratch(scratch, width_scratch);
238
239 transpose::transpose(scratch, buffer, self.width, self.height);
241
242 self.height_size_fft
244 .process_outofplace_with_scratch(buffer, scratch, inner_scratch);
245
246 self.reindex_output(scratch, buffer);
248 }
249
250 fn perform_fft_immut(
251 &self,
252 input: &[Complex<T>],
253 output: &mut [Complex<T>],
254 scratch: &mut [Complex<T>],
255 ) {
256 self.reindex_input(input, output);
258
259 self.width_size_fft.process_with_scratch(output, scratch);
261
262 let (scratch, inner_scratch) = scratch.split_at_mut(self.len());
263
264 transpose::transpose(output, scratch, self.width, self.height);
266
267 self.height_size_fft
269 .process_with_scratch(scratch, inner_scratch);
270
271 self.reindex_output(scratch, output);
273 }
274
275 fn perform_fft_out_of_place(
276 &self,
277 input: &mut [Complex<T>],
278 output: &mut [Complex<T>],
279 scratch: &mut [Complex<T>],
280 ) {
281 self.reindex_input(input, output);
283
284 let width_scratch = if scratch.len() > input.len() {
286 &mut scratch[..]
287 } else {
288 &mut input[..]
289 };
290 self.width_size_fft
291 .process_with_scratch(output, width_scratch);
292
293 transpose::transpose(output, input, self.width, self.height);
295
296 let height_scratch = if scratch.len() > output.len() {
298 &mut scratch[..]
299 } else {
300 &mut output[..]
301 };
302 self.height_size_fft
303 .process_with_scratch(input, height_scratch);
304
305 self.reindex_output(input, output);
307 }
308}
309boilerplate_fft!(
310 GoodThomasAlgorithm,
311 |this: &GoodThomasAlgorithm<_>| this.len,
312 |this: &GoodThomasAlgorithm<_>| this.inplace_scratch_len,
313 |this: &GoodThomasAlgorithm<_>| this.outofplace_scratch_len,
314 |this: &GoodThomasAlgorithm<_>| this.immut_scratch_len
315);
316
317pub struct GoodThomasAlgorithmSmall<T> {
345 width: usize,
346 width_size_fft: Arc<dyn Fft<T>>,
347
348 height: usize,
349 height_size_fft: Arc<dyn Fft<T>>,
350
351 input_output_map: Box<[usize]>,
352
353 direction: FftDirection,
354}
355
356impl<T: FftNum> GoodThomasAlgorithmSmall<T> {
357 pub fn new(width_fft: Arc<dyn Fft<T>>, height_fft: Arc<dyn Fft<T>>) -> Self {
361 assert_eq!(
362 width_fft.fft_direction(), height_fft.fft_direction(),
363 "n1_fft and height_fft must have the same direction. got width direction={}, height direction={}",
364 width_fft.fft_direction(), height_fft.fft_direction());
365
366 let width = width_fft.len();
367 let height = height_fft.len();
368 let len = width * height;
369
370 assert_eq!(width_fft.get_outofplace_scratch_len(), 0, "GoodThomasAlgorithmSmall should only be used with algorithms that require 0 out-of-place scratch. Width FFT (len={}) requires {}, should require 0", width, width_fft.get_outofplace_scratch_len());
371 assert_eq!(height_fft.get_outofplace_scratch_len(), 0, "GoodThomasAlgorithmSmall should only be used with algorithms that require 0 out-of-place scratch. Height FFT (len={}) requires {}, should require 0", height, height_fft.get_outofplace_scratch_len());
372
373 assert!(width_fft.get_inplace_scratch_len() <= width, "GoodThomasAlgorithmSmall should only be used with algorithms that require little inplace scratch. Width FFT (len={}) requires {}, should require {} or less", width, width_fft.get_inplace_scratch_len(), width);
374 assert!(height_fft.get_inplace_scratch_len() <= height, "GoodThomasAlgorithmSmall should only be used with algorithms that require little inplace scratch. Height FFT (len={}) requires {}, should require {} or less", height, height_fft.get_inplace_scratch_len(), height);
375
376 let gcd_data = i64::extended_gcd(&(width as i64), &(height as i64));
378 assert!(gcd_data.gcd == 1,
379 "Invalid input width and height to Good-Thomas Algorithm: ({},{}): Inputs must be coprime",
380 width,
381 height);
382
383 let width_inverse = if gcd_data.x >= 0 {
385 gcd_data.x
386 } else {
387 gcd_data.x + height as i64
388 } as usize;
389 let height_inverse = if gcd_data.y >= 0 {
390 gcd_data.y
391 } else {
392 gcd_data.y + width as i64
393 } as usize;
394
395 let input_iter = (0..len)
398 .map(|i| (i % width, i / width))
399 .map(|(x, y)| (x * height + y * width) % len);
400 let output_iter = (0..len).map(|i| (i % height, i / height)).map(|(y, x)| {
401 (x * height * height_inverse as usize + y * width * width_inverse as usize) % len
402 });
403
404 let input_output_map: Vec<usize> = input_iter.chain(output_iter).collect();
405
406 Self {
407 direction: width_fft.fft_direction(),
408
409 width,
410 width_size_fft: width_fft,
411
412 height,
413 height_size_fft: height_fft,
414
415 input_output_map: input_output_map.into_boxed_slice(),
416 }
417 }
418
419 fn perform_fft_immut(
420 &self,
421 input: &[Complex<T>],
422 output: &mut [Complex<T>],
423 scratch: &mut [Complex<T>],
424 ) {
425 assert_eq!(self.len(), input.len());
427 assert_eq!(self.len(), output.len());
428
429 let (input_map, output_map) = self.input_output_map.split_at(self.len());
430
431 for (output_element, &input_index) in output.iter_mut().zip(input_map.iter()) {
433 *output_element = input[input_index];
434 }
435
436 self.width_size_fft.process_with_scratch(output, scratch);
438
439 unsafe { array_utils::transpose_small(self.width, self.height, output, scratch) };
441
442 self.height_size_fft.process_with_scratch(scratch, output);
444
445 for (input_element, &output_index) in scratch.iter().zip(output_map.iter()) {
447 output[output_index] = *input_element;
448 }
449 }
450
451 fn perform_fft_out_of_place(
452 &self,
453 input: &mut [Complex<T>],
454 output: &mut [Complex<T>],
455 _scratch: &mut [Complex<T>],
456 ) {
457 assert_eq!(self.len(), input.len());
459 assert_eq!(self.len(), output.len());
460
461 let (input_map, output_map) = self.input_output_map.split_at(self.len());
462
463 for (output_element, &input_index) in output.iter_mut().zip(input_map.iter()) {
465 *output_element = input[input_index];
466 }
467
468 self.width_size_fft.process_with_scratch(output, input);
470
471 unsafe { array_utils::transpose_small(self.width, self.height, output, input) };
473
474 self.height_size_fft.process_with_scratch(input, output);
476
477 for (input_element, &output_index) in input.iter().zip(output_map.iter()) {
479 output[output_index] = *input_element;
480 }
481 }
482
483 fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
484 assert_eq!(self.len(), buffer.len());
486 assert_eq!(self.len(), scratch.len());
487
488 let (input_map, output_map) = self.input_output_map.split_at(self.len());
489
490 for (output_element, &input_index) in scratch.iter_mut().zip(input_map.iter()) {
492 *output_element = buffer[input_index];
493 }
494
495 self.width_size_fft.process_with_scratch(scratch, buffer);
497
498 unsafe { array_utils::transpose_small(self.width, self.height, scratch, buffer) };
500
501 self.height_size_fft
503 .process_outofplace_with_scratch(buffer, scratch, &mut []);
504
505 for (input_element, &output_index) in scratch.iter().zip(output_map.iter()) {
507 buffer[output_index] = *input_element;
508 }
509 }
510}
511boilerplate_fft!(
512 GoodThomasAlgorithmSmall,
513 |this: &GoodThomasAlgorithmSmall<_>| this.width * this.height,
514 |this: &GoodThomasAlgorithmSmall<_>| this.len(),
515 |_| 0,
516 |this: &GoodThomasAlgorithmSmall<_>| this.len()
517);
518
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::check_fft_algorithm;
    use crate::{algorithm::Dft, test_utils::BigScratchAlgorithm};
    use num_integer::gcd;
    use num_traits::Zero;
    use std::sync::Arc;

    /// Directions every size combination is checked in, in the order they are exercised.
    const DIRECTIONS: [FftDirection; 2] = [FftDirection::Forward, FftDirection::Inverse];

    /// Builds a naive DFT of length `len`, as a trait object usable as an inner FFT.
    fn inner_dft(len: usize, direction: FftDirection) -> Arc<dyn Fft<f32>> {
        Arc::new(Dft::<f32>::new(len, direction)) as Arc<dyn Fft<f32>>
    }

    #[test]
    fn test_good_thomas() {
        for width in 1..12usize {
            for height in (1..12usize).filter(|&h| gcd(width, h) == 1) {
                for &direction in DIRECTIONS.iter() {
                    test_good_thomas_with_lengths(width, height, direction);
                }
            }
        }
    }

    #[test]
    fn test_good_thomas_small() {
        let butterfly_sizes: [usize; 8] = [2, 3, 4, 5, 6, 7, 8, 16];
        for &width in butterfly_sizes.iter() {
            for &height in butterfly_sizes.iter().filter(|&&h| gcd(width, h) == 1) {
                for &direction in DIRECTIONS.iter() {
                    test_good_thomas_small_with_lengths(width, height, direction);
                }
            }
        }
    }

    fn test_good_thomas_with_lengths(width: usize, height: usize, direction: FftDirection) {
        let fft = GoodThomasAlgorithm::new(
            inner_dft(width, direction),
            inner_dft(height, direction),
        );
        check_fft_algorithm(&fft, width * height, direction);
    }

    fn test_good_thomas_small_with_lengths(width: usize, height: usize, direction: FftDirection) {
        let fft = GoodThomasAlgorithmSmall::new(
            inner_dft(width, direction),
            inner_dft(height, direction),
        );
        check_fft_algorithm(&fft, width * height, direction);
    }

    /// Runs the index remapping for a fixed width against every smaller coprime height. Any
    /// out-of-bounds write in the re-indexing would panic here.
    #[test]
    fn test_output_mapping() {
        let width = 15;
        for height in (3..width).filter(|&h| gcd(width, h) == 1) {
            let fft = GoodThomasAlgorithm::new(
                inner_dft(width, FftDirection::Forward),
                inner_dft(height, FftDirection::Forward),
            );

            let mut buffer = vec![Complex { re: 0.0, im: 0.0 }; fft.len()];
            fft.process(&mut buffer);
        }
    }

    /// Exercises every combination of inner FFTs with large/small scratch requirements, through
    /// all three entry points, with exactly the advertised scratch length.
    #[test]
    fn test_good_thomas_inner_scratch() {
        let scratch_lengths = [1, 5, 24];

        let mut inner_ffts: Vec<Arc<dyn Fft<f32>>> = Vec::new();
        for &len in scratch_lengths.iter() {
            for &inplace_scratch in scratch_lengths.iter() {
                for &outofplace_scratch in scratch_lengths.iter() {
                    for &immut_scratch in scratch_lengths.iter() {
                        let algorithm = BigScratchAlgorithm {
                            len,
                            inplace_scratch,
                            outofplace_scratch,
                            immut_scratch,
                            direction: FftDirection::Forward,
                        };
                        inner_ffts.push(Arc::new(algorithm) as Arc<dyn Fft<f32>>);
                    }
                }
            }
        }

        for width_fft in inner_ffts.iter() {
            // Equal lengths can never be coprime (all lengths here are > 1 or distinct), so skip them.
            let other_lengths = inner_ffts
                .iter()
                .filter(|candidate| candidate.len() != width_fft.len());
            for height_fft in other_lengths {
                let fft = GoodThomasAlgorithm::new(Arc::clone(width_fft), Arc::clone(height_fft));

                // In-place entry point.
                let mut buffer = vec![Complex::zero(); fft.len()];
                let mut scratch = vec![Complex::zero(); fft.get_inplace_scratch_len()];
                fft.process_with_scratch(&mut buffer, &mut scratch);

                // Out-of-place entry point.
                let mut input = vec![Complex::zero(); fft.len()];
                let mut output = vec![Complex::zero(); fft.len()];
                let mut scratch = vec![Complex::zero(); fft.get_outofplace_scratch_len()];
                fft.process_outofplace_with_scratch(&mut input, &mut output, &mut scratch);

                // Immutable entry point.
                let input = vec![Complex::zero(); fft.len()];
                let mut output = vec![Complex::zero(); fft.len()];
                let mut scratch = vec![Complex::zero(); fft.get_immutable_scratch_len()];
                fft.process_immutable_with_scratch(&input, &mut output, &mut scratch);
            }
        }
    }
}