use num_traits::{FromPrimitive, Signed};
use std::fmt::Debug;
/// Element type accepted by the FFT algorithms.
///
/// This trait carries no methods of its own; it only bundles the bounds the
/// algorithms rely on. Thanks to the blanket impl below, every type meeting
/// those bounds is an `FftNum` automatically, so it never needs a manual impl.
pub trait FftNum
where
    Self: Copy + FromPrimitive + Signed + Sync + Send + Debug + 'static,
{
}

impl<T: Copy + FromPrimitive + Signed + Sync + Send + Debug + 'static> FftNum for T {}
/// Panics with a descriptive message describing why an in-place FFT call
/// received unusable buffer or scratch lengths.
///
/// Marked `#[cold]` / `#[inline(never)]` so this panic path stays out of line
/// in its (hot) callers.
///
/// # Panics
///
/// * if `actual_len < expected_len`;
/// * if `expected_len > 0` and `actual_len` is not a multiple of `expected_len`;
/// * if `actual_scratch < expected_scratch`.
///
/// Returns normally if none of these conditions hold.
#[cold]
#[inline(never)]
pub fn fft_error_inplace(
    expected_len: usize,
    actual_len: usize,
    expected_scratch: usize,
    actual_scratch: usize,
) {
    assert!(
        actual_len >= expected_len,
        "Provided FFT buffer was too small. Expected len = {}, got len = {}",
        expected_len,
        actual_len
    );
    // A zero-length FFT accepts any buffer (the `process_*` methods return
    // early for it). Skipping the check also avoids a divide-by-zero panic,
    // which would otherwise mask the scratch diagnostic below.
    if expected_len > 0 {
        assert_eq!(
            actual_len % expected_len,
            0,
            "Input FFT buffer must be a multiple of FFT length. Expected multiple of {}, got len = {}",
            expected_len,
            actual_len
        );
    }
    assert!(
        actual_scratch >= expected_scratch,
        "Not enough scratch space was provided. Expected scratch len >= {}, got scratch len = {}",
        expected_scratch,
        actual_scratch
    );
}
/// Panics with a descriptive message describing why an out-of-place FFT call
/// received unusable input, output, or scratch lengths.
///
/// Marked `#[cold]` / `#[inline(never)]` so this panic path stays out of line
/// in its (hot) callers.
///
/// # Panics
///
/// * if `actual_input != actual_output`;
/// * if `actual_input < expected_len`;
/// * if `expected_len > 0` and `actual_input` is not a multiple of `expected_len`;
/// * if `actual_scratch < expected_scratch`.
///
/// Returns normally if none of these conditions hold.
#[cold]
#[inline(never)]
pub fn fft_error_outofplace(
    expected_len: usize,
    actual_input: usize,
    actual_output: usize,
    expected_scratch: usize,
    actual_scratch: usize,
) {
    assert_eq!(
        actual_input,
        actual_output,
        "Provided FFT input buffer and output buffer must have the same length. Got input.len() = {}, output.len() = {}",
        actual_input,
        actual_output
    );
    assert!(
        actual_input >= expected_len,
        "Provided FFT buffer was too small. Expected len = {}, got len = {}",
        expected_len,
        actual_input
    );
    // A zero-length FFT accepts any buffer (the `process_*` methods return
    // early for it). Skipping the check also avoids a divide-by-zero panic,
    // which would otherwise mask the scratch diagnostic below.
    if expected_len > 0 {
        assert_eq!(
            actual_input % expected_len,
            0,
            "Input FFT buffer must be a multiple of FFT length. Expected multiple of {}, got len = {}",
            expected_len,
            actual_input
        );
    }
    assert!(
        actual_scratch >= expected_scratch,
        "Not enough scratch space was provided. Expected scratch len >= {}, got scratch len = {}",
        expected_scratch,
        actual_scratch
    );
}
// Generates `Fft<T>`, `Length` and `Direction` impls for `$struct_name<T>`,
// an algorithm that only provides an out-of-place kernel and needs no scratch
// of its own.
//
// * `$struct_name` - the generic algorithm type, used as `$struct_name<T>`.
// * `$len_fn` - a callable taking `&$struct_name<T>` and returning the FFT length.
//
// The generated code expects `$struct_name<T>` to have a `direction` field and
// a `perform_fft_out_of_place(&self, input, output, scratch)` method.
// In-place processing is emulated: each chunk is transformed into the caller's
// scratch buffer and then copied back. That is why the in-place scratch
// requirement is exactly one FFT length.
macro_rules! boilerplate_fft_oop {
    ($struct_name:ident, $len_fn:expr) => {
        impl<T: FftNum> Fft<T> for $struct_name<T> {
            fn process_outofplace_with_scratch(
                &self,
                input: &mut [Complex<T>],
                output: &mut [Complex<T>],
                _scratch: &mut [Complex<T>],
            ) {
                let fft_len = self.len();
                if fft_len == 0 {
                    return;
                }

                let lengths_ok = input.len() >= fft_len && output.len() == input.len();
                if !lengths_ok {
                    // The cold helper panics for these lengths; the `return`
                    // only keeps the fallthrough explicit.
                    fft_error_outofplace(fft_len, input.len(), output.len(), 0, 0);
                    return;
                }

                let chunked = array_utils::iter_chunks_zipped(
                    input,
                    output,
                    fft_len,
                    |src, dst| self.perform_fft_out_of_place(src, dst, &mut []),
                );
                // An error here means the buffers were not a whole number of FFTs.
                if chunked.is_err() {
                    fft_error_outofplace(fft_len, input.len(), output.len(), 0, 0);
                }
            }

            fn process_with_scratch(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
                let fft_len = self.len();
                if fft_len == 0 {
                    return;
                }

                let required_scratch = self.get_inplace_scratch_len();
                if buffer.len() < fft_len || scratch.len() < required_scratch {
                    fft_error_inplace(fft_len, buffer.len(), required_scratch, scratch.len());
                    return;
                }

                // Only the first `required_scratch` elements serve as the
                // temporary output for each chunk.
                let scratch = &mut scratch[..required_scratch];
                let chunked = array_utils::iter_chunks(buffer, fft_len, |chunk| {
                    self.perform_fft_out_of_place(chunk, scratch, &mut []);
                    chunk.copy_from_slice(scratch);
                });
                if chunked.is_err() {
                    fft_error_inplace(fft_len, buffer.len(), required_scratch, scratch.len());
                }
            }

            #[inline(always)]
            fn get_inplace_scratch_len(&self) -> usize {
                self.len()
            }

            #[inline(always)]
            fn get_outofplace_scratch_len(&self) -> usize {
                0
            }
        }

        impl<T> Length for $struct_name<T> {
            #[inline(always)]
            fn len(&self) -> usize {
                $len_fn(self)
            }
        }

        impl<T> Direction for $struct_name<T> {
            #[inline(always)]
            fn fft_direction(&self) -> FftDirection {
                self.direction
            }
        }
    };
}
// Generates `Fft<T>`, `Length` and `Direction` impls for `$struct_name<T>`,
// an algorithm that provides both an in-place and an out-of-place kernel.
//
// * `$struct_name` - the generic algorithm type, used as `$struct_name<T>`.
// * `$len_fn` - callable `&Self -> usize` giving the FFT length.
// * `$inplace_scratch_len_fn` - callable `&Self -> usize` giving the scratch
//   needed by `perform_fft_inplace`.
// * `$out_of_place_scratch_len_fn` - callable `&Self -> usize` giving the
//   scratch needed by `perform_fft_out_of_place`.
//
// The generated code expects `$struct_name<T>` to have a `direction` field and
// the methods `perform_fft_inplace(&self, buffer, scratch)` and
// `perform_fft_out_of_place(&self, input, output, scratch)`.
macro_rules! boilerplate_fft {
    ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr) => {
        impl<T: FftNum> Fft<T> for $struct_name<T> {
            fn process_outofplace_with_scratch(
                &self,
                input: &mut [Complex<T>],
                output: &mut [Complex<T>],
                scratch: &mut [Complex<T>],
            ) {
                let fft_len = self.len();
                if fft_len == 0 {
                    return;
                }

                let required_scratch = self.get_outofplace_scratch_len();
                let lengths_ok = scratch.len() >= required_scratch
                    && input.len() >= fft_len
                    && output.len() == input.len();
                if !lengths_ok {
                    // The cold helper panics for these lengths; the `return`
                    // only keeps the fallthrough explicit.
                    fft_error_outofplace(
                        fft_len,
                        input.len(),
                        output.len(),
                        required_scratch,
                        scratch.len(),
                    );
                    return;
                }

                // Hand the kernel exactly the scratch it asked for.
                let scratch = &mut scratch[..required_scratch];
                let chunked = array_utils::iter_chunks_zipped(
                    input,
                    output,
                    fft_len,
                    |src, dst| self.perform_fft_out_of_place(src, dst, scratch),
                );
                // An error here means the buffers were not a whole number of FFTs.
                if chunked.is_err() {
                    fft_error_outofplace(
                        fft_len,
                        input.len(),
                        output.len(),
                        required_scratch,
                        scratch.len(),
                    );
                }
            }

            fn process_with_scratch(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
                let fft_len = self.len();
                if fft_len == 0 {
                    return;
                }

                let required_scratch = self.get_inplace_scratch_len();
                if buffer.len() < fft_len || scratch.len() < required_scratch {
                    fft_error_inplace(fft_len, buffer.len(), required_scratch, scratch.len());
                    return;
                }

                let scratch = &mut scratch[..required_scratch];
                let chunked = array_utils::iter_chunks(buffer, fft_len, |chunk| {
                    self.perform_fft_inplace(chunk, scratch)
                });
                if chunked.is_err() {
                    fft_error_inplace(fft_len, buffer.len(), required_scratch, scratch.len());
                }
            }

            #[inline(always)]
            fn get_inplace_scratch_len(&self) -> usize {
                $inplace_scratch_len_fn(self)
            }

            #[inline(always)]
            fn get_outofplace_scratch_len(&self) -> usize {
                $out_of_place_scratch_len_fn(self)
            }
        }

        impl<T: FftNum> Length for $struct_name<T> {
            #[inline(always)]
            fn len(&self) -> usize {
                $len_fn(self)
            }
        }

        impl<T: FftNum> Direction for $struct_name<T> {
            #[inline(always)]
            fn fft_direction(&self) -> FftDirection {
                self.direction
            }
        }
    };
}