tract_linalg/frame/mmm/
fuse.rs

1use std::fmt::Debug;
2use std::ops::Deref;
3
4use super::pack::PackedFormat;
5use crate::BinOp;
6
7use super::{MMMInputValue, OutputStore, OutputStoreKer};
8use tract_data::internal::*;
9
/// Rounding mode selector carried by the `QScale` and `RoundingShiftRight`
/// fused operations.
///
/// The enum is `#[repr(usize)]`, so a value occupies exactly one `usize` and
/// its numeric value is the discriminant written below. The discriminants are
/// spelled out to make that numbering explicit; they are the same values the
/// compiler would assign implicitly from declaration order.
///
/// NOTE(review): the precise arithmetic meaning of each mode is implemented by
/// the consumers of this value, not in this file — confirm there before
/// relying on the names alone.
#[repr(usize)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum RoundingPolicy {
    Native = 0,
    Zero = 1,
    Away = 2,
    MinusInf = 3,
    PlusInf = 4,
    Even = 5,
    Odd = 6,
}
21
22#[derive(Clone, Debug)]
23pub enum AsInputValue<'t> {
24    Owned(Box<dyn MMMInputValue>),
25    Borrowed(&'t dyn MMMInputValue),
26}
27
28impl Deref for AsInputValue<'_> {
29    type Target = dyn MMMInputValue;
30    fn deref(&self) -> &Self::Target {
31        match self {
32            AsInputValue::Owned(b) => &**b,
33            AsInputValue::Borrowed(r) => *r,
34        }
35    }
36}
37
38#[derive(Clone, Debug)]
39pub enum FusedSpec<'t> {
40    BinScalar(&'t Tensor, BinOp),
41    BinPerRow(TensorView<'t>, BinOp),
42    BinPerCol(TensorView<'t>, BinOp),
43    AddRowColProducts(&'t Tensor, &'t Tensor),
44    AddUnicast(OutputStore),
45    LeakyRelu(&'t Tensor),
46    QScale(isize, RoundingPolicy, i32),
47    RoundingShiftRight(usize, RoundingPolicy),
48    ShiftLeft(usize),
49    Store(OutputStore),
50    AddMatMul { a: AsInputValue<'t>, b: AsInputValue<'t>, packing: usize },
51}
52
53impl FusedSpec<'_> {
54    pub fn prefer_col_outer(&self) -> Option<bool> {
55        if let FusedSpec::AddMatMul { a, b, .. } = self {
56            let a_is_eager = a.format().is::<PackedFormat>();
57            let b_is_eager = b.format().is::<PackedFormat>();
58            if a_is_eager == b_is_eager {
59                None
60            } else {
61                Some(a_is_eager)
62            }
63        } else {
64            None
65        }
66    }
67}
68
// Careful here, the jump_to comments are used by the build script.
/// Kernel-facing description of one fused operation, generic over the
/// element type `TI`.
///
/// Payloads are plain values, raw pointers or `OutputStoreKer`, so the type
/// is `Copy`. `#[repr(C, usize)]` gives it a defined layout: a `usize` tag
/// followed by the variant payload. `check_non_linear_enum_size` pins the
/// size of `FusedKerSpec<f32>` to five `usize` words.
///
/// Variant order determines the tag values, and each variant's trailing
/// comment names a label consumed by the build script, so neither should be
/// reordered or reworded casually. `#[rustfmt::skip]` keeps the column
/// layout of those comments intact.
#[repr(C, usize)]
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
#[rustfmt::skip]
pub enum FusedKerSpec<TI: Copy> {
    Done,                                       // jump_to:done
    Clear,                                      // jump_to:clear
                                                //
    LoadTile(*const TI, *const TI),             // jump_to:load_tile

    ScalarMin(TI),                              // jump_to:scalar_min
    ScalarMax(TI),                              // jump_to:scalar_max
    ScalarAdd(TI),                              // jump_to:scalar_add
    ScalarMul(TI),                              // jump_to:scalar_mul
    ScalarSub(TI),                              // jump_to:scalar_sub
    ScalarSubF(TI),                             // jump_to:scalar_sub_flipped

    LeakyRelu(TI),                              // jump_to:leaky_relu

    PerRowMin(*const TI),                       // jump_to:per_row_min
    PerRowMax(*const TI),                       // jump_to:per_row_max
    PerRowAdd(*const TI),                       // jump_to:per_row_add
    PerRowMul(*const TI),                       // jump_to:per_row_mul
    PerRowSub(*const TI),                       // jump_to:per_row_sub
    PerRowSubF(*const TI),                      // jump_to:per_row_sub_flipped

    PerColMin(*const TI),                       // jump_to:per_col_min
    PerColMax(*const TI),                       // jump_to:per_col_max
    PerColAdd(*const TI),                       // jump_to:per_col_add
    PerColMul(*const TI),                       // jump_to:per_col_mul
    PerColSub(*const TI),                       // jump_to:per_col_sub
    PerColSubF(*const TI),                      // jump_to:per_col_sub_flipped

    QScale(isize, RoundingPolicy, i32),         // jump_to:q_scale
    RoundingShiftRight(usize, RoundingPolicy),  // jump_to:q_shr
    ShiftLeft(usize),                           // jump_to:q_shl
    AddUnicast(OutputStoreKer),                 // jump_to:add_unicast
    AddRowColProducts(*const TI, *const TI),    // jump_to:add_row_col_products
    Store(OutputStoreKer),                      // jump_to:store

    // jump_to:add_mat_mul
    AddMatMul { k: usize, pa: *const u8, pb: *const u8, packing: usize },
}
112
// `FusedKerSpec` stores raw pointers (`*const TI`, `*const u8`), which are
// neither `Send` nor `Sync` on their own, so the auto traits are not derived
// and are asserted by hand here.
//
// SAFETY: NOTE(review): nothing in this file enforces the invariant these
// impls rely on. Presumably the pointed-to buffers outlive every spec that
// refers to them and are not mutated while a spec is shared across threads —
// confirm at the sites that build and consume `FusedKerSpec` values. Note
// also that the impls apply to any `TI: Copy`, including a `TI` that is not
// itself `Send`/`Sync`.
unsafe impl<TI: Copy> Send for FusedKerSpec<TI> {}
unsafe impl<TI: Copy> Sync for FusedKerSpec<TI> {}
115
/// Guards the memory layout of the types defined in this file:
/// `RoundingPolicy` must be one machine word, and `FusedKerSpec<f32>` must be
/// one word of tag plus an `OutputStoreKer`, five words in total.
#[cfg(test)]
#[test]
fn check_non_linear_enum_size() {
    use std::mem::size_of;

    let word = size_of::<usize>();
    let ker_spec = size_of::<FusedKerSpec<f32>>();

    assert_eq!(size_of::<RoundingPolicy>(), word);
    assert_eq!(ker_spec, word + size_of::<OutputStoreKer>());
    assert_eq!(ker_spec, 5 * word);
}