1use crate::internal::*;
2use num_complex::Complex;
3use rustfft::num_traits::{Float, FromPrimitive};
4use rustfft::{FftDirection, FftNum};
5use tract_data::itertools::Itertools;
6use tract_ndarray::Axis;
7
8#[derive(Clone, Debug, Hash)]
9pub struct Fft {
10 pub axis: usize,
11 pub inverse: bool,
12}
13
14impl Fft {
15 fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
16 &self,
17 tensor: &mut Tensor,
18 ) -> TractResult<()> {
19 let mut iterator_shape: TVec<usize> = tensor.shape().into();
20 iterator_shape.pop(); iterator_shape[self.axis] = 1;
22 let len = tensor.shape()[self.axis];
23 let direction = if self.inverse { FftDirection::Inverse } else { FftDirection::Forward };
24 let fft = rustfft::FftPlanner::new().plan_fft(len, direction);
25 let mut array = tensor.to_array_view_mut::<T>()?;
26 let mut v = Vec::with_capacity(len);
27 for coords in tract_ndarray::indices(&*iterator_shape) {
28 v.clear();
29 let mut slice = array.slice_each_axis_mut(|ax| {
30 if ax.axis.index() == self.axis || ax.stride == 1 {
31 (..).into()
33 } else {
34 let c = coords[ax.axis.index()] as isize;
35 (c..=c).into()
36 }
37 });
38 v.extend(slice.iter().tuples().map(|(r, i)| Complex::new(*r, *i)));
39 fft.process(&mut v);
40 slice
41 .iter_mut()
42 .zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
43 .for_each(|(s, v)| *s = v);
44 }
45 Ok(())
46 }
47}
48
49impl Op for Fft {
50 fn name(&self) -> Cow<str> {
51 "Fft".into()
52 }
53
54 fn info(&self) -> TractResult<Vec<String>> {
55 Ok(vec![if self.inverse { "inverse" } else { "forward" }.into()])
56 }
57
58 op_as_typed_op!();
59}
60
61impl EvalOp for Fft {
62 fn is_stateless(&self) -> bool {
63 true
64 }
65
66 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
67 let mut tensor = args_1!(inputs).into_tensor();
68 match tensor.datum_type() {
69 DatumType::F16 => {
70 let mut temp = tensor.cast_to::<f32>()?.into_owned();
71 self.eval_t::<f32>(&mut temp)?;
72 tensor = temp.cast_to::<f16>()?.into_owned();
73 }
74 DatumType::F32 => self.eval_t::<f32>(&mut tensor)?,
75 DatumType::F64 => self.eval_t::<f64>(&mut tensor)?,
76 _ => bail!("FFT not implemented for type {:?}", tensor.datum_type()),
77 }
78 Ok(tvec!(tensor.into_tvalue()))
79 }
80}
81
82impl TypedOp for Fft {
83 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
84 anyhow::ensure!(
85 inputs[0].rank() >= 2,
86 "Expect rank 2 (one for fft dimension, one for complex dimension"
87 );
88 anyhow::ensure!(
89 inputs[0].shape.last().unwrap() == &2.to_dim(),
90 "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
91 );
92 Ok(tvec!(inputs[0].without_value()))
93 }
94
95 as_op!();
96}
97
98#[derive(Clone, Debug, Hash)]
99pub struct Stft {
100 pub axis: usize,
101 pub frame: usize,
102 pub stride: usize,
103 pub window: Option<Arc<Tensor>>,
104}
105
106impl Stft {
107 fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
108 &self,
109 input: &Tensor,
110 ) -> TractResult<Tensor> {
111 let mut iterator_shape: TVec<usize> = input.shape().into();
112 iterator_shape.pop(); iterator_shape[self.axis] = 1;
114 let mut output_shape: TVec<usize> = input.shape().into();
115 let frames = (input.shape()[self.axis] - self.frame) / self.stride + 1;
116 output_shape.insert(self.axis, frames);
117 output_shape[self.axis + 1] = self.frame;
118 let mut output = unsafe { Tensor::uninitialized::<T>(&output_shape)? };
119 let fft = rustfft::FftPlanner::new().plan_fft_forward(self.frame);
120 let input = input.to_array_view::<T>()?;
121 let mut oview = output.to_array_view_mut::<T>()?;
122 let mut v = Vec::with_capacity(self.frame);
123 for coords in tract_ndarray::indices(&*iterator_shape) {
124 let islice = input.slice_each_axis(|ax| {
125 if ax.axis.index() == self.axis || ax.stride == 1 {
126 (..).into()
127 } else {
128 let c = coords[ax.axis.index()] as isize;
129 (c..=c).into()
130 }
131 });
132 let mut oslice = oview.slice_each_axis_mut(|ax| {
133 if ax.stride == 1 {
134 (..).into()
135 } else if ax.axis.index() < self.axis {
136 let c = coords[ax.axis.index()] as isize;
137 (c..=c).into()
138 } else if ax.axis.index() == self.axis || ax.axis.index() == self.axis + 1 {
139 (..).into()
140 } else {
141 let c = coords[ax.axis.index() - 1] as isize;
142 (c..=c).into()
143 }
144 });
145 for f in 0..frames {
146 v.clear();
147 v.extend(
148 islice
149 .iter()
150 .tuples()
151 .skip(self.stride * f)
152 .take(self.frame)
153 .map(|(re, im)| Complex::new(*re, *im)),
154 );
155 if let Some(win) = &self.window {
156 let win = win.as_slice::<T>()?;
157 v.iter_mut()
158 .zip(win.iter())
159 .for_each(|(v, w)| *v = *v * Complex::new(*w, T::zero()));
160 }
161 fft.process(&mut v);
162 oslice
163 .index_axis_mut(Axis(self.axis), f)
164 .iter_mut()
165 .zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
166 .for_each(|(s, v)| *s = v);
167 }
168 }
169 Ok(output)
170 }
171}
172
173impl Op for Stft {
174 fn name(&self) -> Cow<str> {
175 "STFT".into()
176 }
177
178 op_as_typed_op!();
179}
180
181impl EvalOp for Stft {
182 fn is_stateless(&self) -> bool {
183 true
184 }
185
186 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
187 let input = args_1!(inputs);
188 let output = match input.datum_type() {
189 DatumType::F16 => {
190 let temp = input.cast_to::<f32>()?;
191 self.eval_t::<f32>(&temp)?.cast_to::<f16>()?.into_owned()
192 }
193 DatumType::F32 => self.eval_t::<f32>(&input)?,
194 DatumType::F64 => self.eval_t::<f64>(&input)?,
195 _ => bail!("FFT not implemented for type {:?}", input.datum_type()),
196 };
197 Ok(tvec!(output.into_tvalue()))
198 }
199}
200
201impl TypedOp for Stft {
202 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
203 anyhow::ensure!(
204 inputs[0].rank() >= 2,
205 "Expect rank 2 (one for fft dimension, one for complex dimension"
206 );
207 anyhow::ensure!(
208 inputs[0].shape.last().unwrap() == &2.to_dim(),
209 "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
210 );
211 let mut shape = inputs[0].shape.to_tvec();
212 let frames = (inputs[0].shape[self.axis].clone() - self.frame) / self.stride + 1;
213 shape[self.axis] = frames;
214 shape.insert(self.axis + 1, self.frame.to_dim());
215 Ok(tvec!(inputs[0].datum_type.fact(shape)))
216 }
217
218 as_op!();
219}