tract_hir/ops/array/
add_dims.rs

1use crate::infer::*;
2use crate::internal::*;
3use tract_itertools::Itertools;
4
5#[derive(Debug, Clone, new, Hash)]
6pub struct AddDims {
7    pub axes: Vec<isize>,
8}
9
10
11
12impl AddDims {
13    pub fn output_shape<D: DimLike>(&self, input: &[D]) -> TVec<D> {
14        let rank = input.len() as isize;
15        let mut shape: TVec<D> = input.iter().cloned().collect();
16        let output_rank = rank + self.axes.len() as isize;
17        let axes = self
18            .axes
19            .iter()
20            .map(|&axis| if axis < 0 { axis + output_rank } else { axis } as usize)
21            .sorted();
22        for axis in axes {
23            shape.insert(axis, D::one())
24        }
25        shape
26    }
27}
28
29impl Expansion for AddDims {
30    fn name(&self) -> StaticName {
31        "AddDims".into()
32    }
33
34    fn info(&self) -> TractResult<Vec<String>> {
35        Ok(vec![format!("Axes: {:?}", self.axes)])
36    }
37
38
39    fn rules<'r, 'p: 'r, 's: 'r>(
40        &'s self,
41        s: &mut Solver<'r>,
42        inputs: &'p [TensorProxy],
43        outputs: &'p [TensorProxy],
44    ) -> InferenceResult {
45        check_output_arity(outputs, 1)?;
46        s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
47        s.equals(&outputs[0].rank, (&inputs[0].rank).bex() + self.axes.len() as i64)?;
48        s.given(&inputs[0].shape, move |s, shape| {
49            let output_shape = self.output_shape(&shape);
50            s.equals(&outputs[0].shape, output_shape)
51        })
52    }
53
54    fn wire(
55        &self,
56        prefix: &str,
57        model: &mut TypedModel,
58        inputs: &[OutletId],
59    ) -> TractResult<TVec<OutletId>> {
60        let rank = model.outlet_fact(inputs[0])?.rank() as isize;
61        let mut wire: TVec<OutletId> = inputs.into();
62        let output_rank = rank + self.axes.len() as isize;
63        let axes = self
64            .axes
65            .iter()
66            .map(|&axis| if axis < 0 { axis + output_rank } else { axis } as usize)
67            .sorted();
68        for axis in axes {
69            wire =
70                model.wire_node(format!("{prefix}.axis-{axis}"), AxisOp::Add(axis), &wire)?;
71        }
72        Ok(wire)
73    }
74}