tract_core/ops/array/
tile.rs1use crate::internal::*;
2use ndarray::*;
3
4use super::MultiBroadcastTo;
5
6#[derive(Debug, Clone, new, Default, Hash)]
7pub struct Tile {
8 pub multipliers: TVec<TDim>,
9}
10
11impl Op for Tile {
12 fn name(&self) -> Cow<str> {
13 "Tile".into()
14 }
15
16 fn info(&self) -> TractResult<Vec<String>> {
17 Ok(vec![format!("multipliers: {:?}", self.multipliers)])
18 }
19
20 op_as_typed_op!();
21}
22
23impl EvalOp for Tile {
24 fn is_stateless(&self) -> bool {
25 true
26 }
27
28 fn eval_with_session(
29 &self,
30 session: &SessionState,
31 inputs: TVec<TValue>,
32 ) -> TractResult<TVec<TValue>> {
33 let multipliers: TVec<usize> = self
34 .multipliers
35 .iter()
36 .map(|m| m.eval(&session.resolved_symbols).to_usize())
37 .collect::<TractResult<_>>()?;
38 let result =
39 dispatch_datum_by_size!(eval_t(inputs[0].datum_type())(&inputs[0], &multipliers))?;
40 Ok(tvec!(result))
41 }
42}
43
44impl TypedOp for Tile {
45 as_op!();
46
47 fn concretize_dims(
48 &self,
49 _source: &TypedModel,
50 node: &TypedNode,
51 target: &mut TypedModel,
52 mapping: &HashMap<OutletId, OutletId>,
53 values: &SymbolValues,
54 ) -> TractResult<TVec<OutletId>> {
55 let multipliers = self.multipliers.iter().map(|m| m.eval(values)).collect();
56 target.wire_node(&node.name, Self { multipliers }, &[mapping[&node.inputs[0]]])
57 }
58
59 fn declutter(
60 &self,
61 model: &TypedModel,
62 node: &TypedNode,
63 ) -> TractResult<Option<TypedModelPatch>> {
64 let input_fact = model.outlet_fact(node.inputs[0])?;
65 if input_fact
66 .shape
67 .iter()
68 .zip(self.multipliers.iter())
69 .all(|(i, m)| i.is_one() || m.is_one())
70 {
71 let output_fact = self.output_facts(&[input_fact])?.remove(0);
72 TypedModelPatch::replace_single_op(
73 model,
74 node,
75 &node.inputs[0..1],
76 MultiBroadcastTo { shape: output_fact.shape },
77 )
78 .map(Some)
79 } else {
80 Ok(None)
81 }
82 }
83
84 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
85 let shape = inputs[0]
86 .shape
87 .iter()
88 .zip(self.multipliers.iter())
89 .map(|(a, b)| a.clone() * b)
90 .collect::<TVec<_>>();
91 Ok(tvec!(inputs[0].datum_type.fact(shape)))
92 }
93}
94
95#[derive(Debug, Clone, Hash)]
96pub struct DynTile {
97 pub multiplier_placeholders: TVec<TDim>,
98}
99
100impl DynTile {
101 pub fn new(scope: &SymbolScope, rank: usize) -> DynTile {
102 let multiplier_placeholders =
103 (0..rank).map(|_| scope.new_with_prefix("_tile_mult_").to_dim()).collect();
104 DynTile { multiplier_placeholders }
105 }
106}
107
108impl Op for DynTile {
109 fn name(&self) -> Cow<str> {
110 "DynTile".into()
111 }
112
113 op_as_typed_op!();
114}
115
116impl EvalOp for DynTile {
117 fn is_stateless(&self) -> bool {
118 true
119 }
120
121 fn eval_with_session(
122 &self,
123 session: &SessionState,
124 inputs: TVec<TValue>,
125 ) -> TractResult<TVec<TValue>> {
126 let multipliers = inputs[1].cast_to::<TDim>()?;
127 let multipliers: TVec<usize> = multipliers
128 .as_slice::<TDim>()?
129 .iter()
130 .map(|m| Ok(m.eval_to_i64(&session.resolved_symbols)? as usize))
131 .collect::<TractResult<_>>()?;
132 let result =
133 dispatch_datum_by_size!(eval_t(inputs[0].datum_type())(&inputs[0], &multipliers))?;
134 Ok(tvec!(result))
135 }
136}
137
138impl TypedOp for DynTile {
139 as_op!();
140
141 fn declutter(
142 &self,
143 model: &TypedModel,
144 node: &TypedNode,
145 ) -> TractResult<Option<TypedModelPatch>> {
146 if let Some(mult) = &model.outlet_fact(node.inputs[1])?.konst {
147 let multipliers = mult.cast_to::<TDim>()?.as_slice::<TDim>()?.iter().cloned().collect();
148 return TypedModelPatch::replace_single_op(
149 model,
150 node,
151 &node.inputs,
152 Tile { multipliers },
153 )
154 .map(Some);
155 }
156 Ok(None)
157 }
158
159 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
160 let multipliers = if let Some(k) = &inputs[1].konst {
161 k.cast_to::<TDim>()?.as_slice::<TDim>()?.iter().cloned().collect()
162 } else {
163 self.multiplier_placeholders.clone()
164 };
165 let shape =
166 inputs[0].shape.iter().zip(multipliers).map(|(a, b)| b * a).collect::<TVec<_>>();
167 Ok(tvec!(inputs[0].datum_type.fact(shape)))
168 }
169}
170
171fn eval_t<T: Datum>(data: &TValue, multipliers: &[usize]) -> TractResult<TValue> {
172 let view = unsafe { data.to_array_view_unchecked::<T>() };
173 let output_shape: TVec<usize> =
174 view.shape().iter().zip(multipliers.iter()).map(|(&d, &m)| d * m).collect();
175 let output = ndarray::ArrayD::from_shape_fn(&*output_shape, |coords| {
176 let coords: TVec<usize> =
177 coords.slice().iter().zip(data.shape().iter()).map(|(&x, &d)| x % d).collect();
178 view[&*coords].clone()
179 });
180 let mut output = output.into_tensor();
181 unsafe {
182 output.set_datum_type(data.datum_type());
183 }
184
185 Ok(output.into_tvalue())
186}