1#![allow(unpredictable_function_pointer_comparisons)]
2use crate::device::{DeviceContext, get_context};
3use crate::session_handler::make_tensor_for_node;
4use crate::tensor::{DeviceTensor, DeviceTensorExt, IntoDevice};
5use std::ops::Range;
6use tract_core::internal::*;
7use tract_core::ops::array::PadMode;
8use tract_core::trivial_op_state_freeze;
9use tract_pulse_opl::ops::{Delay, PulsePad};
10
/// Device (GPU) counterpart of the pulse [`Delay`] op.
///
/// The op itself is pure configuration; the rolling delay buffer lives in the
/// per-session `GpuDelayState`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GpuDelay {
    /// CPU-side `Delay` definition (delay, overlap, axis) this op mirrors.
    pub inner: Delay,
}
17
18impl GpuDelay {
19 pub fn new(inner: &Delay) -> Self {
20 Self { inner: inner.clone() }
21 }
22}
23
impl Op for GpuDelay {
    fn name(&self) -> StaticName {
        "GpuDelay".into()
    }

    /// Forward human-readable details to the wrapped CPU op.
    fn info(&self) -> TractResult<Vec<String>> {
        self.inner.info()
    }

    op_as_typed_op!();
}
35
impl EvalOp for GpuDelay {
    // Stateful: evaluation goes through `GpuDelayState`'s rolling buffer.
    fn is_stateless(&self) -> bool {
        false
    }

    /// Create per-session state; the buffer and scratch tensors are allocated
    /// lazily on first evaluation.
    fn state(&self, _session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
        Ok(Some(Box::new(GpuDelayState { node_id, buffer: None, shift_scratch: None })))
    }
}
45
impl TypedOp for GpuDelay {
    /// Output facts are the wrapped CPU op's facts translated to device facts.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        crate::utils::facts_to_device_facts(inputs, |facts| self.inner.output_facts(facts))
            .with_context(|| format!("Error while computing output facts for {}", self.name()))
    }

    /// Delegate cost estimation to the wrapped CPU op.
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        crate::utils::get_device_facts(inputs, |facts| self.inner.cost(facts))
    }

    as_op!();
}
58
/// Per-session state for `GpuDelay`.
#[derive(Debug, Clone)]
pub struct GpuDelayState {
    /// Node id, used to obtain the output tensor for this node.
    pub node_id: usize,
    /// Rolling buffer of `delay + overlap` frames along the delay axis;
    /// allocated (zeroed) on first evaluation.
    pub buffer: Option<DeviceTensor>,
    /// Scratch tensor used to shift the buffer without an overlapping
    /// device copy; allocated lazily on first use.
    pub shift_scratch: Option<DeviceTensor>,
}
65
impl GpuDelayState {
    /// Core delay kernel: write one output pulse and roll the internal buffer.
    ///
    /// The output pulse (`input_pulse + overlap` frames along `op.axis`) is
    /// the concatenation of the oldest buffered frames and the leading frames
    /// of the current input; the buffer is then refreshed with the most
    /// recent frames so the next pulse sees them delayed.
    ///
    /// # Safety
    ///
    /// `self.buffer` must already be initialized (it is unwrapped here), and
    /// `input`/`output` shapes must match the op's pulse geometry on
    /// `op.axis` — the caller (`eval`) is responsible for both.
    unsafe fn apply_delay_unchecked(
        &mut self,
        ctx: &dyn DeviceContext,
        op: &Delay,
        input: &DeviceTensor,
        output: &mut DeviceTensor,
    ) -> TractResult<()> {
        // Buffer length along the delay axis.
        let buffered = op.delay + op.overlap;
        let input_pulse = input.shape()[op.axis];
        let output_pulse = input_pulse + op.overlap;
        let buffer = self.buffer.as_mut().unwrap();

        // Frames of the output served directly from the current input, and
        // the remainder served from the buffer.
        let from_input = input_pulse.saturating_sub(op.delay);
        let from_buffer = output_pulse.saturating_sub(from_input);

        // output = [ buffer head | input head ]
        ctx.assign_slice(output, 0..from_buffer, buffer, 0..from_buffer, op.axis)?;
        ctx.assign_slice(output, from_buffer..output_pulse, input, 0..from_input, op.axis)?;

        if buffered < input_pulse {
            // The input alone covers the whole buffer: overwrite it with the
            // last `buffered` input frames.
            ctx.assign_slice(
                buffer,
                0..buffered,
                input,
                (input_pulse - buffered)..input_pulse,
                op.axis,
            )?;
        } else {
            // Shift the `keep` most recent buffered frames to the front,
            // going through a scratch tensor because source and destination
            // ranges may overlap, then append the whole input at the tail.
            let keep = buffered - input_pulse;
            let scratch = self.shift_scratch.get_or_insert_with(|| {
                DeviceTensor::uninitialized_dt(input.datum_type(), buffer.shape()).unwrap()
            });
            ctx.assign_slice(scratch, 0..keep, buffer, input_pulse..buffered, op.axis)?;
            ctx.assign_slice(buffer, 0..keep, scratch, 0..keep, op.axis)?;
            ctx.assign_slice(
                buffer,
                (buffered - input_pulse)..buffered,
                input,
                0..input_pulse,
                op.axis,
            )?;
        }
        Ok(())
    }
}
118
119impl OpState for GpuDelayState {
120 fn eval(
121 &mut self,
122 state: &mut TurnState,
123 op: &dyn Op,
124 inputs: TVec<TValue>,
125 ) -> TractResult<TVec<TValue>> {
126 let input = args_1!(inputs);
127 let op = &op.downcast_ref::<GpuDelay>().ok_or_else(|| format_err!("Wrong Op type"))?.inner;
128 let buffered = op.delay + op.overlap;
129 let device_input = input.as_device_tensor().context("Expected a GPU tensor")?;
130 let input_pulse = device_input.shape()[op.axis];
131 let output_pulse = input_pulse + op.overlap;
132 let mut output_shape: TVec<usize> = device_input.shape().into();
133 output_shape[op.axis] = output_pulse;
134 let dt = device_input.datum_type();
135 let ctx = get_context()?;
136 unsafe {
137 if self.buffer.is_none() {
138 let mut shape = device_input.shape().to_owned();
139 shape[op.axis] = buffered;
140 self.buffer = Some(Tensor::zero_dt(dt, &shape)?.into_device()?);
141 };
142 let mut output = make_tensor_for_node(state, self.node_id, dt, &output_shape)?;
143 self.apply_delay_unchecked(&*ctx, op, device_input, &mut output)?;
144 Ok(tvec!(output.into_tensor().into()))
145 }
146 }
147}
148
// Generate the trivial session-freeze impl for the delay state.
trivial_op_state_freeze!(GpuDelayState);
150
/// Device (GPU) counterpart of the pulse [`PulsePad`] op.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GpuPulsePad {
    /// CPU-side `PulsePad` definition this op mirrors.
    pub op: PulsePad,
    /// Padding constant pre-uploaded to the device when `op.mode` is
    /// `PadMode::Constant`; `None` for the other modes.
    pub device_cst: Option<DeviceTensor>,
}
158
159impl GpuPulsePad {
160 pub fn new(op: &PulsePad) -> TractResult<Self> {
161 let device_cst =
162 if let PadMode::Constant(c) = &op.mode { Some(c.clone().into_device()?) } else { None };
163 Ok(Self { op: op.clone(), device_cst })
164 }
165}
166
impl std::hash::Hash for GpuPulsePad {
    // Hash only the op definition. The `a == b => hash(a) == hash(b)`
    // contract still holds with the derived PartialEq, since equal values
    // necessarily have equal `op` fields.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.op.hash(state);
    }
}
172
impl Op for GpuPulsePad {
    fn name(&self) -> StaticName {
        "GpuPulsePad".into()
    }

    /// Forward human-readable details to the wrapped CPU op.
    fn info(&self) -> TractResult<Vec<String>> {
        self.op.info()
    }

    op_as_typed_op!();
}
184
impl EvalOp for GpuPulsePad {
    // Stateful: padding depends on the absolute stream position tracked in
    // `GpuPulsePadState::current_pos`.
    fn is_stateless(&self) -> bool {
        false
    }

    /// Create per-session state, starting at stream position 0 with no saved
    /// edge frame yet.
    fn state(&self, _session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
        Ok(Some(Box::new(GpuPulsePadState { node_id, current_pos: 0, last_valid_frame: None })))
    }
}
194
impl TypedOp for GpuPulsePad {
    /// Output facts are the wrapped CPU op's facts translated to device facts.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        crate::utils::facts_to_device_facts(inputs, |facts| self.op.output_facts(facts))
            .with_context(|| format!("Error while computing output facts for {}", self.name()))
    }

    /// Delegate cost estimation to the wrapped CPU op.
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        crate::utils::get_device_facts(inputs, |facts| self.op.cost(facts))
    }

    as_op!();
}
207
/// Per-session state for `GpuPulsePad`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct GpuPulsePadState {
    // Node id, used to obtain the output tensor for this node.
    node_id: usize,
    // Absolute stream position of the start of the next pulse; advanced by
    // `pulse - overlap` on every `pad` call.
    current_pos: usize,
    // Latest frame known to contain valid data, saved for `PadMode::Edge`
    // padding past the end of the input.
    last_valid_frame: Option<DeviceTensor>,
}
214
215fn fill_slice_constant(
216 ctx: &dyn DeviceContext,
217 dst: &mut DeviceTensor,
218 cst: &DeviceTensor,
219 axis: usize,
220 range: Range<usize>,
221) -> TractResult<()> {
222 let mut zone_shape: TVec<usize> = dst.shape().into();
223 zone_shape[axis] = range.len();
224 let mut dst_origin = tvec!(0; dst.rank());
225 dst_origin[axis] = range.start;
226 ctx.copy_with_origins(
227 &zone_shape,
228 dst,
229 &dst_origin,
230 dst.strides(),
231 cst,
232 &tvec!(0; dst.rank()),
233 &tvec!(0; dst.rank()),
234 )
235}
236
237fn fill_slice_repeating_one_frame(
238 ctx: &dyn DeviceContext,
239 dst: &mut DeviceTensor,
240 src: &DeviceTensor,
241 axis: usize,
242 dst_range: Range<usize>,
243 src_frame: usize,
244) -> TractResult<()> {
245 let mut zone_shape: TVec<usize> = dst.shape().into();
246 zone_shape[axis] = dst_range.len();
247 let mut dst_origin = tvec!(0; dst.rank());
248 dst_origin[axis] = dst_range.start;
249 let mut src_origin = tvec!(0; src.rank());
250 src_origin[axis] = src_frame;
251 let mut src_strides: TVec<isize> = src.strides().into();
252 src_strides[axis] = 0;
253 ctx.copy_with_origins(
254 &zone_shape,
255 dst,
256 &dst_origin,
257 dst.strides(),
258 src,
259 &src_origin,
260 &src_strides,
261 )
262}
263
impl GpuPulsePadState {
    /// Copy frame `frame` of `input` along `op.axis` into a freshly allocated
    /// one-frame tensor and remember it as `last_valid_frame`, so later
    /// pulses can `Edge`-pad past the end of the valid input.
    fn save_frame(
        &mut self,
        ctx: &dyn DeviceContext,
        op: &PulsePad,
        input: &DeviceTensor,
        frame: usize,
    ) -> TractResult<()> {
        let mut frame_shape: TVec<usize> = input.shape().into();
        frame_shape[op.axis] = 1;
        let last_valid_frame = DeviceTensor::uninitialized_dt(input.datum_type(), &frame_shape)?;
        ctx.assign_slice(&last_valid_frame, 0..1, input, frame..frame + 1, op.axis)?;
        self.last_valid_frame = Some(last_valid_frame);
        Ok(())
    }

    /// Build the padded output for one pulse and advance the stream cursor.
    ///
    /// The pulse covers absolute stream positions `[pulse_begin, pulse_end)`;
    /// frames outside `[begin_input, end_input)` are overwritten with padding
    /// according to `op.mode`.
    fn pad(
        &mut self,
        session: &TurnState,
        gpu_op: &GpuPulsePad,
        input: &DeviceTensor,
    ) -> TractResult<DeviceTensor> {
        let ctx = get_context()?;
        let op = &gpu_op.op;
        let pulse = input.shape()[op.axis];
        let pulse_begin = self.current_pos;
        let pulse_end = self.current_pos + pulse;
        // Consecutive pulses overlap by `op.overlap` frames.
        self.current_pos += pulse - op.overlap;
        // Unresolvable symbols mean "unbounded": usize::MAX effectively
        // disables the corresponding padding logic below.
        let end_input =
            op.end_input.eval(&session.resolved_symbols).to_usize().unwrap_or(usize::MAX);
        let after = op.after.eval(&session.resolved_symbols).to_usize().unwrap_or(usize::MAX);

        // Edge mode with right padding: while this pulse still holds valid
        // frames, remember the latest one for padding subsequent pulses.
        if let PadMode::Edge = op.mode
            && after != 0
            && pulse_begin < end_input
        {
            let latest_valid_frame = (end_input - pulse_begin).min(pulse) - 1;
            self.save_frame(&*ctx, op, input, latest_valid_frame)?;
        }

        // Start from a verbatim copy of the input pulse.
        let mut output =
            make_tensor_for_node(session, self.node_id, input.datum_type(), input.shape())?;
        let flat_len = input.len() * input.datum_type().size_of();
        ctx.flat_copy(input, 0, &output, 0, flat_len)?;

        // Fast path: the pulse lies entirely inside the valid region, or
        // entirely outside the padded region — nothing to fill.
        // NOTE(review): `op.begin_input - op.before` would underflow when
        // `before > begin_input`; presumably ruled out upstream — confirm.
        if (pulse_begin >= op.begin_input && pulse_end <= end_input)
            || (pulse_end <= op.begin_input - op.before
                || pulse_begin >= end_input.saturating_add(after))
        {
            return Ok(output);
        }

        // Left padding: output frames before `begin_input`.
        if pulse_begin < op.begin_input {
            let fill_up_to = (op.begin_input - pulse_begin).min(pulse);
            match &op.mode {
                PadMode::Constant(_) => fill_slice_constant(
                    &*ctx,
                    &mut output,
                    gpu_op.device_cst.as_ref().unwrap(),
                    op.axis,
                    0..fill_up_to,
                )?,
                // Repeat the first valid frame (at index `fill_up_to`).
                PadMode::Edge => fill_slice_repeating_one_frame(
                    &*ctx,
                    &mut output,
                    input,
                    op.axis,
                    0..fill_up_to,
                    fill_up_to,
                )?,
                _ => unimplemented!(),
            }
        }

        // Right padding: output frames at or past `end_input`.
        if pulse_end > end_input {
            let fill_from = pulse - (pulse_end - end_input).min(pulse);
            match &op.mode {
                PadMode::Constant(_) => fill_slice_constant(
                    &*ctx,
                    &mut output,
                    gpu_op.device_cst.as_ref().unwrap(),
                    op.axis,
                    fill_from..pulse,
                )?,
                // Repeat the saved last valid frame (single-frame tensor).
                PadMode::Edge => fill_slice_repeating_one_frame(
                    &*ctx,
                    &mut output,
                    self.last_valid_frame.as_ref().unwrap(),
                    op.axis,
                    fill_from..pulse,
                    0,
                )?,
                _ => unimplemented!(),
            }
        }
        Ok(output)
    }
}
364
365impl OpState for GpuPulsePadState {
366 fn eval(
367 &mut self,
368 session: &mut TurnState,
369 op: &dyn Op,
370 inputs: TVec<TValue>,
371 ) -> TractResult<TVec<TValue>> {
372 let input = args_1!(inputs);
373 let gpu_op =
374 op.downcast_ref::<GpuPulsePad>().ok_or_else(|| format_err!("Wrong Op type"))?;
375 let device_input = input.as_device_tensor().context("Expected a GPU tensor")?;
376 let output = self.pad(session, gpu_op, device_input)?;
377 Ok(tvec!(output.into_tensor().into_tvalue()))
378 }
379}
380
// Generate the trivial session-freeze impl for the pad state.
trivial_op_state_freeze!(GpuPulsePadState);