Skip to main content

tract_gpu/ops/
element_wise.rs

1use crate::tensor::{DeviceTensor, DeviceTensorExt};
2use derive_new::new;
3use tract_core::internal::*;
4use tract_core::ops::element_wise::ElementWiseMiniOp;
5
6pub type DispatchElementWiseFn =
7    fn(&dyn ElementWiseMiniOp, &DeviceTensor, &DeviceTensor) -> TractResult<()>;
8
9#[derive(Clone, new)]
10pub struct GpuElementWise {
11    pub mini_op: Box<dyn ElementWiseMiniOp>,
12    pub backend_name: &'static str,
13    pub dispatch: DispatchElementWiseFn,
14}
15
16impl std::fmt::Debug for GpuElementWise {
17    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
18        write!(f, "GpuElementWise({}{:?})", self.backend_name, self.mini_op)
19    }
20}
21
22impl PartialEq for GpuElementWise {
23    fn eq(&self, other: &Self) -> bool {
24        self.backend_name == other.backend_name && self.mini_op == other.mini_op
25    }
26}
27
28impl Eq for GpuElementWise {}
29
30impl std::hash::Hash for GpuElementWise {
31    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
32        self.backend_name.hash(state);
33        self.mini_op.name().hash(state);
34    }
35}
36
37impl Op for GpuElementWise {
38    fn name(&self) -> StaticName {
39        format!("{}{}", self.backend_name, self.mini_op.name()).into()
40    }
41
42    op_as_typed_op!();
43}
44
45impl EvalOp for GpuElementWise {
46    fn is_stateless(&self) -> bool {
47        true
48    }
49
50    fn eval_with_session(
51        &self,
52        node_id: usize,
53        session: &TurnState,
54        inputs: TVec<TValue>,
55    ) -> TractResult<TVec<TValue>> {
56        let input_value = args_1!(inputs);
57        let input = input_value.to_device_tensor()?;
58        let output = crate::session_handler::make_tensor_for_node(
59            session,
60            node_id,
61            input.datum_type(),
62            input.shape(),
63        )?;
64        (self.dispatch)(&*self.mini_op, input, &output)
65            .with_context(|| format!("Error while dispatching eval for {}", self.name()))?;
66        Ok(tvec!(output.into_tensor().into_tvalue()))
67    }
68}
69
70impl TypedOp for GpuElementWise {
71    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
72        crate::utils::facts_to_device_facts(inputs, |facts| {
73            let dt = facts[0].datum_type;
74            let fact = dt.fact(facts[0].shape.clone());
75            Ok(tvec!(fact))
76        })
77        .with_context(|| format!("Error while computing facts for {:?}", self.name()))
78    }
79
80    as_op!();
81}