1#![allow(clippy::type_complexity)]
2
3use dyn_clone::clone_box;
4use tract_itertools::Itertools;
5use tract_linalg::block_quant::BlockQuantFact;
6use tract_linalg::mmm::{ImplementationQuality, MMMInputFormat, MatMatMul, PanelExtractor};
7use tract_linalg::WeightType;
8
9use crate::internal::*;
10use crate::ops::matmul::ModePicker;
11
12use super::einsum_matmul::EinSumMatMul;
13
/// A candidate kernel: the matmul implementation, an index into that implementation's
/// `packings()` list, and an optional panel extractor (set by `list_impls` when the first
/// operand does not directly match the selected packing).
pub type Impl = (Box<dyn MatMatMul>, usize, Option<PanelExtractor>);
/// A selected strategy: the mode picker, the input format chosen for the first operand
/// (the "left packing" in `strategize`), and the kernel implementations to choose from.
pub type Strat = (ModePicker, Box<dyn MMMInputFormat>, Vec<Impl>);
16
17fn single_strat(it: Impl) -> Strat {
18 (ModePicker::Single, it.0.packings()[it.1].0.clone(), vec![it])
19}
20
21pub fn strategize(model: &TypedModel, node: &TypedNode, op: &EinSumMatMul) -> TractResult<Strat> {
22 let input_facts = model.node_input_facts(node.id)?;
23 if let (Some(m), Some(k), Some(n)) = (op.m.as_i64(), op.k.as_i64(), op.n.as_i64()) {
24 if op.op.operating_dt == input_facts[0].datum_type
25 && op.op.operating_dt == input_facts[1].datum_type
26 {
27 if let Some(mmm) = tract_linalg::ops().mmm(
28 op.operating_dt,
29 Some(m as usize),
30 Some(k as usize),
31 Some(n as usize),
32 ) {
33 if mmm.quality() == ImplementationQuality::ManuallyOptimized {
34 return Ok((
35 ModePicker::Single,
36 mmm.packings()[0].0.clone(),
37 vec![(mmm, 0, None)],
38 ));
39 }
40 }
41 };
42 }
43
44 let mut impls = list_impls(model, node, op)?;
45 ensure!(impls.len() > 0);
46 fn score(mmm: &dyn MatMatMul) -> isize {
47 -(mmm.quality().cost() as isize * 1000) + mmm.dynamic_boost()
48 }
49 let wanted_quality = impls.iter().map(|(mmm, _, _)| score(&**mmm)).max().unwrap();
50 impls.retain(|(mmm, _, _)| score(&**mmm) == wanted_quality);
51 if impls.len() == 1 {
52 return Ok(single_strat(impls.remove(0)));
53 }
54 if op.n.is_one() {
55 let it =
56 impls.into_iter().max_by_key(|(m, _, pe)| (m.nr() == 1, pe.is_none(), m.mr())).unwrap();
57 return Ok(single_strat(it));
58 }
59 if op.n.as_i64().is_some_and(|n| n > 1) {
60 let it =
61 impls.into_iter().max_by_key(|(m, _, pe)| (pe.is_none(), m.nr() * m.mr())).unwrap();
62 return Ok(single_strat(it));
63 }
64 let mut grouped_by_left_packing = Vec::<(&dyn MMMInputFormat, Vec<_>)>::new();
65 'mmm: for (m, p, pe) in &impls {
66 let left_packing: &dyn MMMInputFormat =
67 pe.as_ref().map(|pe| &*pe.from).unwrap_or(&*m.packings()[*p].0);
68 for kit in &mut grouped_by_left_packing {
69 if let Some(merged) = kit.0.merge_with(left_packing) {
70 kit.0 = merged;
71 kit.1.push((m, p, pe));
72 continue 'mmm;
73 }
74 }
75 grouped_by_left_packing.push((left_packing, vec![(m, p, pe)]));
76 }
77 let (p, mmv, mmm) = grouped_by_left_packing
78 .iter()
79 .map(|(p, kit)| {
80 let best_for_mmv =
81 kit.iter().max_by_key(|(m, _, pe)| (m.nr() == 1, pe.is_none())).unwrap();
82 let best_for_mmm = kit.iter().max_by_key(|(m, _, _)| m.nr()).unwrap();
83 (p, best_for_mmv, best_for_mmm)
84 })
85 .max_by_key(|(_, mmv, mmm)| {
86 (mmv.0.nr() == 1 && mmm.0.nr() > 1, mmv.2.is_none(), mmm.0.mr(), mmm.0.nr())
87 })
88 .unwrap();
89
90 if mmm == mmv {
91 Ok((ModePicker::Single, clone_box(*p), vec![(mmv.0.clone(), *mmv.1, mmv.2.clone())]))
92 } else {
93 Ok((
94 ModePicker::VecVsMat,
95 clone_box(*p),
96 vec![(mmv.0.clone(), *mmv.1, mmv.2.clone()), (mmm.0.clone(), *mmm.1, mmm.2.clone())],
97 ))
98 }
99}
100
101pub fn list_impls(
102 model: &TypedModel,
103 node: &TypedNode,
104 op: &EinSumMatMul,
105) -> TractResult<Vec<Impl>> {
106 let (a_fact, b_fact) = model.node_input_facts(node.id)?.into_iter().collect_tuple().unwrap();
107 let a_dt = a_fact.datum_type;
108 let b_dt = b_fact.datum_type;
109
110 let a_weight: WeightType = if let Some(of) = a_fact.opaque_fact() {
111 if let Some(bqf) = of.downcast_ref::<BlockQuantFact>() {
112 WeightType::BlockQuant(bqf.format.clone())
113 } else {
114 bail!("Can not translate to matmul operand {a_fact:?}");
115 }
116 } else {
117 a_dt.into()
118 };
119
120 let impls = tract_linalg::ops()
121 .mmm_impls()
122 .iter()
123 .filter(|mmm| {
124 op.acceptable_accumulators().contains(&mmm.internal_type())
125 && mmm.stores().contains(&op.operating_dt.unquantized())
126 })
127 .flat_map(move |mmm| {
128 mmm.packings().iter().enumerate().map(|(ix, p)| (mmm.clone(), ix, &p.0, &p.1))
129 })
130 .filter_map(|(m, p, pa, pb)| {
131 if pb.precursor().as_dt().is_none_or(|dt| dt != b_dt.unquantized()) {
132 return None;
133 }
134 if pa.precursor() == a_weight {
135 Some((m, p, None))
136 } else {
137 tract_linalg::ops()
138 .panel_extractors()
139 .iter()
140 .find(|pe| pe.from.precursor() == a_weight && pe.to.same_as(&**pa))
141 .map(|pe| (m, p, Some(pe.clone())))
142 }
143 })
144 .collect_vec();
145 Ok(impls)
146}