1use std::fmt::Formatter;
2use std::ops::Deref;
3
4use tract_itertools::{izip, multiunzip};
5use tract_linalg::block_quant::PackedBlockQuantFormat;
6
7use super::*;
8use crate::ops::cast::cast;
9use crate::ops::math::add;
10use crate::ops::matmul::ModePicker;
11use crate::ops::matmul::optimized::{
12 AddMatMulGeometry, MapOutputAxisToInput, OptMatMul, ProtoFusedSpec,
13};
14use crate::ops::matmul::pack::{OptMatMulPack, OptSimpleMatMulPack};
15use crate::ops::matmul::quant::{
16 combine_scales, compensate_zero_points, requant, wire_ensure_q8_flavour,
17};
18use crate::ops::nn::{Reduce, Reducer};
19
20pub fn merge_consecutive_same_role_axes(model: &mut TypedModel) -> TractResult<()> {
21 Rewriter::default()
22 .with_rule_for("merge-same-role-axes", merge_same_role_axes_rule)
23 .rewrite(&(), model)
24}
25
/// Rewrite rule simplifying a two-input `EinSum` by merging runs of adjacent axes
/// that share the same "role".
///
/// An axis role is the triple `(present in input 0, present in input 1, present in output 0)`.
/// The rule tries two strategies, in order:
///
/// 1. Find the longest run (at least two labels) of same-role axes that is contiguous in
///    input 0 or input 1, and contiguous and in the same order in the other input wherever
///    its labels appear. The run is accepted only if its dims are equal in both inputs
///    (for labels present in both) and at least two of them are non-unit. The inputs are
///    reshaped to fold the run into one axis that keeps the label of the run's first axis.
///    The einsum is rewired without the other labels. The output is reshaped back, and
///    moved back to the original output axis order if the run was not contiguous and
///    ordered there.
/// 2. Otherwise, look for a window `left, mid, right` in one input where `mid` has the
///    contracted role `(true, true, false)`, `left` and `right` share a role and are both
///    non-unit, and `left`/`right` are adjacent (in that order) in the other input when
///    present there. `mid` is moved to the last position of that input, making `left` and
///    `right` adjacent for a later application of strategy 1.
///
/// Returns `Ok(None)` when the node does not have exactly two inputs, or when neither
/// strategy applies.
fn merge_same_role_axes_rule(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    node_name: &str,
    op: &EinSum,
) -> TractResult<Option<TypedModelPatch>> {
    rule_if!(node.inputs.len() == 2);

    // (present in input 0, present in input 1, present in output 0)
    type Role = (bool, bool, bool);
    let axes: Vec<(char, Role)> = op
        .axes
        .iter_all_axes()
        .map(|a| {
            (a.repr, (!a.inputs[0].is_empty(), !a.inputs[1].is_empty(), !a.outputs[0].is_empty()))
        })
        .collect();

    // Axis labels listed in positional order for input 0, input 1 and the output.
    let a_order: Vec<char> = op.axes.axes(InOut::In(0)).map(|a| a.repr).collect();
    let b_order: Vec<char> = op.axes.axes(InOut::In(1)).map(|a| a.repr).collect();
    let c_order: Vec<char> = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();

    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    // True unless the label's dim (looked up in input 0 first, then input 1) is known to be
    // exactly 1. Symbolic dims and labels found in neither input count as non-unit.
    let is_non_unit = |c: &char| -> bool {
        let dim = a_order
            .iter()
            .position(|x| x == c)
            .map(|p| &input_shapes[0][p])
            .or_else(|| b_order.iter().position(|x| x == c).map(|p| &input_shapes[1][p]));
        dim.map_or(true, |d| d.as_i64() != Some(1))
    };

    let role_map: std::collections::HashMap<char, Role> = axes.iter().cloned().collect();
    // Longest candidate run found so far (strategy 1).
    let mut best_group: Option<Vec<char>> = None;

    let all_orders = [&a_order, &b_order];
    // Scan each input as the "primary" order, greedily growing runs of same-role labels.
    for (primary_idx, primary_order) in all_orders.iter().enumerate() {
        let mut i = 0;
        while i < primary_order.len() {
            let first = primary_order[i];
            let first_role = role_map[&first];
            let mut group = vec![first];
            let mut j = i + 1;
            while j < primary_order.len() {
                let candidate = primary_order[j];
                if role_map[&candidate] != first_role {
                    break;
                }
                // The extended run must also be contiguous and in the same order in the
                // other input (labels absent from it are ignored).
                let consecutive_in_others = all_orders
                    .iter()
                    .enumerate()
                    .filter(|(idx, _)| *idx != primary_idx)
                    .all(|(_, order)| {
                        let positions: Vec<usize> = group
                            .iter()
                            .chain(std::iter::once(&candidate))
                            .filter_map(|c| order.iter().position(|x| x == c))
                            .collect();
                        if positions.len() <= 1 {
                            return true;
                        }
                        let mut sorted = positions.clone();
                        sorted.sort();
                        sorted == positions
                            && sorted.last().unwrap() - sorted.first().unwrap() == sorted.len() - 1
                    });
                if !consecutive_in_others {
                    break;
                }
                group.push(candidate);
                j += 1;
            }
            // Keep only strictly longer runs: ties keep the earliest one found.
            if group.len() >= 2 && best_group.as_ref().map_or(true, |bg| group.len() > bg.len()) {
                best_group = Some(group);
            }
            i = j;
        }
    }

    // Discard the run if the inputs disagree on a dim, or if fewer than two axes are
    // non-unit (merging would not reduce anything meaningful).
    if let Some(ref group) = best_group {
        let dims_match = group.iter().all(|c| {
            match (a_order.iter().position(|x| x == c), b_order.iter().position(|x| x == c)) {
                (Some(p0), Some(p1)) => input_shapes[0][p0] == input_shapes[1][p1],
                _ => true,
            }
        });
        let worth_merging = group.iter().filter(|c| is_non_unit(c)).count() >= 2;
        if !dims_match || !worth_merging {
            best_group = None;
        }
    }

    // Strategy 1: fold the run into its first label.
    if let Some(group) = best_group {
        let output_shape = super::eval::output_shape(&op.axes, &input_shapes)?;

        // Every label of the run but the first disappears from the new einsum.
        let drop_set: Vec<char> = group[1..].to_vec();

        let mut patch = TypedModelPatch::default();
        let mut wires: TVec<OutletId> = patch.taps(model, &node.inputs)?;

        // Reshape each input containing the run so it becomes a single axis.
        for (slot, order) in [(0, &a_order), (1, &b_order)] {
            let positions: Vec<usize> =
                group.iter().filter_map(|c| order.iter().position(|x| x == c)).collect();
            if positions.len() < 2 {
                continue;
            }
            let start = positions[0];
            let from_dims: TVec<TDim> =
                positions.iter().map(|&p| input_shapes[slot][p].clone()).collect();
            let merged: TDim = from_dims.iter().product();
            wires[slot] = patch.wire_node(
                format!("{node_name}.merge_in{slot}"),
                AxisOp::Reshape(start, from_dims, tvec![merged]),
                &[wires[slot]],
            )?[0];
        }

        // The run must also be contiguous and ordered in the output for the merged axis
        // to be unmerged by a plain reshape. If it is not, compute an adjusted output order
        // where it is.
        let c_positions: Vec<usize> =
            group.iter().filter_map(|c| c_order.iter().position(|x| x == c)).collect();
        let c_needs_reorder = c_positions.len() >= 2 && {
            let mut sorted = c_positions.clone();
            sorted.sort();
            sorted.last().unwrap() - sorted.first().unwrap() != sorted.len() - 1
                || sorted != c_positions
        };
        let mut adjusted_c_order = c_order.clone();
        if c_needs_reorder {
            // Move each run label immediately after its predecessor in the run.
            for k in 1..c_positions.len() {
                let cur_pos = adjusted_c_order.iter().position(|&c| c == group[k]).unwrap();
                let target_pos =
                    adjusted_c_order.iter().position(|&c| c == group[k - 1]).unwrap() + 1;
                if cur_pos != target_pos {
                    let removed = adjusted_c_order.remove(cur_pos);
                    let insert_at = if cur_pos < target_pos { target_pos - 1 } else { target_pos };
                    adjusted_c_order.insert(insert_at, removed);
                }
            }
        }

        // Build the new einsum from the (unmerged) expression with the adjusted output
        // order, then drop the folded labels.
        let in0: String = a_order.iter().collect();
        let in1: String = b_order.iter().collect();
        let out: String = adjusted_c_order.iter().collect();
        let expr = format!("{in0},{in1}->{out}");
        let mut new_axes: AxesMapping = expr.parse()?;
        for &drop in &drop_set {
            new_axes = new_axes.remove_axis(drop)?;
        }
        let new_op =
            EinSum { axes: new_axes, operating_dt: op.operating_dt, q_params: op.q_params };
        let mut result = patch.wire_node(node_name, new_op, &wires)?;

        // Split the merged output axis back into the original run dims.
        let merged_c_positions: Vec<usize> =
            group.iter().filter_map(|c| adjusted_c_order.iter().position(|x| x == c)).collect();
        if merged_c_positions.len() >= 2 {
            let start = merged_c_positions[0];
            let original_c_positions: Vec<usize> =
                group.iter().filter_map(|c| c_order.iter().position(|x| x == c)).collect();
            let original_dims: TVec<TDim> =
                original_c_positions.iter().map(|&p| output_shape[p].clone()).collect();
            let merged: TDim = original_dims.iter().product();
            result[0] = patch.wire_node(
                format!("{node_name}.unmerge_out"),
                AxisOp::Reshape(start, tvec![merged], original_dims),
                &[result[0]],
            )?[0];
        }

        // Undo the output reordering with a sequence of axis moves restoring `c_order`.
        if c_needs_reorder {
            // Current layout after unmerging: adjusted order with the run expanded in place.
            let mut unmerged_adj: Vec<char> = Vec::new();
            for &c in &adjusted_c_order {
                if c == group[0] {
                    unmerged_adj.extend(&group);
                } else if !group.contains(&c) {
                    unmerged_adj.push(c);
                }
            }
            for target_pos in 0..c_order.len() {
                let cur_pos = unmerged_adj.iter().position(|&c| c == c_order[target_pos]).unwrap();
                if cur_pos != target_pos {
                    result[0] = patch.wire_node(
                        format!("{node_name}.restore_out_{target_pos}"),
                        AxisOp::Move(cur_pos, target_pos),
                        &[result[0]],
                    )?[0];
                    let removed = unmerged_adj.remove(cur_pos);
                    unmerged_adj.insert(target_pos, removed);
                }
            }
        }

        patch.shunt_outside(model, node.id.into(), result[0])?;
        return Ok(Some(patch));
    }

    // Strategy 2: move a contracted axis out of the way of two same-role neighbours.
    let k_role: Role = (true, true, false);
    let role_of = |c: char| axes.iter().find(|(ch, _)| *ch == c).map(|(_, r)| *r);

    for (slot, order) in [(0usize, &a_order), (1, &b_order)] {
        for w in order.windows(3) {
            let (left, mid, right) = (w[0], w[1], w[2]);
            let left_role = role_of(left);
            let mid_role = role_of(mid);
            let right_role = role_of(right);
            if left_role != right_role || mid_role != Some(k_role) {
                continue;
            }
            if !is_non_unit(&left) || !is_non_unit(&right) {
                continue;
            }
            // In the other input, `left` must be directly followed by `right`
            // (or at least one of them must be absent there).
            let other_input_orders: Vec<&Vec<char>> = [(0, &a_order), (1, &b_order)]
                .iter()
                .filter(|(s, _)| *s != slot)
                .map(|(_, o)| *o)
                .collect();
            let consecutive_elsewhere = other_input_orders.iter().all(|order| {
                let lp = order.iter().position(|&c| c == left);
                let rp = order.iter().position(|&c| c == right);
                match (lp, rp) {
                    (Some(l), Some(r)) => r == l + 1,
                    _ => true,
                }
            });
            if !consecutive_elsewhere {
                continue;
            }

            let mid_pos = order.iter().position(|&c| c == mid).unwrap();
            let end_pos = order.len() - 1;
            if mid_pos == end_pos {
                continue;
            }

            // Ask the einsum how it adapts to `mid` being moved last in this input.
            let move_op = AxisOp::Move(mid_pos, end_pos);
            let Some(AxisChangeConsequence { substitute_op, .. }) =
                op.change_axes(model, node, InOut::In(slot), &move_op)?
            else {
                continue;
            };
            let mut current_op = *substitute_op
                .unwrap()
                .downcast::<EinSum>()
                .map_err(|_| anyhow!("expected EinSum"))?;

            // If `left` and `right` are both in the output but `right` does not directly
            // follow `left`, also try to make them adjacent in the einsum output. The
            // original output order is restored after the einsum.
            let new_c: Vec<char> = current_op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
            let left_c = new_c.iter().position(|&c| c == left);
            let right_c = new_c.iter().position(|&c| c == right);
            let need_output_fix = matches!((left_c, right_c), (Some(l), Some(r)) if r != l + 1);
            if need_output_fix {
                let r_pos = right_c.unwrap();
                let l_pos = left_c.unwrap();
                // Destination index for `right` so it lands right after `left` once removed.
                let target = if r_pos < l_pos { l_pos } else { l_pos + 1 };
                if let Some(AxisChangeConsequence { substitute_op, .. }) = current_op.change_axes(
                    model,
                    node,
                    InOut::Out(0),
                    &AxisOp::Move(r_pos, target),
                )? {
                    current_op = *substitute_op
                        .unwrap()
                        .downcast::<EinSum>()
                        .map_err(|_| anyhow!("expected EinSum"))?;
                }
            }

            let mut patch = TypedModelPatch::default();
            let mut wires: TVec<OutletId> = patch.taps(model, &node.inputs)?;

            wires[slot] =
                patch.wire_node(format!("{node_name}.move_k_in{slot}"), move_op, &[wires[slot]])?
                    [0];

            let final_c: Vec<char> = current_op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
            let mut result = patch.wire_node(node_name, current_op, &wires)?;

            // Move `right` back to its original output position if it was displaced.
            if need_output_fix {
                let r_cur = final_c.iter().position(|&c| c == right).unwrap();
                let r_orig = c_order.iter().position(|&c| c == right).unwrap();
                if r_cur != r_orig {
                    result[0] = patch.wire_node(
                        format!("{node_name}.restore_out"),
                        AxisOp::Move(r_cur, r_orig),
                        &[result[0]],
                    )?[0];
                }
            }

            patch.shunt_outside(model, node.id.into(), result[0])?;
            return Ok(Some(patch));
        }
    }

    Ok(None)
}
370
371pub fn detect_all(model: &mut TypedModel) -> TractResult<()> {
372 Rewriter::default().with_rule_for("detect-matmul-einsum", detect_rule).rewrite(&(), model)
373}
374
375pub fn flatten_all(model: &mut TypedModel) -> TractResult<()> {
376 Rewriter::default().with_rule_for("flatten-matmul-einsum", flatten_rule).rewrite(&(), model)
377}
378
/// An `EinSum` recognized as a matrix multiplication, annotated with the labels of
/// its m, k and n axes and their dimensions.
///
/// Built by `detect_rule`; `flatten_rule` turns it back into the wrapped `EinSum`.
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct EinSumMatMul {
    /// The wrapped einsum (its axes mapping gives the positions of m, k and n).
    pub op: EinSum,
    /// Label of the m axis (looked up in input 0 by `a_m`, in the output by `c_m`).
    pub m_axis: char,
    /// Label of the contracted k axis (looked up in both inputs by `a_k` / `b_k`).
    pub k_axis: char,
    /// Label of the n axis (looked up in input 1 by `b_n`, in the output by `c_n`).
    pub n_axis: char,
    /// Size of the m axis (read from input 0 by `detect_rule`).
    pub m: TDim,
    /// Size of the k axis (read from input 0 by `detect_rule`).
    pub k: TDim,
    /// Size of the n axis (read from input 1 by `detect_rule`).
    pub n: TDim,
}
389
390impl EinSumMatMul {
391 pub fn m_axis(&self) -> &Axis {
392 self.op.axes.axis(self.m_axis).unwrap()
393 }
394 pub fn k_axis(&self) -> &Axis {
395 self.op.axes.axis(self.k_axis).unwrap()
396 }
397 pub fn n_axis(&self) -> &Axis {
398 self.op.axes.axis(self.n_axis).unwrap()
399 }
400 pub fn a_m(&self) -> usize {
401 self.m_axis().inputs[0][0]
402 }
403 pub fn a_k(&self) -> usize {
404 self.k_axis().inputs[0][0]
405 }
406 pub fn b_k(&self) -> usize {
407 self.k_axis().inputs[1][0]
408 }
409 pub fn b_n(&self) -> usize {
410 self.n_axis().inputs[1][0]
411 }
412 pub fn c_m(&self) -> Option<usize> {
413 self.m_axis().outputs[0].first().cloned()
414 }
415 pub fn c_n(&self) -> Option<usize> {
416 self.n_axis().outputs[0].first().cloned()
417 }
418
419 fn new(
420 op: EinSum,
421 m_axis: char,
422 k_axis: char,
423 n_axis: char,
424 m: TDim,
425 k: TDim,
426 n: TDim,
427 ) -> Self {
428 Self { op, m_axis, k_axis, n_axis, m, k, n }
429 }
430}
431
432impl Debug for EinSumMatMul {
433 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
434 write!(
435 f,
436 "EinsumMatMul: {} {:?} m: {}={}; k: {}={}; n: {}={}",
437 self.op.axes,
438 self.op.operating_dt,
439 self.m_axis,
440 self.m,
441 self.k_axis,
442 self.k,
443 self.n_axis,
444 self.n
445 )
446 }
447}
448
449impl Deref for EinSumMatMul {
450 type Target = EinSum;
451 fn deref(&self) -> &Self::Target {
452 &self.op
453 }
454}
455
456impl Op for EinSumMatMul {
457 fn name(&self) -> StaticName {
458 "EinSumMatMul".into()
459 }
460
461 op_as_typed_op!();
462}
463
464impl EvalOp for EinSumMatMul {
465 fn is_stateless(&self) -> bool {
466 true
467 }
468 fn eval_with_session(
469 &self,
470 node_id: usize,
471 session: &TurnState,
472 inputs: TVec<TValue>,
473 ) -> TractResult<TVec<TValue>> {
474 self.op.eval_with_session(node_id, session, inputs)
475 }
476}
477
impl TypedOp for EinSumMatMul {
    /// Output facts are those of the wrapped `EinSum`.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        self.op.output_facts(inputs)
    }

    /// Lowers the matmul one step at a time:
    /// - with 9 inputs (quantized form, `q_params` required), rewrites it through `dequant`;
    /// - swaps the two operands (exchanging the m and n roles) when input b is exotic, or,
    ///   for plain inputs, when the size heuristic below favours it;
    /// - otherwise, when m or n appears in the output, replaces it via `optimized_mat_mul`.
    ///
    /// Returns `Ok(None)` when none of the above applies.
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if node.inputs.len() == 9 {
            ensure!(self.op.q_params.is_some());
            return dequant(model, node, self).map(Some);
        }
        ensure!(node.inputs.len() == 2);
        let (a, b) = model.node_input_facts(node.id)?.into_iter().collect_tuple().unwrap();
        // Exotic facts are only accepted when they are block-quantized. An exotic a stays
        // in place; an exotic b is swapped into the a slot.
        let must_transpose = if let Some(of) = a.exotic_fact() {
            ensure!(of.is::<BlockQuantFact>());
            false
        } else if let Some(of) = b.exotic_fact() {
            ensure!(of.is::<BlockQuantFact>());
            true
        } else if self.m == self.n {
            false
        } else {
            // Transpose when n looks larger than m: compare concrete values; with only n
            // known, transpose if n >= 8; with only m known, keep; with both symbolic,
            // transpose if n - m is provably non-negative.
            match (self.m.as_i64(), self.n.as_i64()) {
                (Some(m), Some(n)) => m < n,
                (None, Some(n)) => n >= 8,
                (Some(_), _) => false,
                _ => (self.n.clone() - &self.m).prove_positive_or_zero(),
            }
        };
        if must_transpose {
            // Swap input roles on every axis, then swap the m/n bookkeeping accordingly.
            let mut op = self.clone();
            op.op.axes.iter_all_axes_mut().for_each(|axis| axis.inputs.swap(0, 1));
            std::mem::swap(&mut op.m_axis, &mut op.n_axis);
            std::mem::swap(&mut op.m, &mut op.n);
            return TypedModelPatch::replace_single_op(
                model,
                node,
                &[node.inputs[1], node.inputs[0]],
                op,
            )
            .map(|p| Some(p.with_context("transposing")));
        }
        if self.c_m().is_some() || self.c_n().is_some() {
            return optimized_mat_mul(model, node, self)
                .map(|opt| opt.map(|p| p.with_context("optimizing")));
        }
        Ok(None)
    }

    as_op!();
}
535
/// Rewrite rule recognizing an `EinSum` that can run as a matrix multiplication
/// and replacing it by an `EinSumMatMul`.
///
/// Applies to nodes with 2 inputs, or 9 when the op carries `q_params`. When the einsum
/// is not directly in matmul form, it may return a preparatory patch instead:
/// - `regroup_k_axes` when several contracted axes are non-trivial;
/// - `inject_k_axis` when there is no contracted axis;
/// - `inject_m_or_n_axis` when no m (resp. n) candidate exists.
///
/// Returns `Ok(None)` when some axis is non-unit in exactly one input while being unit
/// or absent in the output (a reduction a matmul cannot express).
pub(crate) fn detect_rule(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    _name: &str,
    op: &EinSum,
) -> TractResult<Option<TypedModelPatch>> {
    rule_if!(node.inputs.len() == (2 + op.q_params.is_some() as usize * 7));
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    let output_shape = super::eval::output_shape(&op.axes, &input_shapes)?;
    // Contracted axes: exactly once in each input, absent from the output.
    let k_axes: TVec<&Axis> = op
        .axes
        .iter_all_axes()
        .filter(|a| a.inputs[0].len() == 1 && a.inputs[1].len() == 1 && a.outputs[0].is_empty())
        .collect();

    // Contracted axes whose dim is not 1 in at least one input.
    let non_trivial_k_axis = k_axes
        .iter()
        .filter(|a| {
            !input_shapes[0][a.inputs[0][0]].is_one() || !input_shapes[1][a.inputs[1][0]].is_one()
        })
        .copied()
        .collect::<TVec<_>>();

    let k_axis = if non_trivial_k_axis.len() > 1 {
        return regroup_k_axes(op, model, node, non_trivial_k_axis);
    } else {
        non_trivial_k_axis.first().or_else(|| k_axes.first()).copied()
    };
    let Some(k_axis) = k_axis else { return inject_k_axis(op, model, node).map(Some) };

    // m candidates: once in a; absent or unit in b; and either once in the output, or
    // unit in a and absent from b.
    let mut possible_m_axes: Vec<_> = op
        .axes
        .iter_all_axes()
        .filter(|a| {
            a.inputs[0].len() == 1
                && (a.inputs[1].is_empty() || input_shapes[1][a.inputs[1][0]].is_one())
                && (a.outputs[0].len() == 1
                    || (input_shapes[0][a.inputs[0][0]].is_one() && a.inputs[1].is_empty()))
        })
        .collect();

    // Prefer candidates that appear in the output.
    if possible_m_axes.iter().any(|a| !a.outputs[0].is_empty()) {
        possible_m_axes.retain(|a| !a.outputs[0].is_empty());
    }

    // Pick the largest dim; symbolic dims rank as largest.
    let m_axis = possible_m_axes
        .into_iter()
        .max_by_key(|a| input_shapes[0][a.inputs[0][0]].as_i64().unwrap_or(i64::MAX));

    let Some(m_axis) = m_axis else {
        return inject_m_or_n_axis(op, model, node, false).map(Some);
    };

    // n: absent or unit in a, once in b, once in the output, distinct from m; largest dim
    // wins (symbolic ranks as largest).
    let n_axis = op
        .axes
        .iter_all_axes()
        .filter(|a| {
            (a.inputs[0].is_empty() || input_shapes[0][a.inputs[0][0]].is_one())
                && a.inputs[1].len() == 1
                && a.outputs[0].len() == 1
                && *a != m_axis
        })
        .max_by_key(|a| input_shapes[1][a.inputs[1][0]].as_i64().unwrap_or(i64::MAX));
    let Some(n_axis) = n_axis else {
        return inject_m_or_n_axis(op, model, node, true).map(Some);
    };
    // Reject axes that are non-unit in only one input and are summed away in the output.
    for axis in op.axes.iter_all_axes() {
        let one = TDim::one();
        let in_left =
            axis.inputs[0].first().map(|pos| &input_shapes[0][*pos]).unwrap_or(&one) != &one;
        let in_right =
            axis.inputs[1].first().map(|pos| &input_shapes[1][*pos]).unwrap_or(&one) != &one;
        let in_out = axis.outputs[0].first().map(|pos| &output_shape[*pos]).unwrap_or(&one) != &one;
        if (in_left ^ in_right) && !in_out {
            return Ok(None);
        }
    }
    let m = input_shapes[0][m_axis.inputs[0][0]].clone();
    let k = input_shapes[0][k_axis.inputs[0][0]].clone();
    let n = input_shapes[1][n_axis.inputs[1][0]].clone();
    TypedModelPatch::replace_single_op(
        model,
        node,
        &node.inputs,
        EinSumMatMul::new(op.clone(), m_axis.repr, k_axis.repr, n_axis.repr, m, k, n),
    )
    .map(Some)
}
632
633pub(super) fn inject_k_axis(
634 op: &EinSum,
635 model: &TypedModel,
636 node: &TypedNode,
637) -> TractResult<TypedModelPatch> {
638 let mut new_axes = op.axes.clone();
639 let name = &node.name;
640 let mut patch = TypedModelPatch::new("inject k axis");
641 let mut wire = patch.taps(model, &node.inputs)?;
642 let repr = new_axes.available_label();
643 new_axes = new_axes.with_extra_axis(repr, InOut::In(0), 0)?.with_extra_axis_occurency(
644 repr,
645 InOut::In(1),
646 0,
647 )?;
648 wire[0] = patch.wire_node(format!("{name}.add_k.0"), AxisOp::Add(0), &[wire[0]])?[0];
649 wire[1] = patch.wire_node(format!("{name}.add_k.1"), AxisOp::Add(0), &[wire[1]])?[0];
650 wire = patch.wire_node(&node.name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
651 patch.shunt_outside(model, node.id.into(), wire[0])?;
652 Ok(patch)
653}
654
655pub(super) fn regroup_k_axes(
656 op: &EinSum,
657 model: &TypedModel,
658 node: &TypedNode,
659 mut k_axes: TVec<&Axis>,
660) -> TractResult<Option<TypedModelPatch>> {
661 let input_facts = model.node_input_facts(node.id)?;
662 let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
663 let contig_in_a = k_axes
664 .iter()
665 .map(|axis| axis.inputs[0][0])
666 .sorted()
667 .tuple_windows()
668 .all(|(a, b)| a + 1 == b);
669 if contig_in_a {
670 k_axes.sort_by_key(|ax| ax.inputs[0][0]);
671 } else {
672 k_axes.sort_by_key(|ax| ax.inputs[1][0]);
673 }
674 let k_dims: TVec<_> =
675 k_axes.iter().map(|ax| input_shapes[0][ax.inputs[0][0]].clone()).collect();
676 let k: TDim = k_dims.iter().product();
677 let mut patch = TypedModelPatch::default();
678 let mut wires = patch.taps(model, &node.inputs)?;
679 let mut exprs: Vec<String> =
680 (0..2).map(|slot| op.axes.axes(InOut::In(slot)).map(|ax| ax.repr).join("")).collect();
681 for slot in 0..2 {
682 if k_axes.iter().map(|ax| ax.inputs[slot][0]).tuple_windows().any(|(a, b)| a + 1 != b) {
683 let after = op
684 .axes
685 .axes(InOut::In(slot))
686 .filter(|ax| !k_axes.contains(ax))
687 .chain(k_axes.iter().copied())
688 .map(|ax| ax.repr)
689 .join("");
690 let transpose =
691 AxesMapping::from_strs(&[&exprs[slot]], &[&after])?.translate_to_axis_ops()?;
692 for (ix, op) in transpose.into_iter().enumerate() {
693 wires[slot] = patch.wire_node(
694 format!("{}.transpose_input_{}.{}", &node.name, slot, ix),
695 op,
696 &[wires[slot]],
697 )?[0];
698 }
699 exprs[slot] = after;
700 }
701 let pos = exprs[slot].chars().position(|c| k_axes[0].repr == c).unwrap();
702 wires[slot] = patch.wire_node(
703 format!("{}.fold_k_in_input_{}", &node.name, slot),
704 AxisOp::Reshape(pos, k_dims.clone(), tvec!(k.clone())),
705 &[wires[slot]],
706 )?[0];
707 exprs[slot] =
708 exprs[slot].chars().filter(|c| !k_axes.iter().any(|k| k.repr == *c)).collect();
709 exprs[slot].insert(pos, k_axes[0].repr);
710 }
711 let old = op.axes.to_string();
712 let (iexpr, oexpr) = old.split_once("->").unwrap();
713 let mut expr: String = exprs.iter().join(",");
714 if node.inputs.len() > 2 {
715 expr = expr + "," + &iexpr.split(",").skip(2).join(",");
716 }
717 expr = expr + "->" + oexpr;
718 let wire = patch.wire_node(
719 &node.name,
720 EinSum { axes: expr.parse().unwrap(), ..op.clone() },
721 &wires,
722 )?[0];
723 patch.shunt_outside(model, node.id.into(), wire)?;
724 Ok(Some(patch))
725}
726
727pub(super) fn inject_m_or_n_axis(
728 op: &EinSum,
729 model: &TypedModel,
730 node: &TypedNode,
731 is_n: bool,
732) -> TractResult<TypedModelPatch> {
733 let input_to_fix = is_n as usize;
734 let label = if is_n { "n" } else { "m" };
735 let name = &node.name;
736 let mut patch = TypedModelPatch::new("Injecting m or n axis");
737 let mut wire = patch.taps(model, &node.inputs)?;
738 let repr = op.axes.available_label();
739 let new_axes = op
740 .axes
741 .clone()
742 .with_extra_axis(repr, InOut::In(input_to_fix), 0)?
743 .with_extra_axis_occurency(repr, InOut::Out(0), 0)?;
744 wire[input_to_fix] =
745 patch.wire_node(format!("{name}.add_{label}"), AxisOp::Add(0), &[wire[input_to_fix]])?[0];
746 wire = patch.wire_node(name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
747 wire = patch.wire_node(&node.name, AxisOp::Rm(0), &wire)?;
748 patch.shunt_outside(model, node.id.into(), wire[0])?;
749 Ok(patch)
750}
751
752fn wire_axes_fix(
753 patch: &mut TypedModelPatch,
754 name: &str,
755 var: &str,
756 mapping: &AxesMapping,
757 mut outlet: TVec<OutletId>,
758) -> TractResult<TVec<OutletId>> {
759 for (ix, axis_op) in mapping.translate_to_axis_ops()?.into_iter().enumerate() {
760 outlet = patch.wire_node(format!("{name}.fix_{var}.{ix})"), axis_op, &outlet)?;
761 }
762 Ok(outlet)
763}
764
/// Expands a quantized `EinSumMatMul` into explicit operations.
///
/// The node has 9 inputs: a, b, bias, a0, a_scale, b0, b_scale, c0, c_scale.
///
/// Operands a and b go through `wire_ensure_q8_flavour` (i8 flavour) and are fed to a
/// non-quantized `EinSum` restricted to those two inputs. The bias is then added.
/// Zero points are compensated with k-axis sums of a and b (as i32), and the result is
/// requantized using the combined a/b/c scales and c0.
///
/// Errors if the node does not have exactly 9 inputs.
fn dequant(
    model: &TypedModel,
    node: &TypedNode,
    op: &EinSumMatMul,
) -> TractResult<TypedModelPatch> {
    let name = &node.name;
    let mut patch = TypedModelPatch::new("Dequantizing einsum");

    let k_axis = op.k_axis();

    let mut taps = patch.taps(model, &node.inputs)?;
    // Non-scalar scales (inputs 4 and 6): append unit axes after axis 0 so the scale has
    // rank `output_rank - q_axis_in_output`. NOTE(review): presumably this lines its axis 0
    // up with its output axis under right-aligned broadcasting — confirm.
    for ab in [0, 1] {
        let scale_input = 4 + ab * 2;
        if !patch.outlet_fact(taps[scale_input])?.shape.volume().is_one() {
            let q_axis_in_output = op.axes.axis((InOut::In(scale_input), 0))?.outputs[0][0];
            let output_rank = node.outputs[0].fact.rank();
            for i in 1..(output_rank - q_axis_in_output) {
                taps[scale_input] = patch.wire_node(
                    format!("{name}.scale_input{ab}_axis_fix_{i}"),
                    AxisOp::Add(i),
                    &[taps[scale_input]],
                )?[0];
            }
        }
    }

    let [mut a, mut b, bias, mut a0, a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
        bail!("Expect exactly 9 inputs")
    };

    wire_ensure_q8_flavour(&mut patch, &node.name, &mut a, "a", &mut a0, i8::datum_type())?;
    wire_ensure_q8_flavour(&mut patch, &node.name, &mut b, "b", &mut b0, i8::datum_type())?;

    // Plain einsum on a and b only (sub-mapping of inputs 0 and 1 to output 0).
    let mut output = patch.wire_node(
        &node.name,
        EinSum {
            q_params: None,
            axes: op.axes.extract_sub_mapping(&[0, 1], &[0])?,
            operating_dt: op.operating_dt,
        },
        &[a, b],
    )?;

    // Sums of a and b along their k axis, used for zero point compensation.
    let a_i32 = patch.wire_node(format!("{name}.a_as_i32"), cast(i32::datum_type()), &[a])?[0];
    let b_i32 = patch.wire_node(format!("{name}.b_as_i32"), cast(i32::datum_type()), &[b])?[0];
    let sum_a = patch.wire_node(
        format!("{name}.sum_a"),
        Reduce::new(tvec!(k_axis.inputs[0][0]), Reducer::Sum),
        &[a_i32],
    )?;
    let sum_b = patch.wire_node(
        format!("{name}.sum_b"),
        Reduce::new(tvec!(k_axis.inputs[1][0]), Reducer::Sum),
        &[b_i32],
    )?;

    // Bring sums and bias to the output axis layout.
    let sum_a =
        wire_axes_fix(&mut patch, name, "sum_a", &op.axes.extract_sub_mapping(&[0], &[0])?, sum_a)?;
    let sum_b =
        wire_axes_fix(&mut patch, name, "sum_b", &op.axes.extract_sub_mapping(&[1], &[0])?, sum_b)?;
    let bias = tvec!(bias);
    let bias =
        wire_axes_fix(&mut patch, name, "bias", &op.axes.extract_sub_mapping(&[2], &[0])?, bias)?;

    let abc_scale = combine_scales(&mut patch, name, a_scale, b_scale, c_scale)?;

    output = patch.wire_node(format!("{name}.add_bias"), add(), &[output[0], bias[0]])?;

    let k = model.outlet_fact(node.inputs[0])?.shape[k_axis.inputs[0][0]].clone();
    let output = compensate_zero_points(&mut patch, name, output[0], k, a0, b0, sum_a[0], sum_b[0])
        .context("Zero point compensation")?;
    let output = requant(&mut patch, name, output, op.q_params.unwrap(), abc_scale, c0)?;
    patch.shunt_outside(model, node.id.into(), output)?;
    Ok(patch)
}
840
841fn flatten_rule(
842 _ctx: &(),
843 model: &TypedModel,
844 node: &TypedNode,
845 _name: &str,
846 op: &EinSumMatMul,
847) -> TractResult<Option<TypedModelPatch>> {
848 TypedModelPatch::replace_single_op(model, node, &node.inputs, op.op.clone()).map(Some)
849}
850
/// Replaces an `EinSumMatMul` node by packing nodes for both operands followed by an
/// `OptMatMul`, using the kernel implementation(s) chosen by `kernel_selection::strategize`.
///
/// Operand a is packed with:
/// - `OptSimpleMatMulPack` when it is a constant and the chosen left packing is a
///   `PackedBlockQuantFormat`;
/// - a single-packer `OptMatMulPack` when it is another constant;
/// - one packer per implementation otherwise.
///
/// Operand b always gets one packer per implementation.
fn optimized_mat_mul(
    model: &TypedModel,
    node: &TypedNode,
    op: &EinSumMatMul,
) -> TractResult<Option<TypedModelPatch>> {
    let (mode_picker, left_pack, impls) = kernel_selection::strategize(model, node, op)?;
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    let prefix = &node.name;

    let mut patch = TypedModelPatch::new("Einsum to OptMatMul");
    let taps = patch.taps(model, &node.inputs)?;
    let name = &node.name;

    let pack_a: Box<dyn TypedOp> = if input_facts[0].konst.is_some() {
        if let Some(packed_format) = left_pack.downcast_ref::<PackedBlockQuantFormat>().cloned() {
            Box::new(OptSimpleMatMulPack {
                packed_format,
                k: input_shapes[0][op.a_k()].to_usize().unwrap(),
                m: input_shapes[0][op.a_m()].to_usize().unwrap(),
            })
        } else {
            Box::new(OptMatMulPack {
                packers: vec![left_pack],
                mode_picker: ModePicker::Single,
                k_axis: op.a_k(),
                mn_axis: op.a_m(),
            })
        }
    } else {
        // One packer per implementation: the extractor's source packing when present,
        // the implementation's own left packing otherwise.
        Box::new(OptMatMulPack {
            packers: impls
                .iter()
                .map(|(mmm, p, pe)| {
                    pe.as_ref()
                        .map(|pe| pe.from.clone())
                        .unwrap_or_else(|| mmm.packings()[*p].0.clone())
                })
                .collect(),
            mode_picker: mode_picker.clone(),
            k_axis: op.a_k(),
            mn_axis: op.a_m(),
        })
    };
    let pa = patch.wire_node(format!("{prefix}.pack_a"), pack_a, &[taps[0]])?[0];

    let pb = patch.wire_node(
        format!("{prefix}.pack_b"),
        OptMatMulPack {
            k_axis: op.b_k(),
            mn_axis: op.b_n(),
            packers: impls.iter().map(|(mmm, p, _)| mmm.packings()[*p].1.clone()).collect(),
            mode_picker: mode_picker.clone(),
        },
        &[taps[1]],
    )?[0];

    // For every axis other than m/k/n that appears once in the output and is non-unit
    // in an input, record (output position, position in the packed input). The packed
    // position is the input position shifted down past the removed m/n and k axes.
    let mut c_to_a_axis_mapping = tvec!();
    let mut c_to_b_axis_mapping = tvec!();
    for axis in op
        .op
        .axes
        .iter_all_axes()
        .filter(|&axis| ![op.m_axis, op.k_axis, op.n_axis].contains(&axis.repr))
    {
        if let (&[c], &[a]) = (&*axis.outputs[0], &*axis.inputs[0])
            && input_shapes[0][a] != 1.to_dim()
        {
            let a = a - (a > op.a_m()) as usize - (a > op.a_k()) as usize;
            c_to_a_axis_mapping.push((c, a));
        }
        if let (&[c], &[b]) = (&*axis.outputs[0], &*axis.inputs[1])
            && input_shapes[1][b] != 1.to_dim()
        {
            let b = b - (b > op.b_n()) as usize - (b > op.b_k()) as usize;
            c_to_b_axis_mapping.push((c, b));
        }
    }

    let c_fact = op.output_facts(&input_facts)?.remove(0);
    let geo = AddMatMulGeometry {
        k: op.k.clone(),
        c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
        c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
    };
    let (mmms, packings, extractor): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(impls);
    let outputs = mmms.iter().map(|mmm| unsafe { mmm.c_view(op.c_m(), op.c_n()) }).collect();
    // Trivial packing: a single implementation using its packing 0, with no extractor,
    // and a non-exotic operand a.
    let trivial_packing = mmms.len() == 1
        && packings[0] == 0
        && extractor[0].is_none()
        && input_facts[0].exotic_fact.is_none();
    let opt = OptMatMul::new(
        mmms,
        mode_picker,
        c_fact,
        op.c_m(),
        op.c_n(),
        vec![
            ProtoFusedSpec::AddMatMul {
                geo,
                a: 0,
                b: 1,
                packings: izip!(packings, extractor).collect_vec(),
            },
            ProtoFusedSpec::Store(outputs),
        ],
        trivial_packing,
    )
    .context("Creating OptMatMul")?;
    let output = patch.wire_node(name, opt, &[pa, pb])?[0];
    patch.shunt_outside(model, node.id.into(), output)?;
    Ok(Some(patch))
}
965}