// tract_hir/ops/array/constant_like.rs

1use crate::internal::*;
2use tract_ndarray::*;
3use tract_num_traits::{AsPrimitive, One, Zero};
4
5#[derive(Debug, Clone, new, Default)]
6pub struct ConstantLike {
7    value: f32,
8}
9
10impl Op for ConstantLike {
11    fn name(&self) -> StaticName {
12        "ConstantLike".into()
13    }
14
15    op_as_typed_op!();
16}
17
18impl EvalOp for ConstantLike {
19    fn is_stateless(&self) -> bool {
20        true
21    }
22
23    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
24        let input = args_1!(inputs);
25        Ok(tvec!(tensor0(self.value).broadcast_scalar_to_shape(input.shape())?.into_tvalue()))
26    }
27}
28
29impl InferenceRulesOp for ConstantLike {
30    fn rules<'r, 'p: 'r, 's: 'r>(
31        &'s self,
32        s: &mut Solver<'r>,
33        inputs: &'p [TensorProxy],
34        outputs: &'p [TensorProxy],
35    ) -> InferenceResult {
36        check_input_arity(inputs, 1)?;
37        check_output_arity(outputs, 1)?;
38        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
39        s.equals(&inputs[0].rank, &outputs[0].rank)?;
40        s.equals(&inputs[0].shape, &outputs[0].shape)?;
41        s.given_2(&inputs[0].shape, &inputs[0].datum_type, move |s, shape, dt| {
42            if shape.iter().all(|d| d.to_usize().is_ok()) {
43                let shape: Vec<usize> = shape.iter().map(|d| d.to_usize().unwrap()).collect();
44                let value = tensor0(self.value)
45                    .cast_to_dt(dt)?
46                    .broadcast_scalar_to_shape(&shape)?
47                    .into_arc_tensor();
48                s.equals(&outputs[0].value, value)?;
49            }
50            Ok(())
51        })
52    }
53
54    as_op!();
55    to_typed!();
56}
57
58impl TypedOp for ConstantLike {
59    as_op!();
60
61    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
62        Ok(tvec!(inputs[0].clone()))
63    }
64}
65
66#[derive(Debug, Clone, new, Default, Hash)]
67pub struct EyeLike {
68    dt: Option<DatumType>,
69    k: isize,
70}
71
72impl EyeLike {
73    pub fn make<T>(&self, (r, c): (usize, usize)) -> TractResult<TValue>
74    where
75        T: Copy + Datum + One + Zero,
76        f32: AsPrimitive<T>,
77    {
78        let mut array = Array2::<T>::zeros((r, c));
79        for y in 0..r {
80            let x = y as isize + self.k;
81            if x >= 0 && x < c as isize {
82                array[(y, x as usize)] = T::one()
83            }
84        }
85        Ok(array.into_dyn().into_tvalue())
86    }
87}
88
89impl Op for EyeLike {
90    fn name(&self) -> StaticName {
91        "EyeLike".into()
92    }
93
94    op_as_typed_op!();
95}
96
97impl EvalOp for EyeLike {
98    fn is_stateless(&self) -> bool {
99        true
100    }
101
102    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
103        let input = args_1!(inputs);
104        let dt = self.dt.unwrap_or_else(|| input.datum_type());
105        Ok(tvec!(dispatch_numbers!(Self::make(dt)(self, (input.shape()[0], input.shape()[1])))?))
106    }
107}
108
109impl InferenceRulesOp for EyeLike {
110    fn rules<'r, 'p: 'r, 's: 'r>(
111        &'s self,
112        s: &mut Solver<'r>,
113        inputs: &'p [TensorProxy],
114        outputs: &'p [TensorProxy],
115    ) -> InferenceResult {
116        check_input_arity(inputs, 1)?;
117        check_output_arity(outputs, 1)?;
118        if let Some(dt) = self.dt {
119            s.equals(&outputs[0].datum_type, dt)?;
120        } else {
121            s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
122        }
123        s.equals(&inputs[0].rank, 2)?;
124        s.equals(&inputs[0].shape, &outputs[0].shape)?;
125        s.given(&inputs[0].shape, move |s, shape| {
126            if let (Ok(r), Ok(c)) = (shape[0].to_usize(), shape[1].to_usize()) {
127                let shape = (r, c);
128                if let Some(dt) = self.dt {
129                    let value = dispatch_numbers!(Self::make(dt)(self, shape))?;
130                    s.equals(&outputs[0].value, value.into_arc_tensor())?;
131                } else {
132                    s.given(&inputs[0].datum_type, move |s, dt| {
133                        let value = dispatch_numbers!(Self::make(dt)(self, shape))?;
134                        s.equals(&outputs[0].value, value.into_arc_tensor())
135                    })?;
136                }
137            }
138            Ok(())
139        })
140    }
141
142    as_op!();
143    to_typed!();
144}
145
146impl TypedOp for EyeLike {
147    as_op!();
148
149    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
150        Ok(tvec!(self.dt.unwrap_or(inputs[0].datum_type).fact(inputs[0].shape.iter())))
151    }
152}