// tract_gpu/ops/pulse.rs — GPU implementations of the pulse operators (Delay, PulsePad).

1#![allow(unpredictable_function_pointer_comparisons)]
2use crate::device::{DeviceContext, get_context};
3use crate::session_handler::make_tensor_for_node;
4use crate::tensor::{DeviceTensor, DeviceTensorExt, IntoDevice};
5use std::ops::Range;
6use tract_core::internal::*;
7use tract_core::ops::array::PadMode;
8use tract_core::trivial_op_state_freeze;
9use tract_pulse_opl::ops::{Delay, PulsePad};
10
11// ─── GpuDelay ────────────────────────────────────────────────────────────────
12
/// GPU wrapper around the pulse `Delay` op.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GpuDelay {
    // The wrapped core op (provides `delay`, `overlap` and `axis`).
    pub inner: Delay,
}
17
18impl GpuDelay {
19    pub fn new(inner: &Delay) -> Self {
20        Self { inner: inner.clone() }
21    }
22}
23
impl Op for GpuDelay {
    fn name(&self) -> StaticName {
        "GpuDelay".into()
    }

    // Human-readable details are delegated to the wrapped core op.
    fn info(&self) -> TractResult<Vec<String>> {
        self.inner.info()
    }

    op_as_typed_op!();
}
35
36impl EvalOp for GpuDelay {
37    fn is_stateless(&self) -> bool {
38        false
39    }
40
41    fn state(&self, _session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
42        Ok(Some(Box::new(GpuDelayState { node_id, buffer: None, shift_scratch: None })))
43    }
44}
45
46impl TypedOp for GpuDelay {
47    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
48        crate::utils::facts_to_device_facts(inputs, |facts| self.inner.output_facts(facts))
49            .with_context(|| format!("Error while computing output facts for {}", self.name()))
50    }
51
52    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
53        crate::utils::get_device_facts(inputs, |facts| self.inner.cost(facts))
54    }
55
56    as_op!();
57}
58
/// Per-session state for [`GpuDelay`].
#[derive(Debug, Clone)]
pub struct GpuDelayState {
    // Id of the node this state belongs to (used to allocate its output).
    pub node_id: usize,
    // Rolling device buffer of `delay + overlap` frames along the delay axis;
    // lazily allocated (zero-filled) on first eval.
    pub buffer: Option<DeviceTensor>,
    // Scratch tensor used to shift `buffer` without overlapping device
    // copies; lazily allocated.
    pub shift_scratch: Option<DeviceTensor>,
}
65
impl GpuDelayState {
    /// Emit one delayed pulse: `output` receives buffered frames followed by
    /// (delayed) input frames along `op.axis`, then the rolling buffer is
    /// updated so it always holds the latest `op.delay + op.overlap` frames.
    ///
    /// # Safety
    /// Caller must ensure `self.buffer` is `Some`, and that `output` is shaped
    /// like `input` except with `op.axis` sized `input_pulse + op.overlap`.
    unsafe fn apply_delay_unchecked(
        &mut self,
        ctx: &dyn DeviceContext,
        op: &Delay,
        input: &DeviceTensor,
        output: &mut DeviceTensor,
    ) -> TractResult<()> {
        // Frames retained in the rolling buffer along `op.axis`.
        let buffered = op.delay + op.overlap;
        let input_pulse = input.shape()[op.axis];
        let output_pulse = input_pulse + op.overlap;
        let buffer = self.buffer.as_mut().unwrap();

        // Frames of the output served directly from this input pulse...
        let from_input = input_pulse.saturating_sub(op.delay);
        // ...and the remaining frames served from the buffer.
        let from_buffer = output_pulse.saturating_sub(from_input);

        // Copy from buffer to output
        ctx.assign_slice(output, 0..from_buffer, buffer, 0..from_buffer, op.axis)?;
        // Copy from input to output
        ctx.assign_slice(output, from_buffer..output_pulse, input, 0..from_input, op.axis)?;

        // Maintain buffer
        if buffered < input_pulse {
            // Pulse is larger than the buffer: refresh it entirely from the
            // tail of the input.
            ctx.assign_slice(
                buffer,
                0..buffered,
                input,
                (input_pulse - buffered)..input_pulse,
                op.axis,
            )?;
        } else {
            // Shift buffer left by input_pulse elements.
            // CUDA memcpy is undefined for overlapping regions in the same
            // buffer (parallel threads), so copy via a scratch buffer.
            let keep = buffered - input_pulse;
            let scratch = self.shift_scratch.get_or_insert_with(|| {
                // NOTE(review): `unwrap` assumes device allocation cannot fail
                // here — `get_or_insert_with` cannot propagate the error.
                DeviceTensor::uninitialized_dt(input.datum_type(), buffer.shape()).unwrap()
            });
            ctx.assign_slice(scratch, 0..keep, buffer, input_pulse..buffered, op.axis)?;
            ctx.assign_slice(buffer, 0..keep, scratch, 0..keep, op.axis)?;
            // Copy input to end of buffer
            ctx.assign_slice(
                buffer,
                (buffered - input_pulse)..buffered,
                input,
                0..input_pulse,
                op.axis,
            )?;
        }
        Ok(())
    }
}
118
impl OpState for GpuDelayState {
    /// Run one pulse of the delay: lazily allocates the rolling buffer, then
    /// delegates the device copies to `apply_delay_unchecked`.
    fn eval(
        &mut self,
        state: &mut TurnState,
        op: &dyn Op,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let op = &op.downcast_ref::<GpuDelay>().ok_or_else(|| format_err!("Wrong Op type"))?.inner;
        let buffered = op.delay + op.overlap;
        let device_input = input.as_device_tensor().context("Expected a GPU tensor")?;
        let input_pulse = device_input.shape()[op.axis];
        let output_pulse = input_pulse + op.overlap;
        // Output is input-shaped except `op.axis` grows by `op.overlap`.
        let mut output_shape: TVec<usize> = device_input.shape().into();
        output_shape[op.axis] = output_pulse;
        let dt = device_input.datum_type();
        let ctx = get_context()?;
        unsafe {
            // First call: allocate the rolling buffer, zero-filled so the
            // initial delayed frames read as zeros.
            if self.buffer.is_none() {
                let mut shape = device_input.shape().to_owned();
                shape[op.axis] = buffered;
                self.buffer = Some(Tensor::zero_dt(dt, &shape)?.into_device()?);
            };
            let mut output = make_tensor_for_node(state, self.node_id, dt, &output_shape)?;
            // SAFETY: `self.buffer` is Some (ensured just above) and `output`
            // has the pulse shape computed above.
            self.apply_delay_unchecked(&*ctx, op, device_input, &mut output)?;
            Ok(tvec!(output.into_tensor().into()))
        }
    }
}
148
// Generate the boilerplate freeze/unfreeze impl (macro from tract_core).
trivial_op_state_freeze!(GpuDelayState);
150
151// ─── GpuPulsePad ─────────────────────────────────────────────────────────────
152
/// GPU wrapper around the pulse `PulsePad` op.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GpuPulsePad {
    // The wrapped core op.
    pub op: PulsePad,
    // Padding constant pre-uploaded to the device when `op.mode` is
    // `PadMode::Constant`; `None` for other modes.
    pub device_cst: Option<DeviceTensor>,
}
158
159impl GpuPulsePad {
160    pub fn new(op: &PulsePad) -> TractResult<Self> {
161        let device_cst =
162            if let PadMode::Constant(c) = &op.mode { Some(c.clone().into_device()?) } else { None };
163        Ok(Self { op: op.clone(), device_cst })
164    }
165}
166
impl std::hash::Hash for GpuPulsePad {
    // Only the logical op is hashed: `device_cst` is derived from `op.mode`
    // (see `GpuPulsePad::new`), so it carries no extra identity.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.op.hash(state);
    }
}
172
impl Op for GpuPulsePad {
    fn name(&self) -> StaticName {
        "GpuPulsePad".into()
    }

    // Human-readable details are delegated to the wrapped core op.
    fn info(&self) -> TractResult<Vec<String>> {
        self.op.info()
    }

    op_as_typed_op!();
}
184
185impl EvalOp for GpuPulsePad {
186    fn is_stateless(&self) -> bool {
187        false
188    }
189
190    fn state(&self, _session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
191        Ok(Some(Box::new(GpuPulsePadState { node_id, current_pos: 0, last_valid_frame: None })))
192    }
193}
194
195impl TypedOp for GpuPulsePad {
196    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
197        crate::utils::facts_to_device_facts(inputs, |facts| self.op.output_facts(facts))
198            .with_context(|| format!("Error while computing output facts for {}", self.name()))
199    }
200
201    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
202        crate::utils::get_device_facts(inputs, |facts| self.op.cost(facts))
203    }
204
205    as_op!();
206}
207
/// Per-session state for [`GpuPulsePad`].
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct GpuPulsePadState {
    // Id of the node this state belongs to (used to allocate its output).
    node_id: usize,
    // Absolute stream position of the first frame of the incoming pulse;
    // advanced by `pulse - overlap` on each eval.
    current_pos: usize,
    // Last frame known to be valid, kept for `PadMode::Edge` end padding.
    last_valid_frame: Option<DeviceTensor>,
}
214
215fn fill_slice_constant(
216    ctx: &dyn DeviceContext,
217    dst: &mut DeviceTensor,
218    cst: &DeviceTensor,
219    axis: usize,
220    range: Range<usize>,
221) -> TractResult<()> {
222    let mut zone_shape: TVec<usize> = dst.shape().into();
223    zone_shape[axis] = range.len();
224    let mut dst_origin = tvec!(0; dst.rank());
225    dst_origin[axis] = range.start;
226    ctx.copy_with_origins(
227        &zone_shape,
228        dst,
229        &dst_origin,
230        dst.strides(),
231        cst,
232        &tvec!(0; dst.rank()),
233        &tvec!(0; dst.rank()),
234    )
235}
236
237fn fill_slice_repeating_one_frame(
238    ctx: &dyn DeviceContext,
239    dst: &mut DeviceTensor,
240    src: &DeviceTensor,
241    axis: usize,
242    dst_range: Range<usize>,
243    src_frame: usize,
244) -> TractResult<()> {
245    let mut zone_shape: TVec<usize> = dst.shape().into();
246    zone_shape[axis] = dst_range.len();
247    let mut dst_origin = tvec!(0; dst.rank());
248    dst_origin[axis] = dst_range.start;
249    let mut src_origin = tvec!(0; src.rank());
250    src_origin[axis] = src_frame;
251    let mut src_strides: TVec<isize> = src.strides().into();
252    src_strides[axis] = 0;
253    ctx.copy_with_origins(
254        &zone_shape,
255        dst,
256        &dst_origin,
257        dst.strides(),
258        src,
259        &src_origin,
260        &src_strides,
261    )
262}
263
impl GpuPulsePadState {
    /// Snapshot frame `frame` of `input` (along `op.axis`) into
    /// `self.last_valid_frame`, for later `Edge` end padding.
    fn save_frame(
        &mut self,
        ctx: &dyn DeviceContext,
        op: &PulsePad,
        input: &DeviceTensor,
        frame: usize,
    ) -> TractResult<()> {
        let mut frame_shape: TVec<usize> = input.shape().into();
        frame_shape[op.axis] = 1;
        let last_valid_frame = DeviceTensor::uninitialized_dt(input.datum_type(), &frame_shape)?;
        ctx.assign_slice(&last_valid_frame, 0..1, input, frame..frame + 1, op.axis)?;
        self.last_valid_frame = Some(last_valid_frame);
        Ok(())
    }

    /// Copy the incoming pulse to a fresh output tensor, then overwrite the
    /// frames that fall outside the valid input range
    /// `[op.begin_input, end_input)` according to `op.mode`.
    fn pad(
        &mut self,
        session: &TurnState,
        gpu_op: &GpuPulsePad,
        input: &DeviceTensor,
    ) -> TractResult<DeviceTensor> {
        let ctx = get_context()?;
        let op = &gpu_op.op;
        let pulse = input.shape()[op.axis];
        // Absolute stream coordinates covered by this pulse.
        let pulse_begin = self.current_pos;
        let pulse_end = self.current_pos + pulse;
        // Consecutive pulses advance by `pulse - overlap` frames.
        self.current_pos += pulse - op.overlap;
        // Symbolic dims resolve to usize::MAX when still unknown (stream not
        // finished), which disables end-of-stream padding below.
        let end_input =
            op.end_input.eval(&session.resolved_symbols).to_usize().unwrap_or(usize::MAX);
        let after = op.after.eval(&session.resolved_symbols).to_usize().unwrap_or(usize::MAX);

        // For edge padding, remember the latest valid frame while the pulse
        // still intersects the valid range; a later pulse may need it.
        if let PadMode::Edge = op.mode
            && after != 0
            && pulse_begin < end_input
        {
            let latest_valid_frame = (end_input - pulse_begin).min(pulse) - 1;
            self.save_frame(&*ctx, op, input, latest_valid_frame)?;
        }

        // Start with a copy of input
        let mut output =
            make_tensor_for_node(session, self.node_id, input.datum_type(), input.shape())?;
        let flat_len = input.len() * input.datum_type().size_of();
        ctx.flat_copy(input, 0, &output, 0, flat_len)?;

        // Quick return if entirely in valid or invalid range
        // NOTE(review): `op.begin_input - op.before` would underflow if
        // `before > begin_input`; presumably the core op guarantees
        // `begin_input >= before` — confirm against tract-pulse.
        if (pulse_begin >= op.begin_input && pulse_end <= end_input)
            || (pulse_end <= op.begin_input - op.before
                || pulse_begin >= end_input.saturating_add(after))
        {
            return Ok(output);
        }

        // Frames before `begin_input` are padding.
        if pulse_begin < op.begin_input {
            let fill_up_to = (op.begin_input - pulse_begin).min(pulse);
            match &op.mode {
                PadMode::Constant(_) => fill_slice_constant(
                    &*ctx,
                    &mut output,
                    gpu_op.device_cst.as_ref().unwrap(),
                    op.axis,
                    0..fill_up_to,
                )?,
                // Repeat the first valid frame (index `fill_up_to`) backwards.
                PadMode::Edge => fill_slice_repeating_one_frame(
                    &*ctx,
                    &mut output,
                    input,
                    op.axis,
                    0..fill_up_to,
                    fill_up_to,
                )?,
                _ => unimplemented!(),
            }
        }

        // Frames at or past `end_input` are padding.
        if pulse_end > end_input {
            let fill_from = pulse - (pulse_end - end_input).min(pulse);
            match &op.mode {
                PadMode::Constant(_) => fill_slice_constant(
                    &*ctx,
                    &mut output,
                    gpu_op.device_cst.as_ref().unwrap(),
                    op.axis,
                    fill_from..pulse,
                )?,
                // Repeat the saved last valid frame forwards.
                PadMode::Edge => fill_slice_repeating_one_frame(
                    &*ctx,
                    &mut output,
                    self.last_valid_frame.as_ref().unwrap(),
                    op.axis,
                    fill_from..pulse,
                    0,
                )?,
                _ => unimplemented!(),
            }
        }
        Ok(output)
    }
}
364
365impl OpState for GpuPulsePadState {
366    fn eval(
367        &mut self,
368        session: &mut TurnState,
369        op: &dyn Op,
370        inputs: TVec<TValue>,
371    ) -> TractResult<TVec<TValue>> {
372        let input = args_1!(inputs);
373        let gpu_op =
374            op.downcast_ref::<GpuPulsePad>().ok_or_else(|| format_err!("Wrong Op type"))?;
375        let device_input = input.as_device_tensor().context("Expected a GPU tensor")?;
376        let output = self.pad(session, gpu_op, device_input)?;
377        Ok(tvec!(output.into_tensor().into_tvalue()))
378    }
379}
380
// Generate the boilerplate freeze/unfreeze impl (macro from tract_core).
trivial_op_state_freeze!(GpuPulsePadState);