//! `tract_gpu/ops/iff.rs` — backend-generic GPU `Iff` (element-wise select) operator.
1use crate::tensor::{DeviceTensor, DeviceTensorExt};
2use derive_new::new;
3use tract_core::broadcast::multi_broadcast;
4use tract_core::internal::*;
5
6static IFF_MAX_RANK: usize = 5;
7
/// Dispatch function for the iff (select) kernel, supplied per backend.
///
/// Args: cond, then, else tensors with pre-computed broadcast strides,
/// output tensor, output shape and strides. All strides and the output shape
/// are left-padded to `IFF_MAX_RANK` entries; a stride of 0 marks an axis
/// that is broadcast (the same element is re-read along it).
pub type DispatchIffFn = fn(
    cond: &DeviceTensor,
    then_value: &DeviceTensor,
    else_value: &DeviceTensor,
    cond_strides: &[isize],
    then_strides: &[isize],
    else_strides: &[isize],
    output: &DeviceTensor,
    output_shape: &[usize],
    output_strides: &[isize],
) -> TractResult<()>;
22
/// Backend-generic iff (select) op: the backend is identified by name and
/// provides the concrete kernel through `dispatch`.
#[derive(Clone, new)]
pub struct GpuIff {
    // Backend identifier, used for display and for equality/hashing.
    pub backend_name: &'static str,
    // Backend-specific kernel launcher conforming to `DispatchIffFn`.
    pub dispatch: DispatchIffFn,
}
28
29impl std::fmt::Debug for GpuIff {
30    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
31        write!(f, "{}Iff", self.backend_name)
32    }
33}
34
35impl PartialEq for GpuIff {
36    fn eq(&self, other: &Self) -> bool {
37        self.backend_name == other.backend_name
38    }
39}
40
41impl Eq for GpuIff {}
42
43impl std::hash::Hash for GpuIff {
44    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
45        self.backend_name.hash(state);
46    }
47}
48
impl Op for GpuIff {
    /// Op name as shown in dumps/logs: "<backend_name>Iff".
    fn name(&self) -> StaticName {
        format!("{}Iff", self.backend_name).into()
    }

    op_as_typed_op!();
}
56
57impl EvalOp for GpuIff {
58    fn is_stateless(&self) -> bool {
59        true
60    }
61
62    fn eval_with_session(
63        &self,
64        node_id: usize,
65        session: &TurnState,
66        inputs: TVec<TValue>,
67    ) -> TractResult<TVec<TValue>> {
68        let (cond_val, then_val, else_val) = args_3!(inputs);
69
70        let cond = cond_val.to_device_tensor()?;
71        let then_t = then_val.to_device_tensor()?;
72        let else_t = else_val.to_device_tensor()?;
73        ensure!(cond.rank() == then_t.rank());
74        ensure!(cond.rank() == else_t.rank());
75        ensure!(then_t.datum_type() == else_t.datum_type());
76
77        let out_shape = multi_broadcast(&[cond.shape(), then_t.shape(), else_t.shape()])
78            .context("No broadcasting solution found")?;
79        let out_dt = then_t.datum_type();
80        let output =
81            crate::session_handler::make_tensor_for_node(session, node_id, out_dt, &out_shape)?;
82
83        if output.len() > 0 {
84            let rank = cond.rank();
85            ensure!(rank <= IFF_MAX_RANK);
86            let rank_pad = IFF_MAX_RANK - rank;
87
88            let mut padded_cond_strides = [0isize; IFF_MAX_RANK];
89            let mut padded_then_strides = [0isize; IFF_MAX_RANK];
90            let mut padded_else_strides = [0isize; IFF_MAX_RANK];
91            let mut padded_out_shape = [1usize; IFF_MAX_RANK];
92            let mut padded_out_strides = [0isize; IFF_MAX_RANK];
93
94            for axis in 0..rank {
95                padded_out_shape[rank_pad + axis] = output.shape()[axis];
96                padded_out_strides[rank_pad + axis] = output.strides()[axis];
97                padded_cond_strides[rank_pad + axis] = if cond.shape()[axis] < output.shape()[axis]
98                {
99                    0
100                } else {
101                    cond.strides()[axis]
102                };
103                padded_then_strides[rank_pad + axis] =
104                    if then_t.shape()[axis] < output.shape()[axis] {
105                        0
106                    } else {
107                        then_t.strides()[axis]
108                    };
109                padded_else_strides[rank_pad + axis] =
110                    if else_t.shape()[axis] < output.shape()[axis] {
111                        0
112                    } else {
113                        else_t.strides()[axis]
114                    };
115            }
116
117            (self.dispatch)(
118                cond,
119                then_t,
120                else_t,
121                &padded_cond_strides,
122                &padded_then_strides,
123                &padded_else_strides,
124                &output,
125                &padded_out_shape,
126                &padded_out_strides,
127            )
128            .with_context(|| "Error while dispatching eval for Iff")?;
129        }
130        Ok(tvec!(output.into_tensor().into_tvalue()))
131    }
132}
133
134impl TypedOp for GpuIff {
135    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
136        crate::utils::facts_to_device_facts(inputs, |inputs| {
137            let out_shape =
138                multi_broadcast(&[&*inputs[0].shape, &*inputs[1].shape, &*inputs[2].shape])
139                    .context("No broadcasting solution found")?;
140            let out_dt = inputs[1].datum_type;
141            Ok(tvec!(out_dt.fact(out_shape)))
142        })
143    }
144
145    as_op!();
146}