1use tract_data::itertools::izip;
2
3use crate::broadcast::multi_broadcast;
4use crate::internal::*;
5use crate::ops::binary::TypedBinOp;
6
7#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
8pub struct MultiBroadcastTo {
9 pub shape: ShapeFact,
10}
11
12impl Op for MultiBroadcastTo {
13 fn name(&self) -> StaticName {
14 "MultiBroadcastTo".into()
15 }
16
17 op_as_typed_op!();
18}
19
20impl EvalOp for MultiBroadcastTo {
21 fn is_stateless(&self) -> bool {
22 true
23 }
24
25 fn eval_with_session(
26 &self,
27 _node_id: usize,
28 session: &TurnState,
29 inputs: TVec<TValue>,
30 ) -> TractResult<TVec<TValue>> {
31 let shape = self.shape.eval_to_usize(&session.resolved_symbols)?;
32 Ok(tvec!(inputs[0].broadcast_to_shape(&shape)?.into_tvalue()))
33 }
34}
35
36impl TypedOp for MultiBroadcastTo {
37 fn axes_mapping(
38 &self,
39 inputs: &[&TypedFact],
40 outputs: &[&TypedFact],
41 ) -> TractResult<AxesMapping> {
42 let in_rank = inputs[0].rank();
48 let out_rank = outputs[0].rank();
49 let leading = out_rank.saturating_sub(in_rank);
50 let mut axes = tvec!();
51 let mut alphabet = 'a'..;
52 for o in 0..leading {
53 axes.push(
54 Axis::new(alphabet.next().unwrap(), inputs.len(), outputs.len()).output(0, o),
55 );
56 }
57 for i in 0..in_rank.min(out_rank) {
58 axes.push(
59 Axis::new(alphabet.next().unwrap(), inputs.len(), outputs.len())
60 .input(0, i)
61 .output(0, leading + i),
62 );
63 }
64 AxesMapping::new(inputs.len(), outputs.len(), axes)
65 }
66
67 fn change_axes(
68 &self,
69 model: &TypedModel,
70 node: &TypedNode,
71 _io: InOut,
72 change: &AxisOp,
73 ) -> TractResult<Option<AxisChangeConsequence>> {
74 let input_shape = &model.outlet_fact(node.inputs[0])?.shape;
81 let canonical = change.canonical();
82 let touched: TVec<usize> = match canonical.as_ref() {
83 AxisOp::Add(ix) | AxisOp::Rm(ix) => tvec![*ix],
84 AxisOp::Move(from, to) => {
85 rule_if!(input_shape.rank() == self.shape.rank());
86 tvec![*from, *to]
87 }
88 _ => return Ok(None),
89 };
90 for &ix in &touched {
91 if ix < self.shape.rank()
92 && ix < input_shape.rank()
93 && input_shape[ix] != self.shape[ix]
94 {
95 return Ok(None);
96 }
97 }
98
99 let mut shape = self.shape.clone();
100 if change.change_shape(&mut shape, false).is_ok() {
101 return Ok(Some(AxisChangeConsequence::new(
102 model,
103 node,
104 Some(Box::new(MultiBroadcastTo { shape })),
105 change,
106 )));
107 }
108 Ok(None)
109 }
110
111 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
112 ensure!(inputs.len() == 1);
113 let mut fact = inputs[0].datum_type.fact(self.shape.clone());
114 fact.uniform.clone_from(&inputs[0].uniform);
115 fact.uniform_tdim = inputs[0].uniform_tdim.clone();
116 Ok(tvec!(fact))
117 }
118
119 fn input_roi(
120 &self,
121 model: &TypedModel,
122 node: &TypedNode,
123 ) -> TractResult<Option<TVec<Option<TDim>>>> {
124 crate::optim::propagate_roi::bubble_roi(model, node)
125 }
126
127 fn set_symbols(
128 &self,
129 _source: &TypedModel,
130 node: &TypedNode,
131 target: &mut TypedModel,
132 mapping: &HashMap<OutletId, OutletId>,
133 subs: &HashMap<Symbol, TDim>,
134 ) -> TractResult<TVec<OutletId>> {
135 let input = mapping[&node.inputs[0]];
136 let shape: TVec<_> =
137 self.shape.iter().map(|d| d.substitute_all(subs)).collect::<TractResult<_>>()?;
138 let op = Self { shape: shape.into() };
139 target.wire_node(&node.name, op, &[input])
140 }
141
142 fn declutter(
143 &self,
144 model: &TypedModel,
145 node: &TypedNode,
146 ) -> TractResult<Option<TypedModelPatch>> {
147 let input_fact = model.outlet_fact(node.inputs[0])?;
148 if input_fact.shape == self.shape {
149 return TypedModelPatch::shunt_one_op(model, node);
150 }
151 for succ in &*node.outputs[0].successors {
157 let succ = model.node(succ.node);
158 let Some(op) = succ.op_as::<AxisOp>() else { continue };
159 let mut shape = self.shape.clone();
160 if izip!(0.., &*input_fact.shape, &*self.shape)
161 .filter(|(_, l, r)| l != r)
162 .all(|(axis, _, _)| op.transform_axis(axis).is_some())
163 && op.change_shape(&mut shape, false).is_ok()
164 {
165 let mut patch = TypedModelPatch::default();
166 let mut wire = patch.tap_model(model, node.inputs[0])?;
167 wire = patch.wire_node(&succ.name, op.clone(), &[wire])?[0];
168 wire = patch.wire_node(&node.name, MultiBroadcastTo { shape }, &[wire])?[0];
169 patch.shunt_outside(model, succ.id.into(), wire)?;
170 return Ok(Some(patch));
171 }
172 }
173 if let [succ] = &*node.outputs[0].successors {
174 let succ = model.node(succ.node);
175 if succ.op_is::<TypedBinOp>() {
176 let our_slot = node.outputs[0].successors[0].slot;
177 let other_slot = 1 - our_slot;
178 let other_operand = succ.inputs[other_slot];
179 let other_fact = model.outlet_fact(other_operand)?;
180 let output_fact = model.outlet_fact(succ.id.into())?;
181 if input_fact.rank() == other_fact.rank()
182 && multi_broadcast(&[&input_fact.shape, &other_fact.shape])
183 .is_ok_and(|s| &*s == &*output_fact.shape)
184 {
185 let mut operands = tvec!(node.inputs[0], other_operand);
186 if our_slot == 1 {
187 operands.swap(0, 1);
188 }
189 return TypedModelPatch::rewire(
190 &model,
191 &operands,
192 &[succ.id.into()],
193 &|p, inputs| p.wire_node(&succ.name, succ.op.clone(), &inputs),
194 )
195 .map(Some);
196 }
197 }
198 }
199 Ok(None)
200 }
201
202 as_op!();
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::change_axes::AxisOp;
    use crate::ops::logic::And;

    /// Wires `pad: [T]` -> `Add(0)` -> `MultiBroadcastTo([T, T])` into `model`.
    /// Returns the `T` symbol and the broadcast output outlet.
    fn wire_padded_broadcast(model: &mut TypedModel) -> TractResult<(Symbol, OutletId)> {
        let t = model.symbols.sym("T");
        let source = model.add_source("pad", bool::fact(&[t.to_dim()]))?;
        let unsqueezed = model.wire_node("unsq", AxisOp::Add(0), &[source])?[0];
        let target = ShapeFact::from_dims([t.to_dim(), t.to_dim()]);
        let broadcast =
            model.wire_node("bcast", MultiBroadcastTo { shape: target }, &[unsqueezed])?[0];
        Ok((t, broadcast))
    }

    /// A lone `Move` after the broadcast must be pushed above it and absorbed.
    #[test]
    fn broadcast_move_single_successor_swaps() -> TractResult<()> {
        let mut model = TypedModel::default();
        let (_t, bcast) = wire_padded_broadcast(&mut model)?;
        let moved = model.wire_node("move", AxisOp::Move(0, 1), &[bcast])?[0];
        model.select_output_outlets(&[moved])?;

        let decluttered = model.into_decluttered()?;

        let remaining_moves = decluttered
            .nodes()
            .iter()
            .filter(|n| matches!(n.op_as::<AxisOp>(), Some(AxisOp::Move(0, 1))))
            .count();
        assert_eq!(
            remaining_moves, 0,
            "Move should have been pushed through Broadcast and absorbed"
        );
        Ok(())
    }

    /// With a fan-out (broadcast feeds both AND and a Move into AND), both broadcasts
    /// must vanish and AND must receive `[1, T]` and `[T, 1]` in either order.
    #[test]
    fn broadcast_move_fanout_pushes_through_one_branch() -> TractResult<()> {
        let mut model = TypedModel::default();
        let (t, bcast) = wire_padded_broadcast(&mut model)?;
        let moved = model.wire_node("move", AxisOp::Move(0, 1), &[bcast])?[0];
        let and = model.wire_node("and", TypedBinOp(Box::new(And), None), &[bcast, moved])?[0];
        model.select_output_outlets(&[and])?;

        let decluttered = model.into_decluttered()?;

        let broadcasts =
            decluttered.nodes().iter().filter(|n| n.op_is::<MultiBroadcastTo>()).count();
        assert_eq!(
            broadcasts, 0,
            "Both broadcasts should be subsumed into AND's implicit broadcasting"
        );

        let and_node = decluttered
            .nodes()
            .iter()
            .find(|n| n.op_is::<TypedBinOp>())
            .expect("AND should survive");
        assert_eq!(and_node.inputs.len(), 2);
        let shapes = and_node
            .inputs
            .iter()
            .map(|i| decluttered.outlet_fact(*i).unwrap().shape.to_tvec())
            .collect::<Vec<_>>();
        let row = tvec![1.to_dim(), t.to_dim()];
        let column = tvec![t.to_dim(), 1.to_dim()];
        let (a, b) = (&shapes[0], &shapes[1]);
        let either_order = (*a == row && *b == column) || (*a == column && *b == row);
        assert!(either_order, "AND should receive [1, T] and [T, 1]; got {a:?} and {b:?}");
        Ok(())
    }
}