Skip to main content

tract_gpu/ops/
pulse.rs

1#![allow(unpredictable_function_pointer_comparisons)]
2use crate::device::{DeviceContext, get_context};
3use crate::session_handler::make_tensor_for_node;
4use crate::tensor::{DeviceTensor, DeviceTensorExt, IntoDevice};
5use std::ops::Range;
6use tract_core::internal::*;
7use tract_core::ops::array::PadMode;
8use tract_core::trivial_op_state_freeze;
9use tract_pulse_opl::ops::{Delay, PulsePad};
10
11// ─── GpuDelay ────────────────────────────────────────────────────────────────
12
13#[derive(Debug, Clone, PartialEq, Eq, Hash)]
14pub struct GpuDelay {
15    pub inner: Delay,
16}
17
18impl GpuDelay {
19    pub fn new(inner: &Delay) -> Self {
20        Self { inner: inner.clone() }
21    }
22}
23
24impl Op for GpuDelay {
25    fn name(&self) -> StaticName {
26        "GpuDelay".into()
27    }
28
29    fn info(&self) -> TractResult<Vec<String>> {
30        self.inner.info()
31    }
32
33    op_as_typed_op!();
34}
35
36impl EvalOp for GpuDelay {
37    fn is_stateless(&self) -> bool {
38        false
39    }
40
41    fn state(&self, _session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
42        Ok(Some(Box::new(GpuDelayState { node_id, buffer: None, shift_scratch: None })))
43    }
44}
45
46impl TypedOp for GpuDelay {
47    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
48        crate::utils::facts_to_device_facts(inputs, |facts| self.inner.output_facts(facts))
49            .with_context(|| format!("Error while computing output facts for {}", self.name()))
50    }
51
52    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
53        crate::utils::get_device_facts(inputs, |facts| self.inner.cost(facts))
54    }
55
56    as_op!();
57}
58
59#[derive(Debug, Clone)]
60pub struct GpuDelayState {
61    pub node_id: usize,
62    pub buffer: Option<DeviceTensor>,
63    pub shift_scratch: Option<DeviceTensor>,
64}
65
66impl GpuDelayState {
67    unsafe fn apply_delay_unchecked(
68        &mut self,
69        ctx: &dyn DeviceContext,
70        op: &Delay,
71        input: &DeviceTensor,
72        output: &mut DeviceTensor,
73    ) -> TractResult<()> {
74        let buffered = op.delay + op.overlap;
75        let input_pulse = input.shape()[op.axis];
76        let output_pulse = input_pulse + op.overlap;
77        let buffer = self.buffer.as_mut().unwrap();
78
79        let from_input = input_pulse.saturating_sub(op.delay);
80        let from_buffer = output_pulse.saturating_sub(from_input);
81
82        // Copy from buffer to output
83        ctx.assign_slice(output, 0..from_buffer, buffer, 0..from_buffer, op.axis)?;
84        // Copy from input to output
85        ctx.assign_slice(output, from_buffer..output_pulse, input, 0..from_input, op.axis)?;
86
87        // Maintain buffer
88        if buffered < input_pulse {
89            ctx.assign_slice(
90                buffer,
91                0..buffered,
92                input,
93                (input_pulse - buffered)..input_pulse,
94                op.axis,
95            )?;
96        } else {
97            // Shift buffer left by input_pulse elements.
98            // CUDA memcpy is undefined for overlapping regions in the same
99            // buffer (parallel threads), so copy via a scratch buffer.
100            let keep = buffered - input_pulse;
101            let scratch = self.shift_scratch.get_or_insert_with(|| {
102                DeviceTensor::uninitialized_dt(input.datum_type(), buffer.shape()).unwrap()
103            });
104            ctx.assign_slice(scratch, 0..keep, buffer, input_pulse..buffered, op.axis)?;
105            ctx.assign_slice(buffer, 0..keep, scratch, 0..keep, op.axis)?;
106            // Copy input to end of buffer
107            ctx.assign_slice(
108                buffer,
109                (buffered - input_pulse)..buffered,
110                input,
111                0..input_pulse,
112                op.axis,
113            )?;
114        }
115        Ok(())
116    }
117}
118
119impl OpState for GpuDelayState {
120    fn eval(
121        &mut self,
122        state: &mut TurnState,
123        op: &dyn Op,
124        inputs: TVec<TValue>,
125    ) -> TractResult<TVec<TValue>> {
126        let input = args_1!(inputs);
127        let op = &op.downcast_ref::<GpuDelay>().ok_or_else(|| format_err!("Wrong Op type"))?.inner;
128        let buffered = op.delay + op.overlap;
129        let device_input = input.as_device_tensor().context("Expected a GPU tensor")?;
130        let input_pulse = device_input.shape()[op.axis];
131        let output_pulse = input_pulse + op.overlap;
132        let mut output_shape: TVec<usize> = device_input.shape().into();
133        output_shape[op.axis] = output_pulse;
134        let dt = device_input.datum_type();
135        let ctx = get_context()?;
136        unsafe {
137            if self.buffer.is_none() {
138                let mut shape = device_input.shape().to_owned();
139                shape[op.axis] = buffered;
140                self.buffer = Some(Tensor::zero_dt(dt, &shape)?.into_device()?);
141            };
142            let mut output = make_tensor_for_node(state, self.node_id, dt, &output_shape)?;
143            self.apply_delay_unchecked(&*ctx, op, device_input, &mut output)?;
144            Ok(tvec!(output.into_tensor().into()))
145        }
146    }
147}
148
149trivial_op_state_freeze!(GpuDelayState);
150
151// ─── GpuPulsePad ─────────────────────────────────────────────────────────────
152
153#[derive(Debug, Clone, PartialEq, Eq)]
154pub struct GpuPulsePad {
155    pub op: PulsePad,
156    pub device_cst: Option<DeviceTensor>,
157}
158
159impl GpuPulsePad {
160    pub fn new(op: &PulsePad) -> TractResult<Self> {
161        let device_cst =
162            if let PadMode::Constant(c) = &op.mode { Some(c.clone().into_device()?) } else { None };
163        Ok(Self { op: op.clone(), device_cst })
164    }
165}
166
167impl std::hash::Hash for GpuPulsePad {
168    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
169        self.op.hash(state);
170    }
171}
172
173impl Op for GpuPulsePad {
174    fn name(&self) -> StaticName {
175        "GpuPulsePad".into()
176    }
177
178    fn info(&self) -> TractResult<Vec<String>> {
179        self.op.info()
180    }
181
182    op_as_typed_op!();
183}
184
185impl EvalOp for GpuPulsePad {
186    fn is_stateless(&self) -> bool {
187        false
188    }
189
190    fn state(&self, _session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
191        Ok(Some(Box::new(GpuPulsePadState { node_id, current_pos: 0, last_valid_frame: None })))
192    }
193}
194
195impl TypedOp for GpuPulsePad {
196    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
197        crate::utils::facts_to_device_facts(inputs, |facts| self.op.output_facts(facts))
198            .with_context(|| format!("Error while computing output facts for {}", self.name()))
199    }
200
201    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
202        crate::utils::get_device_facts(inputs, |facts| self.op.cost(facts))
203    }
204
205    as_op!();
206}
207
208#[derive(Debug, Clone, Hash, PartialEq, Eq)]
209struct GpuPulsePadState {
210    node_id: usize,
211    current_pos: usize,
212    last_valid_frame: Option<DeviceTensor>,
213}
214
215fn fill_slice_constant(
216    ctx: &dyn DeviceContext,
217    dst: &mut DeviceTensor,
218    cst: &DeviceTensor,
219    axis: usize,
220    range: Range<usize>,
221) -> TractResult<()> {
222    let mut zone_shape: TVec<usize> = dst.shape().into();
223    zone_shape[axis] = range.len();
224    let mut dst_origin = tvec!(0; dst.rank());
225    dst_origin[axis] = range.start;
226    ctx.copy_with_origins(
227        &zone_shape,
228        dst,
229        &dst_origin,
230        dst.strides(),
231        cst,
232        &tvec!(0; dst.rank()),
233        &tvec!(0; dst.rank()),
234    )
235}
236
237fn fill_slice_repeating_one_frame(
238    ctx: &dyn DeviceContext,
239    dst: &mut DeviceTensor,
240    src: &DeviceTensor,
241    axis: usize,
242    dst_range: Range<usize>,
243    src_frame: usize,
244) -> TractResult<()> {
245    let mut zone_shape: TVec<usize> = dst.shape().into();
246    zone_shape[axis] = dst_range.len();
247    let mut dst_origin = tvec!(0; dst.rank());
248    dst_origin[axis] = dst_range.start;
249    let mut src_origin = tvec!(0; src.rank());
250    src_origin[axis] = src_frame;
251    let mut src_strides: TVec<isize> = src.strides().into();
252    src_strides[axis] = 0;
253    ctx.copy_with_origins(
254        &zone_shape,
255        dst,
256        &dst_origin,
257        dst.strides(),
258        src,
259        &src_origin,
260        &src_strides,
261    )
262}
263
264impl GpuPulsePadState {
265    fn save_frame(
266        &mut self,
267        ctx: &dyn DeviceContext,
268        op: &PulsePad,
269        input: &DeviceTensor,
270        frame: usize,
271    ) -> TractResult<()> {
272        let mut frame_shape: TVec<usize> = input.shape().into();
273        frame_shape[op.axis] = 1;
274        let last_valid_frame = DeviceTensor::uninitialized_dt(input.datum_type(), &frame_shape)?;
275        ctx.assign_slice(&last_valid_frame, 0..1, input, frame..frame + 1, op.axis)?;
276        self.last_valid_frame = Some(last_valid_frame);
277        Ok(())
278    }
279
280    fn pad(
281        &mut self,
282        session: &TurnState,
283        gpu_op: &GpuPulsePad,
284        input: &DeviceTensor,
285    ) -> TractResult<DeviceTensor> {
286        let ctx = get_context()?;
287        let op = &gpu_op.op;
288        let pulse = input.shape()[op.axis];
289        let pulse_begin = self.current_pos;
290        let pulse_end = self.current_pos + pulse;
291        self.current_pos += pulse - op.overlap;
292        let end_input =
293            op.end_input.eval(&session.resolved_symbols).to_usize().unwrap_or(usize::MAX);
294        let after = op.after.eval(&session.resolved_symbols).to_usize().unwrap_or(usize::MAX);
295
296        if let PadMode::Edge = op.mode
297            && after != 0
298            && pulse_begin < end_input
299        {
300            let latest_valid_frame = (end_input - pulse_begin).min(pulse) - 1;
301            self.save_frame(&*ctx, op, input, latest_valid_frame)?;
302        }
303
304        // Start with a copy of input.  The fused-axis-op chain may have
305        // installed a non-contiguous view (Move only permutes strides,
306        // never materialises), so a flat memcpy would read the buffer in
307        // pre-Move order; copy_nd honours `input.strides()` instead.
308        let mut output =
309            make_tensor_for_node(session, self.node_id, input.datum_type(), input.shape())?;
310        ctx.copy_nd(input, 0, input.strides(), &output, 0, input.shape(), output.strides())?;
311
312        // Quick return if entirely in valid or invalid range
313        if (pulse_begin >= op.begin_input && pulse_end <= end_input)
314            || (pulse_end <= op.begin_input - op.before
315                || pulse_begin >= end_input.saturating_add(after))
316        {
317            return Ok(output);
318        }
319
320        if pulse_begin < op.begin_input {
321            let fill_up_to = (op.begin_input - pulse_begin).min(pulse);
322            match &op.mode {
323                PadMode::Constant(_) => fill_slice_constant(
324                    &*ctx,
325                    &mut output,
326                    gpu_op.device_cst.as_ref().unwrap(),
327                    op.axis,
328                    0..fill_up_to,
329                )?,
330                PadMode::Edge => fill_slice_repeating_one_frame(
331                    &*ctx,
332                    &mut output,
333                    input,
334                    op.axis,
335                    0..fill_up_to,
336                    fill_up_to,
337                )?,
338                _ => unimplemented!(),
339            }
340        }
341
342        if pulse_end > end_input {
343            let fill_from = pulse - (pulse_end - end_input).min(pulse);
344            match &op.mode {
345                PadMode::Constant(_) => fill_slice_constant(
346                    &*ctx,
347                    &mut output,
348                    gpu_op.device_cst.as_ref().unwrap(),
349                    op.axis,
350                    fill_from..pulse,
351                )?,
352                PadMode::Edge => fill_slice_repeating_one_frame(
353                    &*ctx,
354                    &mut output,
355                    self.last_valid_frame.as_ref().unwrap(),
356                    op.axis,
357                    fill_from..pulse,
358                    0,
359                )?,
360                _ => unimplemented!(),
361            }
362        }
363        Ok(output)
364    }
365}
366
367impl OpState for GpuPulsePadState {
368    fn eval(
369        &mut self,
370        session: &mut TurnState,
371        op: &dyn Op,
372        inputs: TVec<TValue>,
373    ) -> TractResult<TVec<TValue>> {
374        let input = args_1!(inputs);
375        let gpu_op =
376            op.downcast_ref::<GpuPulsePad>().ok_or_else(|| format_err!("Wrong Op type"))?;
377        let device_input = input.as_device_tensor().context("Expected a GPU tensor")?;
378        let output = self.pad(session, gpu_op, device_input)?;
379        Ok(tvec!(output.into_tensor().into_tvalue()))
380    }
381}
382
383trivial_op_state_freeze!(GpuPulsePadState);