Skip to main content

tract_core/ops/
fft.rs

1use crate::internal::*;
2use num_complex::Complex;
3use rustfft::num_traits::{Float, FromPrimitive};
4use rustfft::{FftDirection, FftNum};
5use tract_data::itertools::Itertools;
6use tract_ndarray::Axis;
7
/// Complex-to-complex Fourier transform applied along a single axis.
///
/// The operand's innermost dimension carries the `[re, im]` pair of each
/// complex sample.
#[derive(Debug, Clone, Hash)]
pub struct Fft {
    /// Index of the dimension whose entries are transformed.
    pub axis: usize,
    /// `true` requests the inverse transform, `false` the forward one.
    pub inverse: bool,
}
13
14impl Fft {
15    fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
16        &self,
17        tensor: &mut Tensor,
18    ) -> TractResult<()> {
19        let mut iterator_shape: TVec<usize> = tensor.shape().into();
20        iterator_shape.pop(); // last dim is [re, im]
21        iterator_shape[self.axis] = 1;
22        let len = tensor.shape()[self.axis];
23        let direction = if self.inverse { FftDirection::Inverse } else { FftDirection::Forward };
24        let fft = rustfft::FftPlanner::new().plan_fft(len, direction);
25        let mut tensor_dense = tensor.try_as_dense_mut()?;
26        let mut array = tensor_dense.to_array_view_mut::<T>()?;
27        let mut v = Vec::with_capacity(len);
28        for coords in tract_ndarray::indices(&*iterator_shape) {
29            v.clear();
30            let mut slice = array.slice_each_axis_mut(|ax| {
31                if ax.axis.index() == self.axis || ax.stride == 1 {
32                    // ax.stride == 1 => last dim
33                    (..).into()
34                } else {
35                    let c = coords[ax.axis.index()] as isize;
36                    (c..=c).into()
37                }
38            });
39            v.extend(slice.iter().tuples().map(|(r, i)| Complex::new(*r, *i)));
40            fft.process(&mut v);
41            slice
42                .iter_mut()
43                .zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
44                .for_each(|(s, v)| *s = v);
45        }
46        Ok(())
47    }
48}
49
50impl Op for Fft {
51    fn name(&self) -> StaticName {
52        "Fft".into()
53    }
54
55    fn info(&self) -> TractResult<Vec<String>> {
56        Ok(vec![if self.inverse { "inverse" } else { "forward" }.into()])
57    }
58
59    op_as_typed_op!();
60}
61
62impl EvalOp for Fft {
63    fn is_stateless(&self) -> bool {
64        true
65    }
66
67    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
68        let mut tensor = args_1!(inputs).into_tensor();
69        match tensor.datum_type() {
70            DatumType::F16 => {
71                let mut temp = tensor.cast_to::<f32>()?.into_owned();
72                self.eval_t::<f32>(&mut temp)?;
73                tensor = temp.cast_to::<f16>()?.into_owned();
74            }
75            DatumType::F32 => self.eval_t::<f32>(&mut tensor)?,
76            DatumType::F64 => self.eval_t::<f64>(&mut tensor)?,
77            _ => bail!("FFT not implemented for type {:?}", tensor.datum_type()),
78        }
79        Ok(tvec!(tensor.into_tvalue()))
80    }
81}
82
83impl TypedOp for Fft {
84    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
85        anyhow::ensure!(
86            inputs[0].rank() >= 2,
87            "Expect rank 2 (one for fft dimension, one for complex dimension"
88        );
89        anyhow::ensure!(
90            inputs[0].shape.last().unwrap() == &2.to_dim(),
91            "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
92        );
93        Ok(tvec!(inputs[0].without_value()))
94    }
95
96    as_op!();
97}
98
99#[derive(Clone, Debug, Hash)]
100pub struct Stft {
101    pub axis: usize,
102    pub frame: usize,
103    pub stride: usize,
104    pub window: Option<Arc<Tensor>>,
105}
106
107impl Stft {
108    fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
109        &self,
110        input: &Tensor,
111    ) -> TractResult<Tensor> {
112        let mut iterator_shape: TVec<usize> = input.shape().into();
113        iterator_shape.pop(); // [re,im]
114        iterator_shape[self.axis] = 1;
115        let mut output_shape: TVec<usize> = input.shape().into();
116        let frames = (input.shape()[self.axis] - self.frame) / self.stride + 1;
117        output_shape.insert(self.axis, frames);
118        output_shape[self.axis + 1] = self.frame;
119        let mut output = unsafe { Tensor::uninitialized::<T>(&output_shape)? };
120        let fft = rustfft::FftPlanner::new().plan_fft_forward(self.frame);
121        let input = input.to_dense_array_view::<T>()?;
122        let mut output_dense = output.try_as_dense_mut()?;
123        let mut oview = output_dense.to_array_view_mut::<T>()?;
124        let mut v = Vec::with_capacity(self.frame);
125        for coords in tract_ndarray::indices(&*iterator_shape) {
126            let islice = input.slice_each_axis(|ax| {
127                if ax.axis.index() == self.axis || ax.stride == 1 {
128                    (..).into()
129                } else {
130                    let c = coords[ax.axis.index()] as isize;
131                    (c..=c).into()
132                }
133            });
134            let mut oslice = oview.slice_each_axis_mut(|ax| {
135                if ax.stride == 1 {
136                    (..).into()
137                } else if ax.axis.index() < self.axis {
138                    let c = coords[ax.axis.index()] as isize;
139                    (c..=c).into()
140                } else if ax.axis.index() == self.axis || ax.axis.index() == self.axis + 1 {
141                    (..).into()
142                } else {
143                    let c = coords[ax.axis.index() - 1] as isize;
144                    (c..=c).into()
145                }
146            });
147            for f in 0..frames {
148                v.clear();
149                v.extend(
150                    islice
151                        .iter()
152                        .tuples()
153                        .skip(self.stride * f)
154                        .take(self.frame)
155                        .map(|(re, im)| Complex::new(*re, *im)),
156                );
157                if let Some(win) = &self.window {
158                    let win = win.try_as_dense()?.as_slice::<T>()?;
159                    // symmetric padding in case window is smaller than frames (aka n fft)
160                    let pad_left = (self.frame - win.len()) / 2;
161                    v.iter_mut().enumerate().for_each(|(ix, v)| {
162                        *v = if ix < pad_left || ix >= pad_left + win.len() {
163                            Complex::new(T::zero(), T::zero())
164                        } else {
165                            *v * Complex::new(win[ix - pad_left], T::zero())
166                        }
167                    });
168                }
169                fft.process(&mut v);
170                oslice
171                    .index_axis_mut(Axis(self.axis), f)
172                    .iter_mut()
173                    .zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
174                    .for_each(|(s, v)| *s = v);
175            }
176        }
177        Ok(output)
178    }
179}
180
181impl Op for Stft {
182    fn name(&self) -> StaticName {
183        "STFT".into()
184    }
185
186    op_as_typed_op!();
187}
188
189impl EvalOp for Stft {
190    fn is_stateless(&self) -> bool {
191        true
192    }
193
194    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
195        let input = args_1!(inputs);
196        let output = match input.datum_type() {
197            DatumType::F16 => {
198                let temp = input.cast_to::<f32>()?;
199                self.eval_t::<f32>(&temp)?.cast_to::<f16>()?.into_owned()
200            }
201            DatumType::F32 => self.eval_t::<f32>(&input)?,
202            DatumType::F64 => self.eval_t::<f64>(&input)?,
203            _ => bail!("FFT not implemented for type {:?}", input.datum_type()),
204        };
205        Ok(tvec!(output.into_tvalue()))
206    }
207}
208
209impl TypedOp for Stft {
210    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
211        anyhow::ensure!(
212            inputs[0].rank() >= 2,
213            "Expect rank 2 (one for fft dimension, one for complex dimension"
214        );
215        anyhow::ensure!(
216            inputs[0].shape.last().unwrap() == &2.to_dim(),
217            "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
218        );
219        let mut shape = inputs[0].shape.to_tvec();
220        let frames = (inputs[0].shape[self.axis].clone() - self.frame) / self.stride + 1;
221        shape[self.axis] = frames;
222        shape.insert(self.axis + 1, self.frame.to_dim());
223        Ok(tvec!(inputs[0].datum_type.fact(shape)))
224    }
225
226    as_op!();
227}