tract_tensorflow/ops/array/
pack.rs

1use tract_hir::internal::*;
2use tract_hir::ops::array::TypedConcat;
3use tract_hir::ops::binary::wire_cast;
4
5use crate::model::ParsingContext;
6use crate::tfpb::tensorflow::NodeDef;
7
8pub fn pack(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
9    let n = pb.input.len();
10    let axis = pb.get_attr_int("axis")?;
11
12    Ok(expand(Pack::new(n, axis)))
13}
14
15#[derive(Debug, Clone, new, Hash)]
16pub struct Pack {
17    n: usize, // The number of inputs
18    axis: usize,
19}
20
21
22
23impl Expansion for Pack {
24    fn name(&self) -> StaticName {
25        "Pack".into()
26    }
27
28    fn rules<'r, 'p: 'r, 's: 'r>(
29        &'s self,
30        s: &mut Solver<'r>,
31        inputs: &'p [TensorProxy],
32        outputs: &'p [TensorProxy],
33    ) -> InferenceResult {
34        let axis = self.axis;
35        check_input_arity(inputs, self.n)?;
36        check_output_arity(outputs, 1)?;
37        s.equals(&outputs[0].rank, inputs[0].rank.bex() + 1)?;
38        s.equals_all((0..self.n).map(|i| inputs[i].rank.bex()).collect())?;
39        s.given_all((0..self.n).map(move |i| &inputs[i].datum_type), move |s, dts| {
40            if let Some(dt) = DatumType::super_type_for(dts) {
41                s.equals(&outputs[0].datum_type, dt)?;
42            }
43            Ok(())
44        })?;
45        s.given(&inputs[0].rank, move |s, r| {
46            for d in 0..r as usize {
47                s.equals_all((0..self.n).map(|i| inputs[i].shape[d].bex()).collect())?;
48            }
49            Ok(())
50        })?;
51        s.given(&inputs[0].rank, move |s, r| {
52            for d in 0..axis {
53                s.equals(&outputs[0].shape[d], &inputs[0].shape[d])?;
54            }
55            if r > 0 {
56                for d in axis..r as usize {
57                    s.equals(&outputs[0].shape[d + 1], &inputs[0].shape[d])?
58                }
59            }
60            Ok(())
61        })?;
62        s.equals(&outputs[0].shape[axis], self.n.to_dim())
63    }
64
65    fn wire(
66        &self,
67        prefix: &str,
68        model: &mut TypedModel,
69        inputs: &[OutletId],
70    ) -> TractResult<TVec<OutletId>> {
71        let dt = inputs
72            .iter()
73            .map(|&i| Ok(model.outlet_fact(i)?.datum_type))
74            .collect::<TractResult<TVec<DatumType>>>()?;
75        let dt = DatumType::super_type_for(dt.iter()).context("No supertype")?;
76        let wires = wire_cast(prefix, model, inputs, dt)?;
77        let inputs: TVec<OutletId> = wires
78            .iter()
79            .enumerate()
80            .map(|(ix, &o)| {
81                Ok(model.wire_node(
82                    format!("{prefix}.add_dims-{ix}"),
83                    AxisOp::Add(self.axis),
84                    &[o],
85                )?[0])
86            })
87            .collect::<TractResult<TVec<OutletId>>>()?;
88        model.wire_node(prefix, TypedConcat::new(self.axis), &inputs)
89    }
90}