1use crate::tensor::DeviceTensorExt;
2use crate::utils::compute_broadcast_strides;
3use tract_core::internal::*;
4use tract_core::ops::array::Slice;
5
6#[derive(Clone, Debug, PartialEq, Eq, Hash)]
7pub struct GpuSlice {
8 pub inner: Slice,
9}
10
11impl GpuSlice {
12 pub fn new(inner: Slice) -> Self {
13 Self { inner }
14 }
15}
16
17impl Op for GpuSlice {
18 fn name(&self) -> StaticName {
19 "GpuSlice".into()
20 }
21
22 fn info(&self) -> TractResult<Vec<String>> {
23 self.inner.info()
24 }
25
26 op_as_typed_op!();
27}
28
29impl EvalOp for GpuSlice {
30 fn is_stateless(&self) -> bool {
31 true
32 }
33
34 fn eval_with_session(
35 &self,
36 node_id: usize,
37 session: &TurnState,
38 inputs: TVec<TValue>,
39 ) -> TractResult<TVec<TValue>> {
40 let input_value = args_1!(inputs);
41 let input = input_value.to_device_tensor()?;
42
43 let start = self.inner.start.eval(&session.resolved_symbols).to_usize()?;
44 let end = self.inner.end.eval(&session.resolved_symbols).to_usize()?;
45 let axis = self.inner.axis;
46
47 let input_shape = input.shape();
48 let input_strides = input.strides();
49 let input_dt = input.datum_type();
50
51 ensure!(
52 end <= input_shape[axis] && start <= end,
53 "Invalid range {}..{} for slicing {:?} on axis {}",
54 start,
55 end,
56 input,
57 axis
58 );
59
60 let mut o_shape: TVec<usize> = input_shape.into();
61 o_shape[axis] = end - start;
62
63 let offset = (start * input_strides[axis] as usize) * input_dt.size_of();
64
65 let output = crate::session_handler::make_tensor_for_node(
66 session,
67 node_id,
68 input.datum_type(),
69 &o_shape,
70 )?;
71
72 if o_shape[axis] != 0 {
73 let broadcast_strides: TVec<isize> =
75 compute_broadcast_strides(&o_shape, input_strides)?;
76 let ctx = crate::device::get_context()?;
77 ctx.copy_nd(
78 input,
79 offset,
80 &broadcast_strides,
81 &output,
82 0,
83 output.shape(),
84 output.strides(),
85 )?;
86 }
87 Ok(tvec![output.into_tensor().into_tvalue()])
88 }
89}
90
91impl TypedOp for GpuSlice {
92 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
93 crate::utils::facts_to_device_facts(inputs, |facts| self.inner.output_facts(facts))
94 .with_context(|| format!("Error while computing facts for {:?}", self.name()))
95 }
96
97 fn concretize_dims(
98 &self,
99 _source: &TypedModel,
100 node: &TypedNode,
101 target: &mut TypedModel,
102 mapping: &HashMap<OutletId, OutletId>,
103 values: &SymbolValues,
104 ) -> TractResult<TVec<OutletId>> {
105 let op = GpuSlice {
106 inner: Slice {
107 axis: self.inner.axis,
108 start: self.inner.start.eval(values),
109 end: self.inner.end.eval(values),
110 },
111 };
112 let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
113 target.wire_node(&node.name, op, &inputs)
114 }
115
116 as_op!();
117}