1use std::fmt::Formatter;
2use std::ops::Deref;
3
4use tract_itertools::{izip, multiunzip};
5use tract_linalg::block_quant::PackedBlockQuantFormat;
6use tract_linalg::pack::PackedFormat;
7
8use super::*;
9use crate::ops::cast::cast;
10use crate::ops::math::add;
11use crate::ops::matmul::ModePicker;
12use crate::ops::matmul::optimized::{
13 AddMatMulGeometry, MapOutputAxisToInput, OptMatMul, ProtoFusedSpec,
14};
15use crate::ops::matmul::pack::{OptMatMulPack, OptSimpleMatMulPack};
16use crate::ops::matmul::quant::{
17 combine_scales, compensate_zero_points, requant, wire_ensure_q8_flavour,
18};
19use crate::ops::nn::{Reduce, Reducer};
20
21pub fn merge_consecutive_same_role_axes(model: &mut TypedModel) -> TractResult<()> {
22 Rewriter::default()
23 .with_rule_for("merge-same-role-axes", merge_same_role_axes_rule)
24 .rewrite(&(), model)
25}
26
/// Rewrite rule: find a run of consecutive einsum axes that all play the same
/// role (same presence pattern in input A, input B and the output) and fold
/// them into one axis via reshapes around a smaller einsum. When no such group
/// exists, fall back to moving a contraction (k) axis out of the way of a
/// same-role pair so a later pass can merge it.
fn merge_same_role_axes_rule(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    node_name: &str,
    op: &EinSum,
) -> TractResult<Option<TypedModelPatch>> {
    // Only handle the plain two-input form (no quantization inputs).
    rule_if!(node.inputs.len() == 2);

    // Role of an axis: (appears in input 0, appears in input 1, appears in output).
    type Role = (bool, bool, bool);
    let axes: Vec<(char, Role)> = op
        .axes
        .iter_all_axes()
        .map(|a| {
            (a.repr, (!a.inputs[0].is_empty(), !a.inputs[1].is_empty(), !a.outputs[0].is_empty()))
        })
        .collect();

    // Axis labels in positional order on each operand and on the output.
    let a_order: Vec<char> = op.axes.axes(InOut::In(0)).map(|a| a.repr).collect();
    let b_order: Vec<char> = op.axes.axes(InOut::In(1)).map(|a| a.repr).collect();
    let c_order: Vec<char> = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();

    let role_map: std::collections::HashMap<char, Role> = axes.iter().cloned().collect();
    // Largest same-role group found so far (kept consecutive in both inputs).
    let mut best_group: Option<Vec<char>> = None;

    let all_orders = [&a_order, &b_order];
    for (primary_idx, primary_order) in all_orders.iter().enumerate() {
        let mut i = 0;
        while i < primary_order.len() {
            // Greedily grow a group starting at position i of the primary operand.
            let first = primary_order[i];
            let first_role = role_map[&first];
            let mut group = vec![first];
            let mut j = i + 1;
            while j < primary_order.len() {
                let candidate = primary_order[j];
                if role_map[&candidate] != first_role {
                    break;
                }
                // A candidate only joins if the extended group stays in the same
                // order and contiguous on the *other* operand too (where present).
                let consecutive_in_others = all_orders
                    .iter()
                    .enumerate()
                    .filter(|(idx, _)| *idx != primary_idx)
                    .all(|(_, order)| {
                        let positions: Vec<usize> = group
                            .iter()
                            .chain(std::iter::once(&candidate))
                            .filter_map(|c| order.iter().position(|x| x == c))
                            .collect();
                        if positions.len() <= 1 {
                            return true;
                        }
                        // Same relative order (sorted == positions) and no gaps
                        // (span of positions == number of positions).
                        let mut sorted = positions.clone();
                        sorted.sort();
                        sorted == positions
                            && sorted.last().unwrap() - sorted.first().unwrap() == sorted.len() - 1
                    });
                if !consecutive_in_others {
                    break;
                }
                group.push(candidate);
                j += 1;
            }
            // Keep the biggest group of at least two axes.
            if group.len() >= 2 && best_group.as_ref().map_or(true, |bg| group.len() > bg.len()) {
                best_group = Some(group);
            }
            i = j;
        }
    }

    if let Some(group) = best_group {
        let input_facts = model.node_input_facts(node.id)?;
        let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
        let output_shape = super::eval::output_shape(&op.axes, &input_shapes)?;

        // All group members but the leader disappear from the inner mapping.
        let drop_set: Vec<char> = group[1..].to_vec();

        let mut patch = TypedModelPatch::default();
        let mut wires: TVec<OutletId> = patch.taps(model, &node.inputs)?;

        // Fold the grouped dims of each input into a single dim with a Reshape.
        for (slot, order) in [(0, &a_order), (1, &b_order)] {
            let positions: Vec<usize> =
                group.iter().filter_map(|c| order.iter().position(|x| x == c)).collect();
            if positions.len() < 2 {
                continue;
            }
            let start = positions[0];
            let from_dims: TVec<TDim> =
                positions.iter().map(|&p| input_shapes[slot][p].clone()).collect();
            let merged: TDim = from_dims.iter().product();
            wires[slot] = patch.wire_node(
                format!("{node_name}.merge_in{slot}"),
                AxisOp::Reshape(start, from_dims, tvec![merged]),
                &[wires[slot]],
            )?[0];
        }

        // The output may list the grouped axes out of order or with gaps; if
        // so, plan an adjusted (group-contiguous) output order for the inner
        // einsum and restore the original order afterwards.
        let c_positions: Vec<usize> =
            group.iter().filter_map(|c| c_order.iter().position(|x| x == c)).collect();
        let c_needs_reorder = c_positions.len() >= 2 && {
            let mut sorted = c_positions.clone();
            sorted.sort();
            sorted.last().unwrap() - sorted.first().unwrap() != sorted.len() - 1
                || sorted != c_positions
        };
        let mut adjusted_c_order = c_order.clone();
        if c_needs_reorder {
            // Pull each group member right next to its predecessor in the group.
            for k in 1..c_positions.len() {
                let cur_pos = adjusted_c_order.iter().position(|&c| c == group[k]).unwrap();
                let target_pos =
                    adjusted_c_order.iter().position(|&c| c == group[k - 1]).unwrap() + 1;
                if cur_pos != target_pos {
                    let removed = adjusted_c_order.remove(cur_pos);
                    // Removing before the target shifts the target left by one.
                    let insert_at = if cur_pos < target_pos { target_pos - 1 } else { target_pos };
                    adjusted_c_order.insert(insert_at, removed);
                }
            }
        }

        // Build the inner einsum expression, then drop the merged axes from its
        // mapping (their extents are now folded into the group leader).
        let in0: String = a_order.iter().collect();
        let in1: String = b_order.iter().collect();
        let out: String = adjusted_c_order.iter().collect();
        let expr = format!("{in0},{in1}->{out}");
        let mut new_axes: AxesMapping = expr.parse()?;
        for &drop in &drop_set {
            new_axes = new_axes.remove_axis(drop)?;
        }
        let new_op =
            EinSum { axes: new_axes, operating_dt: op.operating_dt, q_params: op.q_params };
        let mut result = patch.wire_node(node_name, new_op, &wires)?;

        // Unfold the merged output axis back into the original dims.
        let merged_c_positions: Vec<usize> =
            group.iter().filter_map(|c| adjusted_c_order.iter().position(|x| x == c)).collect();
        if merged_c_positions.len() >= 2 {
            let start = merged_c_positions[0];
            let original_c_positions: Vec<usize> =
                group.iter().filter_map(|c| c_order.iter().position(|x| x == c)).collect();
            let original_dims: TVec<TDim> =
                original_c_positions.iter().map(|&p| output_shape[p].clone()).collect();
            let merged: TDim = original_dims.iter().product();
            result[0] = patch.wire_node(
                format!("{node_name}.unmerge_out"),
                AxisOp::Reshape(start, tvec![merged], original_dims),
                &[result[0]],
            )?[0];
        }

        // If the inner einsum used a reordered output, move the axes back to
        // the caller-visible order with a sequence of Move ops.
        if c_needs_reorder {
            // Current post-unmerge axis order: the adjusted order with the
            // whole group expanded in place of its leader.
            let mut unmerged_adj: Vec<char> = Vec::new();
            for &c in &adjusted_c_order {
                if c == group[0] {
                    unmerged_adj.extend(&group);
                } else if !group.contains(&c) {
                    unmerged_adj.push(c);
                }
            }
            // Selection-style pass: put the right axis at every target position,
            // mirroring each wired Move in the tracking vector.
            for target_pos in 0..c_order.len() {
                let cur_pos = unmerged_adj.iter().position(|&c| c == c_order[target_pos]).unwrap();
                if cur_pos != target_pos {
                    result[0] = patch.wire_node(
                        format!("{node_name}.restore_out_{target_pos}"),
                        AxisOp::Move(cur_pos, target_pos),
                        &[result[0]],
                    )?[0];
                    let removed = unmerged_adj.remove(cur_pos);
                    unmerged_adj.insert(target_pos, removed);
                }
            }
        }

        patch.shunt_outside(model, node.id.into(), result[0])?;
        return Ok(Some(patch));
    }

    // Fallback: no mergeable group. Look for a k axis sandwiched between two
    // same-role axes and move it to the end of its input so the pair becomes
    // adjacent (and mergeable by a later application of this rule).
    let k_role: Role = (true, true, false);
    let role_of = |c: char| axes.iter().find(|(ch, _)| *ch == c).map(|(_, r)| *r);

    for (slot, order) in [(0usize, &a_order), (1, &b_order)] {
        for w in order.windows(3) {
            let (left, mid, right) = (w[0], w[1], w[2]);
            let left_role = role_of(left);
            let mid_role = role_of(mid);
            let right_role = role_of(right);
            // Only interested in <same-role> k <same-role> patterns.
            if left_role != right_role || mid_role != Some(k_role) {
                continue;
            }
            // The pair must already be adjacent in the other input (or one of
            // them absent from it).
            let other_input_orders: Vec<&Vec<char>> = [(0, &a_order), (1, &b_order)]
                .iter()
                .filter(|(s, _)| *s != slot)
                .map(|(_, o)| *o)
                .collect();
            let consecutive_elsewhere = other_input_orders.iter().all(|order| {
                let lp = order.iter().position(|&c| c == left);
                let rp = order.iter().position(|&c| c == right);
                match (lp, rp) {
                    (Some(l), Some(r)) => r == l + 1,
                    _ => true,
                }
            });
            if !consecutive_elsewhere {
                continue;
            }

            let mid_pos = order.iter().position(|&c| c == mid).unwrap();
            let end_pos = order.len() - 1;
            // Nothing to do if the k axis is already last.
            if mid_pos == end_pos {
                continue;
            }

            // Ask the op to absorb the Move; skip this candidate if it cannot.
            let move_op = AxisOp::Move(mid_pos, end_pos);
            let Some(AxisChangeConsequence { substitute_op, .. }) =
                op.change_axes(model, node, InOut::In(slot), &move_op)?
            else {
                continue;
            };
            let mut current_op = *substitute_op
                .unwrap()
                .downcast::<EinSum>()
                .map_err(|_| anyhow!("expected EinSum"))?;

            // Moving the k axis may have separated left/right in the output of
            // the substituted op: bring right back next to left there too.
            let new_c: Vec<char> = current_op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
            let left_c = new_c.iter().position(|&c| c == left);
            let right_c = new_c.iter().position(|&c| c == right);
            let need_output_fix = matches!((left_c, right_c), (Some(l), Some(r)) if r != l + 1);
            if need_output_fix {
                let r_pos = right_c.unwrap();
                let l_pos = left_c.unwrap();
                let target = if r_pos < l_pos { l_pos } else { l_pos + 1 };
                if let Some(AxisChangeConsequence { substitute_op, .. }) = current_op.change_axes(
                    model,
                    node,
                    InOut::Out(0),
                    &AxisOp::Move(r_pos, target),
                )? {
                    current_op = *substitute_op
                        .unwrap()
                        .downcast::<EinSum>()
                        .map_err(|_| anyhow!("expected EinSum"))?;
                }
            }

            let mut patch = TypedModelPatch::default();
            let mut wires: TVec<OutletId> = patch.taps(model, &node.inputs)?;

            // Physically move the k axis on the chosen input.
            wires[slot] =
                patch.wire_node(format!("{node_name}.move_k_in{slot}"), move_op, &[wires[slot]])?
                    [0];

            let final_c: Vec<char> = current_op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
            let mut result = patch.wire_node(node_name, current_op, &wires)?;

            // Restore the original output axis order if it was altered above.
            if need_output_fix {
                let r_cur = final_c.iter().position(|&c| c == right).unwrap();
                let r_orig = c_order.iter().position(|&c| c == right).unwrap();
                if r_cur != r_orig {
                    result[0] = patch.wire_node(
                        format!("{node_name}.restore_out"),
                        AxisOp::Move(r_cur, r_orig),
                        &[result[0]],
                    )?[0];
                }
            }

            patch.shunt_outside(model, node.id.into(), result[0])?;
            return Ok(Some(patch));
        }
    }

    Ok(None)
}
326
327pub fn detect_all(model: &mut TypedModel) -> TractResult<()> {
328 Rewriter::default().with_rule_for("detect-matmul-einsum", detect_rule).rewrite(&(), model)
329}
330
331pub fn flatten_all(model: &mut TypedModel) -> TractResult<()> {
332 Rewriter::default().with_rule_for("flatten-matmul-einsum", flatten_rule).rewrite(&(), model)
333}
334
/// An EinSum recognized as a (possibly batched) matrix multiplication, with
/// its m, k and n axes identified by label and their dimensions recorded.
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct EinSumMatMul {
    /// The wrapped original einsum.
    pub op: EinSum,
    /// Label of the m axis in the einsum mapping.
    pub m_axis: char,
    /// Label of the k (contraction) axis.
    pub k_axis: char,
    /// Label of the n axis.
    pub n_axis: char,
    /// Extent of the m axis.
    pub m: TDim,
    /// Extent of the k axis.
    pub k: TDim,
    /// Extent of the n axis.
    pub n: TDim,
}
345
impl EinSumMatMul {
    /// Full Axis record for the m axis.
    pub fn m_axis(&self) -> &Axis {
        self.op.axes.axis(self.m_axis).unwrap()
    }
    /// Full Axis record for the k (contraction) axis.
    pub fn k_axis(&self) -> &Axis {
        self.op.axes.axis(self.k_axis).unwrap()
    }
    /// Full Axis record for the n axis.
    pub fn n_axis(&self) -> &Axis {
        self.op.axes.axis(self.n_axis).unwrap()
    }
    /// Position of the m axis in input A.
    pub fn a_m(&self) -> usize {
        self.m_axis().inputs[0][0]
    }
    /// Position of the k axis in input A.
    pub fn a_k(&self) -> usize {
        self.k_axis().inputs[0][0]
    }
    /// Position of the k axis in input B.
    pub fn b_k(&self) -> usize {
        self.k_axis().inputs[1][0]
    }
    /// Position of the n axis in input B.
    pub fn b_n(&self) -> usize {
        self.n_axis().inputs[1][0]
    }
    /// Position of the m axis in the output, if it appears there.
    pub fn c_m(&self) -> Option<usize> {
        self.m_axis().outputs[0].first().cloned()
    }
    /// Position of the n axis in the output, if it appears there.
    pub fn c_n(&self) -> Option<usize> {
        self.n_axis().outputs[0].first().cloned()
    }

    /// Plain constructor: record the chosen axes labels and their extents.
    fn new(
        op: EinSum,
        m_axis: char,
        k_axis: char,
        n_axis: char,
        m: TDim,
        k: TDim,
        n: TDim,
    ) -> Self {
        Self { op, m_axis, k_axis, n_axis, m, k, n }
    }
}
387
388impl Debug for EinSumMatMul {
389 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
390 write!(
391 f,
392 "EinsumMatMul: {} {:?} m: {}={}; k: {}={}; n: {}={}",
393 self.op.axes,
394 self.op.operating_dt,
395 self.m_axis,
396 self.m,
397 self.k_axis,
398 self.k,
399 self.n_axis,
400 self.n
401 )
402 }
403}
404
405impl Deref for EinSumMatMul {
406 type Target = EinSum;
407 fn deref(&self) -> &Self::Target {
408 &self.op
409 }
410}
411
impl Op for EinSumMatMul {
    fn name(&self) -> StaticName {
        "EinSumMatMul".into()
    }

    // Boilerplate downcast support to the TypedOp interface.
    op_as_typed_op!();
}
419
impl EvalOp for EinSumMatMul {
    fn is_stateless(&self) -> bool {
        true
    }
    // Evaluation delegates entirely to the wrapped EinSum; this type only adds
    // the m/k/n bookkeeping used at optimization time.
    fn eval_with_session(
        &self,
        node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        self.op.eval_with_session(node_id, session, inputs)
    }
}
433
impl TypedOp for EinSumMatMul {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        // Same output facts as the wrapped einsum.
        self.op.output_facts(inputs)
    }

    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // 9 inputs is the quantized form (a, b, bias + zero-point/scale
        // triplets): lower it to explicit dequantization arithmetic.
        if node.inputs.len() == 9 {
            ensure!(self.op.q_params.is_some());
            return dequant(model, node, self).map(Some);
        }
        ensure!(node.inputs.len() == 2);
        let (a, b) = model.node_input_facts(node.id)?.into_iter().collect_tuple().unwrap();
        // Decide whether to swap operands: block-quantized data must stay on
        // the A side; otherwise heuristically prefer the larger of m/n as m.
        let must_transpose = if let Some(of) = a.exotic_fact() {
            ensure!(of.is::<BlockQuantFact>());
            false
        } else if let Some(of) = b.exotic_fact() {
            ensure!(of.is::<BlockQuantFact>());
            true
        } else if self.m == self.n {
            false
        } else {
            match (self.m.as_i64(), self.n.as_i64()) {
                (Some(m), Some(n)) => m < n,
                // Symbolic m: transpose only when concrete n is at least 8.
                (None, Some(n)) => n >= 8,
                (Some(_), _) => false,
                // Both symbolic: transpose if n >= m can be proven.
                _ => (self.n.clone() - &self.m).prove_positive_or_zero(),
            }
        };
        if must_transpose {
            // Swap the two inputs in the mapping and exchange the m/n roles,
            // then re-emit the node with its inputs reversed.
            let mut op = self.clone();
            op.op.axes.iter_all_axes_mut().for_each(|axis| axis.inputs.swap(0, 1));
            std::mem::swap(&mut op.m_axis, &mut op.n_axis);
            std::mem::swap(&mut op.m, &mut op.n);
            return TypedModelPatch::replace_single_op(
                model,
                node,
                &[node.inputs[1], node.inputs[0]],
                op,
            )
            .map(|p| Some(p.with_context("transposing")));
        }
        // Lower to OptMatMul only when at least one of m/n reaches the output.
        if self.c_m().is_some() || self.c_n().is_some() {
            return optimized_mat_mul(model, node, self)
                .map(|opt| opt.map(|p| p.with_context("optimizing")));
        }
        Ok(None)
    }

    as_op!();
}
491
/// Rewrite rule: recognize the matmul pattern inside an EinSum and replace the
/// node with an EinSumMatMul recording the chosen m, k and n axes. When an
/// axis is missing or ambiguous, delegate to the inject_*/regroup_* helpers,
/// which patch the graph so a later pass can succeed.
pub(crate) fn detect_rule(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    _name: &str,
    op: &EinSum,
) -> TractResult<Option<TypedModelPatch>> {
    // 2 inputs for the plain form, 9 (2 + 7) for the quantized form.
    rule_if!(node.inputs.len() == (2 + op.q_params.is_some() as usize * 7));
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    let output_shape = super::eval::output_shape(&op.axes, &input_shapes)?;
    // Candidate k axes: appear exactly once in each input and never in the output.
    let k_axes: TVec<&Axis> = op
        .axes
        .iter_all_axes()
        .filter(|a| a.inputs[0].len() == 1 && a.inputs[1].len() == 1 && a.outputs[0].is_empty())
        .collect();

    // Ignore k axes of extent 1 on both sides: they contract nothing.
    let non_trivial_k_axis = k_axes
        .iter()
        .filter(|a| {
            !input_shapes[0][a.inputs[0][0]].is_one() || !input_shapes[1][a.inputs[1][0]].is_one()
        })
        .copied()
        .collect::<TVec<_>>();

    let k_axis = if non_trivial_k_axis.len() > 1 {
        // Several contraction axes: fold them into a single one first.
        return regroup_k_axes(op, model, node, non_trivial_k_axis);
    } else {
        non_trivial_k_axis.first().or_else(|| k_axes.first()).copied()
    };
    // No k axis at all: add a dummy one to both inputs.
    let Some(k_axis) = k_axis else { return inject_k_axis(op, model, node).map(Some) };

    // Candidate m axes: once in A, absent (or size-1) in B, and either present
    // in the output or trivial in A while absent from B.
    let mut possible_m_axes: Vec<_> = op
        .axes
        .iter_all_axes()
        .filter(|a| {
            a.inputs[0].len() == 1
                && (a.inputs[1].is_empty() || input_shapes[1][a.inputs[1][0]].is_one())
                && (a.outputs[0].len() == 1
                    || (input_shapes[0][a.inputs[0][0]].is_one() && a.inputs[1].is_empty()))
        })
        .collect();

    // Prefer axes that actually reach the output, when any do.
    if possible_m_axes.iter().any(|a| !a.outputs[0].is_empty()) {
        possible_m_axes.retain(|a| !a.outputs[0].is_empty());
    }

    // Pick the largest m; symbolic extents are treated as arbitrarily large.
    let m_axis = possible_m_axes
        .into_iter()
        .max_by_key(|a| input_shapes[0][a.inputs[0][0]].as_i64().unwrap_or(i64::MAX));

    let Some(m_axis) = m_axis else {
        return inject_m_or_n_axis(op, model, node, false).map(Some);
    };

    // Candidate n axes: absent (or size-1) in A, once in B and in the output,
    // and distinct from the chosen m axis.
    let n_axis = op
        .axes
        .iter_all_axes()
        .filter(|a| {
            (a.inputs[0].is_empty() || input_shapes[0][a.inputs[0][0]].is_one())
                && a.inputs[1].len() == 1
                && a.outputs[0].len() == 1
                && *a != m_axis
        })
        .max_by_key(|a| input_shapes[1][a.inputs[1][0]].as_i64().unwrap_or(i64::MAX));
    let Some(n_axis) = n_axis else {
        return inject_m_or_n_axis(op, model, node, true).map(Some);
    };
    // Give up on axes that are non-trivial in exactly one input but never
    // reach the output: an implicit reduction this lowering cannot express.
    for axis in op.axes.iter_all_axes() {
        let one = TDim::one();
        let in_left =
            axis.inputs[0].first().map(|pos| &input_shapes[0][*pos]).unwrap_or(&one) != &one;
        let in_right =
            axis.inputs[1].first().map(|pos| &input_shapes[1][*pos]).unwrap_or(&one) != &one;
        let in_out = axis.outputs[0].first().map(|pos| &output_shape[*pos]).unwrap_or(&one) != &one;
        if (in_left ^ in_right) && !in_out {
            return Ok(None);
        }
    }
    let m = input_shapes[0][m_axis.inputs[0][0]].clone();
    let k = input_shapes[0][k_axis.inputs[0][0]].clone();
    let n = input_shapes[1][n_axis.inputs[1][0]].clone();
    TypedModelPatch::replace_single_op(
        model,
        node,
        &node.inputs,
        EinSumMatMul::new(op.clone(), m_axis.repr, k_axis.repr, n_axis.repr, m, k, n),
    )
    .map(Some)
}
588
589pub(super) fn inject_k_axis(
590 op: &EinSum,
591 model: &TypedModel,
592 node: &TypedNode,
593) -> TractResult<TypedModelPatch> {
594 let mut new_axes = op.axes.clone();
595 let name = &node.name;
596 let mut patch = TypedModelPatch::new("inject k axis");
597 let mut wire = patch.taps(model, &node.inputs)?;
598 let repr = new_axes.available_label();
599 new_axes = new_axes.with_extra_axis(repr, InOut::In(0), 0)?.with_extra_axis_occurency(
600 repr,
601 InOut::In(1),
602 0,
603 )?;
604 wire[0] = patch.wire_node(format!("{name}.add_k.0"), AxisOp::Add(0), &[wire[0]])?[0];
605 wire[1] = patch.wire_node(format!("{name}.add_k.1"), AxisOp::Add(0), &[wire[1]])?[0];
606 wire = patch.wire_node(&node.name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
607 patch.shunt_outside(model, node.id.into(), wire[0])?;
608 Ok(patch)
609}
610
/// Fold several contraction (k) axes into a single one: transpose each input
/// if needed so the k axes are adjacent, reshape them into one axis, then
/// re-emit the einsum with the updated expression.
pub(super) fn regroup_k_axes(
    op: &EinSum,
    model: &TypedModel,
    node: &TypedNode,
    mut k_axes: TVec<&Axis>,
) -> TractResult<Option<TypedModelPatch>> {
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    // Order the k axes by their position in A when they are already contiguous
    // there, otherwise by their position in B.
    let contig_in_a = k_axes
        .iter()
        .map(|axis| axis.inputs[0][0])
        .sorted()
        .tuple_windows()
        .all(|(a, b)| a + 1 == b);
    if contig_in_a {
        k_axes.sort_by_key(|ax| ax.inputs[0][0]);
    } else {
        k_axes.sort_by_key(|ax| ax.inputs[1][0]);
    }
    let k_dims: TVec<_> =
        k_axes.iter().map(|ax| input_shapes[0][ax.inputs[0][0]].clone()).collect();
    // The merged contraction extent is the product of the folded ones.
    let k: TDim = k_dims.iter().product();
    let mut patch = TypedModelPatch::default();
    let mut wires = patch.taps(model, &node.inputs)?;
    // Per-input axis label strings, kept in sync as each input is rewritten.
    let mut exprs: Vec<String> =
        (0..2).map(|slot| op.axes.axes(InOut::In(slot)).map(|ax| ax.repr).join("")).collect();
    for slot in 0..2 {
        // If the k axes are not adjacent in this input, move them all to the
        // end with a transpose (as elementary axis ops).
        if k_axes.iter().map(|ax| ax.inputs[slot][0]).tuple_windows().any(|(a, b)| a + 1 != b) {
            let after = op
                .axes
                .axes(InOut::In(slot))
                .filter(|ax| !k_axes.contains(ax))
                .chain(k_axes.iter().copied())
                .map(|ax| ax.repr)
                .join("");
            let transpose =
                AxesMapping::from_strs(&[&exprs[slot]], &[&after])?.translate_to_axis_ops()?;
            for (ix, op) in transpose.into_iter().enumerate() {
                wires[slot] = patch.wire_node(
                    format!("{}.transpose_input_{}.{}", &node.name, slot, ix),
                    op,
                    &[wires[slot]],
                )?[0];
            }
            exprs[slot] = after;
        }
        // Reshape the (now adjacent) k axes into one axis, labelled after the
        // first k axis.
        let pos = exprs[slot].chars().position(|c| k_axes[0].repr == c).unwrap();
        wires[slot] = patch.wire_node(
            format!("{}.fold_k_in_input_{}", &node.name, slot),
            AxisOp::Reshape(pos, k_dims.clone(), tvec!(k.clone())),
            &[wires[slot]],
        )?[0];
        exprs[slot] =
            exprs[slot].chars().filter(|c| !k_axes.iter().any(|k| k.repr == *c)).collect();
        exprs[slot].insert(pos, k_axes[0].repr);
    }
    // Rebuild the full expression, preserving any extra (quantization) inputs
    // and the output side verbatim.
    let old = op.axes.to_string();
    let (iexpr, oexpr) = old.split_once("->").unwrap();
    let mut expr: String = exprs.iter().join(",");
    if node.inputs.len() > 2 {
        expr = expr + "," + &iexpr.split(",").skip(2).join(",");
    }
    expr = expr + "->" + oexpr;
    let wire = patch.wire_node(
        &node.name,
        EinSum { axes: expr.parse().unwrap(), ..op.clone() },
        &wires,
    )?[0];
    patch.shunt_outside(model, node.id.into(), wire)?;
    Ok(Some(patch))
}
682
683pub(super) fn inject_m_or_n_axis(
684 op: &EinSum,
685 model: &TypedModel,
686 node: &TypedNode,
687 is_n: bool,
688) -> TractResult<TypedModelPatch> {
689 let input_to_fix = is_n as usize;
690 let label = if is_n { "n" } else { "m" };
691 let name = &node.name;
692 let mut patch = TypedModelPatch::new("Injecting m or n axis");
693 let mut wire = patch.taps(model, &node.inputs)?;
694 let repr = op.axes.available_label();
695 let new_axes = op
696 .axes
697 .clone()
698 .with_extra_axis(repr, InOut::In(input_to_fix), 0)?
699 .with_extra_axis_occurency(repr, InOut::Out(0), 0)?;
700 wire[input_to_fix] =
701 patch.wire_node(format!("{name}.add_{label}"), AxisOp::Add(0), &[wire[input_to_fix]])?[0];
702 wire = patch.wire_node(name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
703 wire = patch.wire_node(&node.name, AxisOp::Rm(0), &wire)?;
704 patch.shunt_outside(model, node.id.into(), wire[0])?;
705 Ok(patch)
706}
707
708fn wire_axes_fix(
709 patch: &mut TypedModelPatch,
710 name: &str,
711 var: &str,
712 mapping: &AxesMapping,
713 mut outlet: TVec<OutletId>,
714) -> TractResult<TVec<OutletId>> {
715 for (ix, axis_op) in mapping.translate_to_axis_ops()?.into_iter().enumerate() {
716 outlet = patch.wire_node(format!("{name}.fix_{var}.{ix})"), axis_op, &outlet)?;
717 }
718 Ok(outlet)
719}
720
/// Lower a quantized EinSumMatMul (9 inputs: a, b, bias, a0, a_scale, b0,
/// b_scale, c0, c_scale) into a plain integer einsum followed by explicit
/// bias addition, zero-point compensation and requantization.
fn dequant(
    model: &TypedModel,
    node: &TypedNode,
    op: &EinSumMatMul,
) -> TractResult<TypedModelPatch> {
    let name = &node.name;
    let mut patch = TypedModelPatch::new("Dequantizing einsum");

    let k_axis = op.k_axis();

    let mut taps = patch.taps(model, &node.inputs)?;
    // Non-scalar scales (inputs 4 and 6) must be rank-aligned with the output:
    // append trailing unit axes until the scale's quantization axis lines up
    // with its output axis.
    for ab in [0, 1] {
        let scale_input = 4 + ab * 2;
        if !patch.outlet_fact(taps[scale_input])?.shape.volume().is_one() {
            let q_axis_in_output = op.axes.axis((InOut::In(scale_input), 0))?.outputs[0][0];
            let output_rank = node.outputs[0].fact.rank();
            for i in 1..(output_rank - q_axis_in_output) {
                taps[scale_input] = patch.wire_node(
                    format!("{name}.scale_input{ab}_axis_fix_{i}"),
                    AxisOp::Add(i),
                    &[taps[scale_input]],
                )?[0];
            }
        }
    }

    let [mut a, mut b, bias, mut a0, a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
        bail!("Expect exactly 9 inputs")
    };

    // Normalize both operands (and their zero points) to the signed 8-bit flavour.
    wire_ensure_q8_flavour(&mut patch, &node.name, &mut a, "a", &mut a0, i8::datum_type())?;
    wire_ensure_q8_flavour(&mut patch, &node.name, &mut b, "b", &mut b0, i8::datum_type())?;

    // The core product: the same einsum restricted to the two data inputs,
    // without quantization parameters.
    let mut output = patch.wire_node(
        &node.name,
        EinSum {
            q_params: None,
            axes: op.axes.extract_sub_mapping(&[0, 1], &[0])?,
            operating_dt: op.operating_dt,
        },
        &[a, b],
    )?;

    // Sums of each operand over the k axis (in i32), needed by the zero-point
    // compensation below.
    let a_i32 = patch.wire_node(format!("{name}.a_as_i32"), cast(i32::datum_type()), &[a])?[0];
    let b_i32 = patch.wire_node(format!("{name}.b_as_i32"), cast(i32::datum_type()), &[b])?[0];
    let sum_a = patch.wire_node(
        format!("{name}.sum_a"),
        Reduce::new(tvec!(k_axis.inputs[0][0]), Reducer::Sum),
        &[a_i32],
    )?;
    let sum_b = patch.wire_node(
        format!("{name}.sum_b"),
        Reduce::new(tvec!(k_axis.inputs[1][0]), Reducer::Sum),
        &[b_i32],
    )?;

    // Re-align the sums and the bias with the output axes order.
    let sum_a =
        wire_axes_fix(&mut patch, name, "sum_a", &op.axes.extract_sub_mapping(&[0], &[0])?, sum_a)?;
    let sum_b =
        wire_axes_fix(&mut patch, name, "sum_b", &op.axes.extract_sub_mapping(&[1], &[0])?, sum_b)?;
    let bias = tvec!(bias);
    let bias =
        wire_axes_fix(&mut patch, name, "bias", &op.axes.extract_sub_mapping(&[2], &[0])?, bias)?;

    let abc_scale = combine_scales(&mut patch, name, a_scale, b_scale, c_scale)?;

    output = patch.wire_node(format!("{name}.add_bias"), add(), &[output[0], bias[0]])?;

    let k = model.outlet_fact(node.inputs[0])?.shape[k_axis.inputs[0][0]].clone();
    let output = compensate_zero_points(&mut patch, name, output[0], k, a0, b0, sum_a[0], sum_b[0])
        .context("Zero point compensation")?;
    // Finally rescale and offset back to the quantized output type.
    let output = requant(&mut patch, name, output, op.q_params.unwrap(), abc_scale, c0)?;
    patch.shunt_outside(model, node.id.into(), output)?;
    Ok(patch)
}
796
797fn flatten_rule(
798 _ctx: &(),
799 model: &TypedModel,
800 node: &TypedNode,
801 _name: &str,
802 op: &EinSumMatMul,
803) -> TractResult<Option<TypedModelPatch>> {
804 TypedModelPatch::replace_single_op(model, node, &node.inputs, op.op.clone()).map(Some)
805}
806
/// Lower an EinSumMatMul to OptMatMul: run kernel selection, wire packing ops
/// for both operands, map the remaining (batch) axes to the inputs they come
/// from, and assemble the fused matmul + store spec.
fn optimized_mat_mul(
    model: &TypedModel,
    node: &TypedNode,
    op: &EinSumMatMul,
) -> TractResult<Option<TypedModelPatch>> {
    // Kernel selection yields the runtime mode picker, the packing format for
    // the left operand, and one (mmm, packing index, extractor) per kernel.
    let (mode_picker, left_pack, impls) = kernel_selection::strategize(model, node, op)?;
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    let prefix = &node.name;

    let mut patch = TypedModelPatch::new("Einsum to OptMatMul");
    let taps = patch.taps(model, &node.inputs)?;
    let name = &node.name;

    // Packing for A. A constant input is packed with the single format picked
    // by kernel selection (plain or block-quantized); a variable input gets
    // one packer per candidate kernel, chosen at runtime by the mode picker.
    let pack_a: Box<dyn TypedOp> = if input_facts[0].konst.is_some() {
        if let Some(pf) = left_pack.downcast_ref::<PackedFormat>() {
            Box::new(OptMatMulPack {
                packers: vec![pf.clone()],
                mode_picker: ModePicker::Single,
                k_axis: op.a_k(),
                mn_axis: op.a_m(),
            })
        } else if let Some(packed_format) =
            left_pack.downcast_ref::<PackedBlockQuantFormat>().cloned()
        {
            Box::new(OptSimpleMatMulPack {
                packed_format,
                k: input_shapes[0][op.a_k()].to_usize().unwrap(),
                m: input_shapes[0][op.a_m()].to_usize().unwrap(),
            })
        } else {
            bail!("Unexpected static input format {left_pack:?}");
        }
    } else {
        Box::new(OptMatMulPack {
            packers: impls
                .iter()
                .map(|(mmm, p, pe)| {
                    // Prefer the extractor's source format when an extractor is
                    // present, else the kernel packing's A-side format.
                    pe.as_ref()
                        .map(|pe| &pe.from)
                        .unwrap_or(&mmm.packings()[*p].0)
                        .downcast_ref::<PackedFormat>()
                        .unwrap()
                        .clone()
                })
                .collect(),
            mode_picker: mode_picker.clone(),
            k_axis: op.a_k(),
            mn_axis: op.a_m(),
        })
    };
    let pa = patch.wire_node(format!("{prefix}.pack_a"), pack_a, &[taps[0]])?[0];

    // Packing for B: one packer per candidate kernel (B-side format).
    let pb = patch.wire_node(
        format!("{prefix}.pack_b"),
        OptMatMulPack {
            k_axis: op.b_k(),
            mn_axis: op.b_n(),
            packers: impls
                .iter()
                .map(|(mmm, p, _)| {
                    mmm.packings()[*p].1.downcast_ref::<PackedFormat>().unwrap().clone()
                })
                .collect(),
            mode_picker: mode_picker.clone(),
        },
        &[taps[1]],
    )?[0];

    // Map every non-m/k/n output axis back to the input axis it comes from,
    // skipping broadcast (size-1) input axes. Input positions are shifted to
    // account for the m/k (resp. n/k) axes consumed by packing.
    let mut c_to_a_axis_mapping = tvec!();
    let mut c_to_b_axis_mapping = tvec!();
    for axis in op
        .op
        .axes
        .iter_all_axes()
        .filter(|&axis| ![op.m_axis, op.k_axis, op.n_axis].contains(&axis.repr))
    {
        if let (&[c], &[a]) = (&*axis.outputs[0], &*axis.inputs[0])
            && input_shapes[0][a] != 1.to_dim()
        {
            let a = a - (a > op.a_m()) as usize - (a > op.a_k()) as usize;
            c_to_a_axis_mapping.push((c, a));
        }
        if let (&[c], &[b]) = (&*axis.outputs[0], &*axis.inputs[1])
            && input_shapes[1][b] != 1.to_dim()
        {
            let b = b - (b > op.b_n()) as usize - (b > op.b_k()) as usize;
            c_to_b_axis_mapping.push((c, b));
        }
    }

    let c_fact = op.output_facts(&input_facts)?.remove(0);
    let geo = AddMatMulGeometry {
        k: op.k.clone(),
        c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
        c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
    };
    let (mmms, packings, extractor): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(impls);
    // NOTE(review): c_view is an unsafe kernel API — presumably sound given
    // the validated m/n output positions; confirm against its contract.
    let outputs = mmms.iter().map(|mmm| unsafe { mmm.c_view(op.c_m(), op.c_n()) }).collect();
    // Trivial packing: single kernel, default packing, no extractor, and a
    // plain (non-exotic) A input.
    let trivial_packing = mmms.len() == 1
        && packings[0] == 0
        && extractor[0].is_none()
        && input_facts[0].exotic_fact.is_none();
    let opt = OptMatMul::new(
        mmms,
        mode_picker,
        c_fact,
        op.c_m(),
        op.c_n(),
        vec![
            ProtoFusedSpec::AddMatMul {
                geo,
                a: 0,
                b: 1,
                packings: izip!(packings, extractor).collect_vec(),
            },
            ProtoFusedSpec::Store(outputs),
        ],
        trivial_packing,
    )
    .context("Creating OptMatMul")?;
    let output = patch.wire_node(name, opt, &[pa, pb])?[0];
    patch.shunt_outside(model, node.id.into(), output)?;
    Ok(Some(patch))
}