tract_tensorflow/ops/array/
pack.rs1use tract_hir::internal::*;
2use tract_hir::ops::array::TypedConcat;
3use tract_hir::ops::binary::wire_cast;
4
5use crate::model::ParsingContext;
6use crate::tfpb::tensorflow::NodeDef;
7
8pub fn pack(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
9 let n = pb.input.len();
10 let axis = pb.get_attr_int("axis")?;
11
12 Ok(expand(Pack::new(n, axis)))
13}
14
15#[derive(Debug, Clone, new, Hash)]
16pub struct Pack {
17 n: usize, axis: usize,
19}
20
21
22
23impl Expansion for Pack {
24 fn name(&self) -> StaticName {
25 "Pack".into()
26 }
27
28 fn rules<'r, 'p: 'r, 's: 'r>(
29 &'s self,
30 s: &mut Solver<'r>,
31 inputs: &'p [TensorProxy],
32 outputs: &'p [TensorProxy],
33 ) -> InferenceResult {
34 let axis = self.axis;
35 check_input_arity(inputs, self.n)?;
36 check_output_arity(outputs, 1)?;
37 s.equals(&outputs[0].rank, inputs[0].rank.bex() + 1)?;
38 s.equals_all((0..self.n).map(|i| inputs[i].rank.bex()).collect())?;
39 s.given_all((0..self.n).map(move |i| &inputs[i].datum_type), move |s, dts| {
40 if let Some(dt) = DatumType::super_type_for(dts) {
41 s.equals(&outputs[0].datum_type, dt)?;
42 }
43 Ok(())
44 })?;
45 s.given(&inputs[0].rank, move |s, r| {
46 for d in 0..r as usize {
47 s.equals_all((0..self.n).map(|i| inputs[i].shape[d].bex()).collect())?;
48 }
49 Ok(())
50 })?;
51 s.given(&inputs[0].rank, move |s, r| {
52 for d in 0..axis {
53 s.equals(&outputs[0].shape[d], &inputs[0].shape[d])?;
54 }
55 if r > 0 {
56 for d in axis..r as usize {
57 s.equals(&outputs[0].shape[d + 1], &inputs[0].shape[d])?
58 }
59 }
60 Ok(())
61 })?;
62 s.equals(&outputs[0].shape[axis], self.n.to_dim())
63 }
64
65 fn wire(
66 &self,
67 prefix: &str,
68 model: &mut TypedModel,
69 inputs: &[OutletId],
70 ) -> TractResult<TVec<OutletId>> {
71 let dt = inputs
72 .iter()
73 .map(|&i| Ok(model.outlet_fact(i)?.datum_type))
74 .collect::<TractResult<TVec<DatumType>>>()?;
75 let dt = DatumType::super_type_for(dt.iter()).context("No supertype")?;
76 let wires = wire_cast(prefix, model, inputs, dt)?;
77 let inputs: TVec<OutletId> = wires
78 .iter()
79 .enumerate()
80 .map(|(ix, &o)| {
81 Ok(model.wire_node(
82 format!("{prefix}.add_dims-{ix}"),
83 AxisOp::Add(self.axis),
84 &[o],
85 )?[0])
86 })
87 .collect::<TractResult<TVec<OutletId>>>()?;
88 model.wire_node(prefix, TypedConcat::new(self.axis), &inputs)
89 }
90}