tract_hir/ops/array/
add_dims.rs1use crate::infer::*;
2use crate::internal::*;
3use tract_itertools::Itertools;
4
/// Operator that inserts unit (size-1) dimensions into its input's shape.
///
/// How `axes` is interpreted is implemented by `AddDims::output_shape` and the
/// `Expansion` impl in this file.
#[derive(Debug, Clone, new, Hash)]
pub struct AddDims {
    /// Positions of the inserted dimensions, expressed in the *output* shape
    /// (whose rank is the input rank plus `axes.len()`); negative values count
    /// back from the end of that output shape.
    pub axes: Vec<isize>,
}
9
10
11
12impl AddDims {
13 pub fn output_shape<D: DimLike>(&self, input: &[D]) -> TVec<D> {
14 let rank = input.len() as isize;
15 let mut shape: TVec<D> = input.iter().cloned().collect();
16 let output_rank = rank + self.axes.len() as isize;
17 let axes = self
18 .axes
19 .iter()
20 .map(|&axis| if axis < 0 { axis + output_rank } else { axis } as usize)
21 .sorted();
22 for axis in axes {
23 shape.insert(axis, D::one())
24 }
25 shape
26 }
27}
28
29impl Expansion for AddDims {
30 fn name(&self) -> StaticName {
31 "AddDims".into()
32 }
33
34 fn info(&self) -> TractResult<Vec<String>> {
35 Ok(vec![format!("Axes: {:?}", self.axes)])
36 }
37
38
39 fn rules<'r, 'p: 'r, 's: 'r>(
40 &'s self,
41 s: &mut Solver<'r>,
42 inputs: &'p [TensorProxy],
43 outputs: &'p [TensorProxy],
44 ) -> InferenceResult {
45 check_output_arity(outputs, 1)?;
46 s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
47 s.equals(&outputs[0].rank, (&inputs[0].rank).bex() + self.axes.len() as i64)?;
48 s.given(&inputs[0].shape, move |s, shape| {
49 let output_shape = self.output_shape(&shape);
50 s.equals(&outputs[0].shape, output_shape)
51 })
52 }
53
54 fn wire(
55 &self,
56 prefix: &str,
57 model: &mut TypedModel,
58 inputs: &[OutletId],
59 ) -> TractResult<TVec<OutletId>> {
60 let rank = model.outlet_fact(inputs[0])?.rank() as isize;
61 let mut wire: TVec<OutletId> = inputs.into();
62 let output_rank = rank + self.axes.len() as isize;
63 let axes = self
64 .axes
65 .iter()
66 .map(|&axis| if axis < 0 { axis + output_rank } else { axis } as usize)
67 .sorted();
68 for axis in axes {
69 wire =
70 model.wire_node(format!("{prefix}.axis-{axis}"), AxisOp::Add(axis), &wire)?;
71 }
72 Ok(wire)
73 }
74}