use crate::internal::*;
use num_complex::Complex;
use rustfft::num_traits::{Float, FromPrimitive};
use rustfft::{FftDirection, FftNum};
use tract_data::itertools::Itertools;
use tract_ndarray::Axis;
/// Complex-to-complex FFT operator applied along one axis of a tensor.
///
/// The input stores complex numbers with the real and imaginary parts in its
/// last axis, which must have length 2.
#[derive(Clone, Debug, Hash)]
pub struct Fft {
/// Index of the axis along which the transform is computed.
pub axis: usize,
/// Selects `FftDirection::Inverse` when true, `FftDirection::Forward` otherwise.
pub inverse: bool,
}
impl Fft {
fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
&self,
tensor: &mut Tensor,
) -> TractResult<()> {
let mut iterator_shape: TVec<usize> = tensor.shape().into();
iterator_shape.pop(); iterator_shape[self.axis] = 1;
let len = tensor.shape()[self.axis];
let direction = if self.inverse { FftDirection::Inverse } else { FftDirection::Forward };
let fft = rustfft::FftPlanner::new().plan_fft(len, direction);
let mut array = tensor.to_array_view_mut::<T>()?;
let mut v = Vec::with_capacity(len);
for coords in tract_ndarray::indices(&*iterator_shape) {
v.clear();
let mut slice = array.slice_each_axis_mut(|ax| {
if ax.axis.index() == self.axis || ax.stride == 1 {
(..).into()
} else {
let c = coords[ax.axis.index()] as isize;
(c..=c).into()
}
});
v.extend(slice.iter().tuples().map(|(r, i)| Complex::new(*r, *i)));
fft.process(&mut v);
slice
.iter_mut()
.zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
.for_each(|(s, v)| *s = v);
}
Ok(())
}
}
impl Op for Fft {
    /// Operator name used for display and debugging.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("Fft")
    }

    /// Reports the transform direction as a single info line.
    fn info(&self) -> TractResult<Vec<String>> {
        let direction = if self.inverse { "inverse" } else { "forward" };
        Ok(vec![String::from(direction)])
    }

    op_as_typed_op!();
}
impl EvalOp for Fft {
    /// The operator keeps no state between evaluations.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Transforms the single input tensor and returns it as the single output.
    ///
    /// `f32` and `f64` tensors are processed in place. `f16` tensors are
    /// widened to `f32`, transformed, then cast back to `f16`. Any other
    /// datum type is rejected with an error.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut tensor = args_1!(inputs).into_tensor();
        let dt = tensor.datum_type();
        match dt {
            DatumType::F32 => self.eval_t::<f32>(&mut tensor)?,
            DatumType::F64 => self.eval_t::<f64>(&mut tensor)?,
            DatumType::F16 => {
                let mut widened = tensor.cast_to::<f32>()?.into_owned();
                self.eval_t::<f32>(&mut widened)?;
                tensor = widened.cast_to::<f16>()?.into_owned();
            }
            _ => bail!("FFT not implemented for type {:?}", dt),
        }
        Ok(tvec!(tensor.into_tvalue()))
    }
}
impl TypedOp for Fft {
    /// Checks the input layout and returns an output fact with the same
    /// datum type and shape as the input (its constant value, if any, is
    /// dropped).
    ///
    /// # Errors
    ///
    /// Fails when the input has rank lower than 2, when its last dimension is
    /// not 2, or when `self.axis` does not designate one of the axes before
    /// the trailing real/imaginary axis.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        anyhow::ensure!(
            inputs[0].rank() >= 2,
            "Expect rank 2 (one for fft dimension, one for complex dimension"
        );
        anyhow::ensure!(
            inputs[0].shape.last().unwrap() == &2.to_dim(),
            "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
        );
        // Reject a bad axis at typing time: evaluation indexes the shape with
        // `self.axis` and would otherwise panic.
        anyhow::ensure!(
            self.axis < inputs[0].rank() - 1,
            "Fft axis {} is out of range for an input of rank {} (the last axis holds real and imaginary parts)",
            self.axis,
            inputs[0].rank()
        );
        Ok(tvec!(inputs[0].without_value()))
    }

    as_op!();
}
/// Short-time Fourier transform operator: cuts the signal axis into frames
/// and runs a forward FFT on each frame.
///
/// The input stores complex numbers with the real and imaginary parts in its
/// last axis, which must have length 2.
#[derive(Clone, Debug, Hash)]
pub struct Stft {
/// Index of the signal axis that is cut into frames.
pub axis: usize,
/// Number of complex samples in each frame (also the FFT length).
pub frame: usize,
/// Number of samples between the starts of two consecutive frames.
pub stride: usize,
/// Optional per-sample coefficients multiplied into each frame before the FFT.
pub window: Option<Arc<Tensor>>,
}
impl Stft {
fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
&self,
input: &Tensor,
) -> TractResult<Tensor> {
let mut iterator_shape: TVec<usize> = input.shape().into();
iterator_shape.pop(); iterator_shape[self.axis] = 1;
let mut output_shape: TVec<usize> = input.shape().into();
let frames = (input.shape()[self.axis] - self.frame) / self.stride + 1;
output_shape.insert(self.axis, frames);
output_shape[self.axis + 1] = self.frame;
let mut output = unsafe { Tensor::uninitialized::<T>(&output_shape)? };
let fft = rustfft::FftPlanner::new().plan_fft_forward(self.frame);
let input = input.to_array_view::<T>()?;
let mut oview = output.to_array_view_mut::<T>()?;
let mut v = Vec::with_capacity(self.frame);
for coords in tract_ndarray::indices(&*iterator_shape) {
let islice = input.slice_each_axis(|ax| {
if ax.axis.index() == self.axis || ax.stride == 1 {
(..).into()
} else {
let c = coords[ax.axis.index()] as isize;
(c..=c).into()
}
});
let mut oslice = oview.slice_each_axis_mut(|ax| {
if ax.stride == 1 {
(..).into()
} else if ax.axis.index() < self.axis {
let c = coords[ax.axis.index()] as isize;
(c..=c).into()
} else if ax.axis.index() == self.axis || ax.axis.index() == self.axis + 1 {
(..).into()
} else {
let c = coords[ax.axis.index() - 1] as isize;
(c..=c).into()
}
});
for f in 0..frames {
v.clear();
v.extend(
islice
.iter()
.tuples()
.skip(self.stride * f)
.take(self.frame)
.map(|(re, im)| Complex::new(*re, *im)),
);
if let Some(win) = &self.window {
let win = win.as_slice::<T>()?;
v.iter_mut()
.zip(win.iter())
.for_each(|(v, w)| *v = *v * Complex::new(*w, T::zero()));
}
fft.process(&mut v);
oslice
.index_axis_mut(Axis(self.axis), f)
.iter_mut()
.zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
.for_each(|(s, v)| *s = v);
}
}
Ok(output)
}
}
impl Op for Stft {
    /// Operator name used for display and debugging.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("STFT")
    }

    op_as_typed_op!();
}
impl EvalOp for Stft {
    /// The operator keeps no state between evaluations.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Computes the STFT of the single input and returns it as the single
    /// output.
    ///
    /// `f32` and `f64` inputs are processed directly. `f16` inputs are widened
    /// to `f32` for the computation and the result is cast back to `f16`. Any
    /// other datum type is rejected with an error.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let dt = input.datum_type();
        let output = match dt {
            DatumType::F32 => self.eval_t::<f32>(&input)?,
            DatumType::F64 => self.eval_t::<f64>(&input)?,
            DatumType::F16 => {
                let widened = input.cast_to::<f32>()?;
                let spectrum = self.eval_t::<f32>(&widened)?;
                // Bound to a local so the temporary borrow of `spectrum` ends
                // before `spectrum` itself is dropped.
                let narrowed = spectrum.cast_to::<f16>()?.into_owned();
                narrowed
            }
            _ => bail!("FFT not implemented for type {:?}", dt),
        };
        Ok(tvec!(output.into_tvalue()))
    }
}
impl TypedOp for Stft {
    /// Checks the input layout and computes the output fact.
    ///
    /// The output keeps the input datum type. Its shape is the input shape
    /// where the signal axis becomes the number of frames,
    /// `(len - self.frame) / self.stride + 1`, and a new axis of size
    /// `self.frame` is inserted right after it.
    ///
    /// # Errors
    ///
    /// Fails when the input has rank lower than 2, when its last dimension is
    /// not 2, when `self.axis` does not designate one of the axes before the
    /// trailing real/imaginary axis, or when `self.stride` is zero.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        anyhow::ensure!(
            inputs[0].rank() >= 2,
            "Expect rank 2 (one for fft dimension, one for complex dimension"
        );
        anyhow::ensure!(
            inputs[0].shape.last().unwrap() == &2.to_dim(),
            "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
        );
        // Both checks protect the shape arithmetic below: indexing the shape
        // with a bad axis and dividing by a zero stride would panic.
        anyhow::ensure!(
            self.axis < inputs[0].rank() - 1,
            "STFT axis {} is out of range for an input of rank {} (the last axis holds real and imaginary parts)",
            self.axis,
            inputs[0].rank()
        );
        anyhow::ensure!(self.stride > 0, "STFT stride must be strictly positive");
        let mut shape = inputs[0].shape.to_tvec();
        let frames = (inputs[0].shape[self.axis].clone() - self.frame) / self.stride + 1;
        shape[self.axis] = frames;
        shape.insert(self.axis + 1, self.frame.to_dim());
        Ok(tvec!(inputs[0].datum_type.fact(shape)))
    }

    as_op!();
}