tract_core/ops/fft.rs

1use crate::internal::*;
2use num_complex::Complex;
3use rustfft::num_traits::{Float, FromPrimitive};
4use rustfft::{FftDirection, FftNum};
5use tract_data::itertools::Itertools;
6use tract_ndarray::Axis;
7
/// Complex FFT applied along one axis of a tensor whose last dimension holds
/// `[re, im]` pairs.
#[derive(Debug, Clone, Hash)]
pub struct Fft {
    /// Axis along which each 1-D transform is computed.
    pub axis: usize,
    /// When `true`, the inverse transform is computed instead of the forward one.
    pub inverse: bool,
}
13
14impl Fft {
15    fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
16        &self,
17        tensor: &mut Tensor,
18    ) -> TractResult<()> {
19        let mut iterator_shape: TVec<usize> = tensor.shape().into();
20        iterator_shape.pop(); // last dim is [re, im]
21        iterator_shape[self.axis] = 1;
22        let len = tensor.shape()[self.axis];
23        let direction = if self.inverse { FftDirection::Inverse } else { FftDirection::Forward };
24        let fft = rustfft::FftPlanner::new().plan_fft(len, direction);
25        let mut array = tensor.to_array_view_mut::<T>()?;
26        let mut v = Vec::with_capacity(len);
27        for coords in tract_ndarray::indices(&*iterator_shape) {
28            v.clear();
29            let mut slice = array.slice_each_axis_mut(|ax| {
30                if ax.axis.index() == self.axis || ax.stride == 1 {
31                    // ax.stride == 1 => last dim
32                    (..).into()
33                } else {
34                    let c = coords[ax.axis.index()] as isize;
35                    (c..=c).into()
36                }
37            });
38            v.extend(slice.iter().tuples().map(|(r, i)| Complex::new(*r, *i)));
39            fft.process(&mut v);
40            slice
41                .iter_mut()
42                .zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
43                .for_each(|(s, v)| *s = v);
44        }
45        Ok(())
46    }
47}
48
49impl Op for Fft {
50    fn name(&self) -> Cow<str> {
51        "Fft".into()
52    }
53
54    fn info(&self) -> TractResult<Vec<String>> {
55        Ok(vec![if self.inverse { "inverse" } else { "forward" }.into()])
56    }
57
58    op_as_typed_op!();
59}
60
61impl EvalOp for Fft {
62    fn is_stateless(&self) -> bool {
63        true
64    }
65
66    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
67        let mut tensor = args_1!(inputs).into_tensor();
68        match tensor.datum_type() {
69            DatumType::F16 => {
70                let mut temp = tensor.cast_to::<f32>()?.into_owned();
71                self.eval_t::<f32>(&mut temp)?;
72                tensor = temp.cast_to::<f16>()?.into_owned();
73            }
74            DatumType::F32 => self.eval_t::<f32>(&mut tensor)?,
75            DatumType::F64 => self.eval_t::<f64>(&mut tensor)?,
76            _ => bail!("FFT not implemented for type {:?}", tensor.datum_type()),
77        }
78        Ok(tvec!(tensor.into_tvalue()))
79    }
80}
81
82impl TypedOp for Fft {
83    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
84        anyhow::ensure!(
85            inputs[0].rank() >= 2,
86            "Expect rank 2 (one for fft dimension, one for complex dimension"
87        );
88        anyhow::ensure!(
89            inputs[0].shape.last().unwrap() == &2.to_dim(),
90            "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
91        );
92        Ok(tvec!(inputs[0].without_value()))
93    }
94
95    as_op!();
96}
97
98#[derive(Clone, Debug, Hash)]
99pub struct Stft {
100    pub axis: usize,
101    pub frame: usize,
102    pub stride: usize,
103    pub window: Option<Arc<Tensor>>,
104}
105
106impl Stft {
107    fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
108        &self,
109        input: &Tensor,
110    ) -> TractResult<Tensor> {
111        let mut iterator_shape: TVec<usize> = input.shape().into();
112        iterator_shape.pop(); // [re,im]
113        iterator_shape[self.axis] = 1;
114        let mut output_shape: TVec<usize> = input.shape().into();
115        let frames = (input.shape()[self.axis] - self.frame) / self.stride + 1;
116        output_shape.insert(self.axis, frames);
117        output_shape[self.axis + 1] = self.frame;
118        let mut output = unsafe { Tensor::uninitialized::<T>(&output_shape)? };
119        let fft = rustfft::FftPlanner::new().plan_fft_forward(self.frame);
120        let input = input.to_array_view::<T>()?;
121        let mut oview = output.to_array_view_mut::<T>()?;
122        let mut v = Vec::with_capacity(self.frame);
123        for coords in tract_ndarray::indices(&*iterator_shape) {
124            let islice = input.slice_each_axis(|ax| {
125                if ax.axis.index() == self.axis || ax.stride == 1 {
126                    (..).into()
127                } else {
128                    let c = coords[ax.axis.index()] as isize;
129                    (c..=c).into()
130                }
131            });
132            let mut oslice = oview.slice_each_axis_mut(|ax| {
133                if ax.stride == 1 {
134                    (..).into()
135                } else if ax.axis.index() < self.axis {
136                    let c = coords[ax.axis.index()] as isize;
137                    (c..=c).into()
138                } else if ax.axis.index() == self.axis || ax.axis.index() == self.axis + 1 {
139                    (..).into()
140                } else {
141                    let c = coords[ax.axis.index() - 1] as isize;
142                    (c..=c).into()
143                }
144            });
145            for f in 0..frames {
146                v.clear();
147                v.extend(
148                    islice
149                        .iter()
150                        .tuples()
151                        .skip(self.stride * f)
152                        .take(self.frame)
153                        .map(|(re, im)| Complex::new(*re, *im)),
154                );
155                if let Some(win) = &self.window {
156                    let win = win.as_slice::<T>()?;
157                    v.iter_mut()
158                        .zip(win.iter())
159                        .for_each(|(v, w)| *v = *v * Complex::new(*w, T::zero()));
160                }
161                fft.process(&mut v);
162                oslice
163                    .index_axis_mut(Axis(self.axis), f)
164                    .iter_mut()
165                    .zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
166                    .for_each(|(s, v)| *s = v);
167            }
168        }
169        Ok(output)
170    }
171}
172
173impl Op for Stft {
174    fn name(&self) -> Cow<str> {
175        "STFT".into()
176    }
177
178    op_as_typed_op!();
179}
180
181impl EvalOp for Stft {
182    fn is_stateless(&self) -> bool {
183        true
184    }
185
186    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
187        let input = args_1!(inputs);
188        let output = match input.datum_type() {
189            DatumType::F16 => {
190                let temp = input.cast_to::<f32>()?;
191                self.eval_t::<f32>(&temp)?.cast_to::<f16>()?.into_owned()
192            }
193            DatumType::F32 => self.eval_t::<f32>(&input)?,
194            DatumType::F64 => self.eval_t::<f64>(&input)?,
195            _ => bail!("FFT not implemented for type {:?}", input.datum_type()),
196        };
197        Ok(tvec!(output.into_tvalue()))
198    }
199}
200
201impl TypedOp for Stft {
202    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
203        anyhow::ensure!(
204            inputs[0].rank() >= 2,
205            "Expect rank 2 (one for fft dimension, one for complex dimension"
206        );
207        anyhow::ensure!(
208            inputs[0].shape.last().unwrap() == &2.to_dim(),
209            "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
210        );
211        let mut shape = inputs[0].shape.to_tvec();
212        let frames = (inputs[0].shape[self.axis].clone() - self.frame) / self.stride + 1;
213        shape[self.axis] = frames;
214        shape.insert(self.axis + 1, self.frame.to_dim());
215        Ok(tvec!(inputs[0].datum_type.fact(shape)))
216    }
217
218    as_op!();
219}