// tract_core/ops/fft.rs

1use crate::internal::*;
2use num_complex::Complex;
3use rustfft::num_traits::{Float, FromPrimitive};
4use rustfft::{FftDirection, FftNum};
5use tract_data::itertools::Itertools;
6use tract_ndarray::Axis as NdAxis;
7
/// Complex-to-complex FFT operator applied along a single axis.
///
/// The evaluation code reads the innermost dimension of its input as
/// `[re, im]` pairs and rewrites the tensor in place.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct Fft {
    /// Index of the axis along which the transform is computed.
    pub axis: usize,
    /// `true` plans the inverse transform direction, `false` the forward one.
    pub inverse: bool,
}
13
14impl Fft {
15    fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
16        &self,
17        tensor: &mut Tensor,
18    ) -> TractResult<()> {
19        let mut iterator_shape: TVec<usize> = tensor.shape().into();
20        iterator_shape.pop(); // last dim is [re, im]
21        iterator_shape[self.axis] = 1;
22        let len = tensor.shape()[self.axis];
23        let direction = if self.inverse { FftDirection::Inverse } else { FftDirection::Forward };
24        let fft = rustfft::FftPlanner::new().plan_fft(len, direction);
25        let mut tensor_plain = tensor.try_as_plain_mut()?;
26        let mut array = tensor_plain.to_array_view_mut::<T>()?;
27        let mut v = Vec::with_capacity(len);
28        for coords in tract_ndarray::indices(&*iterator_shape) {
29            v.clear();
30            let mut slice = array.slice_each_axis_mut(|ax| {
31                if ax.axis.index() == self.axis || ax.stride == 1 {
32                    // ax.stride == 1 => last dim
33                    (..).into()
34                } else {
35                    let c = coords[ax.axis.index()] as isize;
36                    (c..=c).into()
37                }
38            });
39            v.extend(slice.iter().tuples().map(|(r, i)| Complex::new(*r, *i)));
40            fft.process(&mut v);
41            slice
42                .iter_mut()
43                .zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
44                .for_each(|(s, v)| *s = v);
45        }
46        Ok(())
47    }
48}
49
50impl Op for Fft {
51    fn name(&self) -> StaticName {
52        "Fft".into()
53    }
54
55    fn info(&self) -> TractResult<Vec<String>> {
56        Ok(vec![if self.inverse { "inverse" } else { "forward" }.into()])
57    }
58
59    op_as_typed_op!();
60}
61
62impl EvalOp for Fft {
63    fn is_stateless(&self) -> bool {
64        true
65    }
66
67    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
68        let mut tensor = args_1!(inputs).into_tensor();
69        match tensor.datum_type() {
70            DatumType::F16 => {
71                let mut temp = tensor.cast_to::<f32>()?.into_owned();
72                self.eval_t::<f32>(&mut temp)?;
73                tensor = temp.cast_to::<f16>()?.into_owned();
74            }
75            DatumType::F32 => self.eval_t::<f32>(&mut tensor)?,
76            DatumType::F64 => self.eval_t::<f64>(&mut tensor)?,
77            _ => bail!("FFT not implemented for type {:?}", tensor.datum_type()),
78        }
79        Ok(tvec!(tensor.into_tvalue()))
80    }
81}
82
83impl TypedOp for Fft {
84    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
85        anyhow::ensure!(
86            inputs[0].rank() >= 2,
87            "Expect rank 2 (one for fft dimension, one for complex dimension"
88        );
89        anyhow::ensure!(
90            inputs[0].shape.last().unwrap() == &2.to_dim(),
91            "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
92        );
93        Ok(tvec!(inputs[0].without_value()))
94    }
95
96    fn axes_mapping(
97        &self,
98        inputs: &[&TypedFact],
99        _outputs: &[&TypedFact],
100    ) -> TractResult<AxesMapping> {
101        // Fft is rank-preserving but it is NOT axes-natural: two axes do
102        // not map 1-to-1 from input to output and must be declared as a
103        // separate input-only and output-only axis.
104        //
105        //   - the FFT axis (`self.axis`): every output sample along it
106        //     depends on every input sample, so the axis cannot be
107        //     sliced or streamed.
108        //   - the trailing complex axis (`rank - 1`): the FFT mixes the
109        //     real and imaginary parts, so re/im do not map 1-to-1.
110        //
111        // Splitting them is exactly what makes the generic pulse fallback
112        // bail when asked to track a streaming axis through the FFT or
113        // complex axis, while every genuine batch axis stays 1-to-1 and
114        // is handled by the per-pulse `PulseWrappingOp`. No dedicated
115        // `Fft` pulsifier is needed.
116        let rank = inputs[0].rank();
117        let complex_axis = rank - 1;
118        let mut axes = tvec!();
119        let mut alphabet = 'a'..;
120        for i in 0..rank {
121            if i == self.axis || i == complex_axis {
122                axes.push(crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1).input(0, i));
123                axes.push(crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1).output(0, i));
124            } else {
125                axes.push(
126                    crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1).input(0, i).output(0, i),
127                );
128            }
129        }
130        AxesMapping::new(1, 1, axes)
131    }
132
133    as_op!();
134}
135
/// Short-time Fourier transform operator.
///
/// The evaluation code cuts the input along `axis` into frames of `frame`
/// complex samples, each starting `stride` samples after the previous one,
/// and runs a forward FFT on every frame.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct Stft {
    /// Index of the input axis that is cut into frames.
    pub axis: usize,
    /// Number of complex samples per frame (also the FFT length).
    pub frame: usize,
    /// Offset, in samples, between the starts of two consecutive frames.
    pub stride: usize,
    /// Optional window multiplied onto each frame before the FFT. When it is
    /// shorter than `frame` it is centred and samples outside it are zeroed.
    pub window: Option<Arc<Tensor>>,
}
143
144impl Stft {
145    fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
146        &self,
147        input: &Tensor,
148    ) -> TractResult<Tensor> {
149        let mut iterator_shape: TVec<usize> = input.shape().into();
150        iterator_shape.pop(); // [re,im]
151        iterator_shape[self.axis] = 1;
152        let mut output_shape: TVec<usize> = input.shape().into();
153        let frames = (input.shape()[self.axis] - self.frame) / self.stride + 1;
154        output_shape.insert(self.axis, frames);
155        output_shape[self.axis + 1] = self.frame;
156        let mut output = unsafe { Tensor::uninitialized::<T>(&output_shape)? };
157        let fft = rustfft::FftPlanner::new().plan_fft_forward(self.frame);
158        let input = input.to_plain_array_view::<T>()?;
159        let mut output_plain = output.try_as_plain_mut()?;
160        let mut oview = output_plain.to_array_view_mut::<T>()?;
161        let mut v = Vec::with_capacity(self.frame);
162        for coords in tract_ndarray::indices(&*iterator_shape) {
163            let islice = input.slice_each_axis(|ax| {
164                if ax.axis.index() == self.axis || ax.stride == 1 {
165                    (..).into()
166                } else {
167                    let c = coords[ax.axis.index()] as isize;
168                    (c..=c).into()
169                }
170            });
171            let mut oslice = oview.slice_each_axis_mut(|ax| {
172                if ax.stride == 1 {
173                    (..).into()
174                } else if ax.axis.index() < self.axis {
175                    let c = coords[ax.axis.index()] as isize;
176                    (c..=c).into()
177                } else if ax.axis.index() == self.axis || ax.axis.index() == self.axis + 1 {
178                    (..).into()
179                } else {
180                    let c = coords[ax.axis.index() - 1] as isize;
181                    (c..=c).into()
182                }
183            });
184            for f in 0..frames {
185                v.clear();
186                v.extend(
187                    islice
188                        .iter()
189                        .tuples()
190                        .skip(self.stride * f)
191                        .take(self.frame)
192                        .map(|(re, im)| Complex::new(*re, *im)),
193                );
194                if let Some(win) = &self.window {
195                    let win = win.try_as_plain()?.as_slice::<T>()?;
196                    // symmetric padding in case window is smaller than frames (aka n fft)
197                    let pad_left = (self.frame - win.len()) / 2;
198                    v.iter_mut().enumerate().for_each(|(ix, v)| {
199                        *v = if ix < pad_left || ix >= pad_left + win.len() {
200                            Complex::new(T::zero(), T::zero())
201                        } else {
202                            *v * Complex::new(win[ix - pad_left], T::zero())
203                        }
204                    });
205                }
206                fft.process(&mut v);
207                oslice
208                    .index_axis_mut(NdAxis(self.axis), f)
209                    .iter_mut()
210                    .zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
211                    .for_each(|(s, v)| *s = v);
212            }
213        }
214        Ok(output)
215    }
216}
217
impl Op for Stft {
    /// Operator name reported to the framework.
    fn name(&self) -> StaticName {
        "STFT".into()
    }

    op_as_typed_op!();
}
225
226impl EvalOp for Stft {
227    fn is_stateless(&self) -> bool {
228        true
229    }
230
231    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
232        let input = args_1!(inputs);
233        let output = match input.datum_type() {
234            DatumType::F16 => {
235                let temp = input.cast_to::<f32>()?;
236                self.eval_t::<f32>(&temp)?.cast_to::<f16>()?.into_owned()
237            }
238            DatumType::F32 => self.eval_t::<f32>(&input)?,
239            DatumType::F64 => self.eval_t::<f64>(&input)?,
240            _ => bail!("FFT not implemented for type {:?}", input.datum_type()),
241        };
242        Ok(tvec!(output.into_tvalue()))
243    }
244}
245
246impl TypedOp for Stft {
247    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
248        anyhow::ensure!(
249            inputs[0].rank() >= 2,
250            "Expect rank 2 (one for fft dimension, one for complex dimension"
251        );
252        anyhow::ensure!(
253            inputs[0].shape.last().unwrap() == &2.to_dim(),
254            "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
255        );
256        let mut shape = inputs[0].shape.to_tvec();
257        let frames = (inputs[0].shape[self.axis].clone() - self.frame) / self.stride + 1;
258        shape[self.axis] = frames;
259        shape.insert(self.axis + 1, self.frame.to_dim());
260        Ok(tvec!(inputs[0].datum_type.fact(shape)))
261    }
262
263    fn axes_mapping(
264        &self,
265        inputs: &[&TypedFact],
266        _outputs: &[&TypedFact],
267    ) -> TractResult<crate::axes::AxesMapping> {
268        // Stft is NOT rank-preserving: it inserts a frame axis at
269        // `axis + 1`. The mapping is:
270        //   - axes 0..self.axis (leading dims): 1-to-1 input <-> output.
271        //   - input axis `self.axis` (the time axis) <-> output axis
272        //     `self.axis` (now the n_frames axis -- same position, the
273        //     dim shrinks from `T` to `(T - frame) / stride + 1`).
274        //   - output axis `self.axis + 1` (the inserted frame axis):
275        //     output-only, no input correspondence.
276        //   - input axes `self.axis + 1..rank` (trailing dims incl.
277        //     the complex pair) <-> output axes `self.axis + 2..rank+1`
278        //     (shifted right by 1 to make room for the frame axis).
279        //
280        // Without this mapping the generic `PulseWrappingOp` fallback
281        // bails with "could not track pulsing axis" the moment a user
282        // streams a non-time axis through STFT (typical pattern: a
283        // batched STFT pipeline that streams the batch axis).
284        let in_rank = inputs[0].rank();
285        let mut axes = tvec!();
286        let mut alphabet = 'a'..;
287        for i in 0..in_rank {
288            let out_axis = if i <= self.axis { i } else { i + 1 };
289            axes.push(
290                crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1)
291                    .input(0, i)
292                    .output(0, out_axis),
293            );
294        }
295        // Inserted frame axis (output-only).
296        axes.push(crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1).output(0, self.axis + 1));
297        crate::axes::AxesMapping::new(1, 1, axes)
298    }
299
300    as_op!();
301}