//! tract_core/ops/einsum/kernel_selection.rs
//!
//! Selection of matmul kernels (mmm implementations, packings and panel
//! extractors) for EinSum nodes lowered to matrix multiplication.

1#![allow(clippy::type_complexity)]
2
3use dyn_clone::clone_box;
4use tract_itertools::Itertools;
5use tract_linalg::block_quant::BlockQuantFact;
6use tract_linalg::mmm::{ImplementationQuality, MMMInputFormat, MatMatMul, PanelExtractor};
7use tract_linalg::WeightType;
8
9use crate::internal::*;
10use crate::ops::matmul::ModePicker;
11
12use super::einsum_matmul::EinSumMatMul;
13
/// One concrete kernel choice: the matmul implementation, the index of the
/// packing to use within `MatMatMul::packings()`, and an optional panel
/// extractor translating the stored weight layout into that packing.
pub type Impl = (Box<dyn MatMatMul>, usize, Option<PanelExtractor>);
/// A complete strategy: the runtime mode picker, the common left-operand
/// input format, and the candidate kernel implementations.
pub type Strat = (ModePicker, Box<dyn MMMInputFormat>, Vec<Impl>);
16
17fn single_strat(it: Impl) -> Strat {
18    (ModePicker::Single, it.0.packings()[it.1].0.clone(), vec![it])
19}
20
21pub fn strategize(model: &TypedModel, node: &TypedNode, op: &EinSumMatMul) -> TractResult<Strat> {
22    let input_facts = model.node_input_facts(node.id)?;
23    if let (Some(m), Some(k), Some(n)) = (op.m.as_i64(), op.k.as_i64(), op.n.as_i64()) {
24        if op.op.operating_dt == input_facts[0].datum_type
25            && op.op.operating_dt == input_facts[1].datum_type
26        {
27            if let Some(mmm) = tract_linalg::ops().mmm(
28                op.operating_dt,
29                Some(m as usize),
30                Some(k as usize),
31                Some(n as usize),
32            ) {
33                if mmm.quality() == ImplementationQuality::ManuallyOptimized {
34                    return Ok((
35                        ModePicker::Single,
36                        mmm.packings()[0].0.clone(),
37                        vec![(mmm, 0, None)],
38                    ));
39                }
40            }
41        };
42    }
43
44    let mut impls = list_impls(model, node, op)?;
45    ensure!(impls.len() > 0);
46    fn score(mmm: &dyn MatMatMul) -> isize {
47        -(mmm.quality().cost() as isize * 1000) + mmm.dynamic_boost()
48    }
49    let wanted_quality = impls.iter().map(|(mmm, _, _)| score(&**mmm)).max().unwrap();
50    impls.retain(|(mmm, _, _)| score(&**mmm) == wanted_quality);
51    if impls.len() == 1 {
52        return Ok(single_strat(impls.remove(0)));
53    }
54    if op.n.is_one() {
55        let it =
56            impls.into_iter().max_by_key(|(m, _, pe)| (m.nr() == 1, pe.is_none(), m.mr())).unwrap();
57        return Ok(single_strat(it));
58    }
59    if op.n.as_i64().is_some_and(|n| n > 1) {
60        let it =
61            impls.into_iter().max_by_key(|(m, _, pe)| (pe.is_none(), m.nr() * m.mr())).unwrap();
62        return Ok(single_strat(it));
63    }
64    let mut grouped_by_left_packing = Vec::<(&dyn MMMInputFormat, Vec<_>)>::new();
65    'mmm: for (m, p, pe) in &impls {
66        let left_packing: &dyn MMMInputFormat =
67            pe.as_ref().map(|pe| &*pe.from).unwrap_or(&*m.packings()[*p].0);
68        for kit in &mut grouped_by_left_packing {
69            if let Some(merged) = kit.0.merge_with(left_packing) {
70                kit.0 = merged;
71                kit.1.push((m, p, pe));
72                continue 'mmm;
73            }
74        }
75        grouped_by_left_packing.push((left_packing, vec![(m, p, pe)]));
76    }
77    let (p, mmv, mmm) = grouped_by_left_packing
78        .iter()
79        .map(|(p, kit)| {
80            let best_for_mmv =
81                kit.iter().max_by_key(|(m, _, pe)| (m.nr() == 1, pe.is_none())).unwrap();
82            let best_for_mmm = kit.iter().max_by_key(|(m, _, _)| m.nr()).unwrap();
83            (p, best_for_mmv, best_for_mmm)
84        })
85        .max_by_key(|(_, mmv, mmm)| {
86            (mmv.0.nr() == 1 && mmm.0.nr() > 1, mmv.2.is_none(), mmm.0.mr(), mmm.0.nr())
87        })
88        .unwrap();
89
90    if mmm == mmv {
91        Ok((ModePicker::Single, clone_box(*p), vec![(mmv.0.clone(), *mmv.1, mmv.2.clone())]))
92    } else {
93        Ok((
94            ModePicker::VecVsMat,
95            clone_box(*p),
96            vec![(mmv.0.clone(), *mmv.1, mmv.2.clone()), (mmm.0.clone(), *mmm.1, mmm.2.clone())],
97        ))
98    }
99}
100
101pub fn list_impls(
102    model: &TypedModel,
103    node: &TypedNode,
104    op: &EinSumMatMul,
105) -> TractResult<Vec<Impl>> {
106    let (a_fact, b_fact) = model.node_input_facts(node.id)?.into_iter().collect_tuple().unwrap();
107    let a_dt = a_fact.datum_type;
108    let b_dt = b_fact.datum_type;
109
110    let a_weight: WeightType = if let Some(of) = a_fact.opaque_fact() {
111        if let Some(bqf) = of.downcast_ref::<BlockQuantFact>() {
112            WeightType::BlockQuant(bqf.format.clone())
113        } else {
114            bail!("Can not translate to matmul operand {a_fact:?}");
115        }
116    } else {
117        a_dt.into()
118    };
119
120    let impls = tract_linalg::ops()
121        .mmm_impls()
122        .iter()
123        .filter(|mmm| {
124            op.acceptable_accumulators().contains(&mmm.internal_type())
125                && mmm.stores().contains(&op.operating_dt.unquantized())
126        })
127        .flat_map(move |mmm| {
128            mmm.packings().iter().enumerate().map(|(ix, p)| (mmm.clone(), ix, &p.0, &p.1))
129        })
130        .filter_map(|(m, p, pa, pb)| {
131            if pb.precursor().as_dt().is_none_or(|dt| dt != b_dt.unquantized()) {
132                return None;
133            }
134            if pa.precursor() == a_weight {
135                Some((m, p, None))
136            } else {
137                tract_linalg::ops()
138                    .panel_extractors()
139                    .iter()
140                    .find(|pe| pe.from.precursor() == a_weight && pe.to.same_as(&**pa))
141                    .map(|pe| (m, p, Some(pe.clone())))
142            }
143        })
144        .collect_vec();
145    Ok(impls)
146}