Skip to main content

tract_core/ops/einsum/
einsum_matmul.rs

1use std::fmt::Formatter;
2use std::ops::Deref;
3
4use tract_itertools::{izip, multiunzip};
5use tract_linalg::block_quant::PackedBlockQuantFormat;
6use tract_linalg::pack::PackedFormat;
7
8use super::*;
9use crate::ops::cast::cast;
10use crate::ops::math::add;
11use crate::ops::matmul::ModePicker;
12use crate::ops::matmul::optimized::{
13    AddMatMulGeometry, MapOutputAxisToInput, OptMatMul, ProtoFusedSpec,
14};
15use crate::ops::matmul::pack::{OptMatMulPack, OptSimpleMatMulPack};
16use crate::ops::matmul::quant::{
17    combine_scales, compensate_zero_points, requant, wire_ensure_q8_flavour,
18};
19use crate::ops::nn::{Reduce, Reducer};
20
21pub fn merge_consecutive_same_role_axes(model: &mut TypedModel) -> TractResult<()> {
22    Rewriter::default()
23        .with_rule_for("merge-same-role-axes", merge_same_role_axes_rule)
24        .rewrite(&(), model)
25}
26
/// Rewrite rule: collapse consecutive axes that play the same role into one axis.
///
/// The "role" of an axis is the triple (present in input A, present in input B,
/// present in the output). Two or more axes with the same role that are
/// consecutive in both inputs can be merged into a single axis via reshapes
/// around a simplified EinSum, which gives downstream matmul detection larger,
/// cleaner m/k/n dimensions to work with.
///
/// Two strategies are attempted, in order:
/// 1. find the longest directly mergeable run and emit reshape/EinSum/reshape;
/// 2. if none exists, find a same-role pair separated by exactly one k-axis in
///    a single input and move that k-axis out of the way, so a later pass can
///    merge the pair.
///
/// Returns `Ok(None)` when nothing applies.
fn merge_same_role_axes_rule(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    node_name: &str,
    op: &EinSum,
) -> TractResult<Option<TypedModelPatch>> {
    // Only handle 2-input EinSums (matmul-like)
    rule_if!(node.inputs.len() == 2);

    // Compute role signature for each axis: (in_input_0, in_input_1, in_output)
    type Role = (bool, bool, bool);
    let axes: Vec<(char, Role)> = op
        .axes
        .iter_all_axes()
        .map(|a| {
            (a.repr, (!a.inputs[0].is_empty(), !a.inputs[1].is_empty(), !a.outputs[0].is_empty()))
        })
        .collect();

    // For each input/output slot, get the axis order
    let a_order: Vec<char> = op.axes.axes(InOut::In(0)).map(|a| a.repr).collect();
    let b_order: Vec<char> = op.axes.axes(InOut::In(1)).map(|a| a.repr).collect();
    let c_order: Vec<char> = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();

    // Find first group of 2+ same-role axes that are consecutive in all inputs.
    // Scan each input's axis order for runs of same-role axes.
    let role_map: std::collections::HashMap<char, Role> = axes.iter().cloned().collect();
    let mut best_group: Option<Vec<char>> = None;

    // Try each input order as the primary scan order
    let all_orders = [&a_order, &b_order];
    for (primary_idx, primary_order) in all_orders.iter().enumerate() {
        let mut i = 0;
        while i < primary_order.len() {
            let first = primary_order[i];
            let first_role = role_map[&first];
            let mut group = vec![first];
            let mut j = i + 1;
            // Greedily extend the run while candidates share the role and
            // remain contiguous (in the same relative order) in the other input.
            while j < primary_order.len() {
                let candidate = primary_order[j];
                if role_map[&candidate] != first_role {
                    break;
                }
                // Check consecutive in the OTHER input too
                let consecutive_in_others = all_orders
                    .iter()
                    .enumerate()
                    .filter(|(idx, _)| *idx != primary_idx)
                    .all(|(_, order)| {
                        // Axes absent from the other input impose no constraint
                        // (they are simply skipped by filter_map).
                        let positions: Vec<usize> = group
                            .iter()
                            .chain(std::iter::once(&candidate))
                            .filter_map(|c| order.iter().position(|x| x == c))
                            .collect();
                        if positions.len() <= 1 {
                            return true;
                        }
                        // Must be both in increasing order and gap-free.
                        let mut sorted = positions.clone();
                        sorted.sort();
                        sorted == positions
                            && sorted.last().unwrap() - sorted.first().unwrap() == sorted.len() - 1
                    });
                if !consecutive_in_others {
                    break;
                }
                group.push(candidate);
                j += 1;
            }
            // Keep the longest run found across both scan orders.
            if group.len() >= 2 && best_group.as_ref().map_or(true, |bg| group.len() > bg.len()) {
                best_group = Some(group);
            }
            i = j;
        }
    }

    if let Some(group) = best_group {
        // Found a mergeable group. Emit the patch.
        let input_facts = model.node_input_facts(node.id)?;
        let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
        let output_shape = super::eval::output_shape(&op.axes, &input_shapes)?;

        // All axes but the first of the group disappear from the new formula.
        let drop_set: Vec<char> = group[1..].to_vec();

        let mut patch = TypedModelPatch::default();
        let mut wires: TVec<OutletId> = patch.taps(model, &node.inputs)?;

        // Reshape each input to merge the group
        for (slot, order) in [(0, &a_order), (1, &b_order)] {
            let positions: Vec<usize> =
                group.iter().filter_map(|c| order.iter().position(|x| x == c)).collect();
            if positions.len() < 2 {
                continue;
            }
            let start = positions[0];
            let from_dims: TVec<TDim> =
                positions.iter().map(|&p| input_shapes[slot][p].clone()).collect();
            let merged: TDim = from_dims.iter().product();
            wires[slot] = patch.wire_node(
                format!("{node_name}.merge_in{slot}"),
                AxisOp::Reshape(start, from_dims, tvec![merged]),
                &[wires[slot]],
            )?[0];
        }

        // If group axes aren't consecutive in C, reorder the EinSum output
        let c_positions: Vec<usize> =
            group.iter().filter_map(|c| c_order.iter().position(|x| x == c)).collect();
        let c_needs_reorder = c_positions.len() >= 2 && {
            let mut sorted = c_positions.clone();
            sorted.sort();
            // Reorder needed if positions are gapped or out of order.
            sorted.last().unwrap() - sorted.first().unwrap() != sorted.len() - 1
                || sorted != c_positions
        };
        let mut adjusted_c_order = c_order.clone();
        if c_needs_reorder {
            // Move group axes together (put second next to first)
            for k in 1..c_positions.len() {
                let cur_pos = adjusted_c_order.iter().position(|&c| c == group[k]).unwrap();
                let target_pos =
                    adjusted_c_order.iter().position(|&c| c == group[k - 1]).unwrap() + 1;
                if cur_pos != target_pos {
                    let removed = adjusted_c_order.remove(cur_pos);
                    // Removing before the target shifts the target left by one.
                    let insert_at = if cur_pos < target_pos { target_pos - 1 } else { target_pos };
                    adjusted_c_order.insert(insert_at, removed);
                }
            }
        }

        // Rebuild EinSum formula with adjusted output and dropped axes
        let in0: String = a_order.iter().collect();
        let in1: String = b_order.iter().collect();
        let out: String = adjusted_c_order.iter().collect();
        let expr = format!("{in0},{in1}->{out}");
        let mut new_axes: AxesMapping = expr.parse()?;
        for &drop in &drop_set {
            new_axes = new_axes.remove_axis(drop)?;
        }
        let new_op =
            EinSum { axes: new_axes, operating_dt: op.operating_dt, q_params: op.q_params };
        let mut result = patch.wire_node(node_name, new_op, &wires)?;

        // Reshape output to split the merged axis back
        let merged_c_positions: Vec<usize> =
            group.iter().filter_map(|c| adjusted_c_order.iter().position(|x| x == c)).collect();
        if merged_c_positions.len() >= 2 {
            let start = merged_c_positions[0];
            // Use original output dims for the group axes
            let original_c_positions: Vec<usize> =
                group.iter().filter_map(|c| c_order.iter().position(|x| x == c)).collect();
            let original_dims: TVec<TDim> =
                original_c_positions.iter().map(|&p| output_shape[p].clone()).collect();
            let merged: TDim = original_dims.iter().product();
            result[0] = patch.wire_node(
                format!("{node_name}.unmerge_out"),
                AxisOp::Reshape(start, tvec![merged], original_dims),
                &[result[0]],
            )?[0];
        }

        // Restore original output order if we reordered
        if c_needs_reorder {
            // After unmerge, axes are in adjusted_c_order (but with group expanded).
            // Need to permute back to c_order.
            // Build the unmerged adjusted order
            let mut unmerged_adj: Vec<char> = Vec::new();
            for &c in &adjusted_c_order {
                if c == group[0] {
                    unmerged_adj.extend(&group);
                } else if !group.contains(&c) {
                    unmerged_adj.push(c);
                }
            }
            // Find what moves are needed to get from unmerged_adj to c_order
            // (selection-sort style: one AxisOp::Move per misplaced axis).
            for target_pos in 0..c_order.len() {
                let cur_pos = unmerged_adj.iter().position(|&c| c == c_order[target_pos]).unwrap();
                if cur_pos != target_pos {
                    result[0] = patch.wire_node(
                        format!("{node_name}.restore_out_{target_pos}"),
                        AxisOp::Move(cur_pos, target_pos),
                        &[result[0]],
                    )?[0];
                    let removed = unmerged_adj.remove(cur_pos);
                    unmerged_adj.insert(target_pos, removed);
                }
            }
        }

        patch.shunt_outside(model, node.id.into(), result[0])?;
        return Ok(Some(patch));
    }

    // Second pass: look for same-role pairs separated by exactly one k-like axis
    // in a single input. Insert a MoveAxis to push the separator to the end.
    let k_role: Role = (true, true, false); // present in both inputs, absent from output
    let role_of = |c: char| axes.iter().find(|(ch, _)| *ch == c).map(|(_, r)| *r);

    for (slot, order) in [(0usize, &a_order), (1, &b_order)] {
        // Find three consecutive axes in this input where the outer two share a role
        // and the middle one is a k-axis
        for w in order.windows(3) {
            let (left, mid, right) = (w[0], w[1], w[2]);
            let left_role = role_of(left);
            let mid_role = role_of(mid);
            let right_role = role_of(right);
            if left_role != right_role || mid_role != Some(k_role) {
                continue;
            }
            // left and right must also be consecutive in other inputs
            // (output order is handled by the EinSum formula)
            let other_input_orders: Vec<&Vec<char>> = [(0, &a_order), (1, &b_order)]
                .iter()
                .filter(|(s, _)| *s != slot)
                .map(|(_, o)| *o)
                .collect();
            let consecutive_elsewhere = other_input_orders.iter().all(|order| {
                let lp = order.iter().position(|&c| c == left);
                let rp = order.iter().position(|&c| c == right);
                match (lp, rp) {
                    (Some(l), Some(r)) => r == l + 1,
                    _ => true, // one or both absent — no constraint
                }
            });
            if !consecutive_elsewhere {
                continue;
            }

            // Move the k-axis to the inner (last) position in inputs and
            // make left,right adjacent in the output too.
            let mid_pos = order.iter().position(|&c| c == mid).unwrap();
            let end_pos = order.len() - 1;
            if mid_pos == end_pos {
                continue;
            }

            // Use change_axes to update the EinSum formula for the input move
            let move_op = AxisOp::Move(mid_pos, end_pos);
            let Some(AxisChangeConsequence { substitute_op, .. }) =
                op.change_axes(model, node, InOut::In(slot), &move_op)?
            else {
                continue;
            };
            // NOTE(review): assumes change_axes on an EinSum always yields a
            // substitute op that is itself an EinSum — TODO confirm.
            let mut current_op = *substitute_op
                .unwrap()
                .downcast::<EinSum>()
                .map_err(|_| anyhow!("expected EinSum"))?;

            // Also make left,right adjacent in the output if needed
            let new_c: Vec<char> = current_op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
            let left_c = new_c.iter().position(|&c| c == left);
            let right_c = new_c.iter().position(|&c| c == right);
            let need_output_fix = matches!((left_c, right_c), (Some(l), Some(r)) if r != l + 1);
            if need_output_fix {
                let r_pos = right_c.unwrap();
                let l_pos = left_c.unwrap();
                let target = if r_pos < l_pos { l_pos } else { l_pos + 1 };
                if let Some(AxisChangeConsequence { substitute_op, .. }) = current_op.change_axes(
                    model,
                    node,
                    InOut::Out(0),
                    &AxisOp::Move(r_pos, target),
                )? {
                    current_op = *substitute_op
                        .unwrap()
                        .downcast::<EinSum>()
                        .map_err(|_| anyhow!("expected EinSum"))?;
                }
            }

            let mut patch = TypedModelPatch::default();
            let mut wires: TVec<OutletId> = patch.taps(model, &node.inputs)?;

            wires[slot] =
                patch.wire_node(format!("{node_name}.move_k_in{slot}"), move_op, &[wires[slot]])?
                    [0];

            let final_c: Vec<char> = current_op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
            let mut result = patch.wire_node(node_name, current_op, &wires)?;

            // Restore original output order
            if need_output_fix {
                let r_cur = final_c.iter().position(|&c| c == right).unwrap();
                let r_orig = c_order.iter().position(|&c| c == right).unwrap();
                if r_cur != r_orig {
                    result[0] = patch.wire_node(
                        format!("{node_name}.restore_out"),
                        AxisOp::Move(r_cur, r_orig),
                        &[result[0]],
                    )?[0];
                }
            }

            patch.shunt_outside(model, node.id.into(), result[0])?;
            return Ok(Some(patch));
        }
    }

    Ok(None)
}
326
327pub fn detect_all(model: &mut TypedModel) -> TractResult<()> {
328    Rewriter::default().with_rule_for("detect-matmul-einsum", detect_rule).rewrite(&(), model)
329}
330
331pub fn flatten_all(model: &mut TypedModel) -> TractResult<()> {
332    Rewriter::default().with_rule_for("flatten-matmul-einsum", flatten_rule).rewrite(&(), model)
333}
334
/// An `EinSum` that has been recognized as a matrix multiplication, with the
/// m, k and n axes identified (see `detect_rule`).
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct EinSumMatMul {
    /// The wrapped einsum operator.
    pub op: EinSum,
    /// Label (in `op.axes`) of the m axis: appears in input A and (usually) the output.
    pub m_axis: char,
    /// Label of the k axis: the contracted dimension, present in both inputs, absent from output.
    pub k_axis: char,
    /// Label of the n axis: appears in input B and (usually) the output.
    pub n_axis: char,
    /// Dimension of the m axis.
    pub m: TDim,
    /// Dimension of the k axis.
    pub k: TDim,
    /// Dimension of the n axis.
    pub n: TDim,
}
345
346impl EinSumMatMul {
347    pub fn m_axis(&self) -> &Axis {
348        self.op.axes.axis(self.m_axis).unwrap()
349    }
350    pub fn k_axis(&self) -> &Axis {
351        self.op.axes.axis(self.k_axis).unwrap()
352    }
353    pub fn n_axis(&self) -> &Axis {
354        self.op.axes.axis(self.n_axis).unwrap()
355    }
356    pub fn a_m(&self) -> usize {
357        self.m_axis().inputs[0][0]
358    }
359    pub fn a_k(&self) -> usize {
360        self.k_axis().inputs[0][0]
361    }
362    pub fn b_k(&self) -> usize {
363        self.k_axis().inputs[1][0]
364    }
365    pub fn b_n(&self) -> usize {
366        self.n_axis().inputs[1][0]
367    }
368    pub fn c_m(&self) -> Option<usize> {
369        self.m_axis().outputs[0].first().cloned()
370    }
371    pub fn c_n(&self) -> Option<usize> {
372        self.n_axis().outputs[0].first().cloned()
373    }
374
375    fn new(
376        op: EinSum,
377        m_axis: char,
378        k_axis: char,
379        n_axis: char,
380        m: TDim,
381        k: TDim,
382        n: TDim,
383    ) -> Self {
384        Self { op, m_axis, k_axis, n_axis, m, k, n }
385    }
386}
387
388impl Debug for EinSumMatMul {
389    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
390        write!(
391            f,
392            "EinsumMatMul: {} {:?} m: {}={}; k: {}={}; n: {}={}",
393            self.op.axes,
394            self.op.operating_dt,
395            self.m_axis,
396            self.m,
397            self.k_axis,
398            self.k,
399            self.n_axis,
400            self.n
401        )
402    }
403}
404
/// Deref to the wrapped `EinSum`, so code written against `EinSum`
/// (e.g. `op.axes`, `op.q_params`) works transparently on `EinSumMatMul`.
impl Deref for EinSumMatMul {
    type Target = EinSum;
    fn deref(&self) -> &Self::Target {
        &self.op
    }
}
411
impl Op for EinSumMatMul {
    // Operator name as displayed in model dumps and error messages.
    fn name(&self) -> StaticName {
        "EinSumMatMul".into()
    }

    op_as_typed_op!();
}
419
impl EvalOp for EinSumMatMul {
    fn is_stateless(&self) -> bool {
        true
    }
    // Evaluation is delegated verbatim to the wrapped EinSum: the m/k/n
    // annotations only matter at optimization (codegen) time.
    fn eval_with_session(
        &self,
        node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        self.op.eval_with_session(node_id, session, inputs)
    }
}
433
impl TypedOp for EinSumMatMul {
    // Type inference is the wrapped EinSum's.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        self.op.output_facts(inputs)
    }

    /// Codegen strategy:
    /// 1. 9 inputs means parametric quantization: dequantize first.
    /// 2. Otherwise decide whether to swap A and B (so the "bigger" side
    ///    lands in the preferred slot), then lower to OptMatMul.
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // deal with parametric quantization extra inputs
        if node.inputs.len() == 9 {
            ensure!(self.op.q_params.is_some());
            return dequant(model, node, self).map(Some);
        }
        ensure!(node.inputs.len() == 2);
        let (a, b) = model.node_input_facts(node.id)?.into_iter().collect_tuple().unwrap();
        // at this stage a and b must NOT be packed yet. if they are exotic, we can assume it's just compression
        let must_transpose = if let Some(of) = a.exotic_fact() {
            // Block-quantized A must stay in slot 0: no transpose.
            ensure!(of.is::<BlockQuantFact>());
            false
        } else if let Some(of) = b.exotic_fact() {
            // Block-quantized B must be moved to slot 0: transpose.
            ensure!(of.is::<BlockQuantFact>());
            true
        } else if self.m == self.n {
            false
        } else {
            // Heuristic: put the larger of m/n on the A side.
            match (self.m.as_i64(), self.n.as_i64()) {
                (Some(m), Some(n)) => m < n,
                // Symbolic m vs concrete n: only transpose for "large enough" n.
                (None, Some(n)) => n >= 8,
                (Some(_), _) => false,
                // Both symbolic: transpose if n - m can be proven >= 0.
                _ => (self.n.clone() - &self.m).prove_positive_or_zero(),
            }
        };
        if must_transpose {
            // Swap inputs and relabel: m becomes n and vice versa.
            let mut op = self.clone();
            op.op.axes.iter_all_axes_mut().for_each(|axis| axis.inputs.swap(0, 1));
            std::mem::swap(&mut op.m_axis, &mut op.n_axis);
            std::mem::swap(&mut op.m, &mut op.n);
            return TypedModelPatch::replace_single_op(
                model,
                node,
                &[node.inputs[1], node.inputs[0]],
                op,
            )
            .map(|p| Some(p.with_context("transposing")));
        }
        // opt mat mul assumes we have at least one m or n
        if self.c_m().is_some() || self.c_n().is_some() {
            return optimized_mat_mul(model, node, self)
                .map(|opt| opt.map(|p| p.with_context("optimizing")));
        }
        Ok(None)
    }

    as_op!();
}
491
/// Rewrite rule: classify an `EinSum`'s axes as m/k/n and wrap it in an
/// `EinSumMatMul`, injecting or regrouping axes first when needed.
///
/// Axis roles, as determined here:
/// - k: exactly once in each input, absent from the output (contracted);
/// - m: once in input A, trivial/absent in input B, present in the output;
/// - n: once in input B, trivial/absent in input A, present in the output.
pub(crate) fn detect_rule(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    _name: &str,
    op: &EinSum,
) -> TractResult<Option<TypedModelPatch>> {
    // 2 data inputs, plus 7 extra quantization inputs when q_params is set.
    rule_if!(node.inputs.len() == (2 + op.q_params.is_some() as usize * 7));
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    let output_shape = super::eval::output_shape(&op.axes, &input_shapes)?;
    let k_axes: TVec<&Axis> = op
        .axes
        .iter_all_axes()
        // Filter possible candidates (should be one time in each inputs but not in output)
        .filter(|a| a.inputs[0].len() == 1 && a.inputs[1].len() == 1 && a.outputs[0].is_empty())
        .collect();

    // A k candidate is non-trivial when it has size > 1 in at least one input.
    let non_trivial_k_axis = k_axes
        .iter()
        .filter(|a| {
            !input_shapes[0][a.inputs[0][0]].is_one() || !input_shapes[1][a.inputs[1][0]].is_one()
        })
        .copied()
        .collect::<TVec<_>>();

    let k_axis = if non_trivial_k_axis.len() > 1 {
        // Several contracted axes: fold them into a single k via reshapes.
        return regroup_k_axes(op, model, node, non_trivial_k_axis);
    } else {
        non_trivial_k_axis.first().or_else(|| k_axes.first()).copied()
    };
    // No k at all: add a trivial unit k axis and retry on the next pass.
    let Some(k_axis) = k_axis else { return inject_k_axis(op, model, node).map(Some) };

    let mut possible_m_axes: Vec<_> = op
        .axes
        .iter_all_axes()
        .filter(|a| {
            a.inputs[0].len() == 1
                && (a.inputs[1].is_empty() || input_shapes[1][a.inputs[1][0]].is_one())
                && (a.outputs[0].len() == 1
                    || (input_shapes[0][a.inputs[0][0]].is_one() && a.inputs[1].is_empty()))
        })
        .collect();

    // Prioritize obvious m-axes
    if possible_m_axes.iter().any(|a| !a.outputs[0].is_empty()) {
        possible_m_axes.retain(|a| !a.outputs[0].is_empty());
    }

    // Pick the biggest candidate; symbolic dims (unknown size) win ties.
    let m_axis = possible_m_axes
        .into_iter()
        .max_by_key(|a| input_shapes[0][a.inputs[0][0]].as_i64().unwrap_or(i64::MAX));

    let Some(m_axis) = m_axis else {
        return inject_m_or_n_axis(op, model, node, false).map(Some);
    };

    let n_axis = op
        .axes
        .iter_all_axes()
        .filter(|a| {
            (a.inputs[0].is_empty() || input_shapes[0][a.inputs[0][0]].is_one())
                && a.inputs[1].len() == 1
                && a.outputs[0].len() == 1
                && *a != m_axis
        })
        .max_by_key(|a| input_shapes[1][a.inputs[1][0]].as_i64().unwrap_or(i64::MAX));
    let Some(n_axis) = n_axis else {
        return inject_m_or_n_axis(op, model, node, true).map(Some);
    };
    // Reject shapes where a non-trivial axis appears in exactly one input and
    // vanishes from the output: that is a reduction, not a matmul.
    for axis in op.axes.iter_all_axes() {
        let one = TDim::one();
        let in_left =
            axis.inputs[0].first().map(|pos| &input_shapes[0][*pos]).unwrap_or(&one) != &one;
        let in_right =
            axis.inputs[1].first().map(|pos| &input_shapes[1][*pos]).unwrap_or(&one) != &one;
        let in_out = axis.outputs[0].first().map(|pos| &output_shape[*pos]).unwrap_or(&one) != &one;
        if (in_left ^ in_right) && !in_out {
            return Ok(None);
            // return Ok(AxesOrPatch::NotAMatMul(
            //     "non trivial single-side disappearing axis",
            //     vec![axis],
            // ));
        }
    }
    let m = input_shapes[0][m_axis.inputs[0][0]].clone();
    let k = input_shapes[0][k_axis.inputs[0][0]].clone();
    let n = input_shapes[1][n_axis.inputs[1][0]].clone();
    TypedModelPatch::replace_single_op(
        model,
        node,
        &node.inputs,
        EinSumMatMul::new(op.clone(), m_axis.repr, k_axis.repr, n_axis.repr, m, k, n),
    )
    .map(Some)
}
588
589pub(super) fn inject_k_axis(
590    op: &EinSum,
591    model: &TypedModel,
592    node: &TypedNode,
593) -> TractResult<TypedModelPatch> {
594    let mut new_axes = op.axes.clone();
595    let name = &node.name;
596    let mut patch = TypedModelPatch::new("inject k axis");
597    let mut wire = patch.taps(model, &node.inputs)?;
598    let repr = new_axes.available_label();
599    new_axes = new_axes.with_extra_axis(repr, InOut::In(0), 0)?.with_extra_axis_occurency(
600        repr,
601        InOut::In(1),
602        0,
603    )?;
604    wire[0] = patch.wire_node(format!("{name}.add_k.0"), AxisOp::Add(0), &[wire[0]])?[0];
605    wire[1] = patch.wire_node(format!("{name}.add_k.1"), AxisOp::Add(0), &[wire[1]])?[0];
606    wire = patch.wire_node(&node.name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
607    patch.shunt_outside(model, node.id.into(), wire[0])?;
608    Ok(patch)
609}
610
/// Several contracted (k) axes were found: make them contiguous in each input
/// (transposing if necessary) and fold them into a single k axis via reshapes,
/// then rewrite the einsum formula to use only that merged axis.
pub(super) fn regroup_k_axes(
    op: &EinSum,
    model: &TypedModel,
    node: &TypedNode,
    mut k_axes: TVec<&Axis>,
) -> TractResult<Option<TypedModelPatch>> {
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    // Prefer the ordering in which the k axes are already contiguous in A;
    // otherwise order them as they appear in B.
    let contig_in_a = k_axes
        .iter()
        .map(|axis| axis.inputs[0][0])
        .sorted()
        .tuple_windows()
        .all(|(a, b)| a + 1 == b);
    if contig_in_a {
        k_axes.sort_by_key(|ax| ax.inputs[0][0]);
    } else {
        k_axes.sort_by_key(|ax| ax.inputs[1][0]);
    }
    // Dims taken from input A; a k axis has the same extent in both inputs.
    let k_dims: TVec<_> =
        k_axes.iter().map(|ax| input_shapes[0][ax.inputs[0][0]].clone()).collect();
    let k: TDim = k_dims.iter().product();
    let mut patch = TypedModelPatch::default();
    let mut wires = patch.taps(model, &node.inputs)?;
    let mut exprs: Vec<String> =
        (0..2).map(|slot| op.axes.axes(InOut::In(slot)).map(|ax| ax.repr).join("")).collect();
    for slot in 0..2 {
        // If the k axes are not contiguous in this input, transpose them to the end.
        if k_axes.iter().map(|ax| ax.inputs[slot][0]).tuple_windows().any(|(a, b)| a + 1 != b) {
            let after = op
                .axes
                .axes(InOut::In(slot))
                .filter(|ax| !k_axes.contains(ax))
                .chain(k_axes.iter().copied())
                .map(|ax| ax.repr)
                .join("");
            let transpose =
                AxesMapping::from_strs(&[&exprs[slot]], &[&after])?.translate_to_axis_ops()?;
            for (ix, op) in transpose.into_iter().enumerate() {
                wires[slot] = patch.wire_node(
                    format!("{}.transpose_input_{}.{}", &node.name, slot, ix),
                    op,
                    &[wires[slot]],
                )?[0];
            }
            exprs[slot] = after;
        }
        // Fold the now-contiguous run of k axes into a single axis.
        let pos = exprs[slot].chars().position(|c| k_axes[0].repr == c).unwrap();
        wires[slot] = patch.wire_node(
            format!("{}.fold_k_in_input_{}", &node.name, slot),
            AxisOp::Reshape(pos, k_dims.clone(), tvec!(k.clone())),
            &[wires[slot]],
        )?[0];
        // Drop all k labels from the expression, then reinsert the surviving one.
        exprs[slot] =
            exprs[slot].chars().filter(|c| !k_axes.iter().any(|k| k.repr == *c)).collect();
        exprs[slot].insert(pos, k_axes[0].repr);
    }
    // Rebuild the full formula, keeping any extra (quantization) inputs and
    // the original output expression untouched.
    let old = op.axes.to_string();
    let (iexpr, oexpr) = old.split_once("->").unwrap();
    let mut expr: String = exprs.iter().join(",");
    if node.inputs.len() > 2 {
        expr = expr + "," + &iexpr.split(",").skip(2).join(",");
    }
    expr = expr + "->" + oexpr;
    let wire = patch.wire_node(
        &node.name,
        EinSum { axes: expr.parse().unwrap(), ..op.clone() },
        &wires,
    )?[0];
    patch.shunt_outside(model, node.id.into(), wire)?;
    Ok(Some(patch))
}
682
/// No m (or n) axis was found: prepend a unit-sized axis to the relevant
/// input (A for m, B for n) and to the output, run the einsum with the
/// extended formula, then strip the extra output axis with `Rm(0)`.
pub(super) fn inject_m_or_n_axis(
    op: &EinSum,
    model: &TypedModel,
    node: &TypedNode,
    is_n: bool,
) -> TractResult<TypedModelPatch> {
    // Slot 0 carries m, slot 1 carries n.
    let input_to_fix = is_n as usize;
    let label = if is_n { "n" } else { "m" };
    let name = &node.name;
    let mut patch = TypedModelPatch::new("Injecting m or n axis");
    let mut wire = patch.taps(model, &node.inputs)?;
    let repr = op.axes.available_label();
    let new_axes = op
        .axes
        .clone()
        .with_extra_axis(repr, InOut::In(input_to_fix), 0)?
        .with_extra_axis_occurency(repr, InOut::Out(0), 0)?;
    wire[input_to_fix] =
        patch.wire_node(format!("{name}.add_{label}"), AxisOp::Add(0), &[wire[input_to_fix]])?[0];
    // NOTE(review): both the EinSum and the final Rm(0) are wired under
    // node.name — confirm wire_node tolerates duplicate names, or whether the
    // einsum should get a distinct suffix (e.g. "{name}.einsum").
    wire = patch.wire_node(name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
    wire = patch.wire_node(&node.name, AxisOp::Rm(0), &wire)?;
    patch.shunt_outside(model, node.id.into(), wire[0])?;
    Ok(patch)
}
707
708fn wire_axes_fix(
709    patch: &mut TypedModelPatch,
710    name: &str,
711    var: &str,
712    mapping: &AxesMapping,
713    mut outlet: TVec<OutletId>,
714) -> TractResult<TVec<OutletId>> {
715    for (ix, axis_op) in mapping.translate_to_axis_ops()?.into_iter().enumerate() {
716        outlet = patch.wire_node(format!("{name}.fix_{var}.{ix})"), axis_op, &outlet)?;
717    }
718    Ok(outlet)
719}
720
/// Lower a 9-input quantized einsum into an integer einsum followed by the
/// explicit dequantization arithmetic: bias addition, zero-point compensation
/// (using per-side row/column sums) and final requantization.
///
/// Input layout: [a, b, bias, a0, a_scale, b0, b_scale, c0, c_scale].
fn dequant(
    model: &TypedModel,
    node: &TypedNode,
    op: &EinSumMatMul,
) -> TractResult<TypedModelPatch> {
    let name = &node.name;
    let mut patch = TypedModelPatch::new("Dequantizing einsum");

    let k_axis = op.k_axis();

    let mut taps = patch.taps(model, &node.inputs)?;
    // For non-scalar a/b scales, pad trailing unit axes so the scale
    // broadcasts against the output rank.
    for ab in [0, 1] {
        let scale_input = 4 + ab * 2;
        if !patch.outlet_fact(taps[scale_input])?.shape.volume().is_one() {
            let q_axis_in_output = op.axes.axis((InOut::In(scale_input), 0))?.outputs[0][0];
            let output_rank = node.outputs[0].fact.rank();
            for i in 1..(output_rank - q_axis_in_output) {
                taps[scale_input] = patch.wire_node(
                    format!("{name}.scale_input{ab}_axis_fix_{i}"),
                    AxisOp::Add(i),
                    &[taps[scale_input]],
                )?[0];
            }
        }
    }

    let [mut a, mut b, bias, mut a0, a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
        bail!("Expect exactly 9 inputs")
    };

    // Normalize both operands (and their zero points) to a signed i8 flavour.
    wire_ensure_q8_flavour(&mut patch, &node.name, &mut a, "a", &mut a0, i8::datum_type())?;
    wire_ensure_q8_flavour(&mut patch, &node.name, &mut b, "b", &mut b0, i8::datum_type())?;

    // Plain (non-quantized) einsum on the two data inputs only.
    let mut output = patch.wire_node(
        &node.name,
        EinSum {
            q_params: None,
            axes: op.axes.extract_sub_mapping(&[0, 1], &[0])?,
            operating_dt: op.operating_dt,
        },
        &[a, b],
    )?;

    // Row/column sums over k, needed for zero-point compensation.
    let a_i32 = patch.wire_node(format!("{name}.a_as_i32"), cast(i32::datum_type()), &[a])?[0];
    let b_i32 = patch.wire_node(format!("{name}.b_as_i32"), cast(i32::datum_type()), &[b])?[0];
    let sum_a = patch.wire_node(
        format!("{name}.sum_a"),
        Reduce::new(tvec!(k_axis.inputs[0][0]), Reducer::Sum),
        &[a_i32],
    )?;
    let sum_b = patch.wire_node(
        format!("{name}.sum_b"),
        Reduce::new(tvec!(k_axis.inputs[1][0]), Reducer::Sum),
        &[b_i32],
    )?;

    // Re-align sums and bias with the output axis layout.
    let sum_a =
        wire_axes_fix(&mut patch, name, "sum_a", &op.axes.extract_sub_mapping(&[0], &[0])?, sum_a)?;
    let sum_b =
        wire_axes_fix(&mut patch, name, "sum_b", &op.axes.extract_sub_mapping(&[1], &[0])?, sum_b)?;
    let bias = tvec!(bias);
    let bias =
        wire_axes_fix(&mut patch, name, "bias", &op.axes.extract_sub_mapping(&[2], &[0])?, bias)?;

    let abc_scale = combine_scales(&mut patch, name, a_scale, b_scale, c_scale)?;

    output = patch.wire_node(format!("{name}.add_bias"), add(), &[output[0], bias[0]])?;

    let k = model.outlet_fact(node.inputs[0])?.shape[k_axis.inputs[0][0]].clone();
    let output = compensate_zero_points(&mut patch, name, output[0], k, a0, b0, sum_a[0], sum_b[0])
        .context("Zero point compensation")?;
    let output = requant(&mut patch, name, output, op.q_params.unwrap(), abc_scale, c0)?;
    patch.shunt_outside(model, node.id.into(), output)?;
    Ok(patch)
}
796
797fn flatten_rule(
798    _ctx: &(),
799    model: &TypedModel,
800    node: &TypedNode,
801    _name: &str,
802    op: &EinSumMatMul,
803) -> TractResult<Option<TypedModelPatch>> {
804    TypedModelPatch::replace_single_op(model, node, &node.inputs, op.op.clone()).map(Some)
805}
806
/// Lower a matmul-shaped `EinSumMatMul` node to packed inputs feeding an
/// `OptMatMul`, wiring everything into a `TypedModelPatch`.
///
/// Steps, as implemented below:
/// 1. pick candidate kernels and a packing strategy (`kernel_selection::strategize`),
/// 2. wire a packing op for each operand (`pack_a`, `pack_b`),
/// 3. map the remaining (non m/k/n) output axes back to input axes,
/// 4. build the `OptMatMul` with one fused AddMatMul + Store pipeline,
/// 5. shunt the original node's output to the new subgraph.
///
/// Returns `Ok(Some(patch))` on success; bails if the constant left operand is
/// in an unsupported packed format.
fn optimized_mat_mul(
    model: &TypedModel,
    node: &TypedNode,
    op: &EinSumMatMul,
) -> TractResult<Option<TypedModelPatch>> {
    // mode_picker: runtime kernel-choice policy; left_pack: format chosen for A;
    // impls: candidate (kernel, packing index, optional extractor) triples.
    let (mode_picker, left_pack, impls) = kernel_selection::strategize(model, node, op)?;
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    let prefix = &node.name;

    let mut patch = TypedModelPatch::new("Einsum to OptMatMul");
    let taps = patch.taps(model, &node.inputs)?;
    let name = &node.name;

    // Choose the packing op for operand A. A constant ("konst") A gets a single
    // fixed packer — either a plain PackedFormat or a block-quantized one —
    // since its layout is decided once. A dynamic A gets one packer per
    // candidate kernel so the mode picker can select at runtime.
    let pack_a: Box<dyn TypedOp> = if input_facts[0].konst.is_some() {
        if let Some(pf) = left_pack.downcast_ref::<PackedFormat>() {
            Box::new(OptMatMulPack {
                packers: vec![pf.clone()],
                mode_picker: ModePicker::Single,
                k_axis: op.a_k(),
                mn_axis: op.a_m(),
            })
        } else if let Some(packed_format) =
            left_pack.downcast_ref::<PackedBlockQuantFormat>().cloned()
        {
            // Block-quant path: k and m must be concrete here (unwrap assumes a
            // constant A has a fully-known shape — TODO confirm invariant).
            Box::new(OptSimpleMatMulPack {
                packed_format,
                k: input_shapes[0][op.a_k()].to_usize().unwrap(),
                m: input_shapes[0][op.a_m()].to_usize().unwrap(),
            })
        } else {
            bail!("Unexpected static input format {left_pack:?}");
        }
    } else {
        Box::new(OptMatMulPack {
            // For each kernel: prefer the extractor's source format when an
            // extractor is present, otherwise the kernel's declared A packing.
            packers: impls
                .iter()
                .map(|(mmm, p, pe)| {
                    pe.as_ref()
                        .map(|pe| &pe.from)
                        .unwrap_or(&mmm.packings()[*p].0)
                        .downcast_ref::<PackedFormat>()
                        .unwrap()
                        .clone()
                })
                .collect(),
            mode_picker: mode_picker.clone(),
            k_axis: op.a_k(),
            mn_axis: op.a_m(),
        })
    };
    let pa = patch.wire_node(format!("{prefix}.pack_a"), pack_a, &[taps[0]])?[0];

    // Operand B always uses the kernels' declared B packing formats.
    let pb = patch.wire_node(
        format!("{prefix}.pack_b"),
        OptMatMulPack {
            k_axis: op.b_k(),
            mn_axis: op.b_n(),
            packers: impls
                .iter()
                .map(|(mmm, p, _)| {
                    mmm.packings()[*p].1.downcast_ref::<PackedFormat>().unwrap().clone()
                })
                .collect(),
            mode_picker: mode_picker.clone(),
        },
        &[taps[1]],
    )?[0];

    // Relate each remaining output axis (everything except m, k, n) back to the
    // matching axis of input A and/or B, skipping axes of extent 1 (broadcast).
    let mut c_to_a_axis_mapping = tvec!();
    let mut c_to_b_axis_mapping = tvec!();
    for axis in op
        .op
        .axes
        .iter_all_axes()
        .filter(|&axis| ![op.m_axis, op.k_axis, op.n_axis].contains(&axis.repr))
    {
        if let (&[c], &[a]) = (&*axis.outputs[0], &*axis.inputs[0])
            && input_shapes[0][a] != 1.to_dim()
        {
            // Shift the index down by one for each of A's m/k axes that precede
            // it, since packing removes those axes from the packed tensor.
            let a = a - (a > op.a_m()) as usize - (a > op.a_k()) as usize;
            c_to_a_axis_mapping.push((c, a));
        }
        if let (&[c], &[b]) = (&*axis.outputs[0], &*axis.inputs[1])
            && input_shapes[1][b] != 1.to_dim()
        {
            // Same compensation for B's n/k axes.
            let b = b - (b > op.b_n()) as usize - (b > op.b_k()) as usize;
            c_to_b_axis_mapping.push((c, b));
        }
    }

    let c_fact = op.output_facts(&input_facts)?.remove(0);
    let geo = AddMatMulGeometry {
        k: op.k.clone(),
        c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
        c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
    };
    // Split the (kernel, packing, extractor) triples into parallel vectors.
    let (mmms, packings, extractor): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(impls);
    let outputs = mmms.iter().map(|mmm| unsafe { mmm.c_view(op.c_m(), op.c_n()) }).collect();
    // "Trivial" packing: a single kernel using its default packing (index 0),
    // no extractor, and no exotic fact on A — presumably enables a fast path in
    // OptMatMul (TODO confirm against OptMatMul::new).
    let trivial_packing = mmms.len() == 1
        && packings[0] == 0
        && extractor[0].is_none()
        && input_facts[0].exotic_fact.is_none();
    let opt = OptMatMul::new(
        mmms,
        mode_picker,
        c_fact,
        op.c_m(),
        op.c_n(),
        vec![
            ProtoFusedSpec::AddMatMul {
                geo,
                a: 0,
                b: 1,
                packings: izip!(packings, extractor).collect_vec(),
            },
            ProtoFusedSpec::Store(outputs),
        ],
        trivial_packing,
    )
    .context("Creating OptMatMul")?;
    let output = patch.wire_node(name, opt, &[pa, pb])?[0];
    patch.shunt_outside(model, node.id.into(), output)?;
    Ok(Some(patch))
}