tract_hir/ops/array/
constant_like.rs1use crate::internal::*;
2use tract_ndarray::*;
3use tract_num_traits::{AsPrimitive, One, Zero};
4
5#[derive(Debug, Clone, new, Default)]
6pub struct ConstantLike {
7 value: f32,
8}
9
10impl Op for ConstantLike {
11 fn name(&self) -> StaticName {
12 "ConstantLike".into()
13 }
14
15 op_as_typed_op!();
16}
17
18impl EvalOp for ConstantLike {
19 fn is_stateless(&self) -> bool {
20 true
21 }
22
23 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
24 let input = args_1!(inputs);
25 Ok(tvec!(tensor0(self.value).broadcast_scalar_to_shape(input.shape())?.into_tvalue()))
26 }
27}
28
29impl InferenceRulesOp for ConstantLike {
30 fn rules<'r, 'p: 'r, 's: 'r>(
31 &'s self,
32 s: &mut Solver<'r>,
33 inputs: &'p [TensorProxy],
34 outputs: &'p [TensorProxy],
35 ) -> InferenceResult {
36 check_input_arity(inputs, 1)?;
37 check_output_arity(outputs, 1)?;
38 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
39 s.equals(&inputs[0].rank, &outputs[0].rank)?;
40 s.equals(&inputs[0].shape, &outputs[0].shape)?;
41 s.given_2(&inputs[0].shape, &inputs[0].datum_type, move |s, shape, dt| {
42 if shape.iter().all(|d| d.to_usize().is_ok()) {
43 let shape: Vec<usize> = shape.iter().map(|d| d.to_usize().unwrap()).collect();
44 let value = tensor0(self.value)
45 .cast_to_dt(dt)?
46 .broadcast_scalar_to_shape(&shape)?
47 .into_arc_tensor();
48 s.equals(&outputs[0].value, value)?;
49 }
50 Ok(())
51 })
52 }
53
54 as_op!();
55 to_typed!();
56}
57
58impl TypedOp for ConstantLike {
59 as_op!();
60
61 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
62 Ok(tvec!(inputs[0].clone()))
63 }
64}
65
66#[derive(Debug, Clone, new, Default, Hash)]
67pub struct EyeLike {
68 dt: Option<DatumType>,
69 k: isize,
70}
71
72impl EyeLike {
73 pub fn make<T>(&self, (r, c): (usize, usize)) -> TractResult<TValue>
74 where
75 T: Copy + Datum + One + Zero,
76 f32: AsPrimitive<T>,
77 {
78 let mut array = Array2::<T>::zeros((r, c));
79 for y in 0..r {
80 let x = y as isize + self.k;
81 if x >= 0 && x < c as isize {
82 array[(y, x as usize)] = T::one()
83 }
84 }
85 Ok(array.into_dyn().into_tvalue())
86 }
87}
88
89impl Op for EyeLike {
90 fn name(&self) -> StaticName {
91 "EyeLike".into()
92 }
93
94 op_as_typed_op!();
95}
96
97impl EvalOp for EyeLike {
98 fn is_stateless(&self) -> bool {
99 true
100 }
101
102 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
103 let input = args_1!(inputs);
104 let dt = self.dt.unwrap_or_else(|| input.datum_type());
105 Ok(tvec!(dispatch_numbers!(Self::make(dt)(self, (input.shape()[0], input.shape()[1])))?))
106 }
107}
108
109impl InferenceRulesOp for EyeLike {
110 fn rules<'r, 'p: 'r, 's: 'r>(
111 &'s self,
112 s: &mut Solver<'r>,
113 inputs: &'p [TensorProxy],
114 outputs: &'p [TensorProxy],
115 ) -> InferenceResult {
116 check_input_arity(inputs, 1)?;
117 check_output_arity(outputs, 1)?;
118 if let Some(dt) = self.dt {
119 s.equals(&outputs[0].datum_type, dt)?;
120 } else {
121 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
122 }
123 s.equals(&inputs[0].rank, 2)?;
124 s.equals(&inputs[0].shape, &outputs[0].shape)?;
125 s.given(&inputs[0].shape, move |s, shape| {
126 if let (Ok(r), Ok(c)) = (shape[0].to_usize(), shape[1].to_usize()) {
127 let shape = (r, c);
128 if let Some(dt) = self.dt {
129 let value = dispatch_numbers!(Self::make(dt)(self, shape))?;
130 s.equals(&outputs[0].value, value.into_arc_tensor())?;
131 } else {
132 s.given(&inputs[0].datum_type, move |s, dt| {
133 let value = dispatch_numbers!(Self::make(dt)(self, shape))?;
134 s.equals(&outputs[0].value, value.into_arc_tensor())
135 })?;
136 }
137 }
138 Ok(())
139 })
140 }
141
142 as_op!();
143 to_typed!();
144}
145
146impl TypedOp for EyeLike {
147 as_op!();
148
149 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
150 Ok(tvec!(self.dt.unwrap_or(inputs[0].datum_type).fact(inputs[0].shape.iter())))
151 }
152}