use std::cmp::max;
use std::sync::Arc;
use num_complex::Complex;
use num_traits::Zero;
use transpose;
use crate::array_utils;
use crate::common::{fft_error_inplace, fft_error_outofplace};
use crate::{common::FftNum, twiddles, FftDirection};
use crate::{Direction, Fft, Length};
/// FFT of length `width * height`, computed from an inner FFT of length `width` and an
/// inner FFT of length `height` using a transpose / FFT / twiddle / transpose / FFT /
/// transpose sequence.
pub struct MixedRadix<T> {
    /// Direction of this FFT; both inner FFTs are required to share it.
    direction: FftDirection,

    /// Inner FFT whose length is `width`.
    width_size_fft: Arc<dyn Fft<T>>,
    width: usize,

    /// Inner FFT whose length is `height`.
    height_size_fft: Arc<dyn Fft<T>>,
    height: usize,

    /// Precomputed twiddle factors, one per element of the full `width * height` FFT.
    twiddles: Box<[Complex<T>]>,

    /// Scratch length requested for in-place processing.
    inplace_scratch_len: usize,
    /// Scratch length requested for out-of-place processing.
    outofplace_scratch_len: usize,
}
impl<T: FftNum> MixedRadix<T> {
    /// Creates an FFT instance that processes inputs/outputs of size
    /// `width_fft.len() * height_fft.len()`.
    ///
    /// # Panics
    ///
    /// Panics if `width_fft` and `height_fft` do not have the same [`FftDirection`].
    pub fn new(width_fft: Arc<dyn Fft<T>>, height_fft: Arc<dyn Fft<T>>) -> Self {
        assert_eq!(
            width_fft.fft_direction(), height_fft.fft_direction(),
            "width_fft and height_fft must have the same direction. got width direction={}, height direction={}",
            width_fft.fft_direction(), height_fft.fft_direction());

        let direction = width_fft.fft_direction();

        let width = width_fft.len();
        let height = height_fft.len();
        let len = width * height;

        // Twiddles are stored as `width` rows of `height` entries; the entry in row `x`,
        // column `y` is the twiddle for exponent `x * y`.
        let mut twiddles = vec![Complex::zero(); len];
        // `chunks_exact_mut` panics on a chunk size of zero. A zero-length height FFT
        // implies `len == 0`, so there is nothing to fill in and the loop is skipped.
        if height > 0 {
            for (x, twiddle_row) in twiddles.chunks_exact_mut(height).enumerate() {
                for (y, twiddle) in twiddle_row.iter_mut().enumerate() {
                    *twiddle = twiddles::compute_twiddle(x * y, len, direction);
                }
            }
        }

        // Scratch requirements of the inner FFTs.
        let height_inplace_scratch = height_fft.get_inplace_scratch_len();
        let width_inplace_scratch = width_fft.get_inplace_scratch_len();
        let width_outofplace_scratch = width_fft.get_outofplace_scratch_len();

        // Out-of-place processing runs both inner FFTs in place. Each one borrows the input or
        // output buffer (both `len` long) as scratch, unless the caller-provided scratch is
        // longer. So extra scratch is only requested when an inner FFT wants more than `len`.
        // In that case it must be big enough for whichever inner FFT wants the most.
        let max_inner_inplace_scratch = max(height_inplace_scratch, width_inplace_scratch);
        let outofplace_scratch_len = if max_inner_inplace_scratch > len {
            max_inner_inplace_scratch
        } else {
            0
        };

        // In-place processing needs `len` elements of its own to transpose into. Those are
        // followed by a tail handed to the inner FFTs. The tail must cover two things:
        // - the width FFT's out-of-place scratch;
        // - the height FFT's in-place scratch, if that exceeds `len` (the size of the buffer
        //   it would otherwise borrow).
        let inplace_scratch_len = len
            + max(
                if height_inplace_scratch > len {
                    height_inplace_scratch
                } else {
                    0
                },
                width_outofplace_scratch,
            );

        Self {
            twiddles: twiddles.into_boxed_slice(),

            width_size_fft: width_fft,
            width,

            height_size_fft: height_fft,
            height,

            inplace_scratch_len,
            outofplace_scratch_len,

            direction,
        }
    }

    /// Six-step FFT of one `self.len()`-sized chunk, in place.
    ///
    /// The first `self.len()` elements of `scratch` receive the transposed data. The
    /// remainder is passed to the inner FFTs as their scratch.
    fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
        let (scratch, inner_scratch) = scratch.split_at_mut(self.len());

        // Step 1: transpose `buffer` into `scratch`.
        transpose::transpose(buffer, scratch, self.width, self.height);

        // Step 2: FFTs of size `height`, in place in `scratch`. `buffer`'s contents are no
        // longer needed, so it serves as their scratch unless the tail is larger.
        let height_scratch = if inner_scratch.len() > buffer.len() {
            &mut inner_scratch[..]
        } else {
            &mut buffer[..]
        };
        self.height_size_fft
            .process_with_scratch(scratch, height_scratch);

        // Step 3: apply twiddle factors.
        for (element, twiddle) in scratch.iter_mut().zip(self.twiddles.iter()) {
            *element = *element * twiddle;
        }

        // Step 4: transpose back into `buffer`.
        transpose::transpose(scratch, buffer, self.height, self.width);

        // Step 5: FFTs of size `width`, out of place from `buffer` into `scratch`.
        self.width_size_fft
            .process_outofplace_with_scratch(buffer, scratch, inner_scratch);

        // Step 6: final transpose into `buffer`.
        transpose::transpose(scratch, buffer, self.width, self.height);
    }

    /// Six-step FFT of one `self.len()`-sized chunk, from `input` into `output`.
    ///
    /// `input` is clobbered: it is reused as an intermediate buffer and, when `scratch`
    /// is not longer than it, as scratch for the inner FFTs.
    fn perform_fft_out_of_place(
        &self,
        input: &mut [Complex<T>],
        output: &mut [Complex<T>],
        scratch: &mut [Complex<T>],
    ) {
        // Step 1: transpose `input` into `output`.
        transpose::transpose(input, output, self.width, self.height);

        // Step 2: FFTs of size `height`, in place in `output`. `input` is free to serve
        // as scratch unless a larger scratch buffer was provided.
        let height_scratch = if scratch.len() > input.len() {
            &mut scratch[..]
        } else {
            &mut input[..]
        };
        self.height_size_fft
            .process_with_scratch(output, height_scratch);

        // Step 3: apply twiddle factors.
        for (element, twiddle) in output.iter_mut().zip(self.twiddles.iter()) {
            *element = *element * twiddle;
        }

        // Step 4: transpose back into `input`.
        transpose::transpose(output, input, self.height, self.width);

        // Step 5: FFTs of size `width`, in place in `input`, with `output` as scratch
        // unless a larger scratch buffer was provided.
        let width_scratch = if scratch.len() > output.len() {
            &mut scratch[..]
        } else {
            &mut output[..]
        };
        self.width_size_fft
            .process_with_scratch(input, width_scratch);

        // Step 6: final transpose into `output`.
        transpose::transpose(input, output, self.width, self.height);
    }
}
// Generates the public FFT entry points for `MixedRadix` via the crate's `boilerplate_fft!`
// macro. `process_with_scratch`, `len`, and the scratch-length getters used on this type
// elsewhere in this file presumably come from here; confirm against the macro definition.
// The closures appear to supply, in order:
// - the FFT length (`twiddles.len()`, which equals `width * height`);
// - the in-place scratch length;
// - the out-of-place scratch length.
boilerplate_fft!(
    MixedRadix,
    |this: &MixedRadix<_>| this.twiddles.len(),
    |this: &MixedRadix<_>| this.inplace_scratch_len,
    |this: &MixedRadix<_>| this.outofplace_scratch_len
);
/// FFT of length `width * height` built from two inner FFTs, like `MixedRadix`. Its inner
/// FFTs must need little scratch, because they borrow the buffers being transformed instead
/// of a dedicated scratch area.
pub struct MixedRadixSmall<T> {
    /// Direction of this FFT; both inner FFTs are required to share it.
    direction: FftDirection,

    /// Inner FFT whose length is `width`.
    width_size_fft: Arc<dyn Fft<T>>,
    width: usize,

    /// Inner FFT whose length is `height`.
    height_size_fft: Arc<dyn Fft<T>>,
    height: usize,

    /// Precomputed twiddle factors, one per element of the full `width * height` FFT.
    twiddles: Box<[Complex<T>]>,
}
impl<T: FftNum> MixedRadixSmall<T> {
    /// Creates an FFT instance that processes inputs/outputs of size
    /// `width_fft.len() * height_fft.len()`.
    ///
    /// The inner FFTs use the buffers being transformed as their scratch. This FFT therefore
    /// requests no out-of-place scratch and only `len` elements of in-place scratch, and the
    /// inner FFTs must fit within that.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - the two inner FFTs have different directions;
    /// - either inner FFT requires any out-of-place scratch;
    /// - either inner FFT requires more in-place scratch than its own length.
    pub fn new(width_fft: Arc<dyn Fft<T>>, height_fft: Arc<dyn Fft<T>>) -> Self {
        assert_eq!(
            width_fft.fft_direction(), height_fft.fft_direction(),
            "width_fft and height_fft must have the same direction. got width direction={}, height direction={}",
            width_fft.fft_direction(), height_fft.fft_direction());

        let width = width_fft.len();
        let height = height_fft.len();
        let len = width * height;

        // The inner FFTs only ever get borrowed buffers (or nothing) as scratch, so their
        // scratch needs must be tiny.
        assert_eq!(width_fft.get_outofplace_scratch_len(), 0, "MixedRadixSmall should only be used with algorithms that require 0 out-of-place scratch. Width FFT (len={}) requires {}, should require 0", width, width_fft.get_outofplace_scratch_len());
        assert_eq!(height_fft.get_outofplace_scratch_len(), 0, "MixedRadixSmall should only be used with algorithms that require 0 out-of-place scratch. Height FFT (len={}) requires {}, should require 0", height, height_fft.get_outofplace_scratch_len());
        assert!(width_fft.get_inplace_scratch_len() <= width, "MixedRadixSmall should only be used with algorithms that require little inplace scratch. Width FFT (len={}) requires {}, should require {} or less", width, width_fft.get_inplace_scratch_len(), width);
        assert!(height_fft.get_inplace_scratch_len() <= height, "MixedRadixSmall should only be used with algorithms that require little inplace scratch. Height FFT (len={}) requires {}, should require {} or less", height, height_fft.get_inplace_scratch_len(), height);

        let direction = width_fft.fft_direction();

        // Twiddles are stored as `width` rows of `height` entries; the entry in row `x`,
        // column `y` is the twiddle for exponent `x * y`.
        let mut twiddles = vec![Complex::zero(); len];
        // `chunks_exact_mut` panics on a chunk size of zero. A zero-length height FFT
        // implies `len == 0`, so there is nothing to fill in and the loop is skipped.
        if height > 0 {
            for (x, twiddle_row) in twiddles.chunks_exact_mut(height).enumerate() {
                for (y, twiddle) in twiddle_row.iter_mut().enumerate() {
                    *twiddle = twiddles::compute_twiddle(x * y, len, direction);
                }
            }
        }

        Self {
            twiddles: twiddles.into_boxed_slice(),

            width_size_fft: width_fft,
            width,

            height_size_fft: height_fft,
            height,

            direction,
        }
    }

    /// Six-step FFT of one `self.len()`-sized chunk, in place, bouncing between
    /// `buffer` and `scratch`.
    fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
        // SAFETY (assumed; confirm against `array_utils::transpose_small`'s contract): both
        // slices are expected to hold exactly `self.width * self.height` elements. `buffer` is
        // presumably one `self.len()` chunk, and `scratch` is presumably trimmed by
        // `boilerplate_fft!` to `get_inplace_scratch_len() == self.len()`.
        // Step 1: transpose `buffer` into `scratch`.
        unsafe { array_utils::transpose_small(self.width, self.height, buffer, scratch) };

        // Step 2: FFTs of size `height`, in place in `scratch`, with `buffer` as scratch.
        // `new` asserted the height FFT needs at most `height` in-place scratch.
        self.height_size_fft.process_with_scratch(scratch, buffer);

        // Step 3: apply twiddle factors.
        for (element, twiddle) in scratch.iter_mut().zip(self.twiddles.iter()) {
            *element = *element * twiddle;
        }

        // Step 4: transpose back into `buffer` (same SAFETY assumption as step 1).
        unsafe { array_utils::transpose_small(self.height, self.width, scratch, buffer) };

        // Step 5: FFTs of size `width`, out of place from `buffer` into `scratch`. No
        // scratch is passed because `new` asserted 0 out-of-place scratch.
        self.width_size_fft
            .process_outofplace_with_scratch(buffer, scratch, &mut []);

        // Step 6: final transpose into `buffer` (same SAFETY assumption as step 1).
        unsafe { array_utils::transpose_small(self.width, self.height, scratch, buffer) };
    }

    /// Six-step FFT of one `self.len()`-sized chunk from `input` into `output`.
    ///
    /// `input` is clobbered, since it is used as an intermediate buffer. No separate
    /// scratch is used.
    fn perform_fft_out_of_place(
        &self,
        input: &mut [Complex<T>],
        output: &mut [Complex<T>],
        _scratch: &mut [Complex<T>],
    ) {
        // SAFETY (assumed; confirm against `array_utils::transpose_small`'s contract): both
        // slices are expected to hold exactly `self.width * self.height` elements. `input` and
        // `output` are presumably each one `self.len()`-sized chunk.
        // Step 1: transpose `input` into `output`.
        unsafe { array_utils::transpose_small(self.width, self.height, input, output) };

        // Step 2: FFTs of size `height`, in place in `output`, with `input` as scratch.
        self.height_size_fft.process_with_scratch(output, input);

        // Step 3: apply twiddle factors.
        for (element, twiddle) in output.iter_mut().zip(self.twiddles.iter()) {
            *element = *element * twiddle;
        }

        // Step 4: transpose back into `input` (same SAFETY assumption as step 1).
        unsafe { array_utils::transpose_small(self.height, self.width, output, input) };

        // Step 5: FFTs of size `width`, in place in `input`, with `output` as scratch.
        self.width_size_fft.process_with_scratch(input, output);

        // Step 6: final transpose into `output` (same SAFETY assumption as step 1).
        unsafe { array_utils::transpose_small(self.width, self.height, input, output) };
    }
}
// Generates the public FFT entry points for `MixedRadixSmall` via the crate's
// `boilerplate_fft!` macro. The closures appear to supply, in order (confirm against the
// macro definition):
// - the FFT length (`twiddles.len()`, which equals `width * height`);
// - the in-place scratch length: one full FFT length, the transposition target that
//   `perform_fft_inplace` writes into;
// - the out-of-place scratch length: zero, since `perform_fft_out_of_place` ignores its
//   scratch.
boilerplate_fft!(
    MixedRadixSmall,
    |this: &MixedRadixSmall<_>| this.twiddles.len(),
    |this: &MixedRadixSmall<_>| this.len(),
    |_| 0
);
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::check_fft_algorithm;
    use crate::{algorithm::Dft, test_utils::BigScratchAlgorithm};
    use num_traits::Zero;
    use std::sync::Arc;

    /// Directions every size combination is exercised in, in this order.
    const DIRECTIONS: [FftDirection; 2] = [FftDirection::Forward, FftDirection::Inverse];

    /// Runs `check_fft_algorithm` on `MixedRadix` for every width/height pair in 1..7.
    #[test]
    fn test_mixed_radix() {
        for width in 1..7 {
            for height in 1..7 {
                for &direction in DIRECTIONS.iter() {
                    test_mixed_radix_with_lengths(width, height, direction);
                }
            }
        }
    }

    /// Runs `check_fft_algorithm` on `MixedRadixSmall` for every width/height pair in 2..7.
    #[test]
    fn test_mixed_radix_small() {
        for width in 2..7 {
            for height in 2..7 {
                for &direction in DIRECTIONS.iter() {
                    test_mixed_radix_small_with_lengths(width, height, direction);
                }
            }
        }
    }

    /// Builds a `MixedRadix` from two `Dft` inner FFTs and checks it.
    fn test_mixed_radix_with_lengths(width: usize, height: usize, direction: FftDirection) {
        let mixed = MixedRadix::new(
            Arc::new(Dft::new(width, direction)) as Arc<dyn Fft<f32>>,
            Arc::new(Dft::new(height, direction)) as Arc<dyn Fft<f32>>,
        );
        check_fft_algorithm(&mixed, width * height, direction);
    }

    /// Builds a `MixedRadixSmall` from two `Dft` inner FFTs and checks it.
    fn test_mixed_radix_small_with_lengths(width: usize, height: usize, direction: FftDirection) {
        let mixed = MixedRadixSmall::new(
            Arc::new(Dft::new(width, direction)) as Arc<dyn Fft<f32>>,
            Arc::new(Dft::new(height, direction)) as Arc<dyn Fft<f32>>,
        );
        check_fft_algorithm(&mixed, width * height, direction);
    }

    /// Pairs inner FFTs with assorted scratch demands and makes sure `MixedRadix` runs
    /// without panicking when given exactly the scratch it advertises.
    #[test]
    fn test_mixed_radix_inner_scratch() {
        let scratch_lengths = [1, 5, 25];

        // Every (len, in-place scratch, out-of-place scratch) combination.
        let mut inner_ffts: Vec<Arc<dyn Fft<f32>>> = Vec::new();
        for &len in scratch_lengths.iter() {
            for &inplace_scratch in scratch_lengths.iter() {
                inner_ffts.extend(scratch_lengths.iter().map(|&outofplace_scratch| {
                    Arc::new(BigScratchAlgorithm {
                        len,
                        inplace_scratch,
                        outofplace_scratch,
                        direction: FftDirection::Forward,
                    }) as Arc<dyn Fft<f32>>
                }));
            }
        }

        for width_fft in &inner_ffts {
            for height_fft in &inner_ffts {
                let fft = MixedRadix::new(Arc::clone(width_fft), Arc::clone(height_fft));
                let len = fft.len();

                // In-place processing with exactly the advertised scratch.
                let mut buffer = vec![Complex::zero(); len];
                let mut scratch = vec![Complex::zero(); fft.get_inplace_scratch_len()];
                fft.process_with_scratch(&mut buffer, &mut scratch);

                // Out-of-place processing with exactly the advertised scratch.
                let mut input = vec![Complex::zero(); len];
                let mut output = vec![Complex::zero(); len];
                let mut scratch = vec![Complex::zero(); fft.get_outofplace_scratch_len()];
                fft.process_outofplace_with_scratch(&mut input, &mut output, &mut scratch);
            }
        }
    }
}