Skip to main content

tract_gpu/ops/
slice.rs

1use crate::tensor::DeviceTensorExt;
2use crate::utils::compute_broadcast_strides;
3use tract_core::internal::*;
4use tract_core::ops::array::Slice;
5
6#[derive(Clone, Debug, PartialEq, Eq, Hash)]
7pub struct GpuSlice {
8    pub inner: Slice,
9}
10
11impl GpuSlice {
12    pub fn new(inner: Slice) -> Self {
13        Self { inner }
14    }
15}
16
17impl Op for GpuSlice {
18    fn name(&self) -> StaticName {
19        "GpuSlice".into()
20    }
21
22    fn info(&self) -> TractResult<Vec<String>> {
23        self.inner.info()
24    }
25
26    op_as_typed_op!();
27}
28
29impl EvalOp for GpuSlice {
30    fn is_stateless(&self) -> bool {
31        true
32    }
33
34    fn eval_with_session(
35        &self,
36        node_id: usize,
37        session: &TurnState,
38        inputs: TVec<TValue>,
39    ) -> TractResult<TVec<TValue>> {
40        let input_value = args_1!(inputs);
41        let input = input_value.to_device_tensor()?;
42
43        let start = self.inner.start.eval(&session.resolved_symbols).to_usize()?;
44        let end = self.inner.end.eval(&session.resolved_symbols).to_usize()?;
45        let axis = self.inner.axis;
46
47        let input_shape = input.shape();
48        let input_strides = input.strides();
49        let input_dt = input.datum_type();
50
51        ensure!(
52            end <= input_shape[axis] && start <= end,
53            "Invalid range {}..{} for slicing {:?} on axis {}",
54            start,
55            end,
56            input,
57            axis
58        );
59
60        let mut o_shape: TVec<usize> = input_shape.into();
61        o_shape[axis] = end - start;
62
63        let offset = (start * input_strides[axis] as usize) * input_dt.size_of();
64
65        let output = crate::session_handler::make_tensor_for_node(
66            session,
67            node_id,
68            input.datum_type(),
69            &o_shape,
70        )?;
71
72        if o_shape[axis] != 0 {
73            // Slice uses same strides as input (broadcast strides with matching shapes)
74            let broadcast_strides: TVec<isize> =
75                compute_broadcast_strides(&o_shape, input_strides)?;
76            let ctx = crate::device::get_context()?;
77            ctx.copy_nd(
78                input,
79                offset,
80                &broadcast_strides,
81                &output,
82                0,
83                output.shape(),
84                output.strides(),
85            )?;
86        }
87        Ok(tvec![output.into_tensor().into_tvalue()])
88    }
89}
90
91impl TypedOp for GpuSlice {
92    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
93        crate::utils::facts_to_device_facts(inputs, |facts| self.inner.output_facts(facts))
94            .with_context(|| format!("Error while computing facts for {:?}", self.name()))
95    }
96
97    fn concretize_dims(
98        &self,
99        _source: &TypedModel,
100        node: &TypedNode,
101        target: &mut TypedModel,
102        mapping: &HashMap<OutletId, OutletId>,
103        values: &SymbolValues,
104    ) -> TractResult<TVec<OutletId>> {
105        let op = GpuSlice {
106            inner: Slice {
107                axis: self.inner.axis,
108                start: self.inner.start.eval(values),
109                end: self.inner.end.eval(values),
110            },
111        };
112        let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
113        target.wire_node(&node.name, op, &inputs)
114    }
115
116    as_op!();
117}