// tract_hir/ops/array/tile.rs

1use crate::internal::*;
2
3#[derive(Debug, Clone, new, Default, Hash)]
4pub struct Tile;
5
6impl Expansion for Tile {
7    fn name(&self) -> StaticName {
8        "Tile".into()
9    }
10
11    fn rules<'r, 'p: 'r, 's: 'r>(
12        &'s self,
13        s: &mut Solver<'r>,
14        inputs: &'p [TensorProxy],
15        outputs: &'p [TensorProxy],
16    ) -> InferenceResult {
17        check_input_arity(inputs, 2)?;
18        check_output_arity(outputs, 1)?;
19        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
20        s.equals(&inputs[0].rank, &outputs[0].rank)?;
21        s.equals(&inputs[1].rank, 1)?;
22        s.equals(&inputs[1].shape[0], inputs[0].rank.bex().to_dim())?;
23        s.given(&inputs[1].value, move |s, mult| {
24            for (ix, m) in mult.cast_to::<TDim>()?.as_slice::<TDim>()?.iter().enumerate() {
25                if let Some(m) = m.as_i64() {
26                    s.equals(m * inputs[0].shape[ix].bex(), &outputs[0].shape[ix])?;
27                } else {
28                    let m = m.clone();
29                    s.given(&inputs[0].shape[ix], move |s, input| {
30                        s.equals(input * &m, &outputs[0].shape[ix])
31                    })?;
32                }
33            }
34            Ok(())
35        })?;
36        Ok(())
37    }
38
39    fn wire(
40        &self,
41        prefix: &str,
42        target: &mut TypedModel,
43        inputs: &[OutletId],
44    ) -> TractResult<TVec<OutletId>> {
45        if let Some(ref mult) = target.outlet_fact(inputs[1])?.konst {
46            let mult: TVec<TDim> = mult.cast_to::<TDim>()?.as_slice::<TDim>()?.into();
47            target.wire_node(prefix, tract_core::ops::array::Tile::new(mult), &inputs[0..1])
48        } else {
49            bail!("shape input is variable")
50        }
51    }
52}