tract_linalg/frame/mmm/
cost_model.rs1use tract_data::internal::*;
2use tract_data::itertools::{izip, Itertools};
3
4use super::MatMatMul;
5
6fn order_f<F: tract_num_traits::Float>(&a: &F, &b: &F) -> std::cmp::Ordering {
7 if a < b {
8 std::cmp::Ordering::Less
9 } else {
10 std::cmp::Ordering::Greater
11 }
12}
13
/// Parameters of a small two-layer network that chooses a matrix-multiplication
/// kernel from the problem dimensions `m`, `k`, `n`.
///
/// All fields are borrowed slices/strings, so a model can be described entirely
/// by data living for `'a`.
#[derive(Debug)]
pub struct CostModel<'a> {
    /// NOTE(review): not read by the methods visible alongside this struct;
    /// presumably a cut-off on `m * k * n` — confirm against callers.
    pub big_product_mkn_threshold: f32,
    /// Name of the kernel `pick` falls back to when any of `m`, `k`, `n` is unknown.
    pub big_product_kernel_choice: &'a str,
    /// Kernel names, indexed by the position of the largest network output.
    pub kernels: &'a [&'a str],
    /// Values against which `m` is reduced (`m % mr`) to build features.
    pub mrs: &'a [u32],
    /// Values against which `n` is reduced (`n % nr`) to build features.
    pub nrs: &'a [u32],
    /// Per-feature value subtracted during normalization.
    pub feat_norm_mean: &'a [f32],
    /// Per-feature divisor applied during normalization.
    pub feat_norm_stddev: &'a [f32],
    /// Flattened weights of the first layer (`b1.len()` rows by feature count).
    pub w1: &'a [f32],
    /// Bias of the first layer; its length is the hidden layer width.
    pub b1: &'a [f32],
    /// Flattened weights of the second layer (`b2.len()` rows by `b1.len()`).
    pub w2: &'a [f32],
    /// Bias of the second layer; its length is the number of output scores.
    pub b2: &'a [f32],
}
28
29impl CostModel<'_> {
30 pub fn features(&self, m: usize, k: usize, n: usize) -> Vec<f32> {
31 let mut feat = vec![
32 (m as f32).ln(),
33 (k as f32).ln(),
34 (n as f32).ln(),
35 (n as f32 * m as f32 * k as f32).ln(),
36 ];
37 for &mr in self.mrs {
38 let mr = mr as usize;
39 feat.push((m % mr) as f32);
40 feat.push((m % mr != 0) as usize as f32);
41 }
42 for &nr in self.nrs {
43 let nr = nr as usize;
44 feat.push((n % nr) as f32);
45 feat.push((n % nr != 0) as usize as f32);
46 }
47 feat
48 }
49
50 fn normalize(&self, feat: &mut [f32]) {
51 izip!(feat, self.feat_norm_mean, self.feat_norm_stddev)
52 .for_each(|(x, m, s)| *x = (*x - m) / s)
53 }
54
55 fn dnn(x: &[f32], w: &[f32], b: &[f32]) -> Vec<f32> {
56 let x = tract_ndarray::Array1::from_vec(x.to_vec());
57 let w = tract_ndarray::Array2::from_shape_vec([b.len(), x.len()], w.to_vec()).unwrap();
58 let b = tract_ndarray::Array1::from_vec(b.to_vec());
59 (w.dot(&x) + b).to_vec()
60 }
61
62 pub fn predict(&self, m: usize, k: usize, n: usize) -> &str {
63 let mut x = self.features(m, k, n);
64 self.normalize(&mut x);
65 let mut hidden = Self::dnn(&x, self.w1, self.b1);
66 (crate::generic().tanh_f32)().run(&mut hidden).unwrap();
67 let output = Self::dnn(&hidden, self.w2, self.b2);
68 let ix = output.iter().copied().position_max_by(order_f).unwrap();
69 self.kernels[ix]
70 }
71
72 pub fn pick(
73 &self,
74 impls: &[Box<dyn MatMatMul>],
75 m: Option<usize>,
76 k: Option<usize>,
77 n: Option<usize>,
78 ) -> Box<dyn MatMatMul> {
79 if let (Some(m), Some(k), Some(n)) = (m, k, n) {
80 let choice = self.predict(m, k, n);
81 impls.iter().find(|k| k.name() == choice).unwrap().clone()
82 } else {
83 impls.iter().find(|k| k.name() == self.big_product_kernel_choice).unwrap().clone()
84 }
85 }
86}