1use crate::tensor::{DeviceTensor, DeviceTensorExt};
2use derive_new::new;
3use tract_core::broadcast::multi_broadcast;
4use tract_core::internal::*;
5
6static IFF_MAX_RANK: usize = 5;
7
/// Backend entry point launching the element-wise select kernel:
/// for every output coordinate, picks `then_value` where `cond` is set,
/// `else_value` otherwise.
///
/// All stride slices are padded to `IFF_MAX_RANK` entries by the caller
/// (`GpuIff::eval_with_session`), with a stride of 0 on broadcast axes so the
/// kernel can address all three inputs with the same output-shaped loop.
pub type DispatchIffFn = fn(
    cond: &DeviceTensor,
    then_value: &DeviceTensor,
    else_value: &DeviceTensor,
    cond_strides: &[isize],
    then_strides: &[isize],
    else_strides: &[isize],
    output: &DeviceTensor,
    output_shape: &[usize],
    output_strides: &[isize],
) -> TractResult<()>;
22
/// Backend-agnostic GPU `Iff` (element-wise select) operator:
/// `output = if cond { then } else { else }`, with multi-directional
/// broadcasting across the three inputs.
#[derive(Clone, new)]
pub struct GpuIff {
    // Backend identifier; drives the op name and is the sole basis for
    // equality and hashing of this op.
    pub backend_name: &'static str,
    // Backend-specific kernel launcher doing the actual select on device.
    pub dispatch: DispatchIffFn,
}
28
29impl std::fmt::Debug for GpuIff {
30 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
31 write!(f, "{}Iff", self.backend_name)
32 }
33}
34
35impl PartialEq for GpuIff {
36 fn eq(&self, other: &Self) -> bool {
37 self.backend_name == other.backend_name
38 }
39}
40
// `PartialEq` is total over `backend_name` (no NaN-like values), so `Eq` holds.
impl Eq for GpuIff {}
42
43impl std::hash::Hash for GpuIff {
44 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
45 self.backend_name.hash(state);
46 }
47}
48
49impl Op for GpuIff {
50 fn name(&self) -> StaticName {
51 format!("{}Iff", self.backend_name).into()
52 }
53
54 op_as_typed_op!();
55}
56
impl EvalOp for GpuIff {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Evaluates the select: broadcasts `cond`, `then`, `else` to a common
    /// shape, allocates (or reuses, via the session handler) the output
    /// tensor, then hands padded strides to the backend `dispatch` kernel.
    fn eval_with_session(
        &self,
        node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let (cond_val, then_val, else_val) = args_3!(inputs);

        let cond = cond_val.to_device_tensor()?;
        let then_t = then_val.to_device_tensor()?;
        let else_t = else_val.to_device_tensor()?;
        // All three inputs must share the same rank: broadcasting below is
        // done axis-by-axis (size-1 axes only), not by rank extension.
        ensure!(cond.rank() == then_t.rank());
        ensure!(cond.rank() == else_t.rank());
        ensure!(then_t.datum_type() == else_t.datum_type());

        let out_shape = multi_broadcast(&[cond.shape(), then_t.shape(), else_t.shape()])
            .context("No broadcasting solution found")?;
        // Output dtype follows the value branches (then/else), not cond.
        let out_dt = then_t.datum_type();
        let output =
            crate::session_handler::make_tensor_for_node(session, node_id, out_dt, &out_shape)?;

        // Nothing to dispatch for an empty output (some broadcast dim is 0).
        if output.len() > 0 {
            let rank = cond.rank();
            ensure!(rank <= IFF_MAX_RANK);
            // Inputs are left-padded to IFF_MAX_RANK: real axes occupy the
            // trailing positions, leading pad axes get shape 1 / stride 0.
            let rank_pad = IFF_MAX_RANK - rank;

            let mut padded_cond_strides = [0isize; IFF_MAX_RANK];
            let mut padded_then_strides = [0isize; IFF_MAX_RANK];
            let mut padded_else_strides = [0isize; IFF_MAX_RANK];
            let mut padded_out_shape = [1usize; IFF_MAX_RANK];
            let mut padded_out_strides = [0isize; IFF_MAX_RANK];

            for axis in 0..rank {
                padded_out_shape[rank_pad + axis] = output.shape()[axis];
                padded_out_strides[rank_pad + axis] = output.strides()[axis];
                // On a broadcast axis (input dim smaller than output dim,
                // i.e. 1 vs N after multi_broadcast), force stride 0 so the
                // kernel re-reads the same input element along that axis.
                padded_cond_strides[rank_pad + axis] = if cond.shape()[axis] < output.shape()[axis]
                {
                    0
                } else {
                    cond.strides()[axis]
                };
                padded_then_strides[rank_pad + axis] =
                    if then_t.shape()[axis] < output.shape()[axis] {
                        0
                    } else {
                        then_t.strides()[axis]
                    };
                padded_else_strides[rank_pad + axis] =
                    if else_t.shape()[axis] < output.shape()[axis] {
                        0
                    } else {
                        else_t.strides()[axis]
                    };
            }

            // NOTE(review): the kernel is assumed to interpret strides in
            // elements (as returned by DeviceTensor::strides) — confirm
            // against the backend implementations.
            (self.dispatch)(
                cond,
                then_t,
                else_t,
                &padded_cond_strides,
                &padded_then_strides,
                &padded_else_strides,
                &output,
                &padded_out_shape,
                &padded_out_strides,
            )
            .with_context(|| "Error while dispatching eval for Iff")?;
        }
        Ok(tvec!(output.into_tensor().into_tvalue()))
    }
}
133
134impl TypedOp for GpuIff {
135 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
136 crate::utils::facts_to_device_facts(inputs, |inputs| {
137 let out_shape =
138 multi_broadcast(&[&*inputs[0].shape, &*inputs[1].shape, &*inputs[2].shape])
139 .context("No broadcasting solution found")?;
140 let out_dt = inputs[1].datum_type;
141 Ok(tvec!(out_dt.fact(out_shape)))
142 })
143 }
144
145 as_op!();
146}