Skip to main content

tract_core/ops/einsum/
mod.rs

1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::ops::array::MultiBroadcastTo;
6use crate::tract_data::itertools::Itertools;
7
8mod eval;
9
10pub mod einsum_matmul;
11pub mod kernel_selection;
12pub mod prefix_matmul;
13
14#[cfg(test)]
15mod proptest;
16
17use num_traits::One;
18use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFact};
19use tract_linalg::mmm::PackedExoticFact;
20
21pub fn block_quant_aware_input_shape(fact: &TypedFact) -> TractResult<Cow<'_, [TDim]>> {
22    if fact.is_plain() {
23        return Ok(Cow::Borrowed(&*fact.shape));
24    }
25    let Some(exotic_fact) = fact.exotic_fact() else {
26        bail!("Datum fact is exotic, but no exotic fact was found.")
27    };
28    if let Some(_bqf) = exotic_fact.downcast_ref::<BlockQuantFact>() {
29        Ok(Cow::Borrowed(&*fact.shape))
30    } else if let Some(pof) = exotic_fact.downcast_ref::<PackedBlockQuantFact>() {
31        Ok(Cow::Owned(
32            fact.shape.iter().cloned().chain(pof.shape.iter().map(|i| i.to_dim())).collect_vec(),
33        ))
34    } else if let Some(pof) = exotic_fact.downcast_ref::<PackedExoticFact>() {
35        Ok(Cow::Owned(
36            fact.shape.iter().cloned().chain([pof.mn.clone(), pof.k.to_dim()]).collect_vec(),
37        ))
38    } else {
39        bail!("Unsupported exotic fact {exotic_fact:?}")
40    }
41}
42
/// Einstein-summation operator over an arbitrary axes mapping.
///
/// `axes` labels every axis of each input and of the single output; `output_facts` derives the
/// output shape from it. Without `q_params` the output datum type is `operating_dt`; with
/// `q_params` the op is the quantized variant and its output datum type is the one stored there.
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct EinSum {
    /// Axis labelling of every input and of the output.
    pub axes: AxesMapping,
    /// Datum type the computation runs in (also the output type when not quantized).
    pub operating_dt: DatumType,
    /// Output datum type of the quantized variant (set by `EinSum::newq`).
    // if present, assume we're a binary op.
    // 9 inputs are: A,B,bias, A0,Ascale, B0,BScale, C0,Cscale
    pub q_params: Option<DatumType>,
}
51
52impl EinSum {
53    pub fn new(axes: AxesMapping, operating_dt: DatumType) -> EinSum {
54        EinSum { axes, operating_dt, q_params: None }
55    }
56
57    pub fn newq(axes: AxesMapping, operating_dt: DatumType, output_type: DatumType) -> EinSum {
58        EinSum { axes, operating_dt, q_params: Some(output_type) }
59    }
60
61    pub fn actual_input_shapes_from_facts<'m>(
62        &self,
63        inputs: &'m [impl Borrow<TypedFact>],
64    ) -> TractResult<TVec<Cow<'m, [TDim]>>> {
65        ensure!(inputs.len() == self.axes.input_count());
66        let shapes: TVec<Cow<[TDim]>> = inputs
67            .iter()
68            .map(|t| block_quant_aware_input_shape(t.borrow()))
69            .collect::<TractResult<_>>()?;
70        ensure!(
71            shapes.iter().enumerate().all(|(ix, fact)| fact.len() == self.axes.rank(InOut::In(ix)))
72        );
73        Ok(shapes)
74    }
75
76    #[allow(unused_variables)]
77    pub(crate) fn propagate_axis(
78        &self,
79        model: &TypedModel,
80        node: &TypedNode,
81        io: InOut,
82        axis: usize,
83    ) -> TractResult<Option<TypedModelPatch>> {
84        let mut new_axis = self.axes.axis((io, axis))?.clone();
85        let repr = new_axis.repr;
86        let mut patch = TypedModelPatch::new(format!("Propagate axis {}", new_axis.repr));
87        let mut taps = tvec!();
88        for (ix, input) in node.inputs.iter().enumerate() {
89            let mut tap = patch.tap_model(model, *input)?;
90            rule_if!(new_axis.inputs[ix].len() <= 1); // FIXME maybe
91            if new_axis.inputs[ix].is_empty() {
92                let insert_at = self.axes.rank(InOut::In(ix));
93                tap = patch.wire_node(
94                    format!("{}.prop_axis.{}.input_{}", node.name, new_axis.repr, ix),
95                    AxisOp::Add(insert_at),
96                    &[tap],
97                )?[0];
98                new_axis.inputs[ix].push(insert_at);
99            }
100            taps.push(tap);
101        }
102        let must_rm_axis: Option<usize> = if new_axis.outputs[0].len() == 0 {
103            let insert_at = self.axes.rank(InOut::Out(0));
104            new_axis.outputs[0].push(insert_at);
105            Some(insert_at)
106        } else {
107            None
108        };
109        let new_expr = self
110            .axes
111            .iter_all_axes()
112            .map(|it| if it.repr == new_axis.repr { new_axis.clone() } else { it.clone() })
113            .collect_vec();
114        let axes = AxesMapping::new(node.inputs.len(), 1, new_expr)?;
115        let mut wire = patch.wire_node(&node.name, Self { axes, ..self.clone() }, &taps)?;
116        if let Some(position) = must_rm_axis {
117            wire = patch.wire_node(
118                format!("{}.prop_axis.{}.output", node.name, repr),
119                AxisOp::Rm(position),
120                &wire,
121            )?;
122        }
123        patch.shunt_outside(model, node.id.into(), wire[0])?;
124        Ok(Some(patch))
125    }
126
127    pub fn acceptable_accumulators(&self) -> TVec<DatumType> {
128        if self.operating_dt.is_integer() {
129            tvec!(i32::datum_type())
130        } else if self.operating_dt == f16::datum_type() {
131            tvec!(f16::datum_type(), f32::datum_type())
132        } else {
133            tvec!(self.operating_dt)
134        }
135    }
136}
137
138impl Debug for EinSum {
139    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        write!(f, "EinSum {} ({:?})", self.axes, self.operating_dt)
141    }
142}
143
144impl Op for EinSum {
145    fn name(&self) -> StaticName {
146        "EinSum".into()
147    }
148
149    fn info(&self) -> TractResult<Vec<String>> {
150        let mut info = vec![format!("{} ({:?})", self.axes, self.operating_dt)];
151        if let Some(qp) = self.q_params {
152            info.push(format!("Quantized output: {qp:?}"));
153        }
154        Ok(info)
155    }
156
157    op_as_typed_op!();
158}
159
160impl EvalOp for EinSum {
161    fn is_stateless(&self) -> bool {
162        true
163    }
164
165    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
166        if inputs.iter().all(|i| i.datum_type().is_number() && i.is_plain()) {
167            let mut adhoc_model = TypedModel::default();
168            let mut wires = tvec!();
169            for (ix, input) in inputs.iter().enumerate() {
170                let fact = TypedFact::shape_and_dt_of(input);
171                let wire = adhoc_model.add_source(format!("input.{ix}"), fact)?;
172                wires.push(wire);
173            }
174            let output = adhoc_model.wire_node("einsum", self.clone(), &wires)?;
175            adhoc_model.select_output_outlets(&output)?;
176            let opti = adhoc_model.into_optimized()?;
177            if opti.nodes.iter().all(|node| !node.op_is::<Self>()) {
178                return opti.into_runnable()?.run(inputs);
179            }
180        }
181
182        let output = if let Some(qp) = self.q_params {
183            eval::eval_q(&self.axes, qp, inputs)
184        } else {
185            dispatch_numbers!(eval::eval_t(self.operating_dt)(&self.axes, inputs))
186        }?;
187        Ok(tvec!(output.into_tvalue()))
188    }
189}
190
191impl TypedOp for EinSum {
192    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
193        let shapes = self.actual_input_shapes_from_facts(inputs)?;
194        for i in 0..inputs.len() {
195            ensure!(shapes[i].len() == self.axes.rank(InOut::In(i)));
196        }
197        for axis in self.axes.iter_all_axes() {
198            assert!(
199                shapes
200                    .iter()
201                    .enumerate()
202                    .flat_map(|(slot, shape)| axis.inputs[slot].iter().map(|a| &shape[*a]))
203                    .try_fold(TDim::one(), |a, b| TDim::broadcast(a, b.clone()))
204                    .is_ok()
205            );
206        }
207        if let Some(qp) = self.q_params {
208            ensure!(inputs.len() == 9);
209            Ok(tvec!(qp.fact(eval::output_shape(&self.axes, &shapes[0..2])?)))
210        } else {
211            Ok(tvec!(TypedFact::dt_shape(
212                self.operating_dt,
213                eval::output_shape(&self.axes, &shapes)?
214            )))
215        }
216    }
217
218    fn input_roi(
219        &self,
220        model: &TypedModel,
221        node: &TypedNode,
222    ) -> TractResult<Option<TVec<Option<TDim>>>> {
223        // First try bubble_roi: works for inputs that cover all ROI coord
224        // axes mentioned in the output ROI.  For inputs that DON'T cover
225        // every coord axis (= contracted/projected-out axes from this
226        // input's perspective), try the closed-form chunked-band recogniser
227        // which yields a constant band on the input's kept axis after
228        // existentially quantifying the projected axes.
229        let output_fact = model.outlet_fact(OutletId::new(node.id, 0))?;
230        let Some(roi) = &output_fact.region_of_interest else { return Ok(None) };
231        let input_facts: TVec<&TypedFact> =
232            node.inputs.iter().map(|i| model.outlet_fact(*i)).collect::<TractResult<_>>()?;
233        let output_facts = tvec![output_fact];
234        let inputs_ref: Vec<&TypedFact> = input_facts.iter().copied().collect();
235        let outputs_ref: Vec<&TypedFact> = output_facts.iter().copied().collect();
236        let mapping = self.axes_mapping(&inputs_ref, &outputs_ref)?;
237        let roi_coord_axes: Vec<(usize, Symbol)> = roi
238            .symbols()
239            .into_iter()
240            .filter_map(|s| crate::ops::logic::sym_to_coord_axis(&s).map(|k| (k, s)))
241            .collect();
242
243        let project_for_input = |input_ix: usize| -> Option<TDim> {
244            // Classify each output ROI coord axis: projected (no input axis)
245            // or preserved (maps to input).
246            let mut projected: Vec<Symbol> = vec![];
247            let mut preserved: Vec<(Symbol, usize)> = vec![];
248            for (out_pos, sym) in &roi_coord_axes {
249                let logical = mapping
250                    .iter_all_axes()
251                    .find(|a| a.outputs.first().is_some_and(|o| o.contains(out_pos)))?;
252                match logical.inputs[input_ix].first() {
253                    None => projected.push(sym.clone()),
254                    Some(&in_pos) => {
255                        if input_facts[input_ix].shape[in_pos] != output_fact.shape[*out_pos] {
256                            return None;
257                        }
258                        preserved.push((sym.clone(), in_pos));
259                    }
260                }
261            }
262            if projected.is_empty() {
263                // All axes preserved — fall through to standard remap.
264                let mut sub_map: HashMap<Symbol, TDim> = HashMap::new();
265                for (sym, in_pos) in &preserved {
266                    if crate::ops::logic::sym_to_coord_axis(sym) != Some(*in_pos) {
267                        let scope = sym.scope()?;
268                        sub_map.insert(sym.clone(), TDim::Sym(scope.coord_sym(*in_pos)));
269                    }
270                }
271                return if sub_map.is_empty() {
272                    Some(roi.clone())
273                } else {
274                    roi.substitute_all(&sub_map).ok()
275                };
276            }
277            // Try the chunked-band recogniser: one projected axis × one
278            // preserved axis at a time.
279            for p_sym in &projected {
280                for (k_sym, k_in_pos) in &preserved {
281                    if let Some(band) = crate::optim::propagate_roi::recognise_chunked_band_project(
282                        roi, p_sym, k_sym,
283                    ) {
284                        // Result mentions k_sym (output frame).  Remap to
285                        // input axis position.
286                        if crate::ops::logic::sym_to_coord_axis(k_sym) != Some(*k_in_pos) {
287                            let scope = k_sym.scope()?;
288                            let mut m: HashMap<Symbol, TDim> = HashMap::new();
289                            m.insert(k_sym.clone(), TDim::Sym(scope.coord_sym(*k_in_pos)));
290                            return band.substitute_all(&m).ok();
291                        }
292                        return Some(band);
293                    }
294                }
295            }
296            None
297        };
298        let result: TVec<Option<TDim>> = (0..node.inputs.len()).map(project_for_input).collect();
299        Ok(Some(result))
300    }
301
302    fn axes_mapping(
303        &self,
304        inputs: &[&TypedFact],
305        _outputs: &[&TypedFact],
306    ) -> TractResult<AxesMapping> {
307        let mut axes = self.axes.clone();
308        for (slot, i) in inputs.iter().enumerate() {
309            if i.is_exotic()
310                && (i.exotic_fact().is_some_and(|of| {
311                    of.is::<PackedExoticFact>() || of.is::<PackedBlockQuantFact>()
312                }))
313            {
314                axes = axes
315                    .remove_axis_occurency(InOut::In(slot), i.rank())?
316                    .remove_axis_occurency(InOut::In(slot), i.rank())?;
317            }
318        }
319        Ok(axes)
320    }
321
322    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
323        let shapes = self.actual_input_shapes_from_facts(inputs)?;
324        let oshape = eval::output_shape(&self.axes, &shapes)?;
325        let ks = self
326            .axes
327            .iter_all_axes()
328            .filter(|axis| axis.outputs[0].len() == 0)
329            .map(|axis| {
330                axis.inputs
331                    .iter()
332                    .enumerate()
333                    .flat_map(|(ix, axes)| {
334                        axes.iter()
335                            .map(|axis| shapes[ix][*axis].clone())
336                            .collect::<TVec<_>>()
337                            .into_iter()
338                    })
339                    .find(|d| !d.is_one())
340                    .unwrap_or_else(|| 1.to_dim())
341            })
342            .product::<TDim>();
343        Ok(tvec!((Cost::FMA(self.operating_dt), oshape.iter().product::<TDim>() * ks)))
344    }
345
346    fn slice(
347        &self,
348        patch: &mut TypedModelPatch,
349        model: &TypedModel,
350        node: &TypedNode,
351        prefix: &str,
352        inputs: &[OutletId],
353        output_axis: usize,
354        _start: &TDim,
355        _end: &TDim,
356    ) -> TractResult<Option<TVec<OutletId>>> {
357        let facts = model.node_input_facts(node.id)?;
358        let axis = self.axes.axis((InOut::Out(0), output_axis))?;
359        if facts
360            .iter()
361            .enumerate()
362            .any(|(slot, fact)| axis.inputs[slot].len() > 0 && fact.is_exotic())
363        {
364            Ok(None)
365        } else {
366            patch.wire_node(prefix, self.clone(), inputs).map(Some)
367        }
368    }
369
370    #[allow(unused_variables)]
371    fn change_axes(
372        &self,
373        model: &TypedModel,
374        node: &TypedNode,
375        io: InOut,
376        change: &AxisOp,
377    ) -> TractResult<Option<AxisChangeConsequence>> {
378        let (mut inputs, mut outputs) = self.axes.to_strs();
379        let interface: &mut String = match io {
380            InOut::In(i) => &mut inputs[i],
381            InOut::Out(o) => &mut outputs[o],
382        };
383        let mut axes: Vec<char> = interface.chars().collect();
384        match change {
385            AxisOp::Rm(rm) => {
386                axes.remove(*rm);
387            }
388            AxisOp::Add(add) => axes.insert(*add, self.axes.available_label()),
389            AxisOp::Move(from, to) => {
390                let c = axes.remove(*from);
391                axes.insert(*to, c);
392            }
393            _ => {
394                return Ok(None);
395            }
396        };
397        *interface = axes.into_iter().collect();
398        let axes = AxesMapping::from_strs(&inputs, &outputs)?;
399        Ok(Some(AxisChangeConsequence {
400            substitute_op: Some(Box::new(EinSum { axes, ..self.clone() })),
401            wire_changes: tvec!((io, change.clone())),
402        }))
403    }
404
405    fn declutter_with_session(
406        &self,
407        session: &mut crate::optim::OptimizerSession,
408        model: &TypedModel,
409        node: &TypedNode,
410    ) -> TractResult<Option<TypedModelPatch>> {
411        if let Some(patch) = declutter_reshape_folding_input_axis(self, session, model, node)? {
412            return Ok(Some(patch));
413        }
414        if let Some(patch) = declutter_broadcast(self, session, model, node)? {
415            return Ok(Some(patch));
416        }
417        if let Some(patch) = unit_k_to_broadcast_mul(self, model, node)? {
418            return Ok(Some(patch));
419        }
420        Ok(None)
421    }
422
423    fn codegen(
424        &self,
425        model: &TypedModel,
426        node: &TypedNode,
427    ) -> TractResult<Option<TypedModelPatch>> {
428        rule_if!(
429            (self.q_params.is_none() && node.inputs.len() == 2)
430                || (self.q_params.is_some() && node.inputs.len() == 9)
431        );
432        // Some EinSums are introduced during codegen itself (e.g. ConvTranspose lowering
433        // emits an EinSum + DeconvSum pair). Those don't get a chance to go through declutter
434        // before being lowered, so we re-check the unit-K → broadcast-Mul rule here as a
435        // fast path. For EinSums that already existed at declutter time, this is a no-op
436        // (the declutter pass would already have rewritten them).
437        if let Some(patch) = unit_k_to_broadcast_mul(self, model, node)? {
438            return Ok(Some(patch));
439        }
440        einsum_matmul::detect_rule(&(), model, node, &node.name, self)
441    }
442
443    as_op!();
444}
445
446fn declutter_reshape_folding_input_axis(
447    op: &EinSum,
448    _session: &mut crate::optim::OptimizerSession,
449    model: &TypedModel,
450    node: &TypedNode,
451) -> TractResult<Option<TypedModelPatch>> {
452    for (slot, prec) in node.inputs.iter().map(|n| model.node(n.node)).enumerate() {
453        let Some(&AxisOp::Reshape(at, ref from, ref to)) = prec.op_as() else { continue };
454        if to.len() > 1 {
455            continue;
456        }
457        let mut axes = op.axes.clone();
458        let extra_labels = axes.available_labels().take(from.len() - 1).collect_vec();
459        // add a temporary input to axes to hold the extra axes
460        let extra_input = node.inputs.len();
461        axes = axes.with_extra_input(extra_input)?;
462        for label in &extra_labels {
463            axes = axes.with_extra_axis(*label, InOut::In(extra_input), 0)?;
464        }
465        let folded_axis = op.axes.axis((InOut::In(slot), at))?;
466        rule_if!(folded_axis.outputs[0].len() <= 1);
467        let mut patch = TypedModelPatch::default();
468        let mut taps = patch.taps(model, &node.inputs)?;
469        for (input, tap) in taps.iter_mut().enumerate() {
470            if folded_axis.inputs[input].len() == 0 {
471                continue;
472            };
473            rule_if!(folded_axis.inputs[input].len() <= 1);
474            let pos = folded_axis.inputs[input][0];
475            for label in &extra_labels {
476                axes = axes.with_extra_axis_occurency(*label, InOut::In(input), pos)?;
477            }
478            *tap = patch.wire_node(
479                format!("{}.reshape_folded_input_{}", node.name, input),
480                AxisOp::Reshape(pos, to.clone(), from.clone()),
481                &[*tap],
482            )?[0];
483        }
484        if folded_axis.outputs[0].len() == 1 {
485            let pos = folded_axis.outputs[0][0];
486            for label in &extra_labels {
487                axes = axes.with_extra_axis_occurency(*label, InOut::Out(0), pos)?;
488            }
489        }
490        axes = axes.remove_slot(InOut::In(extra_input))?;
491        let mut wire = patch.wire_node(&node.name, EinSum { axes, ..op.clone() }, &taps)?;
492        if folded_axis.outputs[0].len() == 1 {
493            let pos = folded_axis.outputs[0][0];
494            wire = patch.wire_node(
495                format!("{}.reshape_folded_output", node.name),
496                AxisOp::Reshape(pos, from.clone(), to.clone()),
497                &wire,
498            )?;
499        }
500        patch.shunt_outside(model, node.id.into(), wire[0])?;
501        return Ok(Some(patch));
502    }
503    Ok(None)
504}
505
506fn declutter_broadcast(
507    op: &EinSum,
508    _session: &mut crate::optim::OptimizerSession,
509    model: &TypedModel,
510    node: &TypedNode,
511) -> TractResult<Option<TypedModelPatch>> {
512    for (ix, outlet) in node.inputs.iter().enumerate() {
513        let prec = model.node(outlet.node);
514        if prec.op_is::<MultiBroadcastTo>() && prec.outputs[0].successors.len() == 1 {
515            let mut patch = TypedModelPatch::default();
516            let mut wires = patch.taps(model, &node.inputs)?;
517            wires[ix] = patch.tap_model(model, prec.inputs[0])?;
518            let wire = patch.wire_node(&node.name, op.clone(), &wires)?[0];
519            patch.shunt_outside(model, node.id.into(), wire)?;
520            return Ok(Some(patch));
521        }
522    }
523    Ok(None)
524}
525
526/// Rewrite an EinSum whose contraction product is statically 1 as a broadcast Mul.
527///
528/// Triggers when:
529/// - All "k-like" axes (present in both inputs, absent from output) have shape 1 in both inputs, OR
530/// - There are no k-like axes at all (Hadamard products like `mn,mn->mn`, outer products like
531///   `m,n->mn`, or any pure broadcast pattern).
532///
533/// In both cases the einsum has no real contraction work — it's a broadcast multiplication
534/// dressed up as an einsum. Lowering it as a matmul leaves the GEMM kernel running per-tile
535/// setup (clear, panel-load, store) for at most one FMA, so a direct broadcast Mul is much
536/// faster on Native (and a net semantic simplification regardless of perf).
537///
538/// Quantized einsums are left untouched: the existing `dequant` path in `EinSumMatMul::codegen`
539/// produces a non-q einsum that this rule then catches naturally on the next declutter pass.
540fn unit_k_to_broadcast_mul(
541    op: &EinSum,
542    model: &TypedModel,
543    node: &TypedNode,
544) -> TractResult<Option<TypedModelPatch>> {
545    if op.q_params.is_some() || node.inputs.len() != 2 {
546        return Ok(None);
547    }
548    let input_facts = model.node_input_facts(node.id)?;
549    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
550    let k_axes: TVec<&Axis> = op
551        .axes
552        .iter_all_axes()
553        .filter(|a| a.inputs[0].len() == 1 && a.inputs[1].len() == 1 && a.outputs[0].is_empty())
554        .collect();
555    // Bail if any k-axis is non-trivial — that's a real contraction, leave it to matmul lowering.
556    let any_nontrivial_k = k_axes.iter().any(|a| {
557        !input_shapes[0][a.inputs[0][0]].is_one() || !input_shapes[1][a.inputs[1][0]].is_one()
558    });
559    if any_nontrivial_k {
560        return Ok(None);
561    }
562    // Scope: only fire when this einsum's output is consumed by a DeconvSum (i.e. it was
563    // emitted by the ConvTranspose lowering pipeline in `Deconv::wire_with_deconv_sum`).
564    // That's the original target case (DFN3 / GTCRN depthwise ConvTranspose with 1×N kernel
565    // collapsing to K=1 — see PR #2183). Other K=1 einsums (e.g. degenerate Q@K^T inside
566    // SDPA when head_dim=1, random-shape proptests with K=1) are intentionally left alone:
567    // backend-specific pipelines (Metal SDPA fusion, MetalMul rank-4 broadcast-segment limit,
568    // …) pattern-match on the matmul shape and break when we substitute a Mul.
569    let has_deconv_sum_consumer = node.outputs.first().is_some_and(|o| {
570        o.successors.iter().any(|inlet| model.node(inlet.node).op.name() == "DeconvSum")
571    });
572    if !has_deconv_sum_consumer {
573        return Ok(None);
574    }
575
576    let one = TDim::one();
577    // Reject "non-trivial single-side disappearing" axes — those need a real reduction.
578    for axis in op.axes.iter_all_axes() {
579        let in_left =
580            axis.inputs[0].first().map(|pos| &input_shapes[0][*pos]).unwrap_or(&one) != &one;
581        let in_right =
582            axis.inputs[1].first().map(|pos| &input_shapes[1][*pos]).unwrap_or(&one) != &one;
583        let in_out = !axis.outputs[0].is_empty();
584        if (in_left ^ in_right) && !in_out {
585            return Ok(None);
586        }
587    }
588
589    let c_axes: Vec<char> = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
590    if c_axes.is_empty() {
591        return Ok(None);
592    }
593
594    let k_reprs: TVec<char> = k_axes.iter().map(|a| a.repr).collect();
595    let mut patch = TypedModelPatch::new("EinSum unit-K → broadcast Mul");
596    let mut wires: TVec<OutletId> = patch.taps(model, &node.inputs)?;
597    let name = &node.name;
598
599    for (slot, wire) in wires.iter_mut().enumerate() {
600        // Promote inputs to operating_dt so the result type matches EinSum::output_facts
601        // (e.g. i8 inputs with i32 operating_dt for an integer matmul that has been dequantized).
602        let cur_dt = patch.outlet_fact(*wire)?.datum_type;
603        if cur_dt != op.operating_dt {
604            *wire = patch.wire_node(
605                format!("{name}.cast_in{slot}"),
606                crate::ops::cast::cast(op.operating_dt),
607                &[*wire],
608            )?[0];
609        }
610
611        // Drop k axes (sorted descending so positions stay valid).
612        let mut k_positions: Vec<usize> = k_axes.iter().map(|a| a.inputs[slot][0]).collect();
613        k_positions.sort_by(|a, b| b.cmp(a));
614        for (i, pos) in k_positions.into_iter().enumerate() {
615            *wire =
616                patch.wire_node(format!("{name}.rm_k_in{slot}.{i}"), AxisOp::Rm(pos), &[*wire])?[0];
617        }
618
619        let mut current: Vec<char> = op
620            .axes
621            .axes(InOut::In(slot))
622            .map(|a| a.repr)
623            .filter(|c| !k_reprs.contains(c))
624            .collect();
625
626        // Drop any remaining axes not in output (must be size 1 by precondition above).
627        let mut to_drop: Vec<(usize, char)> = current
628            .iter()
629            .enumerate()
630            .filter(|(_, c)| !c_axes.contains(c))
631            .map(|(i, c)| (i, *c))
632            .collect();
633        to_drop.sort_by_key(|a| std::cmp::Reverse(a.0));
634        for (pos, c) in to_drop {
635            *wire = patch.wire_node(
636                format!("{name}.rm_extra_in{slot}_{c}"),
637                AxisOp::Rm(pos),
638                &[*wire],
639            )?[0];
640            current.remove(pos);
641        }
642
643        // Insert unit axes for output axes missing from this input.
644        for (target_pos, &t) in c_axes.iter().enumerate() {
645            if !current.contains(&t) {
646                *wire = patch.wire_node(
647                    format!("{name}.add_in{slot}_{t}"),
648                    AxisOp::Add(target_pos),
649                    &[*wire],
650                )?[0];
651                current.insert(target_pos, t);
652            }
653        }
654
655        // Permute to match output axis order.
656        for (target_pos, &t) in c_axes.iter().enumerate() {
657            let cur_pos = current.iter().position(|&c| c == t).unwrap();
658            if cur_pos != target_pos {
659                *wire = patch.wire_node(
660                    format!("{name}.move_in{slot}_{t}"),
661                    AxisOp::Move(cur_pos, target_pos),
662                    &[*wire],
663                )?[0];
664                let removed = current.remove(cur_pos);
665                current.insert(target_pos, removed);
666            }
667        }
668    }
669
670    let result = patch.wire_node(name, crate::ops::math::mul(), &wires)?;
671    patch.shunt_outside(model, node.id.into(), result[0])?;
672    Ok(Some(patch))
673}