Skip to main content

tract_core/ops/einsum/
mod.rs

1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::ops::array::MultiBroadcastTo;
6use crate::tract_data::itertools::Itertools;
7
8mod eval;
9
10pub mod einsum_matmul;
11pub mod kernel_selection;
12pub mod prefix_matmul;
13
14#[cfg(test)]
15mod proptest;
16
17use num_traits::One;
18use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFact};
19use tract_linalg::mmm::PackedExoticFact;
20
21pub fn block_quant_aware_input_shape(fact: &TypedFact) -> TractResult<Cow<'_, [TDim]>> {
22    if fact.is_plain() {
23        return Ok(Cow::Borrowed(&*fact.shape));
24    }
25    let Some(exotic_fact) = fact.exotic_fact() else {
26        bail!("Datum fact is exotic, but no exotic fact was found.")
27    };
28    if let Some(_bqf) = exotic_fact.downcast_ref::<BlockQuantFact>() {
29        Ok(Cow::Borrowed(&*fact.shape))
30    } else if let Some(pof) = exotic_fact.downcast_ref::<PackedBlockQuantFact>() {
31        Ok(Cow::Owned(
32            fact.shape.iter().cloned().chain(pof.shape.iter().map(|i| i.to_dim())).collect_vec(),
33        ))
34    } else if let Some(pof) = exotic_fact.downcast_ref::<PackedExoticFact>() {
35        Ok(Cow::Owned(
36            fact.shape.iter().cloned().chain([pof.mn.clone(), pof.k.to_dim()]).collect_vec(),
37        ))
38    } else {
39        bail!("Unsupported exotic fact {exotic_fact:?}")
40    }
41}
42
43#[derive(Clone, Hash, PartialEq, Eq)]
44pub struct EinSum {
45    pub axes: AxesMapping,
46    pub operating_dt: DatumType,
47    // if present, assume we're a binary op.
48    // 9 inputs are: A,B,bias, A0,Ascale, B0,BScale, C0,Cscale
49    pub q_params: Option<DatumType>,
50}
51
52impl EinSum {
53    pub fn new(axes: AxesMapping, operating_dt: DatumType) -> EinSum {
54        EinSum { axes, operating_dt, q_params: None }
55    }
56
57    pub fn newq(axes: AxesMapping, operating_dt: DatumType, output_type: DatumType) -> EinSum {
58        EinSum { axes, operating_dt, q_params: Some(output_type) }
59    }
60
61    pub fn actual_input_shapes_from_facts<'m>(
62        &self,
63        inputs: &'m [impl Borrow<TypedFact>],
64    ) -> TractResult<TVec<Cow<'m, [TDim]>>> {
65        ensure!(inputs.len() == self.axes.input_count());
66        let shapes: TVec<Cow<[TDim]>> = inputs
67            .iter()
68            .map(|t| block_quant_aware_input_shape(t.borrow()))
69            .collect::<TractResult<_>>()?;
70        ensure!(
71            shapes.iter().enumerate().all(|(ix, fact)| fact.len() == self.axes.rank(InOut::In(ix)))
72        );
73        Ok(shapes)
74    }
75
76    #[allow(unused_variables)]
77    pub(crate) fn propagate_axis(
78        &self,
79        model: &TypedModel,
80        node: &TypedNode,
81        io: InOut,
82        axis: usize,
83    ) -> TractResult<Option<TypedModelPatch>> {
84        let mut new_axis = self.axes.axis((io, axis))?.clone();
85        let repr = new_axis.repr;
86        let mut patch = TypedModelPatch::new(format!("Propagate axis {}", new_axis.repr));
87        let mut taps = tvec!();
88        for (ix, input) in node.inputs.iter().enumerate() {
89            let mut tap = patch.tap_model(model, *input)?;
90            rule_if!(new_axis.inputs[ix].len() <= 1); // FIXME maybe
91            if new_axis.inputs[ix].is_empty() {
92                let insert_at = self.axes.rank(InOut::In(ix));
93                tap = patch.wire_node(
94                    format!("{}.prop_axis.{}.input_{}", &node.name, new_axis.repr, ix),
95                    AxisOp::Add(insert_at),
96                    &[tap],
97                )?[0];
98                new_axis.inputs[ix].push(insert_at);
99            }
100            taps.push(tap);
101        }
102        let must_rm_axis: Option<usize> = if new_axis.outputs[0].len() == 0 {
103            let insert_at = self.axes.rank(InOut::Out(0));
104            new_axis.outputs[0].push(insert_at);
105            Some(insert_at)
106        } else {
107            None
108        };
109        let new_expr = self
110            .axes
111            .iter_all_axes()
112            .map(|it| if it.repr == new_axis.repr { new_axis.clone() } else { it.clone() })
113            .collect_vec();
114        let axes = AxesMapping::new(node.inputs.len(), 1, new_expr)?;
115        let mut wire = patch.wire_node(&node.name, Self { axes, ..self.clone() }, &taps)?;
116        if let Some(position) = must_rm_axis {
117            wire = patch.wire_node(
118                format!("{}.prop_axis.{}.output", &node.name, repr),
119                AxisOp::Rm(position),
120                &wire,
121            )?;
122        }
123        patch.shunt_outside(model, node.id.into(), wire[0])?;
124        Ok(Some(patch))
125    }
126
127    pub fn acceptable_accumulators(&self) -> TVec<DatumType> {
128        if self.operating_dt.is_integer() {
129            tvec!(i32::datum_type())
130        } else if self.operating_dt == f16::datum_type() {
131            tvec!(f16::datum_type(), f32::datum_type())
132        } else {
133            tvec!(self.operating_dt)
134        }
135    }
136}
137
138impl Debug for EinSum {
139    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        write!(f, "EinSum {} ({:?})", self.axes, self.operating_dt)
141    }
142}
143
144impl Op for EinSum {
145    fn name(&self) -> StaticName {
146        "EinSum".into()
147    }
148
149    fn info(&self) -> TractResult<Vec<String>> {
150        let mut info = vec![format!("{} ({:?})", self.axes, self.operating_dt)];
151        if let Some(qp) = self.q_params {
152            info.push(format!("Quantized output: {qp:?}"));
153        }
154        Ok(info)
155    }
156
157    op_as_typed_op!();
158}
159
160impl EvalOp for EinSum {
161    fn is_stateless(&self) -> bool {
162        true
163    }
164
165    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
166        if inputs.iter().all(|i| i.datum_type().is_number() && i.is_plain()) {
167            let mut adhoc_model = TypedModel::default();
168            let mut wires = tvec!();
169            for (ix, input) in inputs.iter().enumerate() {
170                let fact = TypedFact::shape_and_dt_of(input);
171                let wire = adhoc_model.add_source(format!("input.{ix}"), fact)?;
172                wires.push(wire);
173            }
174            let output = adhoc_model.wire_node("einsum", self.clone(), &wires)?;
175            adhoc_model.select_output_outlets(&output)?;
176            let opti = adhoc_model.into_optimized()?;
177            if opti.nodes.iter().all(|node| !node.op_is::<Self>()) {
178                return opti.into_runnable()?.run(inputs);
179            }
180        }
181
182        let output = if let Some(qp) = self.q_params {
183            eval::eval_q(&self.axes, qp, inputs)
184        } else {
185            dispatch_numbers!(eval::eval_t(self.operating_dt)(&self.axes, inputs))
186        }?;
187        Ok(tvec!(output.into_tvalue()))
188    }
189}
190
191impl TypedOp for EinSum {
192    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
193        let shapes = self.actual_input_shapes_from_facts(inputs)?;
194        for i in 0..inputs.len() {
195            ensure!(shapes[i].len() == self.axes.rank(InOut::In(i)));
196        }
197        for axis in self.axes.iter_all_axes() {
198            assert!(
199                shapes
200                    .iter()
201                    .enumerate()
202                    .flat_map(|(slot, shape)| axis.inputs[slot].iter().map(|a| &shape[*a]))
203                    .try_fold(TDim::one(), |a, b| TDim::broadcast(a, b.clone()))
204                    .is_ok()
205            );
206        }
207        if let Some(qp) = self.q_params {
208            ensure!(inputs.len() == 9);
209            Ok(tvec!(qp.fact(eval::output_shape(&self.axes, &shapes[0..2])?)))
210        } else {
211            Ok(tvec!(TypedFact::dt_shape(
212                self.operating_dt,
213                eval::output_shape(&self.axes, &shapes)?
214            )))
215        }
216    }
217
218    fn input_roi(
219        &self,
220        model: &TypedModel,
221        node: &TypedNode,
222    ) -> TractResult<Option<TVec<Option<TDim>>>> {
223        // First try bubble_roi: works for inputs that cover all ROI coord
224        // axes mentioned in the output ROI.  For inputs that DON'T cover
225        // every coord axis (= contracted/projected-out axes from this
226        // input's perspective), try the closed-form chunked-band recogniser
227        // which yields a constant band on the input's kept axis after
228        // existentially quantifying the projected axes.
229        let output_fact = model.outlet_fact(OutletId::new(node.id, 0))?;
230        let Some(roi) = &output_fact.region_of_interest else { return Ok(None) };
231        let input_facts: TVec<&TypedFact> =
232            node.inputs.iter().map(|i| model.outlet_fact(*i)).collect::<TractResult<_>>()?;
233        let output_facts = tvec![output_fact];
234        let inputs_ref: Vec<&TypedFact> = input_facts.iter().copied().collect();
235        let outputs_ref: Vec<&TypedFact> = output_facts.iter().copied().collect();
236        let mapping = self.axes_mapping(&inputs_ref, &outputs_ref)?;
237        let roi_coord_axes: Vec<(usize, Symbol)> = roi
238            .symbols()
239            .into_iter()
240            .filter_map(|s| crate::ops::logic::sym_to_coord_axis(&s).map(|k| (k, s)))
241            .collect();
242
243        let project_for_input = |input_ix: usize| -> Option<TDim> {
244            // Classify each output ROI coord axis: projected (no input axis)
245            // or preserved (maps to input).
246            let mut projected: Vec<Symbol> = vec![];
247            let mut preserved: Vec<(Symbol, usize)> = vec![];
248            for (out_pos, sym) in &roi_coord_axes {
249                let logical = mapping
250                    .iter_all_axes()
251                    .find(|a| a.outputs.first().is_some_and(|o| o.contains(out_pos)))?;
252                match logical.inputs[input_ix].first() {
253                    None => projected.push(sym.clone()),
254                    Some(&in_pos) => {
255                        if input_facts[input_ix].shape[in_pos] != output_fact.shape[*out_pos] {
256                            return None;
257                        }
258                        preserved.push((sym.clone(), in_pos));
259                    }
260                }
261            }
262            if projected.is_empty() {
263                // All axes preserved — fall through to standard remap.
264                let mut sub_map: HashMap<Symbol, TDim> = HashMap::new();
265                for (sym, in_pos) in &preserved {
266                    if crate::ops::logic::sym_to_coord_axis(sym) != Some(*in_pos) {
267                        let scope = sym.scope()?;
268                        sub_map.insert(sym.clone(), TDim::Sym(scope.coord_sym(*in_pos)));
269                    }
270                }
271                return if sub_map.is_empty() {
272                    Some(roi.clone())
273                } else {
274                    roi.substitute_all(&sub_map).ok()
275                };
276            }
277            // Try the chunked-band recogniser: one projected axis × one
278            // preserved axis at a time.
279            for p_sym in &projected {
280                for (k_sym, k_in_pos) in &preserved {
281                    if let Some(band) = crate::optim::propagate_roi::recognise_chunked_band_project(
282                        roi, p_sym, k_sym,
283                    ) {
284                        // Result mentions k_sym (output frame).  Remap to
285                        // input axis position.
286                        if crate::ops::logic::sym_to_coord_axis(k_sym) != Some(*k_in_pos) {
287                            let scope = k_sym.scope()?;
288                            let mut m: HashMap<Symbol, TDim> = HashMap::new();
289                            m.insert(k_sym.clone(), TDim::Sym(scope.coord_sym(*k_in_pos)));
290                            return band.substitute_all(&m).ok();
291                        }
292                        return Some(band);
293                    }
294                }
295            }
296            None
297        };
298        let result: TVec<Option<TDim>> =
299            (0..node.inputs.len()).map(|ix| project_for_input(ix)).collect();
300        Ok(Some(result))
301    }
302
303    fn axes_mapping(
304        &self,
305        inputs: &[&TypedFact],
306        _outputs: &[&TypedFact],
307    ) -> TractResult<AxesMapping> {
308        let mut axes = self.axes.clone();
309        for (slot, i) in inputs.iter().enumerate() {
310            if i.is_exotic()
311                && (i.exotic_fact().is_some_and(|of| {
312                    of.is::<PackedExoticFact>() || of.is::<PackedBlockQuantFact>()
313                }))
314            {
315                axes = axes
316                    .remove_axis_occurency(InOut::In(slot), i.rank())?
317                    .remove_axis_occurency(InOut::In(slot), i.rank())?;
318            }
319        }
320        Ok(axes)
321    }
322
323    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
324        let shapes = self.actual_input_shapes_from_facts(inputs)?;
325        let oshape = eval::output_shape(&self.axes, &shapes)?;
326        let ks = self
327            .axes
328            .iter_all_axes()
329            .filter(|axis| axis.outputs[0].len() == 0)
330            .map(|axis| {
331                axis.inputs
332                    .iter()
333                    .enumerate()
334                    .flat_map(|(ix, axes)| {
335                        axes.iter()
336                            .map(|axis| shapes[ix][*axis].clone())
337                            .collect::<TVec<_>>()
338                            .into_iter()
339                    })
340                    .find(|d| !d.is_one())
341                    .unwrap_or_else(|| 1.to_dim())
342            })
343            .product::<TDim>();
344        Ok(tvec!((Cost::FMA(self.operating_dt), oshape.iter().product::<TDim>() * ks)))
345    }
346
347    fn slice(
348        &self,
349        patch: &mut TypedModelPatch,
350        model: &TypedModel,
351        node: &TypedNode,
352        prefix: &str,
353        inputs: &[OutletId],
354        output_axis: usize,
355        _start: &TDim,
356        _end: &TDim,
357    ) -> TractResult<Option<TVec<OutletId>>> {
358        let facts = model.node_input_facts(node.id)?;
359        let axis = self.axes.axis((InOut::Out(0), output_axis))?;
360        if facts
361            .iter()
362            .enumerate()
363            .any(|(slot, fact)| axis.inputs[slot].len() > 0 && fact.is_exotic())
364        {
365            Ok(None)
366        } else {
367            patch.wire_node(prefix, self.clone(), inputs).map(Some)
368        }
369    }
370
371    #[allow(unused_variables)]
372    fn change_axes(
373        &self,
374        model: &TypedModel,
375        node: &TypedNode,
376        io: InOut,
377        change: &AxisOp,
378    ) -> TractResult<Option<AxisChangeConsequence>> {
379        let (mut inputs, mut outputs) = self.axes.to_strs();
380        let interface: &mut String = match io {
381            InOut::In(i) => &mut inputs[i],
382            InOut::Out(o) => &mut outputs[o],
383        };
384        let mut axes: Vec<char> = interface.chars().collect();
385        match change {
386            AxisOp::Rm(rm) => {
387                axes.remove(*rm);
388            }
389            AxisOp::Add(add) => axes.insert(*add, self.axes.available_label()),
390            AxisOp::Move(from, to) => {
391                let c = axes.remove(*from);
392                axes.insert(*to, c);
393            }
394            _ => {
395                return Ok(None);
396            }
397        };
398        *interface = axes.into_iter().collect();
399        let axes = AxesMapping::from_strs(&inputs, &outputs)?;
400        Ok(Some(AxisChangeConsequence {
401            substitute_op: Some(Box::new(EinSum { axes, ..self.clone() })),
402            wire_changes: tvec!((io, change.clone())),
403        }))
404    }
405
406    fn declutter_with_session(
407        &self,
408        session: &mut crate::optim::OptimizerSession,
409        model: &TypedModel,
410        node: &TypedNode,
411    ) -> TractResult<Option<TypedModelPatch>> {
412        if let Some(patch) = declutter_reshape_folding_input_axis(self, session, model, node)? {
413            return Ok(Some(patch));
414        }
415        if let Some(patch) = declutter_broadcast(self, session, model, node)? {
416            return Ok(Some(patch));
417        }
418        if let Some(patch) = unit_k_to_broadcast_mul(self, model, node)? {
419            return Ok(Some(patch));
420        }
421        Ok(None)
422    }
423
424    fn codegen(
425        &self,
426        model: &TypedModel,
427        node: &TypedNode,
428    ) -> TractResult<Option<TypedModelPatch>> {
429        rule_if!(
430            (self.q_params.is_none() && node.inputs.len() == 2)
431                || (self.q_params.is_some() && node.inputs.len() == 9)
432        );
433        // Some EinSums are introduced during codegen itself (e.g. ConvTranspose lowering
434        // emits an EinSum + DeconvSum pair). Those don't get a chance to go through declutter
435        // before being lowered, so we re-check the unit-K → broadcast-Mul rule here as a
436        // fast path. For EinSums that already existed at declutter time, this is a no-op
437        // (the declutter pass would already have rewritten them).
438        if let Some(patch) = unit_k_to_broadcast_mul(self, model, node)? {
439            return Ok(Some(patch));
440        }
441        einsum_matmul::detect_rule(&(), model, node, &node.name, self)
442    }
443
444    as_op!();
445}
446
447fn declutter_reshape_folding_input_axis(
448    op: &EinSum,
449    _session: &mut crate::optim::OptimizerSession,
450    model: &TypedModel,
451    node: &TypedNode,
452) -> TractResult<Option<TypedModelPatch>> {
453    for (slot, prec) in node.inputs.iter().map(|n| model.node(n.node)).enumerate() {
454        let Some(&AxisOp::Reshape(at, ref from, ref to)) = prec.op_as() else { continue };
455        if to.len() > 1 {
456            continue;
457        }
458        let mut axes = op.axes.clone();
459        let extra_labels = axes.available_labels().take(from.len() - 1).collect_vec();
460        // add a temporary input to axes to hold the extra axes
461        let extra_input = node.inputs.len();
462        axes = axes.with_extra_input(extra_input)?;
463        for label in &extra_labels {
464            axes = axes.with_extra_axis(*label, InOut::In(extra_input), 0)?;
465        }
466        let folded_axis = op.axes.axis((InOut::In(slot), at))?;
467        rule_if!(folded_axis.outputs[0].len() <= 1);
468        let mut patch = TypedModelPatch::default();
469        let mut taps = patch.taps(model, &node.inputs)?;
470        for (input, tap) in taps.iter_mut().enumerate() {
471            if folded_axis.inputs[input].len() == 0 {
472                continue;
473            };
474            rule_if!(folded_axis.inputs[input].len() <= 1);
475            let pos = folded_axis.inputs[input][0];
476            for label in &extra_labels {
477                axes = axes.with_extra_axis_occurency(*label, InOut::In(input), pos)?;
478            }
479            *tap = patch.wire_node(
480                format!("{}.reshape_folded_input_{}", node.name, input),
481                AxisOp::Reshape(pos, to.clone(), from.clone()),
482                &[*tap],
483            )?[0];
484        }
485        if folded_axis.outputs[0].len() == 1 {
486            let pos = folded_axis.outputs[0][0];
487            for label in &extra_labels {
488                axes = axes.with_extra_axis_occurency(*label, InOut::Out(0), pos)?;
489            }
490        }
491        axes = axes.remove_slot(InOut::In(extra_input))?;
492        let mut wire = patch.wire_node(&node.name, EinSum { axes, ..op.clone() }, &taps)?;
493        if folded_axis.outputs[0].len() == 1 {
494            let pos = folded_axis.outputs[0][0];
495            wire = patch.wire_node(
496                format!("{}.reshape_folded_output", node.name),
497                AxisOp::Reshape(pos, from.clone(), to.clone()),
498                &wire,
499            )?;
500        }
501        patch.shunt_outside(model, node.id.into(), wire[0])?;
502        return Ok(Some(patch));
503    }
504    Ok(None)
505}
506
507fn declutter_broadcast(
508    op: &EinSum,
509    _session: &mut crate::optim::OptimizerSession,
510    model: &TypedModel,
511    node: &TypedNode,
512) -> TractResult<Option<TypedModelPatch>> {
513    for (ix, outlet) in node.inputs.iter().enumerate() {
514        let prec = model.node(outlet.node);
515        if prec.op_is::<MultiBroadcastTo>() && prec.outputs[0].successors.len() == 1 {
516            let mut patch = TypedModelPatch::default();
517            let mut wires = patch.taps(model, &node.inputs)?;
518            wires[ix] = patch.tap_model(model, prec.inputs[0])?;
519            let wire = patch.wire_node(&node.name, op.clone(), &wires)?[0];
520            patch.shunt_outside(model, node.id.into(), wire)?;
521            return Ok(Some(patch));
522        }
523    }
524    Ok(None)
525}
526
527/// Rewrite an EinSum whose contraction product is statically 1 as a broadcast Mul.
528///
529/// Triggers when:
530/// - All "k-like" axes (present in both inputs, absent from output) have shape 1 in both inputs, OR
531/// - There are no k-like axes at all (Hadamard products like `mn,mn->mn`, outer products like
532///   `m,n->mn`, or any pure broadcast pattern).
533///
534/// In both cases the einsum has no real contraction work — it's a broadcast multiplication
535/// dressed up as an einsum. Lowering it as a matmul leaves the GEMM kernel running per-tile
536/// setup (clear, panel-load, store) for at most one FMA, so a direct broadcast Mul is much
537/// faster on Native (and a net semantic simplification regardless of perf).
538///
539/// Quantized einsums are left untouched: the existing `dequant` path in `EinSumMatMul::codegen`
540/// produces a non-q einsum that this rule then catches naturally on the next declutter pass.
541fn unit_k_to_broadcast_mul(
542    op: &EinSum,
543    model: &TypedModel,
544    node: &TypedNode,
545) -> TractResult<Option<TypedModelPatch>> {
546    if op.q_params.is_some() || node.inputs.len() != 2 {
547        return Ok(None);
548    }
549    let input_facts = model.node_input_facts(node.id)?;
550    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
551    let k_axes: TVec<&Axis> = op
552        .axes
553        .iter_all_axes()
554        .filter(|a| a.inputs[0].len() == 1 && a.inputs[1].len() == 1 && a.outputs[0].is_empty())
555        .collect();
556    // Bail if any k-axis is non-trivial — that's a real contraction, leave it to matmul lowering.
557    let any_nontrivial_k = k_axes.iter().any(|a| {
558        !input_shapes[0][a.inputs[0][0]].is_one() || !input_shapes[1][a.inputs[1][0]].is_one()
559    });
560    if any_nontrivial_k {
561        return Ok(None);
562    }
563    // Scope: only fire when this einsum's output is consumed by a DeconvSum (i.e. it was
564    // emitted by the ConvTranspose lowering pipeline in `Deconv::wire_with_deconv_sum`).
565    // That's the original target case (DFN3 / GTCRN depthwise ConvTranspose with 1×N kernel
566    // collapsing to K=1 — see PR #2183). Other K=1 einsums (e.g. degenerate Q@K^T inside
567    // SDPA when head_dim=1, random-shape proptests with K=1) are intentionally left alone:
568    // backend-specific pipelines (Metal SDPA fusion, MetalMul rank-4 broadcast-segment limit,
569    // …) pattern-match on the matmul shape and break when we substitute a Mul.
570    let has_deconv_sum_consumer = node.outputs.first().map_or(false, |o| {
571        o.successors.iter().any(|inlet| model.node(inlet.node).op.name() == "DeconvSum")
572    });
573    if !has_deconv_sum_consumer {
574        return Ok(None);
575    }
576
577    let one = TDim::one();
578    // Reject "non-trivial single-side disappearing" axes — those need a real reduction.
579    for axis in op.axes.iter_all_axes() {
580        let in_left =
581            axis.inputs[0].first().map(|pos| &input_shapes[0][*pos]).unwrap_or(&one) != &one;
582        let in_right =
583            axis.inputs[1].first().map(|pos| &input_shapes[1][*pos]).unwrap_or(&one) != &one;
584        let in_out = !axis.outputs[0].is_empty();
585        if (in_left ^ in_right) && !in_out {
586            return Ok(None);
587        }
588    }
589
590    let c_axes: Vec<char> = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
591    if c_axes.is_empty() {
592        return Ok(None);
593    }
594
595    let k_reprs: TVec<char> = k_axes.iter().map(|a| a.repr).collect();
596    let mut patch = TypedModelPatch::new("EinSum unit-K → broadcast Mul");
597    let mut wires: TVec<OutletId> = patch.taps(model, &node.inputs)?;
598    let name = &node.name;
599
600    for (slot, wire) in wires.iter_mut().enumerate() {
601        // Promote inputs to operating_dt so the result type matches EinSum::output_facts
602        // (e.g. i8 inputs with i32 operating_dt for an integer matmul that has been dequantized).
603        let cur_dt = patch.outlet_fact(*wire)?.datum_type;
604        if cur_dt != op.operating_dt {
605            *wire = patch.wire_node(
606                format!("{name}.cast_in{slot}"),
607                crate::ops::cast::cast(op.operating_dt),
608                &[*wire],
609            )?[0];
610        }
611
612        // Drop k axes (sorted descending so positions stay valid).
613        let mut k_positions: Vec<usize> = k_axes.iter().map(|a| a.inputs[slot][0]).collect();
614        k_positions.sort_by(|a, b| b.cmp(a));
615        for (i, pos) in k_positions.into_iter().enumerate() {
616            *wire =
617                patch.wire_node(format!("{name}.rm_k_in{slot}.{i}"), AxisOp::Rm(pos), &[*wire])?[0];
618        }
619
620        let mut current: Vec<char> = op
621            .axes
622            .axes(InOut::In(slot))
623            .map(|a| a.repr)
624            .filter(|c| !k_reprs.contains(c))
625            .collect();
626
627        // Drop any remaining axes not in output (must be size 1 by precondition above).
628        let mut to_drop: Vec<(usize, char)> = current
629            .iter()
630            .enumerate()
631            .filter(|(_, c)| !c_axes.contains(c))
632            .map(|(i, c)| (i, *c))
633            .collect();
634        to_drop.sort_by(|a, b| b.0.cmp(&a.0));
635        for (pos, c) in to_drop {
636            *wire = patch.wire_node(
637                format!("{name}.rm_extra_in{slot}_{c}"),
638                AxisOp::Rm(pos),
639                &[*wire],
640            )?[0];
641            current.remove(pos);
642        }
643
644        // Insert unit axes for output axes missing from this input.
645        for (target_pos, &t) in c_axes.iter().enumerate() {
646            if !current.contains(&t) {
647                *wire = patch.wire_node(
648                    format!("{name}.add_in{slot}_{t}"),
649                    AxisOp::Add(target_pos),
650                    &[*wire],
651                )?[0];
652                current.insert(target_pos, t);
653            }
654        }
655
656        // Permute to match output axis order.
657        for (target_pos, &t) in c_axes.iter().enumerate() {
658            let cur_pos = current.iter().position(|&c| c == t).unwrap();
659            if cur_pos != target_pos {
660                *wire = patch.wire_node(
661                    format!("{name}.move_in{slot}_{t}"),
662                    AxisOp::Move(cur_pos, target_pos),
663                    &[*wire],
664                )?[0];
665                let removed = current.remove(cur_pos);
666                current.insert(target_pos, removed);
667            }
668        }
669    }
670
671    let result = patch.wire_node(name, crate::ops::math::mul(), &wires)?;
672    patch.shunt_outside(model, node.id.into(), result[0])?;
673    Ok(Some(patch))
674}