Skip to main content

tract_core/ops/array/
broadcast.rs

1use tract_data::itertools::izip;
2
3use crate::broadcast::multi_broadcast;
4use crate::internal::*;
5use crate::ops::binary::TypedBinOp;
6
7#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
8pub struct MultiBroadcastTo {
9    pub shape: ShapeFact,
10}
11
12impl Op for MultiBroadcastTo {
13    fn name(&self) -> StaticName {
14        "MultiBroadcastTo".into()
15    }
16
17    op_as_typed_op!();
18}
19
20impl EvalOp for MultiBroadcastTo {
21    fn is_stateless(&self) -> bool {
22        true
23    }
24
25    fn eval_with_session(
26        &self,
27        _node_id: usize,
28        session: &TurnState,
29        inputs: TVec<TValue>,
30    ) -> TractResult<TVec<TValue>> {
31        let shape = self.shape.eval_to_usize(&session.resolved_symbols)?;
32        Ok(tvec!(inputs[0].broadcast_to_shape(&shape)?.into_tvalue()))
33    }
34}
35
36impl TypedOp for MultiBroadcastTo {
37    fn axes_mapping(
38        &self,
39        inputs: &[&TypedFact],
40        outputs: &[&TypedFact],
41    ) -> TractResult<AxesMapping> {
42        // ONNX-style broadcasting right-aligns input over output, so when
43        // output_rank > input_rank the leading output axes are pure
44        // broadcast axes with no input correspondence. natural_for_rank's
45        // square shape would skip them and trip the optimizer's axes-mapping
46        // check (caught under paranoid_assertions).
47        let in_rank = inputs[0].rank();
48        let out_rank = outputs[0].rank();
49        let leading = out_rank.saturating_sub(in_rank);
50        let mut axes = tvec!();
51        let mut alphabet = 'a'..;
52        for o in 0..leading {
53            axes.push(
54                Axis::new(alphabet.next().unwrap(), inputs.len(), outputs.len()).output(0, o),
55            );
56        }
57        for i in 0..in_rank.min(out_rank) {
58            axes.push(
59                Axis::new(alphabet.next().unwrap(), inputs.len(), outputs.len())
60                    .input(0, i)
61                    .output(0, leading + i),
62            );
63        }
64        AxesMapping::new(inputs.len(), outputs.len(), axes)
65    }
66
67    fn change_axes(
68        &self,
69        model: &TypedModel,
70        node: &TypedNode,
71        _io: InOut,
72        change: &AxisOp,
73    ) -> TractResult<Option<AxisChangeConsequence>> {
74        // Only propagate axis changes that touch passthrough axes — those
75        // where the input and output shapes agree. Touching a broadcast
76        // axis (input=1, output=N) would make the input and output rank
77        // diverge through the change and break the broadcast relationship,
78        // and propagating Rm of a non-trivial axis into a Source produces
79        // the "Removing non-trivial axis" hard error from change_shape.
80        let input_shape = &model.outlet_fact(node.inputs[0])?.shape;
81        let canonical = change.canonical();
82        let touched: TVec<usize> = match canonical.as_ref() {
83            AxisOp::Add(ix) | AxisOp::Rm(ix) => tvec![*ix],
84            AxisOp::Move(from, to) => {
85                rule_if!(input_shape.rank() == self.shape.rank());
86                tvec![*from, *to]
87            }
88            _ => return Ok(None),
89        };
90        for &ix in &touched {
91            if ix < self.shape.rank()
92                && ix < input_shape.rank()
93                && input_shape[ix] != self.shape[ix]
94            {
95                return Ok(None);
96            }
97        }
98
99        let mut shape = self.shape.clone();
100        if change.change_shape(&mut shape, false).is_ok() {
101            return Ok(Some(AxisChangeConsequence::new(
102                model,
103                node,
104                Some(Box::new(MultiBroadcastTo { shape })),
105                change,
106            )));
107        }
108        Ok(None)
109    }
110
111    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
112        ensure!(inputs.len() == 1);
113        let mut fact = inputs[0].datum_type.fact(self.shape.clone());
114        fact.uniform.clone_from(&inputs[0].uniform);
115        fact.uniform_tdim = inputs[0].uniform_tdim.clone();
116        Ok(tvec!(fact))
117    }
118
119    fn input_roi(
120        &self,
121        model: &TypedModel,
122        node: &TypedNode,
123    ) -> TractResult<Option<TVec<Option<TDim>>>> {
124        crate::optim::propagate_roi::bubble_roi(model, node)
125    }
126
127    fn set_symbols(
128        &self,
129        _source: &TypedModel,
130        node: &TypedNode,
131        target: &mut TypedModel,
132        mapping: &HashMap<OutletId, OutletId>,
133        subs: &HashMap<Symbol, TDim>,
134    ) -> TractResult<TVec<OutletId>> {
135        let input = mapping[&node.inputs[0]];
136        let shape: TVec<_> =
137            self.shape.iter().map(|d| d.substitute_all(subs)).collect::<TractResult<_>>()?;
138        let op = Self { shape: shape.into() };
139        target.wire_node(&node.name, op, &[input])
140    }
141
142    fn declutter(
143        &self,
144        model: &TypedModel,
145        node: &TypedNode,
146    ) -> TractResult<Option<TypedModelPatch>> {
147        let input_fact = model.outlet_fact(node.inputs[0])?;
148        if input_fact.shape == self.shape {
149            return TypedModelPatch::shunt_one_op(model, node);
150        }
151        // Swap with an AxisOp successor: `Broadcast(x, S) → AxisOp` becomes
152        // `AxisOp(x) → Broadcast(σ(S))` whenever the AxisOp transforms every
153        // axis the broadcast actually expanded.  Fires per-successor, so this
154        // works under fan-out (the original broadcast stays in place for
155        // siblings; only the matched AxisOp branch is rerouted).
156        for succ in &*node.outputs[0].successors {
157            let succ = model.node(succ.node);
158            let Some(op) = succ.op_as::<AxisOp>() else { continue };
159            let mut shape = self.shape.clone();
160            if izip!(0.., &*input_fact.shape, &*self.shape)
161                .filter(|(_, l, r)| l != r)
162                .all(|(axis, _, _)| op.transform_axis(axis).is_some())
163                && op.change_shape(&mut shape, false).is_ok()
164            {
165                let mut patch = TypedModelPatch::default();
166                let mut wire = patch.tap_model(model, node.inputs[0])?;
167                wire = patch.wire_node(&succ.name, op.clone(), &[wire])?[0];
168                wire = patch.wire_node(&node.name, MultiBroadcastTo { shape }, &[wire])?[0];
169                patch.shunt_outside(model, succ.id.into(), wire)?;
170                return Ok(Some(patch));
171            }
172        }
173        if let [succ] = &*node.outputs[0].successors {
174            let succ = model.node(succ.node);
175            if succ.op_is::<TypedBinOp>() {
176                let our_slot = node.outputs[0].successors[0].slot;
177                let other_slot = 1 - our_slot;
178                let other_operand = succ.inputs[other_slot];
179                let other_fact = model.outlet_fact(other_operand)?;
180                let output_fact = model.outlet_fact(succ.id.into())?;
181                if input_fact.rank() == other_fact.rank()
182                    && multi_broadcast(&[&input_fact.shape, &other_fact.shape])
183                        .is_ok_and(|s| &*s == &*output_fact.shape)
184                {
185                    let mut operands = tvec!(node.inputs[0], other_operand);
186                    if our_slot == 1 {
187                        operands.swap(0, 1);
188                    }
189                    return TypedModelPatch::rewire(
190                        &model,
191                        &operands,
192                        &[succ.id.into()],
193                        &|p, inputs| p.wire_node(&succ.name, succ.op.clone(), &inputs),
194                    )
195                    .map(Some);
196                }
197            }
198        }
199        Ok(None)
200    }
201
202    as_op!();
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::change_axes::AxisOp;
    use crate::ops::logic::And;

    /// Wires `pad: bool[T] → Add(0) → MultiBroadcastTo([T, T]) → Move(0, 1)`
    /// into `model` and returns the symbol `T` together with the broadcast
    /// and move outlets.
    fn wire_broadcast_then_move(
        model: &mut TypedModel,
    ) -> TractResult<(Symbol, OutletId, OutletId)> {
        let t = model.symbols.sym("T");
        let pad = model.add_source("pad", bool::fact(&[t.to_dim()]))?;
        let unsq = model.wire_node("unsq", AxisOp::Add(0), &[pad])?[0];
        let square = ShapeFact::from_dims([t.to_dim(), t.to_dim()]);
        let bcast = model.wire_node("bcast", MultiBroadcastTo { shape: square }, &[unsq])?[0];
        let mv = model.wire_node("move", AxisOp::Move(0, 1), &[bcast])?[0];
        Ok((t, bcast, mv))
    }

    /// `Broadcast → Move` where the broadcast has exactly one successor:
    /// the swap rewrite applies and no `Move(0, 1)` is left after declutter.
    #[test]
    fn broadcast_move_single_successor_swaps() -> TractResult<()> {
        let mut model = TypedModel::default();
        let (_, _, mv) = wire_broadcast_then_move(&mut model)?;
        model.select_output_outlets(&[mv])?;

        let model = model.into_decluttered()?;

        let is_move_0_1 = |n: &&TypedNode| matches!(n.op_as::<AxisOp>(), Some(AxisOp::Move(0, 1)));
        let move_count = model.nodes().iter().filter(is_move_0_1).count();
        assert_eq!(move_count, 0, "Move should have been pushed through Broadcast and absorbed");
        Ok(())
    }

    /// `Broadcast → {Move, And-direct}` — the encoder-style pad-mask outer-AND
    /// pattern. The broadcast feeds two consumers; the Move branch is
    /// expected to receive its own swapped chain while the direct AND input
    /// keeps using the original broadcast.
    #[test]
    fn broadcast_move_fanout_pushes_through_one_branch() -> TractResult<()> {
        let mut model = TypedModel::default();
        let (t, bcast, mv) = wire_broadcast_then_move(&mut model)?;
        let and = model.wire_node("and", TypedBinOp(Box::new(And), None), &[bcast, mv])?[0];
        model.select_output_outlets(&[and])?;

        let model = model.into_decluttered()?;

        // Once the Move branch has been swapped, each broadcast has a single
        // binary-op successor and the Broadcast→TypedBinOp rule removes both:
        // the AND then broadcasts [1, T] against [T, 1] on its own.
        let bcast_count = model.nodes().iter().filter(|n| n.op_is::<MultiBroadcastTo>()).count();
        assert_eq!(
            bcast_count, 0,
            "Both broadcasts should be subsumed into AND's implicit broadcasting"
        );

        let and_node =
            model.nodes().iter().find(|n| n.op_is::<TypedBinOp>()).expect("AND should survive");
        assert_eq!(and_node.inputs.len(), 2);

        let shape_of = |outlet: OutletId| model.outlet_fact(outlet).unwrap().shape.to_tvec();
        let a = shape_of(and_node.inputs[0]);
        let b = shape_of(and_node.inputs[1]);
        let expected_a = tvec![1.to_dim(), t.to_dim()];
        let expected_b = tvec![t.to_dim(), 1.to_dim()];

        // Either operand order is accepted.
        let straight = a == expected_a && b == expected_b;
        let flipped = a == expected_b && b == expected_a;
        assert!(straight || flipped, "AND should receive [1, T] and [T, 1]; got {a:?} and {b:?}");
        Ok(())
    }
}