tract_pulse/ops/array/
concat.rs1use crate::internal::*;
2use crate::model::NonPulsingWrappingOp;
3use tract_core::ops::array::TypedConcat;
4use tract_pulse_opl::concat::PulsedSameAxisConcat;
5use tract_pulse_opl::ops::Delay;
6use tract_pulse_opl::tract_core::ops::array::MultiBroadcastTo;
7use tract_pulse_opl::tract_core::tract_data::itertools::Itertools;
8
9register_all!(TypedConcat: pulsify);
10
11fn pulsify(
12 op: &TypedConcat,
13 source: &TypedModel,
14 node: &TypedNode,
15 target: &mut PulsedModel,
16 mapping: &HashMap<OutletId, OutletId>,
17 symbol: &Symbol,
18 _pulse: &TDim,
19) -> TractResult<Option<TVec<OutletId>>> {
20 let pulse_facts: TVec<PulsedFact> =
21 node.inputs.iter().map(|i| target.outlet_fact(mapping[i]).unwrap().clone()).collect();
22 let (_stream_input_ix, pulse_fact) =
23 pulse_facts.iter().enumerate().find(|(_ix, pf)| pf.stream.is_some()).unwrap();
24
25 if pulse_fact.stream.as_ref().unwrap().axis == op.axis {
26 pulsify_along_concat_axis(op, source, node, target, mapping, symbol)
27 } else {
28 Ok(None)
29 }
30}
31
32fn pulsify_along_concat_axis(
33 op: &TypedConcat,
34 source: &TypedModel,
35 node: &TypedNode,
36 target: &mut PulsedModel,
37 mapping: &HashMap<OutletId, OutletId>,
38 symbol: &Symbol,
39) -> TractResult<Option<TVec<OutletId>>> {
40 let name = &node.name;
41 let axis = op.axis;
42 let source_facts: TVec<TypedFact> =
43 node.inputs.iter().map(|i| source.outlet_fact(*i).unwrap().clone()).collect();
44 ensure!(source_facts.iter().filter(|fact| fact.shape[axis].symbols().contains(symbol)).count() == 1,
45 "Concat over pulse axis (#{axis}, {symbol:?}) expcts one single streaming input. Got: {source_facts:?}"
46 );
47 let pulsed_inputs: TVec<OutletId> = node.inputs.iter().map(|i| mapping[i]).collect();
48 let pulse_facts: TVec<PulsedFact> = pulsed_inputs
49 .iter()
50 .map(|i| target.outlet_fact(*i).cloned())
51 .collect::<TractResult<_>>()?;
52 let (stream_input_ix, pulse_fact) =
53 pulse_facts.iter().enumerate().find(|(_ix, pf)| pf.stream.is_some()).unwrap();
54 let stream = pulse_fact.stream.as_ref().unwrap();
55
56 let zero = target
57 .add_const(format!("{name}.zero"), Tensor::zero_scalar_dt(source_facts[0].datum_type)?)?;
58 let mut shape = pulse_fact.shape.clone();
59 shape.set(axis, 0.to_dim());
60 let empty = target.wire_node(
61 format!("{name}.pre"),
62 NonPulsingWrappingOp(Box::new(MultiBroadcastTo { shape })),
63 &[zero],
64 )?[0];
65
66 let pre = if stream_input_ix > 0 {
67 target.wire_node(
68 format!("{name}.pre"),
69 NonPulsingWrappingOp(Box::new(TypedConcat::new(axis))),
70 &pulsed_inputs.iter().take(stream_input_ix).cloned().collect_vec(),
71 )?[0]
72 } else {
73 empty
74 };
75 let post = if stream_input_ix + 1 < pulsed_inputs.len() {
76 target.wire_node(
77 format!("{name}.post"),
78 NonPulsingWrappingOp(Box::new(TypedConcat::new(axis))),
79 &pulsed_inputs.iter().skip(stream_input_ix + 1).cloned().collect_vec(),
80 )?[0]
81 } else {
82 empty
83 };
84
85 let mut input = pulsed_inputs[stream_input_ix];
86 let pre_fact = target.outlet_fact(pre)?;
87 let before = pre_fact.shape[op.axis].to_usize()?;
88 if stream.delay < before {
89 input = target.wire_node(
90 format!("{}.Delay", node.name),
91 Delay::new_typed(
92 source.outlet_fact(node.inputs[stream_input_ix])?,
93 stream.axis,
94 before - stream.delay,
95 0,
96 ),
97 &[input],
98 )?[0];
99 }
100 let main_op = PulsedSameAxisConcat {
101 axis: op.axis,
102 input_delay: stream.delay.saturating_sub(before),
103 input_len: stream.dim.clone(),
104 };
105 Ok(Some(target.wire_node(&*node.name, main_op, &[pre, input, post])?))
106}
107
108impl PulsedOp for PulsedSameAxisConcat {
109 fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
110 let &[pre, fact, post] = inputs else { bail!("Expect 3 inputs") };
111 let mut fact: PulsedFact = fact.clone();
112 let stream = fact.stream.as_mut().unwrap();
113 let before = pre.shape[self.axis].to_usize()?;
114 let after = post.shape[self.axis].to_usize()?;
115 stream.dim += (before + after).to_dim();
116 stream.delay -= before.to_usize()?;
117 Ok(tvec!(fact))
118 }
119
120 as_op!();
121 pulsed_op_to_typed_op!();
122}