1use crate::internal::*;
2use num_complex::Complex;
3use rustfft::num_traits::{Float, FromPrimitive};
4use rustfft::{FftDirection, FftNum};
5use tract_data::itertools::Itertools;
6use tract_ndarray::Axis;
7
/// Complex FFT operator working on tensors whose innermost dimension stores
/// `[re, im]` pairs.
#[derive(Debug, Clone, Hash)]
pub struct Fft {
    /// Index of the dimension the transform runs along; its size is the FFT length.
    pub axis: usize,
    /// `true` requests the inverse transform, `false` the forward one.
    pub inverse: bool,
}
13
14impl Fft {
15 fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
16 &self,
17 tensor: &mut Tensor,
18 ) -> TractResult<()> {
19 let mut iterator_shape: TVec<usize> = tensor.shape().into();
20 iterator_shape.pop(); iterator_shape[self.axis] = 1;
22 let len = tensor.shape()[self.axis];
23 let direction = if self.inverse { FftDirection::Inverse } else { FftDirection::Forward };
24 let fft = rustfft::FftPlanner::new().plan_fft(len, direction);
25 let mut tensor_dense = tensor.try_as_dense_mut()?;
26 let mut array = tensor_dense.to_array_view_mut::<T>()?;
27 let mut v = Vec::with_capacity(len);
28 for coords in tract_ndarray::indices(&*iterator_shape) {
29 v.clear();
30 let mut slice = array.slice_each_axis_mut(|ax| {
31 if ax.axis.index() == self.axis || ax.stride == 1 {
32 (..).into()
34 } else {
35 let c = coords[ax.axis.index()] as isize;
36 (c..=c).into()
37 }
38 });
39 v.extend(slice.iter().tuples().map(|(r, i)| Complex::new(*r, *i)));
40 fft.process(&mut v);
41 slice
42 .iter_mut()
43 .zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
44 .for_each(|(s, v)| *s = v);
45 }
46 Ok(())
47 }
48}
49
50impl Op for Fft {
51 fn name(&self) -> StaticName {
52 "Fft".into()
53 }
54
55 fn info(&self) -> TractResult<Vec<String>> {
56 Ok(vec![if self.inverse { "inverse" } else { "forward" }.into()])
57 }
58
59 op_as_typed_op!();
60}
61
62impl EvalOp for Fft {
63 fn is_stateless(&self) -> bool {
64 true
65 }
66
67 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
68 let mut tensor = args_1!(inputs).into_tensor();
69 match tensor.datum_type() {
70 DatumType::F16 => {
71 let mut temp = tensor.cast_to::<f32>()?.into_owned();
72 self.eval_t::<f32>(&mut temp)?;
73 tensor = temp.cast_to::<f16>()?.into_owned();
74 }
75 DatumType::F32 => self.eval_t::<f32>(&mut tensor)?,
76 DatumType::F64 => self.eval_t::<f64>(&mut tensor)?,
77 _ => bail!("FFT not implemented for type {:?}", tensor.datum_type()),
78 }
79 Ok(tvec!(tensor.into_tvalue()))
80 }
81}
82
83impl TypedOp for Fft {
84 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
85 anyhow::ensure!(
86 inputs[0].rank() >= 2,
87 "Expect rank 2 (one for fft dimension, one for complex dimension"
88 );
89 anyhow::ensure!(
90 inputs[0].shape.last().unwrap() == &2.to_dim(),
91 "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
92 );
93 Ok(tvec!(inputs[0].without_value()))
94 }
95
96 as_op!();
97}
98
99#[derive(Clone, Debug, Hash)]
100pub struct Stft {
101 pub axis: usize,
102 pub frame: usize,
103 pub stride: usize,
104 pub window: Option<Arc<Tensor>>,
105}
106
107impl Stft {
108 fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
109 &self,
110 input: &Tensor,
111 ) -> TractResult<Tensor> {
112 let mut iterator_shape: TVec<usize> = input.shape().into();
113 iterator_shape.pop(); iterator_shape[self.axis] = 1;
115 let mut output_shape: TVec<usize> = input.shape().into();
116 let frames = (input.shape()[self.axis] - self.frame) / self.stride + 1;
117 output_shape.insert(self.axis, frames);
118 output_shape[self.axis + 1] = self.frame;
119 let mut output = unsafe { Tensor::uninitialized::<T>(&output_shape)? };
120 let fft = rustfft::FftPlanner::new().plan_fft_forward(self.frame);
121 let input = input.to_dense_array_view::<T>()?;
122 let mut output_dense = output.try_as_dense_mut()?;
123 let mut oview = output_dense.to_array_view_mut::<T>()?;
124 let mut v = Vec::with_capacity(self.frame);
125 for coords in tract_ndarray::indices(&*iterator_shape) {
126 let islice = input.slice_each_axis(|ax| {
127 if ax.axis.index() == self.axis || ax.stride == 1 {
128 (..).into()
129 } else {
130 let c = coords[ax.axis.index()] as isize;
131 (c..=c).into()
132 }
133 });
134 let mut oslice = oview.slice_each_axis_mut(|ax| {
135 if ax.stride == 1 {
136 (..).into()
137 } else if ax.axis.index() < self.axis {
138 let c = coords[ax.axis.index()] as isize;
139 (c..=c).into()
140 } else if ax.axis.index() == self.axis || ax.axis.index() == self.axis + 1 {
141 (..).into()
142 } else {
143 let c = coords[ax.axis.index() - 1] as isize;
144 (c..=c).into()
145 }
146 });
147 for f in 0..frames {
148 v.clear();
149 v.extend(
150 islice
151 .iter()
152 .tuples()
153 .skip(self.stride * f)
154 .take(self.frame)
155 .map(|(re, im)| Complex::new(*re, *im)),
156 );
157 if let Some(win) = &self.window {
158 let win = win.try_as_dense()?.as_slice::<T>()?;
159 let pad_left = (self.frame - win.len()) / 2;
161 v.iter_mut().enumerate().for_each(|(ix, v)| {
162 *v = if ix < pad_left || ix >= pad_left + win.len() {
163 Complex::new(T::zero(), T::zero())
164 } else {
165 *v * Complex::new(win[ix - pad_left], T::zero())
166 }
167 });
168 }
169 fft.process(&mut v);
170 oslice
171 .index_axis_mut(Axis(self.axis), f)
172 .iter_mut()
173 .zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
174 .for_each(|(s, v)| *s = v);
175 }
176 }
177 Ok(output)
178 }
179}
180
181impl Op for Stft {
182 fn name(&self) -> StaticName {
183 "STFT".into()
184 }
185
186 op_as_typed_op!();
187}
188
189impl EvalOp for Stft {
190 fn is_stateless(&self) -> bool {
191 true
192 }
193
194 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
195 let input = args_1!(inputs);
196 let output = match input.datum_type() {
197 DatumType::F16 => {
198 let temp = input.cast_to::<f32>()?;
199 self.eval_t::<f32>(&temp)?.cast_to::<f16>()?.into_owned()
200 }
201 DatumType::F32 => self.eval_t::<f32>(&input)?,
202 DatumType::F64 => self.eval_t::<f64>(&input)?,
203 _ => bail!("FFT not implemented for type {:?}", input.datum_type()),
204 };
205 Ok(tvec!(output.into_tvalue()))
206 }
207}
208
209impl TypedOp for Stft {
210 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
211 anyhow::ensure!(
212 inputs[0].rank() >= 2,
213 "Expect rank 2 (one for fft dimension, one for complex dimension"
214 );
215 anyhow::ensure!(
216 inputs[0].shape.last().unwrap() == &2.to_dim(),
217 "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
218 );
219 let mut shape = inputs[0].shape.to_tvec();
220 let frames = (inputs[0].shape[self.axis].clone() - self.frame) / self.stride + 1;
221 shape[self.axis] = frames;
222 shape.insert(self.axis + 1, self.frame.to_dim());
223 Ok(tvec!(inputs[0].datum_type.fact(shape)))
224 }
225
226 as_op!();
227}