1use crate::tensor::{DeviceTensor, DeviceTensorExt, IntoDevice};
2use tract_core::internal::*;
3
/// Frame sizes (in samples) accepted by [`is_supported_frame`]; presumably the
/// sizes the backend STFT kernels are built for — confirm against the kernels.
pub const SUPPORTED_FRAMES: [usize; 4] = [256, 512, 1024, 2048];

/// Returns `true` when `frame` is one of [`SUPPORTED_FRAMES`].
pub fn is_supported_frame(frame: usize) -> bool {
    SUPPORTED_FRAMES.iter().any(|&supported| supported == frame)
}
12
/// Signature of a backend-provided STFT launcher.
///
/// `GpuStft` calls it as `dispatch(stride, input, window, output)`, where
/// `output` has already been allocated with the op's output shape.
pub type DispatchStftFn = fn(usize, &DeviceTensor, &DeviceTensor, &DeviceTensor) -> TractResult<()>;

/// Device-side short-time Fourier transform over a complex-interleaved input
/// (`[.., T, 2]`).
///
/// The kernel itself is injected by the backend through `dispatch`.
/// `backend_name` only prefixes the op name and debug output.
#[derive(Clone)]
pub struct GpuStft {
    /// Axis that is cut into frames; its size becomes the number of frames.
    /// NOTE(review): presumably a non-complex axis (not the trailing size-2
    /// axis) — nothing in this struct enforces that.
    pub axis: usize,
    /// Samples per frame; also the size of the axis inserted after `axis`.
    pub frame: usize,
    /// Hop between the starts of consecutive frames, in samples.
    pub stride: usize,
    /// Window handed to `dispatch`; converted to a device tensor on every
    /// evaluation. Presumably built by `padded_window` (f32, length `frame`) —
    /// confirm at the construction site.
    pub window: Arc<Tensor>,
    /// Backend label, e.g. used to build the op name `<backend_name>Stft`.
    pub backend_name: &'static str,
    /// Backend kernel launcher; excluded from equality and hashing.
    pub dispatch: DispatchStftFn,
}
34
35impl GpuStft {
36 fn output_shape<D: DimLike>(&self, input: &[D]) -> TVec<D> {
37 let mut shape: TVec<D> = input.into();
38 let frames = (input[self.axis].clone() - self.frame) / self.stride + 1;
39 shape[self.axis] = frames;
40 shape.insert(self.axis + 1, self.frame.into());
41 shape
42 }
43}
44
45impl std::fmt::Debug for GpuStft {
46 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
47 write!(f, "{}Stft(frame={}, stride={})", self.backend_name, self.frame, self.stride)
48 }
49}
50
51impl PartialEq for GpuStft {
52 fn eq(&self, other: &Self) -> bool {
53 self.backend_name == other.backend_name
54 && self.axis == other.axis
55 && self.frame == other.frame
56 && self.stride == other.stride
57 && self.window == other.window
58 }
59}
60
61impl Eq for GpuStft {}
62
63impl std::hash::Hash for GpuStft {
64 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
65 self.backend_name.hash(state);
66 self.axis.hash(state);
67 self.frame.hash(state);
68 self.stride.hash(state);
69 self.window.hash(state);
70 }
71}
72
73impl Op for GpuStft {
74 fn name(&self) -> StaticName {
75 format!("{}Stft", self.backend_name).into()
76 }
77
78 op_as_typed_op!();
79}
80
81impl EvalOp for GpuStft {
82 fn is_stateless(&self) -> bool {
83 true
84 }
85
86 fn eval_with_session(
87 &self,
88 node_id: usize,
89 session: &TurnState,
90 inputs: TVec<TValue>,
91 ) -> TractResult<TVec<TValue>> {
92 let input = inputs[0].to_device_tensor()?;
93 let window = (*self.window).clone().into_device()?;
94 let output = crate::session_handler::make_tensor_for_node(
95 session,
96 node_id,
97 input.datum_type(),
98 &self.output_shape(input.shape()),
99 )?;
100 (self.dispatch)(self.stride, input, &window, &output)
101 .with_context(|| format!("Error while dispatching eval for {}", self.name()))?;
102 Ok(tvec!(output.into_tensor().into_tvalue()))
103 }
104}
105
106impl TypedOp for GpuStft {
107 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
108 crate::utils::facts_to_device_facts(inputs, |facts| {
109 let input = facts[0];
110 ensure!(
111 input.rank() >= 2 && input.shape[input.rank() - 1] == 2.to_dim(),
112 "{} expects a complex input [.., T, 2]",
113 self.name()
114 );
115 Ok(tvec!(input.datum_type.fact(self.output_shape(&input.shape.to_tvec()))))
116 })
117 .with_context(|| format!("Error while computing facts for {:?}", self.name()))
118 }
119
120 as_op!();
121}
122
/// Signature of a backend-provided FFT launcher.
///
/// `GpuFft` calls it as `dispatch(inverse, input, output)`, where `output` has
/// already been allocated with the input's datum type and shape.
pub type DispatchFftFn = fn(bool, &DeviceTensor, &DeviceTensor) -> TractResult<()>;

/// Device-side forward/inverse FFT over a complex-interleaved input
/// (`[.., N, 2]`); output datum type and shape equal the input's.
///
/// The kernel itself is injected by the backend through `dispatch`.
#[derive(Clone)]
pub struct GpuFft {
    /// Transform axis. NOTE(review): only used for equality/hashing here and
    /// not forwarded to `dispatch`; presumably the kernel assumes a fixed axis —
    /// confirm against the op translation.
    pub axis: usize,
    /// `true` selects the inverse transform, `false` the forward one.
    pub inverse: bool,
    /// Backend label, e.g. used to build the op name `<backend_name>Fft`.
    pub backend_name: &'static str,
    /// Backend kernel launcher; excluded from equality and hashing.
    pub dispatch: DispatchFftFn,
}
139
140impl std::fmt::Debug for GpuFft {
141 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
142 write!(f, "{}Fft({})", self.backend_name, if self.inverse { "inverse" } else { "forward" })
143 }
144}
145
146impl PartialEq for GpuFft {
147 fn eq(&self, other: &Self) -> bool {
148 self.backend_name == other.backend_name
149 && self.axis == other.axis
150 && self.inverse == other.inverse
151 }
152}
153
154impl Eq for GpuFft {}
155
156impl std::hash::Hash for GpuFft {
157 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
158 self.backend_name.hash(state);
159 self.axis.hash(state);
160 self.inverse.hash(state);
161 }
162}
163
164impl Op for GpuFft {
165 fn name(&self) -> StaticName {
166 format!("{}Fft", self.backend_name).into()
167 }
168
169 op_as_typed_op!();
170}
171
172impl EvalOp for GpuFft {
173 fn is_stateless(&self) -> bool {
174 true
175 }
176
177 fn eval_with_session(
178 &self,
179 node_id: usize,
180 session: &TurnState,
181 inputs: TVec<TValue>,
182 ) -> TractResult<TVec<TValue>> {
183 let input = inputs[0].to_device_tensor()?;
184 let output = crate::session_handler::make_tensor_for_node(
185 session,
186 node_id,
187 input.datum_type(),
188 input.shape(),
189 )?;
190 (self.dispatch)(self.inverse, input, &output)
191 .with_context(|| format!("Error while dispatching eval for {}", self.name()))?;
192 Ok(tvec!(output.into_tensor().into_tvalue()))
193 }
194}
195
196impl TypedOp for GpuFft {
197 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
198 crate::utils::facts_to_device_facts(inputs, |facts| {
199 let input = facts[0];
200 ensure!(
201 input.rank() >= 2 && input.shape[input.rank() - 1] == 2.to_dim(),
202 "{} expects a complex input [.., N, 2]",
203 self.name()
204 );
205 Ok(tvec!(input.datum_type.fact(input.shape.clone())))
206 })
207 .with_context(|| format!("Error while computing facts for {:?}", self.name()))
208 }
209
210 as_op!();
211}
212
213pub fn padded_window(window: Option<&Arc<Tensor>>, frame: usize) -> TractResult<Arc<Tensor>> {
216 let mut win = vec![0f32; frame];
217 match window {
218 Some(w) => {
219 let w = w.cast_to::<f32>()?;
220 let w = w.try_as_plain()?;
221 let w = w.as_slice::<f32>()?;
222 ensure!(w.len() <= frame, "STFT window longer than frame");
223 let pad_left = (frame - w.len()) / 2;
224 win[pad_left..pad_left + w.len()].copy_from_slice(w);
225 }
226 None => win.fill(1.0),
227 }
228 Ok(Arc::new(tensor1(&win)))
229}