Skip to main content

tract_core/ops/einsum/
einsum_matmul.rs

1use std::fmt::Formatter;
2use std::ops::Deref;
3
4use tract_itertools::{izip, multiunzip};
5use tract_linalg::block_quant::PackedBlockQuantFormat;
6
7use super::*;
8use crate::ops::cast::cast;
9use crate::ops::math::add;
10use crate::ops::matmul::ModePicker;
11use crate::ops::matmul::optimized::{
12    AddMatMulGeometry, MapOutputAxisToInput, OptMatMul, ProtoFusedSpec,
13};
14use crate::ops::matmul::pack::{OptMatMulPack, OptSimpleMatMulPack};
15use crate::ops::matmul::quant::{
16    combine_scales, compensate_zero_points, requant, wire_ensure_q8_flavour,
17};
18use crate::ops::nn::{Reduce, Reducer};
19
20pub fn merge_consecutive_same_role_axes(model: &mut TypedModel) -> TractResult<()> {
21    Rewriter::default()
22        .with_rule_for("merge-same-role-axes", merge_same_role_axes_rule)
23        .rewrite(&(), model)
24}
25
26fn merge_same_role_axes_rule(
27    _ctx: &(),
28    model: &TypedModel,
29    node: &TypedNode,
30    node_name: &str,
31    op: &EinSum,
32) -> TractResult<Option<TypedModelPatch>> {
33    // Only handle 2-input EinSums (matmul-like)
34    rule_if!(node.inputs.len() == 2);
35
36    // Compute role signature for each axis: (in_input_0, in_input_1, in_output)
37    type Role = (bool, bool, bool);
38    let axes: Vec<(char, Role)> = op
39        .axes
40        .iter_all_axes()
41        .map(|a| {
42            (a.repr, (!a.inputs[0].is_empty(), !a.inputs[1].is_empty(), !a.outputs[0].is_empty()))
43        })
44        .collect();
45
46    // For each input/output slot, get the axis order
47    let a_order: Vec<char> = op.axes.axes(InOut::In(0)).map(|a| a.repr).collect();
48    let b_order: Vec<char> = op.axes.axes(InOut::In(1)).map(|a| a.repr).collect();
49    let c_order: Vec<char> = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
50
51    // Per-input shapes, used both to reject broadcast merges and to gauge
52    // whether a merge earns its reshape / axis-permute cost.
53    let input_facts = model.node_input_facts(node.id)?;
54    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
55    // An axis is "non-unit" if its extent is not statically 1 (symbolic extents
56    // count as potentially large). A fold only reduces the number of matmul
57    // invocations when it combines at least two non-unit axes; folding a unit
58    // axis — the streaming axis at pulse=1, or a batch axis of 1 — leaves the
59    // matmul geometry untouched and only adds reshapes (and, in the k-axis
60    // branch, a MoveAxis), so we decline those.
61    let is_non_unit = |c: &char| -> bool {
62        let dim = a_order
63            .iter()
64            .position(|x| x == c)
65            .map(|p| &input_shapes[0][p])
66            .or_else(|| b_order.iter().position(|x| x == c).map(|p| &input_shapes[1][p]));
67        dim.map_or(true, |d| d.as_i64() != Some(1))
68    };
69
70    // Find first group of 2+ same-role axes that are consecutive in all inputs.
71    // Scan each input's axis order for runs of same-role axes.
72    let role_map: std::collections::HashMap<char, Role> = axes.iter().cloned().collect();
73    let mut best_group: Option<Vec<char>> = None;
74
75    // Try each input order as the primary scan order
76    let all_orders = [&a_order, &b_order];
77    for (primary_idx, primary_order) in all_orders.iter().enumerate() {
78        let mut i = 0;
79        while i < primary_order.len() {
80            let first = primary_order[i];
81            let first_role = role_map[&first];
82            let mut group = vec![first];
83            let mut j = i + 1;
84            while j < primary_order.len() {
85                let candidate = primary_order[j];
86                if role_map[&candidate] != first_role {
87                    break;
88                }
89                // Check consecutive in the OTHER input too
90                let consecutive_in_others = all_orders
91                    .iter()
92                    .enumerate()
93                    .filter(|(idx, _)| *idx != primary_idx)
94                    .all(|(_, order)| {
95                        let positions: Vec<usize> = group
96                            .iter()
97                            .chain(std::iter::once(&candidate))
98                            .filter_map(|c| order.iter().position(|x| x == c))
99                            .collect();
100                        if positions.len() <= 1 {
101                            return true;
102                        }
103                        let mut sorted = positions.clone();
104                        sorted.sort();
105                        sorted == positions
106                            && sorted.last().unwrap() - sorted.first().unwrap() == sorted.len() - 1
107                    });
108                if !consecutive_in_others {
109                    break;
110                }
111                group.push(candidate);
112                j += 1;
113            }
114            if group.len() >= 2 && best_group.as_ref().map_or(true, |bg| group.len() > bg.len()) {
115                best_group = Some(group);
116            }
117            i = j;
118        }
119    }
120
121    if let Some(ref group) = best_group {
122        // Reject the group if any axis has mismatched per-input dims. This
123        // catches broadcasting cases — e.g. GQA `bhgmk,bhgnk->bhgmn` where g has
124        // dim 2 in input[0] and dim 1 in input[1]. Merging would collapse the
125        // broadcast structure into a non-broadcast dim mismatch that downstream
126        // OptMatMul codegen / kernels cannot handle.
127        let dims_match = group.iter().all(|c| {
128            match (a_order.iter().position(|x| x == c), b_order.iter().position(|x| x == c)) {
129                (Some(p0), Some(p1)) => input_shapes[0][p0] == input_shapes[1][p1],
130                _ => true,
131            }
132        });
133        // Decline merges that combine fewer than two non-unit axes: they don't
134        // shrink the matmul loop count, they only add reshapes / a MoveAxis.
135        let worth_merging = group.iter().filter(|c| is_non_unit(c)).count() >= 2;
136        if !dims_match || !worth_merging {
137            best_group = None;
138        }
139    }
140
141    if let Some(group) = best_group {
142        // Found a mergeable group. Emit the patch.
143        let output_shape = super::eval::output_shape(&op.axes, &input_shapes)?;
144
145        let drop_set: Vec<char> = group[1..].to_vec();
146
147        let mut patch = TypedModelPatch::default();
148        let mut wires: TVec<OutletId> = patch.taps(model, &node.inputs)?;
149
150        // Reshape each input to merge the group
151        for (slot, order) in [(0, &a_order), (1, &b_order)] {
152            let positions: Vec<usize> =
153                group.iter().filter_map(|c| order.iter().position(|x| x == c)).collect();
154            if positions.len() < 2 {
155                continue;
156            }
157            let start = positions[0];
158            let from_dims: TVec<TDim> =
159                positions.iter().map(|&p| input_shapes[slot][p].clone()).collect();
160            let merged: TDim = from_dims.iter().product();
161            wires[slot] = patch.wire_node(
162                format!("{node_name}.merge_in{slot}"),
163                AxisOp::Reshape(start, from_dims, tvec![merged]),
164                &[wires[slot]],
165            )?[0];
166        }
167
168        // If group axes aren't consecutive in C, reorder the EinSum output
169        let c_positions: Vec<usize> =
170            group.iter().filter_map(|c| c_order.iter().position(|x| x == c)).collect();
171        let c_needs_reorder = c_positions.len() >= 2 && {
172            let mut sorted = c_positions.clone();
173            sorted.sort();
174            sorted.last().unwrap() - sorted.first().unwrap() != sorted.len() - 1
175                || sorted != c_positions
176        };
177        let mut adjusted_c_order = c_order.clone();
178        if c_needs_reorder {
179            // Move group axes together (put second next to first)
180            for k in 1..c_positions.len() {
181                let cur_pos = adjusted_c_order.iter().position(|&c| c == group[k]).unwrap();
182                let target_pos =
183                    adjusted_c_order.iter().position(|&c| c == group[k - 1]).unwrap() + 1;
184                if cur_pos != target_pos {
185                    let removed = adjusted_c_order.remove(cur_pos);
186                    let insert_at = if cur_pos < target_pos { target_pos - 1 } else { target_pos };
187                    adjusted_c_order.insert(insert_at, removed);
188                }
189            }
190        }
191
192        // Rebuild EinSum formula with adjusted output and dropped axes
193        let in0: String = a_order.iter().collect();
194        let in1: String = b_order.iter().collect();
195        let out: String = adjusted_c_order.iter().collect();
196        let expr = format!("{in0},{in1}->{out}");
197        let mut new_axes: AxesMapping = expr.parse()?;
198        for &drop in &drop_set {
199            new_axes = new_axes.remove_axis(drop)?;
200        }
201        let new_op =
202            EinSum { axes: new_axes, operating_dt: op.operating_dt, q_params: op.q_params };
203        let mut result = patch.wire_node(node_name, new_op, &wires)?;
204
205        // Reshape output to split the merged axis back
206        let merged_c_positions: Vec<usize> =
207            group.iter().filter_map(|c| adjusted_c_order.iter().position(|x| x == c)).collect();
208        if merged_c_positions.len() >= 2 {
209            let start = merged_c_positions[0];
210            // Use original output dims for the group axes
211            let original_c_positions: Vec<usize> =
212                group.iter().filter_map(|c| c_order.iter().position(|x| x == c)).collect();
213            let original_dims: TVec<TDim> =
214                original_c_positions.iter().map(|&p| output_shape[p].clone()).collect();
215            let merged: TDim = original_dims.iter().product();
216            result[0] = patch.wire_node(
217                format!("{node_name}.unmerge_out"),
218                AxisOp::Reshape(start, tvec![merged], original_dims),
219                &[result[0]],
220            )?[0];
221        }
222
223        // Restore original output order if we reordered
224        if c_needs_reorder {
225            // After unmerge, axes are in adjusted_c_order (but with group expanded).
226            // Need to permute back to c_order.
227            // Build the unmerged adjusted order
228            let mut unmerged_adj: Vec<char> = Vec::new();
229            for &c in &adjusted_c_order {
230                if c == group[0] {
231                    unmerged_adj.extend(&group);
232                } else if !group.contains(&c) {
233                    unmerged_adj.push(c);
234                }
235            }
236            // Find what moves are needed to get from unmerged_adj to c_order
237            for target_pos in 0..c_order.len() {
238                let cur_pos = unmerged_adj.iter().position(|&c| c == c_order[target_pos]).unwrap();
239                if cur_pos != target_pos {
240                    result[0] = patch.wire_node(
241                        format!("{node_name}.restore_out_{target_pos}"),
242                        AxisOp::Move(cur_pos, target_pos),
243                        &[result[0]],
244                    )?[0];
245                    let removed = unmerged_adj.remove(cur_pos);
246                    unmerged_adj.insert(target_pos, removed);
247                }
248            }
249        }
250
251        patch.shunt_outside(model, node.id.into(), result[0])?;
252        return Ok(Some(patch));
253    }
254
255    // Second pass: look for same-role pairs separated by exactly one k-like axis
256    // in a single input. Insert a MoveAxis to push the separator to the end.
257    let k_role: Role = (true, true, false); // present in both inputs, absent from output
258    let role_of = |c: char| axes.iter().find(|(ch, _)| *ch == c).map(|(_, r)| *r);
259
260    for (slot, order) in [(0usize, &a_order), (1, &b_order)] {
261        // Find three consecutive axes in this input where the outer two share a role
262        // and the middle one is a k-axis
263        for w in order.windows(3) {
264            let (left, mid, right) = (w[0], w[1], w[2]);
265            let left_role = role_of(left);
266            let mid_role = role_of(mid);
267            let right_role = role_of(right);
268            if left_role != right_role || mid_role != Some(k_role) {
269                continue;
270            }
271            // Only move the k-axis aside if the resulting merge is worth it:
272            // both axes we'd bring together must be non-unit. Otherwise the
273            // MoveAxis buys nothing (e.g. a 1×1 conv at pulse=1, where the
274            // batch / streaming axes are 1 and the only real axis was already
275            // the matmul m).
276            if !is_non_unit(&left) || !is_non_unit(&right) {
277                continue;
278            }
279            // left and right must also be consecutive in other inputs
280            // (output order is handled by the EinSum formula)
281            let other_input_orders: Vec<&Vec<char>> = [(0, &a_order), (1, &b_order)]
282                .iter()
283                .filter(|(s, _)| *s != slot)
284                .map(|(_, o)| *o)
285                .collect();
286            let consecutive_elsewhere = other_input_orders.iter().all(|order| {
287                let lp = order.iter().position(|&c| c == left);
288                let rp = order.iter().position(|&c| c == right);
289                match (lp, rp) {
290                    (Some(l), Some(r)) => r == l + 1,
291                    _ => true, // one or both absent — no constraint
292                }
293            });
294            if !consecutive_elsewhere {
295                continue;
296            }
297
298            // Move the k-axis to the inner (last) position in inputs and
299            // make left,right adjacent in the output too.
300            let mid_pos = order.iter().position(|&c| c == mid).unwrap();
301            let end_pos = order.len() - 1;
302            if mid_pos == end_pos {
303                continue;
304            }
305
306            // Use change_axes to update the EinSum formula for the input move
307            let move_op = AxisOp::Move(mid_pos, end_pos);
308            let Some(AxisChangeConsequence { substitute_op, .. }) =
309                op.change_axes(model, node, InOut::In(slot), &move_op)?
310            else {
311                continue;
312            };
313            let mut current_op = *substitute_op
314                .unwrap()
315                .downcast::<EinSum>()
316                .map_err(|_| anyhow!("expected EinSum"))?;
317
318            // Also make left,right adjacent in the output if needed
319            let new_c: Vec<char> = current_op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
320            let left_c = new_c.iter().position(|&c| c == left);
321            let right_c = new_c.iter().position(|&c| c == right);
322            let need_output_fix = matches!((left_c, right_c), (Some(l), Some(r)) if r != l + 1);
323            if need_output_fix {
324                let r_pos = right_c.unwrap();
325                let l_pos = left_c.unwrap();
326                let target = if r_pos < l_pos { l_pos } else { l_pos + 1 };
327                if let Some(AxisChangeConsequence { substitute_op, .. }) = current_op.change_axes(
328                    model,
329                    node,
330                    InOut::Out(0),
331                    &AxisOp::Move(r_pos, target),
332                )? {
333                    current_op = *substitute_op
334                        .unwrap()
335                        .downcast::<EinSum>()
336                        .map_err(|_| anyhow!("expected EinSum"))?;
337                }
338            }
339
340            let mut patch = TypedModelPatch::default();
341            let mut wires: TVec<OutletId> = patch.taps(model, &node.inputs)?;
342
343            wires[slot] =
344                patch.wire_node(format!("{node_name}.move_k_in{slot}"), move_op, &[wires[slot]])?
345                    [0];
346
347            let final_c: Vec<char> = current_op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
348            let mut result = patch.wire_node(node_name, current_op, &wires)?;
349
350            // Restore original output order
351            if need_output_fix {
352                let r_cur = final_c.iter().position(|&c| c == right).unwrap();
353                let r_orig = c_order.iter().position(|&c| c == right).unwrap();
354                if r_cur != r_orig {
355                    result[0] = patch.wire_node(
356                        format!("{node_name}.restore_out"),
357                        AxisOp::Move(r_cur, r_orig),
358                        &[result[0]],
359                    )?[0];
360                }
361            }
362
363            patch.shunt_outside(model, node.id.into(), result[0])?;
364            return Ok(Some(patch));
365        }
366    }
367
368    Ok(None)
369}
370
371pub fn detect_all(model: &mut TypedModel) -> TractResult<()> {
372    Rewriter::default().with_rule_for("detect-matmul-einsum", detect_rule).rewrite(&(), model)
373}
374
375pub fn flatten_all(model: &mut TypedModel) -> TractResult<()> {
376    Rewriter::default().with_rule_for("flatten-matmul-einsum", flatten_rule).rewrite(&(), model)
377}
378
379#[derive(Clone, Hash, PartialEq, Eq)]
380pub struct EinSumMatMul {
381    pub op: EinSum,
382    pub m_axis: char,
383    pub k_axis: char,
384    pub n_axis: char,
385    pub m: TDim,
386    pub k: TDim,
387    pub n: TDim,
388}
389
390impl EinSumMatMul {
391    pub fn m_axis(&self) -> &Axis {
392        self.op.axes.axis(self.m_axis).unwrap()
393    }
394    pub fn k_axis(&self) -> &Axis {
395        self.op.axes.axis(self.k_axis).unwrap()
396    }
397    pub fn n_axis(&self) -> &Axis {
398        self.op.axes.axis(self.n_axis).unwrap()
399    }
400    pub fn a_m(&self) -> usize {
401        self.m_axis().inputs[0][0]
402    }
403    pub fn a_k(&self) -> usize {
404        self.k_axis().inputs[0][0]
405    }
406    pub fn b_k(&self) -> usize {
407        self.k_axis().inputs[1][0]
408    }
409    pub fn b_n(&self) -> usize {
410        self.n_axis().inputs[1][0]
411    }
412    pub fn c_m(&self) -> Option<usize> {
413        self.m_axis().outputs[0].first().cloned()
414    }
415    pub fn c_n(&self) -> Option<usize> {
416        self.n_axis().outputs[0].first().cloned()
417    }
418
419    fn new(
420        op: EinSum,
421        m_axis: char,
422        k_axis: char,
423        n_axis: char,
424        m: TDim,
425        k: TDim,
426        n: TDim,
427    ) -> Self {
428        Self { op, m_axis, k_axis, n_axis, m, k, n }
429    }
430}
431
432impl Debug for EinSumMatMul {
433    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
434        write!(
435            f,
436            "EinsumMatMul: {} {:?} m: {}={}; k: {}={}; n: {}={}",
437            self.op.axes,
438            self.op.operating_dt,
439            self.m_axis,
440            self.m,
441            self.k_axis,
442            self.k,
443            self.n_axis,
444            self.n
445        )
446    }
447}
448
449impl Deref for EinSumMatMul {
450    type Target = EinSum;
451    fn deref(&self) -> &Self::Target {
452        &self.op
453    }
454}
455
456impl Op for EinSumMatMul {
457    fn name(&self) -> StaticName {
458        "EinSumMatMul".into()
459    }
460
461    op_as_typed_op!();
462}
463
464impl EvalOp for EinSumMatMul {
465    fn is_stateless(&self) -> bool {
466        true
467    }
468    fn eval_with_session(
469        &self,
470        node_id: usize,
471        session: &TurnState,
472        inputs: TVec<TValue>,
473    ) -> TractResult<TVec<TValue>> {
474        self.op.eval_with_session(node_id, session, inputs)
475    }
476}
477
478impl TypedOp for EinSumMatMul {
479    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
480        self.op.output_facts(inputs)
481    }
482
483    fn codegen(
484        &self,
485        model: &TypedModel,
486        node: &TypedNode,
487    ) -> TractResult<Option<TypedModelPatch>> {
488        // deal with parametric quantization extra inputs
489        if node.inputs.len() == 9 {
490            ensure!(self.op.q_params.is_some());
491            return dequant(model, node, self).map(Some);
492        }
493        ensure!(node.inputs.len() == 2);
494        let (a, b) = model.node_input_facts(node.id)?.into_iter().collect_tuple().unwrap();
495        // at this stage a and b must NOT be packed yet. if they are exotic, we can assume it's just compression
496        let must_transpose = if let Some(of) = a.exotic_fact() {
497            ensure!(of.is::<BlockQuantFact>());
498            false
499        } else if let Some(of) = b.exotic_fact() {
500            ensure!(of.is::<BlockQuantFact>());
501            true
502        } else if self.m == self.n {
503            false
504        } else {
505            match (self.m.as_i64(), self.n.as_i64()) {
506                (Some(m), Some(n)) => m < n,
507                (None, Some(n)) => n >= 8,
508                (Some(_), _) => false,
509                _ => (self.n.clone() - &self.m).prove_positive_or_zero(),
510            }
511        };
512        if must_transpose {
513            let mut op = self.clone();
514            op.op.axes.iter_all_axes_mut().for_each(|axis| axis.inputs.swap(0, 1));
515            std::mem::swap(&mut op.m_axis, &mut op.n_axis);
516            std::mem::swap(&mut op.m, &mut op.n);
517            return TypedModelPatch::replace_single_op(
518                model,
519                node,
520                &[node.inputs[1], node.inputs[0]],
521                op,
522            )
523            .map(|p| Some(p.with_context("transposing")));
524        }
525        // opt mat mul assumes we have at least one m or n
526        if self.c_m().is_some() || self.c_n().is_some() {
527            return optimized_mat_mul(model, node, self)
528                .map(|opt| opt.map(|p| p.with_context("optimizing")));
529        }
530        Ok(None)
531    }
532
533    as_op!();
534}
535
536pub(crate) fn detect_rule(
537    _ctx: &(),
538    model: &TypedModel,
539    node: &TypedNode,
540    _name: &str,
541    op: &EinSum,
542) -> TractResult<Option<TypedModelPatch>> {
543    rule_if!(node.inputs.len() == (2 + op.q_params.is_some() as usize * 7));
544    let input_facts = model.node_input_facts(node.id)?;
545    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
546    let output_shape = super::eval::output_shape(&op.axes, &input_shapes)?;
547    let k_axes: TVec<&Axis> = op
548        .axes
549        .iter_all_axes()
550        // Filter possible candidates (should be one time in each inputs but not in output)
551        .filter(|a| a.inputs[0].len() == 1 && a.inputs[1].len() == 1 && a.outputs[0].is_empty())
552        .collect();
553
554    let non_trivial_k_axis = k_axes
555        .iter()
556        .filter(|a| {
557            !input_shapes[0][a.inputs[0][0]].is_one() || !input_shapes[1][a.inputs[1][0]].is_one()
558        })
559        .copied()
560        .collect::<TVec<_>>();
561
562    let k_axis = if non_trivial_k_axis.len() > 1 {
563        return regroup_k_axes(op, model, node, non_trivial_k_axis);
564    } else {
565        non_trivial_k_axis.first().or_else(|| k_axes.first()).copied()
566    };
567    let Some(k_axis) = k_axis else { return inject_k_axis(op, model, node).map(Some) };
568
569    let mut possible_m_axes: Vec<_> = op
570        .axes
571        .iter_all_axes()
572        .filter(|a| {
573            a.inputs[0].len() == 1
574                && (a.inputs[1].is_empty() || input_shapes[1][a.inputs[1][0]].is_one())
575                && (a.outputs[0].len() == 1
576                    || (input_shapes[0][a.inputs[0][0]].is_one() && a.inputs[1].is_empty()))
577        })
578        .collect();
579
580    // Prioritize obvious m-axes
581    if possible_m_axes.iter().any(|a| !a.outputs[0].is_empty()) {
582        possible_m_axes.retain(|a| !a.outputs[0].is_empty());
583    }
584
585    let m_axis = possible_m_axes
586        .into_iter()
587        .max_by_key(|a| input_shapes[0][a.inputs[0][0]].as_i64().unwrap_or(i64::MAX));
588
589    let Some(m_axis) = m_axis else {
590        return inject_m_or_n_axis(op, model, node, false).map(Some);
591    };
592
593    let n_axis = op
594        .axes
595        .iter_all_axes()
596        .filter(|a| {
597            (a.inputs[0].is_empty() || input_shapes[0][a.inputs[0][0]].is_one())
598                && a.inputs[1].len() == 1
599                && a.outputs[0].len() == 1
600                && *a != m_axis
601        })
602        .max_by_key(|a| input_shapes[1][a.inputs[1][0]].as_i64().unwrap_or(i64::MAX));
603    let Some(n_axis) = n_axis else {
604        return inject_m_or_n_axis(op, model, node, true).map(Some);
605    };
606    for axis in op.axes.iter_all_axes() {
607        let one = TDim::one();
608        let in_left =
609            axis.inputs[0].first().map(|pos| &input_shapes[0][*pos]).unwrap_or(&one) != &one;
610        let in_right =
611            axis.inputs[1].first().map(|pos| &input_shapes[1][*pos]).unwrap_or(&one) != &one;
612        let in_out = axis.outputs[0].first().map(|pos| &output_shape[*pos]).unwrap_or(&one) != &one;
613        if (in_left ^ in_right) && !in_out {
614            return Ok(None);
615            // return Ok(AxesOrPatch::NotAMatMul(
616            //     "non trivial single-side disappearing axis",
617            //     vec![axis],
618            // ));
619        }
620    }
621    let m = input_shapes[0][m_axis.inputs[0][0]].clone();
622    let k = input_shapes[0][k_axis.inputs[0][0]].clone();
623    let n = input_shapes[1][n_axis.inputs[1][0]].clone();
624    TypedModelPatch::replace_single_op(
625        model,
626        node,
627        &node.inputs,
628        EinSumMatMul::new(op.clone(), m_axis.repr, k_axis.repr, n_axis.repr, m, k, n),
629    )
630    .map(Some)
631}
632
633pub(super) fn inject_k_axis(
634    op: &EinSum,
635    model: &TypedModel,
636    node: &TypedNode,
637) -> TractResult<TypedModelPatch> {
638    let mut new_axes = op.axes.clone();
639    let name = &node.name;
640    let mut patch = TypedModelPatch::new("inject k axis");
641    let mut wire = patch.taps(model, &node.inputs)?;
642    let repr = new_axes.available_label();
643    new_axes = new_axes.with_extra_axis(repr, InOut::In(0), 0)?.with_extra_axis_occurency(
644        repr,
645        InOut::In(1),
646        0,
647    )?;
648    wire[0] = patch.wire_node(format!("{name}.add_k.0"), AxisOp::Add(0), &[wire[0]])?[0];
649    wire[1] = patch.wire_node(format!("{name}.add_k.1"), AxisOp::Add(0), &[wire[1]])?[0];
650    wire = patch.wire_node(&node.name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
651    patch.shunt_outside(model, node.id.into(), wire[0])?;
652    Ok(patch)
653}
654
655pub(super) fn regroup_k_axes(
656    op: &EinSum,
657    model: &TypedModel,
658    node: &TypedNode,
659    mut k_axes: TVec<&Axis>,
660) -> TractResult<Option<TypedModelPatch>> {
661    let input_facts = model.node_input_facts(node.id)?;
662    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
663    let contig_in_a = k_axes
664        .iter()
665        .map(|axis| axis.inputs[0][0])
666        .sorted()
667        .tuple_windows()
668        .all(|(a, b)| a + 1 == b);
669    if contig_in_a {
670        k_axes.sort_by_key(|ax| ax.inputs[0][0]);
671    } else {
672        k_axes.sort_by_key(|ax| ax.inputs[1][0]);
673    }
674    let k_dims: TVec<_> =
675        k_axes.iter().map(|ax| input_shapes[0][ax.inputs[0][0]].clone()).collect();
676    let k: TDim = k_dims.iter().product();
677    let mut patch = TypedModelPatch::default();
678    let mut wires = patch.taps(model, &node.inputs)?;
679    let mut exprs: Vec<String> =
680        (0..2).map(|slot| op.axes.axes(InOut::In(slot)).map(|ax| ax.repr).join("")).collect();
681    for slot in 0..2 {
682        if k_axes.iter().map(|ax| ax.inputs[slot][0]).tuple_windows().any(|(a, b)| a + 1 != b) {
683            let after = op
684                .axes
685                .axes(InOut::In(slot))
686                .filter(|ax| !k_axes.contains(ax))
687                .chain(k_axes.iter().copied())
688                .map(|ax| ax.repr)
689                .join("");
690            let transpose =
691                AxesMapping::from_strs(&[&exprs[slot]], &[&after])?.translate_to_axis_ops()?;
692            for (ix, op) in transpose.into_iter().enumerate() {
693                wires[slot] = patch.wire_node(
694                    format!("{}.transpose_input_{}.{}", &node.name, slot, ix),
695                    op,
696                    &[wires[slot]],
697                )?[0];
698            }
699            exprs[slot] = after;
700        }
701        let pos = exprs[slot].chars().position(|c| k_axes[0].repr == c).unwrap();
702        wires[slot] = patch.wire_node(
703            format!("{}.fold_k_in_input_{}", &node.name, slot),
704            AxisOp::Reshape(pos, k_dims.clone(), tvec!(k.clone())),
705            &[wires[slot]],
706        )?[0];
707        exprs[slot] =
708            exprs[slot].chars().filter(|c| !k_axes.iter().any(|k| k.repr == *c)).collect();
709        exprs[slot].insert(pos, k_axes[0].repr);
710    }
711    let old = op.axes.to_string();
712    let (iexpr, oexpr) = old.split_once("->").unwrap();
713    let mut expr: String = exprs.iter().join(",");
714    if node.inputs.len() > 2 {
715        expr = expr + "," + &iexpr.split(",").skip(2).join(",");
716    }
717    expr = expr + "->" + oexpr;
718    let wire = patch.wire_node(
719        &node.name,
720        EinSum { axes: expr.parse().unwrap(), ..op.clone() },
721        &wires,
722    )?[0];
723    patch.shunt_outside(model, node.id.into(), wire)?;
724    Ok(Some(patch))
725}
726
727pub(super) fn inject_m_or_n_axis(
728    op: &EinSum,
729    model: &TypedModel,
730    node: &TypedNode,
731    is_n: bool,
732) -> TractResult<TypedModelPatch> {
733    let input_to_fix = is_n as usize;
734    let label = if is_n { "n" } else { "m" };
735    let name = &node.name;
736    let mut patch = TypedModelPatch::new("Injecting m or n axis");
737    let mut wire = patch.taps(model, &node.inputs)?;
738    let repr = op.axes.available_label();
739    let new_axes = op
740        .axes
741        .clone()
742        .with_extra_axis(repr, InOut::In(input_to_fix), 0)?
743        .with_extra_axis_occurency(repr, InOut::Out(0), 0)?;
744    wire[input_to_fix] =
745        patch.wire_node(format!("{name}.add_{label}"), AxisOp::Add(0), &[wire[input_to_fix]])?[0];
746    wire = patch.wire_node(name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
747    wire = patch.wire_node(&node.name, AxisOp::Rm(0), &wire)?;
748    patch.shunt_outside(model, node.id.into(), wire[0])?;
749    Ok(patch)
750}
751
752fn wire_axes_fix(
753    patch: &mut TypedModelPatch,
754    name: &str,
755    var: &str,
756    mapping: &AxesMapping,
757    mut outlet: TVec<OutletId>,
758) -> TractResult<TVec<OutletId>> {
759    for (ix, axis_op) in mapping.translate_to_axis_ops()?.into_iter().enumerate() {
760        outlet = patch.wire_node(format!("{name}.fix_{var}.{ix})"), axis_op, &outlet)?;
761    }
762    Ok(outlet)
763}
764
765fn dequant(
766    model: &TypedModel,
767    node: &TypedNode,
768    op: &EinSumMatMul,
769) -> TractResult<TypedModelPatch> {
770    let name = &node.name;
771    let mut patch = TypedModelPatch::new("Dequantizing einsum");
772
773    let k_axis = op.k_axis();
774
775    let mut taps = patch.taps(model, &node.inputs)?;
776    for ab in [0, 1] {
777        let scale_input = 4 + ab * 2;
778        if !patch.outlet_fact(taps[scale_input])?.shape.volume().is_one() {
779            let q_axis_in_output = op.axes.axis((InOut::In(scale_input), 0))?.outputs[0][0];
780            let output_rank = node.outputs[0].fact.rank();
781            for i in 1..(output_rank - q_axis_in_output) {
782                taps[scale_input] = patch.wire_node(
783                    format!("{name}.scale_input{ab}_axis_fix_{i}"),
784                    AxisOp::Add(i),
785                    &[taps[scale_input]],
786                )?[0];
787            }
788        }
789    }
790
791    let [mut a, mut b, bias, mut a0, a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
792        bail!("Expect exactly 9 inputs")
793    };
794
795    wire_ensure_q8_flavour(&mut patch, &node.name, &mut a, "a", &mut a0, i8::datum_type())?;
796    wire_ensure_q8_flavour(&mut patch, &node.name, &mut b, "b", &mut b0, i8::datum_type())?;
797
798    let mut output = patch.wire_node(
799        &node.name,
800        EinSum {
801            q_params: None,
802            axes: op.axes.extract_sub_mapping(&[0, 1], &[0])?,
803            operating_dt: op.operating_dt,
804        },
805        &[a, b],
806    )?;
807
808    let a_i32 = patch.wire_node(format!("{name}.a_as_i32"), cast(i32::datum_type()), &[a])?[0];
809    let b_i32 = patch.wire_node(format!("{name}.b_as_i32"), cast(i32::datum_type()), &[b])?[0];
810    let sum_a = patch.wire_node(
811        format!("{name}.sum_a"),
812        Reduce::new(tvec!(k_axis.inputs[0][0]), Reducer::Sum),
813        &[a_i32],
814    )?;
815    let sum_b = patch.wire_node(
816        format!("{name}.sum_b"),
817        Reduce::new(tvec!(k_axis.inputs[1][0]), Reducer::Sum),
818        &[b_i32],
819    )?;
820
821    let sum_a =
822        wire_axes_fix(&mut patch, name, "sum_a", &op.axes.extract_sub_mapping(&[0], &[0])?, sum_a)?;
823    let sum_b =
824        wire_axes_fix(&mut patch, name, "sum_b", &op.axes.extract_sub_mapping(&[1], &[0])?, sum_b)?;
825    let bias = tvec!(bias);
826    let bias =
827        wire_axes_fix(&mut patch, name, "bias", &op.axes.extract_sub_mapping(&[2], &[0])?, bias)?;
828
829    let abc_scale = combine_scales(&mut patch, name, a_scale, b_scale, c_scale)?;
830
831    output = patch.wire_node(format!("{name}.add_bias"), add(), &[output[0], bias[0]])?;
832
833    let k = model.outlet_fact(node.inputs[0])?.shape[k_axis.inputs[0][0]].clone();
834    let output = compensate_zero_points(&mut patch, name, output[0], k, a0, b0, sum_a[0], sum_b[0])
835        .context("Zero point compensation")?;
836    let output = requant(&mut patch, name, output, op.q_params.unwrap(), abc_scale, c0)?;
837    patch.shunt_outside(model, node.id.into(), output)?;
838    Ok(patch)
839}
840
841fn flatten_rule(
842    _ctx: &(),
843    model: &TypedModel,
844    node: &TypedNode,
845    _name: &str,
846    op: &EinSumMatMul,
847) -> TractResult<Option<TypedModelPatch>> {
848    TypedModelPatch::replace_single_op(model, node, &node.inputs, op.op.clone()).map(Some)
849}
850
851fn optimized_mat_mul(
852    model: &TypedModel,
853    node: &TypedNode,
854    op: &EinSumMatMul,
855) -> TractResult<Option<TypedModelPatch>> {
856    let (mode_picker, left_pack, impls) = kernel_selection::strategize(model, node, op)?;
857    let input_facts = model.node_input_facts(node.id)?;
858    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
859    let prefix = &node.name;
860
861    let mut patch = TypedModelPatch::new("Einsum to OptMatMul");
862    let taps = patch.taps(model, &node.inputs)?;
863    let name = &node.name;
864
865    let pack_a: Box<dyn TypedOp> = if input_facts[0].konst.is_some() {
866        if let Some(packed_format) = left_pack.downcast_ref::<PackedBlockQuantFormat>().cloned() {
867            Box::new(OptSimpleMatMulPack {
868                packed_format,
869                k: input_shapes[0][op.a_k()].to_usize().unwrap(),
870                m: input_shapes[0][op.a_m()].to_usize().unwrap(),
871            })
872        } else {
873            // PackedFormat or a custom packer (e.g. PackedI8K4); OptMatMulPack
874            // dispatches on the concrete format at pack time.
875            Box::new(OptMatMulPack {
876                packers: vec![left_pack],
877                mode_picker: ModePicker::Single,
878                k_axis: op.a_k(),
879                mn_axis: op.a_m(),
880            })
881        }
882    } else {
883        Box::new(OptMatMulPack {
884            packers: impls
885                .iter()
886                .map(|(mmm, p, pe)| {
887                    pe.as_ref()
888                        .map(|pe| pe.from.clone())
889                        .unwrap_or_else(|| mmm.packings()[*p].0.clone())
890                })
891                .collect(),
892            mode_picker: mode_picker.clone(),
893            k_axis: op.a_k(),
894            mn_axis: op.a_m(),
895        })
896    };
897    let pa = patch.wire_node(format!("{prefix}.pack_a"), pack_a, &[taps[0]])?[0];
898
899    let pb = patch.wire_node(
900        format!("{prefix}.pack_b"),
901        OptMatMulPack {
902            k_axis: op.b_k(),
903            mn_axis: op.b_n(),
904            packers: impls.iter().map(|(mmm, p, _)| mmm.packings()[*p].1.clone()).collect(),
905            mode_picker: mode_picker.clone(),
906        },
907        &[taps[1]],
908    )?[0];
909
910    let mut c_to_a_axis_mapping = tvec!();
911    let mut c_to_b_axis_mapping = tvec!();
912    for axis in op
913        .op
914        .axes
915        .iter_all_axes()
916        .filter(|&axis| ![op.m_axis, op.k_axis, op.n_axis].contains(&axis.repr))
917    {
918        if let (&[c], &[a]) = (&*axis.outputs[0], &*axis.inputs[0])
919            && input_shapes[0][a] != 1.to_dim()
920        {
921            let a = a - (a > op.a_m()) as usize - (a > op.a_k()) as usize;
922            c_to_a_axis_mapping.push((c, a));
923        }
924        if let (&[c], &[b]) = (&*axis.outputs[0], &*axis.inputs[1])
925            && input_shapes[1][b] != 1.to_dim()
926        {
927            let b = b - (b > op.b_n()) as usize - (b > op.b_k()) as usize;
928            c_to_b_axis_mapping.push((c, b));
929        }
930    }
931
932    let c_fact = op.output_facts(&input_facts)?.remove(0);
933    let geo = AddMatMulGeometry {
934        k: op.k.clone(),
935        c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
936        c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
937    };
938    let (mmms, packings, extractor): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(impls);
939    let outputs = mmms.iter().map(|mmm| unsafe { mmm.c_view(op.c_m(), op.c_n()) }).collect();
940    let trivial_packing = mmms.len() == 1
941        && packings[0] == 0
942        && extractor[0].is_none()
943        && input_facts[0].exotic_fact.is_none();
944    let opt = OptMatMul::new(
945        mmms,
946        mode_picker,
947        c_fact,
948        op.c_m(),
949        op.c_n(),
950        vec![
951            ProtoFusedSpec::AddMatMul {
952                geo,
953                a: 0,
954                b: 1,
955                packings: izip!(packings, extractor).collect_vec(),
956            },
957            ProtoFusedSpec::Store(outputs),
958        ],
959        trivial_packing,
960    )
961    .context("Creating OptMatMul")?;
962    let output = patch.wire_node(name, opt, &[pa, pb])?[0];
963    patch.shunt_outside(model, node.id.into(), output)?;
964    Ok(Some(patch))
965}