1use crate::internal::*;
2use num_complex::Complex;
3use rustfft::num_traits::{Float, FromPrimitive};
4use rustfft::{FftDirection, FftNum};
5use tract_data::itertools::Itertools;
6use tract_ndarray::Axis as NdAxis;
7
/// Discrete Fourier transform along one axis of a tensor whose last axis
/// holds `[re, im]` pairs.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct Fft {
    /// Axis along which the transform is computed.
    pub axis: usize,
    /// When set, computes the inverse transform (`FftDirection::Inverse`)
    /// instead of the forward one.
    pub inverse: bool,
}
13
14impl Fft {
15 fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
16 &self,
17 tensor: &mut Tensor,
18 ) -> TractResult<()> {
19 let mut iterator_shape: TVec<usize> = tensor.shape().into();
20 iterator_shape.pop(); iterator_shape[self.axis] = 1;
22 let len = tensor.shape()[self.axis];
23 let direction = if self.inverse { FftDirection::Inverse } else { FftDirection::Forward };
24 let fft = rustfft::FftPlanner::new().plan_fft(len, direction);
25 let mut tensor_plain = tensor.try_as_plain_mut()?;
26 let mut array = tensor_plain.to_array_view_mut::<T>()?;
27 let mut v = Vec::with_capacity(len);
28 for coords in tract_ndarray::indices(&*iterator_shape) {
29 v.clear();
30 let mut slice = array.slice_each_axis_mut(|ax| {
31 if ax.axis.index() == self.axis || ax.stride == 1 {
32 (..).into()
34 } else {
35 let c = coords[ax.axis.index()] as isize;
36 (c..=c).into()
37 }
38 });
39 v.extend(slice.iter().tuples().map(|(r, i)| Complex::new(*r, *i)));
40 fft.process(&mut v);
41 slice
42 .iter_mut()
43 .zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
44 .for_each(|(s, v)| *s = v);
45 }
46 Ok(())
47 }
48}
49
50impl Op for Fft {
51 fn name(&self) -> StaticName {
52 "Fft".into()
53 }
54
55 fn info(&self) -> TractResult<Vec<String>> {
56 Ok(vec![if self.inverse { "inverse" } else { "forward" }.into()])
57 }
58
59 op_as_typed_op!();
60}
61
62impl EvalOp for Fft {
63 fn is_stateless(&self) -> bool {
64 true
65 }
66
67 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
68 let mut tensor = args_1!(inputs).into_tensor();
69 match tensor.datum_type() {
70 DatumType::F16 => {
71 let mut temp = tensor.cast_to::<f32>()?.into_owned();
72 self.eval_t::<f32>(&mut temp)?;
73 tensor = temp.cast_to::<f16>()?.into_owned();
74 }
75 DatumType::F32 => self.eval_t::<f32>(&mut tensor)?,
76 DatumType::F64 => self.eval_t::<f64>(&mut tensor)?,
77 _ => bail!("FFT not implemented for type {:?}", tensor.datum_type()),
78 }
79 Ok(tvec!(tensor.into_tvalue()))
80 }
81}
82
83impl TypedOp for Fft {
84 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
85 anyhow::ensure!(
86 inputs[0].rank() >= 2,
87 "Expect rank 2 (one for fft dimension, one for complex dimension"
88 );
89 anyhow::ensure!(
90 inputs[0].shape.last().unwrap() == &2.to_dim(),
91 "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
92 );
93 Ok(tvec!(inputs[0].without_value()))
94 }
95
96 fn axes_mapping(
97 &self,
98 inputs: &[&TypedFact],
99 _outputs: &[&TypedFact],
100 ) -> TractResult<AxesMapping> {
101 let rank = inputs[0].rank();
117 let complex_axis = rank - 1;
118 let mut axes = tvec!();
119 let mut alphabet = 'a'..;
120 for i in 0..rank {
121 if i == self.axis || i == complex_axis {
122 axes.push(crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1).input(0, i));
123 axes.push(crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1).output(0, i));
124 } else {
125 axes.push(
126 crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1).input(0, i).output(0, i),
127 );
128 }
129 }
130 AxesMapping::new(1, 1, axes)
131 }
132
133 as_op!();
134}
135
136#[derive(Clone, Debug, Hash, PartialEq, Eq)]
137pub struct Stft {
138 pub axis: usize,
139 pub frame: usize,
140 pub stride: usize,
141 pub window: Option<Arc<Tensor>>,
142}
143
144impl Stft {
145 fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
146 &self,
147 input: &Tensor,
148 ) -> TractResult<Tensor> {
149 let mut iterator_shape: TVec<usize> = input.shape().into();
150 iterator_shape.pop(); iterator_shape[self.axis] = 1;
152 let mut output_shape: TVec<usize> = input.shape().into();
153 let frames = (input.shape()[self.axis] - self.frame) / self.stride + 1;
154 output_shape.insert(self.axis, frames);
155 output_shape[self.axis + 1] = self.frame;
156 let mut output = unsafe { Tensor::uninitialized::<T>(&output_shape)? };
157 let fft = rustfft::FftPlanner::new().plan_fft_forward(self.frame);
158 let input = input.to_plain_array_view::<T>()?;
159 let mut output_plain = output.try_as_plain_mut()?;
160 let mut oview = output_plain.to_array_view_mut::<T>()?;
161 let mut v = Vec::with_capacity(self.frame);
162 for coords in tract_ndarray::indices(&*iterator_shape) {
163 let islice = input.slice_each_axis(|ax| {
164 if ax.axis.index() == self.axis || ax.stride == 1 {
165 (..).into()
166 } else {
167 let c = coords[ax.axis.index()] as isize;
168 (c..=c).into()
169 }
170 });
171 let mut oslice = oview.slice_each_axis_mut(|ax| {
172 if ax.stride == 1 {
173 (..).into()
174 } else if ax.axis.index() < self.axis {
175 let c = coords[ax.axis.index()] as isize;
176 (c..=c).into()
177 } else if ax.axis.index() == self.axis || ax.axis.index() == self.axis + 1 {
178 (..).into()
179 } else {
180 let c = coords[ax.axis.index() - 1] as isize;
181 (c..=c).into()
182 }
183 });
184 for f in 0..frames {
185 v.clear();
186 v.extend(
187 islice
188 .iter()
189 .tuples()
190 .skip(self.stride * f)
191 .take(self.frame)
192 .map(|(re, im)| Complex::new(*re, *im)),
193 );
194 if let Some(win) = &self.window {
195 let win = win.try_as_plain()?.as_slice::<T>()?;
196 let pad_left = (self.frame - win.len()) / 2;
198 v.iter_mut().enumerate().for_each(|(ix, v)| {
199 *v = if ix < pad_left || ix >= pad_left + win.len() {
200 Complex::new(T::zero(), T::zero())
201 } else {
202 *v * Complex::new(win[ix - pad_left], T::zero())
203 }
204 });
205 }
206 fft.process(&mut v);
207 oslice
208 .index_axis_mut(NdAxis(self.axis), f)
209 .iter_mut()
210 .zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
211 .for_each(|(s, v)| *s = v);
212 }
213 }
214 Ok(output)
215 }
216}
217
218impl Op for Stft {
219 fn name(&self) -> StaticName {
220 "STFT".into()
221 }
222
223 op_as_typed_op!();
224}
225
226impl EvalOp for Stft {
227 fn is_stateless(&self) -> bool {
228 true
229 }
230
231 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
232 let input = args_1!(inputs);
233 let output = match input.datum_type() {
234 DatumType::F16 => {
235 let temp = input.cast_to::<f32>()?;
236 self.eval_t::<f32>(&temp)?.cast_to::<f16>()?.into_owned()
237 }
238 DatumType::F32 => self.eval_t::<f32>(&input)?,
239 DatumType::F64 => self.eval_t::<f64>(&input)?,
240 _ => bail!("FFT not implemented for type {:?}", input.datum_type()),
241 };
242 Ok(tvec!(output.into_tvalue()))
243 }
244}
245
246impl TypedOp for Stft {
247 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
248 anyhow::ensure!(
249 inputs[0].rank() >= 2,
250 "Expect rank 2 (one for fft dimension, one for complex dimension"
251 );
252 anyhow::ensure!(
253 inputs[0].shape.last().unwrap() == &2.to_dim(),
254 "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
255 );
256 let mut shape = inputs[0].shape.to_tvec();
257 let frames = (inputs[0].shape[self.axis].clone() - self.frame) / self.stride + 1;
258 shape[self.axis] = frames;
259 shape.insert(self.axis + 1, self.frame.to_dim());
260 Ok(tvec!(inputs[0].datum_type.fact(shape)))
261 }
262
263 fn axes_mapping(
264 &self,
265 inputs: &[&TypedFact],
266 _outputs: &[&TypedFact],
267 ) -> TractResult<crate::axes::AxesMapping> {
268 let in_rank = inputs[0].rank();
285 let mut axes = tvec!();
286 let mut alphabet = 'a'..;
287 for i in 0..in_rank {
288 let out_axis = if i <= self.axis { i } else { i + 1 };
289 axes.push(
290 crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1)
291 .input(0, i)
292 .output(0, out_axis),
293 );
294 }
295 axes.push(crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1).output(0, self.axis + 1));
297 crate::axes::AxesMapping::new(1, 1, axes)
298 }
299
300 as_op!();
301}