tract_hir/ops/array/
tile.rs1use crate::internal::*;
2
3#[derive(Debug, Clone, new, Default, Hash)]
4pub struct Tile;
5
6impl Expansion for Tile {
7 fn name(&self) -> StaticName {
8 "Tile".into()
9 }
10
11 fn rules<'r, 'p: 'r, 's: 'r>(
12 &'s self,
13 s: &mut Solver<'r>,
14 inputs: &'p [TensorProxy],
15 outputs: &'p [TensorProxy],
16 ) -> InferenceResult {
17 check_input_arity(inputs, 2)?;
18 check_output_arity(outputs, 1)?;
19 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
20 s.equals(&inputs[0].rank, &outputs[0].rank)?;
21 s.equals(&inputs[1].rank, 1)?;
22 s.equals(&inputs[1].shape[0], inputs[0].rank.bex().to_dim())?;
23 s.given(&inputs[1].value, move |s, mult| {
24 for (ix, m) in mult.cast_to::<TDim>()?.as_slice::<TDim>()?.iter().enumerate() {
25 if let Some(m) = m.as_i64() {
26 s.equals(m * inputs[0].shape[ix].bex(), &outputs[0].shape[ix])?;
27 } else {
28 let m = m.clone();
29 s.given(&inputs[0].shape[ix], move |s, input| {
30 s.equals(input * &m, &outputs[0].shape[ix])
31 })?;
32 }
33 }
34 Ok(())
35 })?;
36 Ok(())
37 }
38
39 fn wire(
40 &self,
41 prefix: &str,
42 target: &mut TypedModel,
43 inputs: &[OutletId],
44 ) -> TractResult<TVec<OutletId>> {
45 if let Some(ref mult) = target.outlet_fact(inputs[1])?.konst {
46 let mult: TVec<TDim> = mult.cast_to::<TDim>()?.as_slice::<TDim>()?.into();
47 target.wire_node(prefix, tract_core::ops::array::Tile::new(mult), &inputs[0..1])
48 } else {
49 bail!("shape input is variable")
50 }
51 }
52}