1use crate::tensor::{DeviceTensor, DeviceTensorExt};
2use derive_new::new;
3use tract_core::internal::*;
4
5pub type DispatchGatherFn = fn(
12 data: &DeviceTensor,
13 indices: &DeviceTensor,
14 axis: usize,
15 output: &DeviceTensor,
16) -> TractResult<()>;
17
18#[derive(Clone, new)]
19pub struct GpuGather {
20 pub axis: usize,
21 pub backend_name: &'static str,
22 pub dispatch: DispatchGatherFn,
23}
24
25impl std::fmt::Debug for GpuGather {
26 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
27 write!(f, "{}Gather", self.backend_name)
28 }
29}
30
31impl PartialEq for GpuGather {
32 fn eq(&self, other: &Self) -> bool {
33 self.backend_name == other.backend_name && self.axis == other.axis
34 }
35}
36impl Eq for GpuGather {}
37
38impl std::hash::Hash for GpuGather {
39 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
40 self.backend_name.hash(state);
41 self.axis.hash(state);
42 }
43}
44
45impl Op for GpuGather {
46 fn name(&self) -> StaticName {
47 format!("{}Gather", self.backend_name).into()
48 }
49 fn info(&self) -> TractResult<Vec<String>> {
50 Ok(vec![format!("axis={}", self.axis)])
51 }
52 op_as_typed_op!();
53}
54
55impl EvalOp for GpuGather {
56 fn is_stateless(&self) -> bool {
57 true
58 }
59
60 fn eval_with_session(
61 &self,
62 node_id: usize,
63 session: &TurnState,
64 inputs: TVec<TValue>,
65 ) -> TractResult<TVec<TValue>> {
66 let (data_val, indices_val) = args_2!(inputs);
67 let data = data_val.to_device_tensor()?;
68 let indices = indices_val.to_device_tensor()?;
69 let out_shape = compute_output_shape(self.axis, data.shape(), indices.shape())?;
70 let output = crate::session_handler::make_tensor_for_node(
71 session,
72 node_id,
73 data.datum_type(),
74 &out_shape,
75 )?;
76 (self.dispatch)(data, indices, self.axis, &output)?;
77 Ok(tvec!(output.into_tensor().into_tvalue()))
78 }
79}
80
81impl TypedOp for GpuGather {
82 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
83 crate::utils::facts_to_device_facts(inputs, |facts| {
84 ensure!(facts.len() == 2);
85 ensure!(facts[1].datum_type == i64::datum_type());
86 ensure!(facts[0].rank() > self.axis);
87 let dt = facts[0].datum_type;
88 let mut shape: TVec<TDim> = facts[0].shape.iter().take(self.axis).cloned().collect();
89 shape.extend(facts[1].shape.iter().cloned());
90 shape.extend(facts[0].shape.iter().skip(self.axis + 1).cloned());
91 Ok(tvec!(dt.fact(&shape)))
92 })
93 .with_context(|| format!("Error while computing facts for {:?}", self.name()))
94 }
95 as_op!();
96}
97
98fn compute_output_shape(
99 axis: usize,
100 data: &[usize],
101 indices: &[usize],
102) -> TractResult<TVec<usize>> {
103 ensure!(data.len() > axis);
104 let mut out: TVec<usize> = data[..axis].into();
105 out.extend(indices.iter().copied());
106 out.extend(data[axis + 1..].iter().copied());
107 Ok(out)
108}