// tract_linalg/frame/mmm/cost_model.rs

1use tract_data::internal::*;
2use tract_data::itertools::{Itertools, izip};
3
4use super::MatMatMul;
5
6fn order_f<F: tract_num_traits::Float>(&a: &F, &b: &F) -> std::cmp::Ordering {
7    if a < b { std::cmp::Ordering::Less } else { std::cmp::Ordering::Greater }
8}
9
/// Learned kernel-selection model for matrix products.
///
/// Given the dimensions of an `m`×`k`×`n` product, a small two-layer network scores every
/// entry of `kernels`. The winning name is then resolved against the available matrix
/// multiplication implementations.
#[derive(Debug, Clone, Copy)]
pub struct CostModel<'a> {
    /// Threshold on `m·k·n` above which a product is presumably considered "big".
    /// NOTE(review): meaning inferred from the name. This file's methods never read it, so
    /// confirm where it is consumed.
    pub big_product_mkn_threshold: f32,
    /// Kernel name used when the product dimensions are not all known.
    pub big_product_kernel_choice: &'a str,
    /// Candidate kernel names, indexed by the position of the network's largest output.
    pub kernels: &'a [&'a str],
    /// Kernel `mr` values (presumably row tile sizes). Each one contributes an `m % mr`
    /// feature and an `m % mr != 0` feature.
    pub mrs: &'a [u32],
    /// Kernel `nr` values (presumably column tile sizes). Each one contributes an `n % nr`
    /// feature and an `n % nr != 0` feature.
    pub nrs: &'a [u32],
    /// Per-feature mean, subtracted during normalization.
    pub feat_norm_mean: &'a [f32],
    /// Per-feature standard deviation, divided out during normalization.
    pub feat_norm_stddev: &'a [f32],
    /// First-layer weights, row-major, with `b1.len()` rows × feature-count columns.
    pub w1: &'a [f32],
    /// First-layer biases. Their count is the hidden-layer width.
    pub b1: &'a [f32],
    /// Second-layer weights, row-major, with `b2.len()` rows × `b1.len()` columns.
    pub w2: &'a [f32],
    /// Second-layer biases, one per network output (presumably one per entry of `kernels`).
    pub b2: &'a [f32],
}
24
25impl CostModel<'_> {
26    pub fn features(&self, m: usize, k: usize, n: usize) -> Vec<f32> {
27        let mut feat = vec![
28            (m as f32).ln(),
29            (k as f32).ln(),
30            (n as f32).ln(),
31            (n as f32 * m as f32 * k as f32).ln(),
32        ];
33        for &mr in self.mrs {
34            let mr = mr as usize;
35            feat.push((m % mr) as f32);
36            feat.push((m % mr != 0) as usize as f32);
37        }
38        for &nr in self.nrs {
39            let nr = nr as usize;
40            feat.push((n % nr) as f32);
41            feat.push((n % nr != 0) as usize as f32);
42        }
43        feat
44    }
45
46    fn normalize(&self, feat: &mut [f32]) {
47        izip!(feat, self.feat_norm_mean, self.feat_norm_stddev)
48            .for_each(|(x, m, s)| *x = (*x - m) / s)
49    }
50
51    fn dnn(x: &[f32], w: &[f32], b: &[f32]) -> Vec<f32> {
52        let x = tract_ndarray::Array1::from_vec(x.to_vec());
53        let w = tract_ndarray::Array2::from_shape_vec([b.len(), x.len()], w.to_vec()).unwrap();
54        let b = tract_ndarray::Array1::from_vec(b.to_vec());
55        (w.dot(&x) + b).to_vec()
56    }
57
58    pub fn predict(&self, m: usize, k: usize, n: usize) -> &str {
59        let mut x = self.features(m, k, n);
60        self.normalize(&mut x);
61        let mut hidden = Self::dnn(&x, self.w1, self.b1);
62        (crate::generic().tanh_f32)().run(&mut hidden).unwrap();
63        let output = Self::dnn(&hidden, self.w2, self.b2);
64        let ix = output.iter().copied().position_max_by(order_f).unwrap();
65        self.kernels[ix]
66    }
67
68    pub fn pick(
69        &self,
70        impls: &[Box<dyn MatMatMul>],
71        m: Option<usize>,
72        k: Option<usize>,
73        n: Option<usize>,
74    ) -> Box<dyn MatMatMul> {
75        if let (Some(m), Some(k), Some(n)) = (m, k, n) {
76            let choice = self.predict(m, k, n);
77            impls.iter().find(|k| k.name() == choice).unwrap().clone()
78        } else {
79            impls.iter().find(|k| k.name() == self.big_product_kernel_choice).unwrap().clone()
80        }
81    }
82}