// tract_core/ops/array/tile.rs

1use crate::internal::*;
2use ndarray::*;
3
4use super::MultiBroadcastTo;
5
6#[derive(Debug, Clone, new, Default, Hash)]
7pub struct Tile {
8    pub multipliers: TVec<TDim>,
9}
10
11impl Op for Tile {
12    fn name(&self) -> Cow<str> {
13        "Tile".into()
14    }
15
16    fn info(&self) -> TractResult<Vec<String>> {
17        Ok(vec![format!("multipliers: {:?}", self.multipliers)])
18    }
19
20    op_as_typed_op!();
21}
22
23impl EvalOp for Tile {
24    fn is_stateless(&self) -> bool {
25        true
26    }
27
28    fn eval_with_session(
29        &self,
30        session: &SessionState,
31        inputs: TVec<TValue>,
32    ) -> TractResult<TVec<TValue>> {
33        let multipliers: TVec<usize> = self
34            .multipliers
35            .iter()
36            .map(|m| m.eval(&session.resolved_symbols).to_usize())
37            .collect::<TractResult<_>>()?;
38        let result =
39            dispatch_datum_by_size!(eval_t(inputs[0].datum_type())(&inputs[0], &multipliers))?;
40        Ok(tvec!(result))
41    }
42}
43
44impl TypedOp for Tile {
45    as_op!();
46
47    fn concretize_dims(
48        &self,
49        _source: &TypedModel,
50        node: &TypedNode,
51        target: &mut TypedModel,
52        mapping: &HashMap<OutletId, OutletId>,
53        values: &SymbolValues,
54    ) -> TractResult<TVec<OutletId>> {
55        let multipliers = self.multipliers.iter().map(|m| m.eval(values)).collect();
56        target.wire_node(&node.name, Self { multipliers }, &[mapping[&node.inputs[0]]])
57    }
58
59    fn declutter(
60        &self,
61        model: &TypedModel,
62        node: &TypedNode,
63    ) -> TractResult<Option<TypedModelPatch>> {
64        let input_fact = model.outlet_fact(node.inputs[0])?;
65        if input_fact
66            .shape
67            .iter()
68            .zip(self.multipliers.iter())
69            .all(|(i, m)| i.is_one() || m.is_one())
70        {
71            let output_fact = self.output_facts(&[input_fact])?.remove(0);
72            TypedModelPatch::replace_single_op(
73                model,
74                node,
75                &node.inputs[0..1],
76                MultiBroadcastTo { shape: output_fact.shape },
77            )
78            .map(Some)
79        } else {
80            Ok(None)
81        }
82    }
83
84    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
85        let shape = inputs[0]
86            .shape
87            .iter()
88            .zip(self.multipliers.iter())
89            .map(|(a, b)| a.clone() * b)
90            .collect::<TVec<_>>();
91        Ok(tvec!(inputs[0].datum_type.fact(shape)))
92    }
93}
94
95#[derive(Debug, Clone, Hash)]
96pub struct DynTile {
97    pub multiplier_placeholders: TVec<TDim>,
98}
99
100impl DynTile {
101    pub fn new(scope: &SymbolScope, rank: usize) -> DynTile {
102        let multiplier_placeholders =
103            (0..rank).map(|_| scope.new_with_prefix("_tile_mult_").to_dim()).collect();
104        DynTile { multiplier_placeholders }
105    }
106}
107
108impl Op for DynTile {
109    fn name(&self) -> Cow<str> {
110        "DynTile".into()
111    }
112
113    op_as_typed_op!();
114}
115
116impl EvalOp for DynTile {
117    fn is_stateless(&self) -> bool {
118        true
119    }
120
121    fn eval_with_session(
122        &self,
123        session: &SessionState,
124        inputs: TVec<TValue>,
125    ) -> TractResult<TVec<TValue>> {
126        let multipliers = inputs[1].cast_to::<TDim>()?;
127        let multipliers: TVec<usize> = multipliers
128            .as_slice::<TDim>()?
129            .iter()
130            .map(|m| Ok(m.eval_to_i64(&session.resolved_symbols)? as usize))
131            .collect::<TractResult<_>>()?;
132        let result =
133            dispatch_datum_by_size!(eval_t(inputs[0].datum_type())(&inputs[0], &multipliers))?;
134        Ok(tvec!(result))
135    }
136}
137
138impl TypedOp for DynTile {
139    as_op!();
140
141    fn declutter(
142        &self,
143        model: &TypedModel,
144        node: &TypedNode,
145    ) -> TractResult<Option<TypedModelPatch>> {
146        if let Some(mult) = &model.outlet_fact(node.inputs[1])?.konst {
147            let multipliers = mult.cast_to::<TDim>()?.as_slice::<TDim>()?.iter().cloned().collect();
148            return TypedModelPatch::replace_single_op(
149                model,
150                node,
151                &node.inputs,
152                Tile { multipliers },
153            )
154            .map(Some);
155        }
156        Ok(None)
157    }
158
159    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
160        let multipliers = if let Some(k) = &inputs[1].konst {
161            k.cast_to::<TDim>()?.as_slice::<TDim>()?.iter().cloned().collect()
162        } else {
163            self.multiplier_placeholders.clone()
164        };
165        let shape =
166            inputs[0].shape.iter().zip(multipliers).map(|(a, b)| b * a).collect::<TVec<_>>();
167        Ok(tvec!(inputs[0].datum_type.fact(shape)))
168    }
169}
170
171fn eval_t<T: Datum>(data: &TValue, multipliers: &[usize]) -> TractResult<TValue> {
172    let view = unsafe { data.to_array_view_unchecked::<T>() };
173    let output_shape: TVec<usize> =
174        view.shape().iter().zip(multipliers.iter()).map(|(&d, &m)| d * m).collect();
175    let output = ndarray::ArrayD::from_shape_fn(&*output_shape, |coords| {
176        let coords: TVec<usize> =
177            coords.slice().iter().zip(data.shape().iter()).map(|(&x, &d)| x % d).collect();
178        view[&*coords].clone()
179    });
180    let mut output = output.into_tensor();
181    unsafe {
182        output.set_datum_type(data.datum_type());
183    }
184
185    Ok(output.into_tvalue())
186}