// tract_gpu/ops/gather.rs

1use crate::tensor::{DeviceTensor, DeviceTensorExt};
2use derive_new::new;
3use tract_core::internal::*;
4
5/// `output = data.gather(axis, indices)`, i.e.
6/// `output[..., i, ...] = data[..., indices[i], ...]` along `axis`.
7/// Negative indices wrap (matches the CPU op).
8///
9/// First implementation supports the plain-tensor path only (no block-quant,
10/// no packed-matrix storage); the translator's `rule_if` guards the rest out.
11pub type DispatchGatherFn = fn(
12    data: &DeviceTensor,
13    indices: &DeviceTensor,
14    axis: usize,
15    output: &DeviceTensor,
16) -> TractResult<()>;
17
18#[derive(Clone, new)]
19pub struct GpuGather {
20    pub axis: usize,
21    pub backend_name: &'static str,
22    pub dispatch: DispatchGatherFn,
23}
24
25impl std::fmt::Debug for GpuGather {
26    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
27        write!(f, "{}Gather", self.backend_name)
28    }
29}
30
31impl PartialEq for GpuGather {
32    fn eq(&self, other: &Self) -> bool {
33        self.backend_name == other.backend_name && self.axis == other.axis
34    }
35}
36impl Eq for GpuGather {}
37
38impl std::hash::Hash for GpuGather {
39    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
40        self.backend_name.hash(state);
41        self.axis.hash(state);
42    }
43}
44
45impl Op for GpuGather {
46    fn name(&self) -> StaticName {
47        format!("{}Gather", self.backend_name).into()
48    }
49    fn info(&self) -> TractResult<Vec<String>> {
50        Ok(vec![format!("axis={}", self.axis)])
51    }
52    op_as_typed_op!();
53}
54
55impl EvalOp for GpuGather {
56    fn is_stateless(&self) -> bool {
57        true
58    }
59
60    fn eval_with_session(
61        &self,
62        node_id: usize,
63        session: &TurnState,
64        inputs: TVec<TValue>,
65    ) -> TractResult<TVec<TValue>> {
66        let (data_val, indices_val) = args_2!(inputs);
67        let data = data_val.to_device_tensor()?;
68        let indices = indices_val.to_device_tensor()?;
69        let out_shape = compute_output_shape(self.axis, data.shape(), indices.shape())?;
70        let output = crate::session_handler::make_tensor_for_node(
71            session,
72            node_id,
73            data.datum_type(),
74            &out_shape,
75        )?;
76        (self.dispatch)(data, indices, self.axis, &output)?;
77        Ok(tvec!(output.into_tensor().into_tvalue()))
78    }
79}
80
81impl TypedOp for GpuGather {
82    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
83        crate::utils::facts_to_device_facts(inputs, |facts| {
84            ensure!(facts.len() == 2);
85            ensure!(facts[1].datum_type == i64::datum_type());
86            ensure!(facts[0].rank() > self.axis);
87            let dt = facts[0].datum_type;
88            let mut shape: TVec<TDim> = facts[0].shape.iter().take(self.axis).cloned().collect();
89            shape.extend(facts[1].shape.iter().cloned());
90            shape.extend(facts[0].shape.iter().skip(self.axis + 1).cloned());
91            Ok(tvec!(dt.fact(&shape)))
92        })
93        .with_context(|| format!("Error while computing facts for {:?}", self.name()))
94    }
95    as_op!();
96}
97
98fn compute_output_shape(
99    axis: usize,
100    data: &[usize],
101    indices: &[usize],
102) -> TractResult<TVec<usize>> {
103    ensure!(data.len() > axis);
104    let mut out: TVec<usize> = data[..axis].into();
105    out.extend(indices.iter().copied());
106    out.extend(data[axis + 1..].iter().copied());
107    Ok(out)
108}