1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
use tract_data::internal::*;
use tract_data::itertools::{izip, Itertools};

use super::MatMatMul;

/// Orders two floats using only the `<` operator; never yields `Equal`.
///
/// Returns `Less` when `a < b`. Every other situation (equal values, `a > b`,
/// or a comparison where `<` evaluates to false because an operand is NaN)
/// is reported as `Greater`.
fn order_f<F: tract_num_traits::Float>(&a: &F, &b: &F) -> std::cmp::Ordering {
    use std::cmp::Ordering::{Greater, Less};
    match a < b {
        true => Less,
        false => Greater,
    }
}

/// Borrowed parameters of a small two-layer network used to choose a
/// matrix-multiplication kernel by name from the product dimensions.
///
/// All slices and strings are borrowed for `'a`; the struct owns no data.
#[derive(Debug)]
pub struct CostModel<'a> {
    /// Not read by any method visible alongside this struct.
    /// NOTE(review): presumably an `m * k * n` size limit related to
    /// `big_product_kernel_choice` — confirm against the model generator/callers.
    pub big_product_mkn_threshold: f32,
    /// Kernel name looked up by `pick` when at least one of `m`, `k`, `n` is `None`.
    pub big_product_kernel_choice: &'a str,
    /// Kernel names; `predict` indexes this slice with the position of the
    /// largest value of the network output.
    pub kernels: &'a [&'a str],
    /// Divisors applied to `m` when building features (`m % mr`).
    pub mrs: &'a [u32],
    /// Divisors applied to `n` when building features (`n % nr`).
    pub nrs: &'a [u32],
    /// Per-feature value subtracted during normalization.
    pub feat_norm_mean: &'a [f32],
    /// Per-feature value the centered feature is divided by during normalization.
    pub feat_norm_stddev: &'a [f32],
    /// First-layer weights, read as a row-major `b1.len()` x `features` matrix.
    pub w1: &'a [f32],
    /// First-layer bias; its length sets the first layer's output size.
    pub b1: &'a [f32],
    /// Second-layer weights, read as a row-major `b2.len()` x `b1.len()` matrix.
    pub w2: &'a [f32],
    /// Second-layer bias; its length sets the number of output scores.
    pub b2: &'a [f32],
}

impl<'a> CostModel<'a> {
    /// Builds the raw (un-normalized) feature vector for an `m` x `k` x `n` product.
    ///
    /// Layout, in order: `ln(m)`, `ln(k)`, `ln(n)`, `ln(n * m * k)`, then for each
    /// entry `mr` of `mrs` the pair (`m % mr`, `1.0` if that remainder is non-zero
    /// else `0.0`), then the same pair for `n` against each entry of `nrs`.
    ///
    /// # Panics
    ///
    /// Panics if any entry of `mrs` or `nrs` is zero (remainder by zero).
    pub fn features(&self, m: usize, k: usize, n: usize) -> Vec<f32> {
        let (mf, kf, nf) = (m as f32, k as f32, n as f32);
        // Final length is known up front: 4 logs plus 2 values per divisor.
        let mut feat = Vec::with_capacity(4 + 2 * (self.mrs.len() + self.nrs.len()));
        feat.extend_from_slice(&[mf.ln(), kf.ln(), nf.ln(), (nf * mf * kf).ln()]);

        // Appends the remainder of `dim` by `tile` and a 0/1 "is not a multiple" flag.
        let mut push_remainder = |dim: usize, tile: u32| {
            let rem = dim % tile as usize;
            feat.push(rem as f32);
            feat.push(if rem != 0 { 1.0 } else { 0.0 });
        };
        for &mr in self.mrs {
            push_remainder(m, mr);
        }
        for &nr in self.nrs {
            push_remainder(n, nr);
        }
        feat
    }

    /// Standardizes `feat` in place: `x = (x - mean) / stddev`, element by element.
    ///
    /// The three sequences are zipped, so only the first
    /// `min(feat.len(), feat_norm_mean.len(), feat_norm_stddev.len())` entries are
    /// touched; any remaining entries of `feat` are left as they were.
    fn normalize(&self, feat: &mut [f32]) {
        izip!(feat, self.feat_norm_mean, self.feat_norm_stddev)
            .for_each(|(x, m, s)| *x = (*x - m) / s)
    }

    /// Evaluates one dense layer, `w * x + b`, without activation.
    ///
    /// `w` is read as a row-major matrix with `b.len()` rows and `x.len()` columns.
    /// The inputs are wrapped in borrowed views, so only the result is allocated.
    ///
    /// # Panics
    ///
    /// Panics if `w.len() != b.len() * x.len()`.
    fn dnn(x: &[f32], w: &[f32], b: &[f32]) -> Vec<f32> {
        // A view only requires the slice to be long enough for the shape, so check
        // the exact length here to keep rejecting mis-sized weight tables.
        assert_eq!(
            w.len(),
            b.len() * x.len(),
            "weight count does not match a {}x{} layer",
            b.len(),
            x.len()
        );
        let input = tract_ndarray::aview1(x);
        let weights = tract_ndarray::ArrayView2::from_shape([b.len(), x.len()], w).unwrap();
        let bias = tract_ndarray::aview1(b);
        (weights.dot(&input) + &bias).to_vec()
    }

    /// Returns the name of the kernel the model scores highest for an
    /// `m` x `k` x `n` product.
    ///
    /// Steps: build and normalize the features, apply the first layer, run the
    /// generic `tanh_f32` kernel over the hidden values in place, apply the second
    /// layer, then take the position of the largest output as an index into
    /// `kernels`.
    ///
    /// # Panics
    ///
    /// Panics if the tanh kernel reports an error, if the output layer is empty,
    /// if the winning index is out of range for `kernels`, or under the
    /// conditions listed for `features` and `dnn`.
    pub fn predict(&self, m: usize, k: usize, n: usize) -> &str {
        let mut x = self.features(m, k, n);
        self.normalize(&mut x);
        let mut hidden = Self::dnn(&x, self.w1, self.b1);
        (crate::generic().tanh_f32)().run(&mut hidden).unwrap();
        let output = Self::dnn(&hidden, self.w2, self.b2);
        let ix = output
            .iter()
            .copied()
            .position_max_by(order_f)
            .expect("cost model output layer produced no scores");
        self.kernels[ix]
    }

    /// Selects an implementation from `impls` by kernel name and returns a clone.
    ///
    /// When `m`, `k` and `n` are all known, the name comes from `predict`;
    /// if any of them is `None`, `big_product_kernel_choice` is used instead.
    ///
    /// # Panics
    ///
    /// Panics if no entry of `impls` has the selected kernel name, or if
    /// `predict` panics.
    pub fn pick(
        &self,
        impls: &[Box<dyn MatMatMul>],
        m: Option<usize>,
        k: Option<usize>,
        n: Option<usize>,
    ) -> Box<dyn MatMatMul> {
        let choice = match (m, k, n) {
            (Some(m), Some(k), Some(n)) => self.predict(m, k, n),
            _ => self.big_product_kernel_choice,
        };
        impls
            .iter()
            .find(|candidate| candidate.kernel_name() == choice)
            .unwrap_or_else(|| {
                panic!("cost model selected kernel {:?}, which is not in the provided list", choice)
            })
            .clone()
    }
}