// tract_gpu/ops/binary.rs — GPU dispatch wrapper around tract-core binary mini-ops.

1use crate::tensor::{DeviceTensor, DeviceTensorExt};
2use derive_new::new;
3use tract_core::internal::*;
4use tract_core::ops::binary::BinMiniOp;
5
/// Backend-provided kernel entry point for a binary op: receives the mini-op
/// plus the two input device tensors and the pre-allocated output tensor to
/// fill (see `GpuBinOp::eval_with_session`, which allocates the output before
/// dispatching).
pub type DispatchBinOpFn =
    fn(&dyn BinMiniOp, &DeviceTensor, &DeviceTensor, &DeviceTensor) -> TractResult<()>;
8
/// Typed-graph operator that evaluates a tract-core binary mini-op on a
/// device backend, delegating the actual kernel launch to `dispatch`.
#[derive(Clone, new)]
pub struct GpuBinOp {
    /// The wrapped tract-core binary mini-op.
    pub mini_op: Box<dyn BinMiniOp>,
    /// Backend prefix used in `name()`, `Debug`, equality, and hashing.
    pub backend_name: &'static str,
    /// Backend kernel invoked by `eval_with_session` to fill the output tensor.
    pub dispatch: DispatchBinOpFn,
}
15
16impl std::fmt::Debug for GpuBinOp {
17    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
18        write!(f, "GpuBinOp({}{:?})", self.backend_name, self.mini_op)
19    }
20}
21
22impl PartialEq for GpuBinOp {
23    fn eq(&self, other: &Self) -> bool {
24        self.backend_name == other.backend_name && self.mini_op == other.mini_op
25    }
26}
27
28impl Eq for GpuBinOp {}
29
impl std::hash::Hash for GpuBinOp {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Hash the backend name and the mini-op's *name* rather than the op
        // itself (`dyn BinMiniOp` exposes no `Hash`). This is coarser than
        // `PartialEq`, which compares the mini-ops directly; that satisfies
        // the Hash/Eq contract only if equal mini-ops always report the same
        // name — NOTE(review): assumed, confirm for all BinMiniOp impls.
        self.backend_name.hash(state);
        self.mini_op.name().hash(state);
    }
}
36
37impl GpuBinOp {
38    fn resolve_output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
39        let (a, b) = (inputs[0], inputs[1]);
40        if a.rank() != b.rank() {
41            bail!(
42                "Typed ops require rank match. Invalid inputs for {}: {{a: {:?}, b: {:?}}}",
43                self.name(),
44                a.shape,
45                b.shape
46            );
47        }
48        let out_shape = tract_core::broadcast::multi_broadcast(&[&a.shape, &b.shape])
49            .with_context(|| format!("Error while broadcasting {:?} {:?}", a.shape, b.shape))?;
50        let out_dt = self.mini_op.result_datum_type(a.datum_type, b.datum_type)?;
51        Ok(tvec!(out_dt.fact(out_shape)))
52    }
53}
54
55impl Op for GpuBinOp {
56    fn name(&self) -> StaticName {
57        format!("{}{}", self.backend_name, self.mini_op.name()).into()
58    }
59
60    op_as_typed_op!();
61}
62
63impl EvalOp for GpuBinOp {
64    fn is_stateless(&self) -> bool {
65        true
66    }
67
68    fn eval_with_session(
69        &self,
70        node_id: usize,
71        session: &TurnState,
72        inputs: TVec<TValue>,
73    ) -> TractResult<TVec<TValue>> {
74        let (a_val, b_val) = args_2!(inputs);
75        let a = a_val.to_device_tensor()?;
76        let b = b_val.to_device_tensor()?;
77        let out_shape = tract_core::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
78        let out_dt = self.mini_op.result_datum_type(a.datum_type(), b.datum_type())?;
79        let output =
80            crate::session_handler::make_tensor_for_node(session, node_id, out_dt, &out_shape)?;
81        if a.len() > 0 && b.len() > 0 {
82            (self.dispatch)(&*self.mini_op, a, b, &output)
83                .with_context(|| format!("Error while dispatching eval for {}", self.name()))?;
84        }
85        Ok(tvec!(output.into_tensor().into_tvalue()))
86    }
87}
88
89impl TypedOp for GpuBinOp {
90    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
91        crate::utils::facts_to_device_facts(inputs, |facts| self.resolve_output_facts(facts))
92            .with_context(|| format!("Error while computing facts for {:?}", self.name()))
93    }
94
95    as_op!();
96}