use std::fmt::Debug;
use std::ops::Deref;

use super::pack::PackedFormat;
use crate::BinOp;

use super::{MMMInputValue, OutputStore, OutputStoreKer};
use tract_data::internal::*;

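/// Rounding policy applied when rescaling or right-shifting quantized
/// accumulators (see `FusedSpec::QScale` and `FusedSpec::RoundingShiftRight`).
/// `Native` defers to whatever the target implementation does by default; the
/// remaining variants name the direction used to resolve half-way values
/// (toward zero, away from zero, toward -inf or +inf, to even, to odd).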
#[repr(usize)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum RoundingPolicy {
    Native,
    Zero,
    Away,
    MinusInf,
    PlusInf,
    Even,
    Odd,
}

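/// A matrix-multiplication operand that is either owned by the spec or
/// borrowed from the caller. `Deref` gives uniform access to the underlying
/// `dyn MMMInputValue` in both cases.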
#[derive(Clone, Debug)]
pub enum AsInputValue<'t> {
    Owned(Box<dyn MMMInputValue>),
    Borrowed(&'t dyn MMMInputValue),
}

impl Deref for AsInputValue<'_> {
    type Target = dyn MMMInputValue;
    fn deref(&self) -> &Self::Target {
        match self {
            AsInputValue::Owned(b) => &**b,
            AsInputValue::Borrowed(r) => *r,
        }
    }
}

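/// One step of the fused matrix-multiplication pipeline, expressed at the
/// interface level (tensors, views and output stores). A pipeline is a
/// sequence of `FusedSpec` values that the implementation translates, tile by
/// tile, into the micro-kernel level `FusedKerSpec`.
///
/// Illustrative sketch only (operand and store construction elided, the
/// `packed_a`, `packed_b`, `bias`, `zero` and `c_store` names are hypothetical
/// placeholders), showing the shape of a pipeline computing
/// `C = relu(A·B + per-row bias)`:
///
/// ```ignore
/// let ops = [
///     FusedSpec::AddMatMul { a: packed_a, b: packed_b, packing: 0 },
///     FusedSpec::BinPerRow(bias.view(), BinOp::Add),
///     FusedSpec::BinScalar(&zero, BinOp::Max),
///     FusedSpec::Store(c_store),
/// ];
/// // `ops` is then handed to the selected MatMatMul implementation.
/// ```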
#[derive(Clone, Debug)]
pub enum FusedSpec<'t> {
    BinScalar(&'t Tensor, BinOp),
    BinPerRow(TensorView<'t>, BinOp),
    BinPerCol(TensorView<'t>, BinOp),
    AddRowColProducts(&'t Tensor, &'t Tensor),
    AddUnicast(OutputStore),
    LeakyRelu(&'t Tensor),
    QScale(isize, RoundingPolicy, i32),
    RoundingShiftRight(usize, RoundingPolicy),
    ShiftLeft(usize),
    Store(OutputStore),
    AddMatMul { a: AsInputValue<'t>, b: AsInputValue<'t>, packing: usize },
}

impl FusedSpec<'_> {
    /// Loop-ordering hint for the matrix product: `Some(true)` to prefer a
    /// column-major outer loop when only `a` is eagerly packed, `Some(false)`
    /// when only `b` is, and `None` when the operands do not disagree.
    pub fn prefer_col_outer(&self) -> Option<bool> {
        if let FusedSpec::AddMatMul { a, b, .. } = self {
            let a_is_eager = a.format().is::<PackedFormat>();
            let b_is_eager = b.format().is::<PackedFormat>();
            if a_is_eager == b_is_eager {
                None
            } else {
                Some(a_is_eager)
            }
        } else {
            None
        }
    }
}

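/// Micro-kernel level counterpart of `FusedSpec`: raw pointers and scalar
/// values only, laid out with `#[repr(C, usize)]` so that the (possibly
/// assembly) kernels can read the discriminant and payload at fixed offsets.
/// Kernels execute a `Done`-terminated sequence of these.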
#[repr(C, usize)]
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
#[rustfmt::skip]
pub enum FusedKerSpec<TI: Copy> {
    Done,
    Clear,
    LoadTile(*const TI, *const TI),

    ScalarMin(TI),
    ScalarMax(TI),
    ScalarAdd(TI),
    ScalarMul(TI),
    ScalarSub(TI),
    ScalarSubF(TI),

    LeakyRelu(TI),

    PerRowMin(*const TI),
    PerRowMax(*const TI),
    PerRowAdd(*const TI),
    PerRowMul(*const TI),
    PerRowSub(*const TI),
    PerRowSubF(*const TI),

    PerColMin(*const TI),
    PerColMax(*const TI),
    PerColAdd(*const TI),
    PerColMul(*const TI),
    PerColSub(*const TI),
    PerColSubF(*const TI),

    QScale(isize, RoundingPolicy, i32),
    RoundingShiftRight(usize, RoundingPolicy),
    ShiftLeft(usize),

    AddUnicast(OutputStoreKer),
    AddRowColProducts(*const TI, *const TI),
    Store(OutputStoreKer),

    AddMatMul { k: usize, pa: *const u8, pb: *const u8, packing: usize },
}

unsafe impl<TI: Copy> Send for FusedKerSpec<TI> {}
unsafe impl<TI: Copy> Sync for FusedKerSpec<TI> {}

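// The kernels rely on the exact size and layout of these enums; this guard
// test pins down the assumptions: usize-sized discriminants, and a
// FusedKerSpec payload no larger than an OutputStoreKer (four usizes).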
#[cfg(test)]
#[test]
fn check_non_linear_enum_size() {
    assert_eq!(std::mem::size_of::<RoundingPolicy>(), std::mem::size_of::<usize>());
    assert_eq!(
        std::mem::size_of::<FusedKerSpec<f32>>(),
        std::mem::size_of::<usize>() + std::mem::size_of::<OutputStoreKer>()
    );
    assert_eq!(std::mem::size_of::<FusedKerSpec<f32>>(), 5 * std::mem::size_of::<usize>());
}