use std::fmt::Debug;
use std::ops::Deref;
use super::pack::PackedFormat;
use crate::BinOp;
use super::{MMMInputValue, OutputStore, OutputStoreKer};
use tract_data::internal::*;
/// Rounding rule carried by the fixed-point rescaling steps
/// ([`FusedSpec::QScale`], [`FusedSpec::RoundingShiftRight`] and the matching
/// [`FusedKerSpec`] variants).
///
/// The enum is `#[repr(usize)]`: every value occupies one machine word holding the
/// discriminant written next to it. The discriminants are spelled out so that
/// reordering the declaration can never silently change them.
///
/// NOTE(review): the names suggest the rule used to resolve exact halves (ties):
/// toward zero, away from zero, toward -inf / +inf, to even, to odd. `Native`
/// presumably defers to the target's native behaviour. Confirm against the kernels
/// before relying on this reading.
#[repr(usize)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum RoundingPolicy {
    Native = 0,
    Zero = 1,
    Away = 2,
    MinusInf = 3,
    PlusInf = 4,
    Even = 5,
    Odd = 6,
}
/// An operand of [`FusedSpec::AddMatMul`], held either by value (boxed) or by a
/// borrow that lives for `'t`.
///
/// Both forms dereference to the same `dyn MMMInputValue`, so consumers can read
/// the operand without caring who owns it.
#[derive(Clone, Debug)]
pub enum AsInputValue<'t> {
    /// The operand is owned by this value.
    Owned(Box<dyn MMMInputValue>),
    /// The operand is borrowed for the lifetime `'t`.
    Borrowed(&'t dyn MMMInputValue),
}

impl Deref for AsInputValue<'_> {
    type Target = dyn MMMInputValue;

    /// Returns the wrapped input value, regardless of whether it is owned or borrowed.
    fn deref(&self) -> &Self::Target {
        match *self {
            Self::Owned(ref boxed) => &**boxed,
            // `&dyn MMMInputValue` is `Copy`, so the reference is copied out directly.
            Self::Borrowed(borrowed) => borrowed,
        }
    }
}
/// One step of a fused matrix-multiplication pipeline, described at the tensor level.
///
/// Operands are either borrowed for `'t` (tensors, tensor views, input values) or
/// carried inline (output stores, integer parameters).
/// NOTE(review): the exact effect of each variant is defined by the code that consumes
/// these specs (presumably by lowering them to [`FusedKerSpec`]) and is not visible here.
#[derive(Clone, Debug)]
pub enum FusedSpec<'t> {
    /// Applies `BinOp` with a borrowed tensor (presumably a single scalar value).
    BinScalar(&'t Tensor, BinOp),
    /// Applies `BinOp` with a tensor view (presumably one value per row).
    BinPerRow(TensorView<'t>, BinOp),
    /// Applies `BinOp` with a tensor view (presumably one value per column).
    BinPerCol(TensorView<'t>, BinOp),
    /// Two borrowed tensors combined as row/column products — semantics to confirm.
    AddRowColProducts(&'t Tensor, &'t Tensor),
    /// Presumably adds values read through the given output store.
    AddUnicast(OutputStore),
    /// Leaky ReLU whose parameter is read from a borrowed tensor.
    LeakyRelu(&'t Tensor),
    /// Fixed-point rescale: an `isize`, a rounding rule and an `i32` parameter.
    QScale(isize, RoundingPolicy, i32),
    /// Right shift by the given amount, rounded according to the policy.
    RoundingShiftRight(usize, RoundingPolicy),
    /// Left shift by the given amount.
    ShiftLeft(usize),
    /// Presumably writes the result through the given output store.
    Store(OutputStore),
    /// Matrix product of operands `a` and `b`; `packing` is an opaque index forwarded as-is.
    AddMatMul { a: AsInputValue<'t>, b: AsInputValue<'t>, packing: usize },
}

impl FusedSpec<'_> {
    /// Hints whether a column-outer loop order is preferable for this step.
    ///
    /// Only `AddMatMul` expresses a preference, and only when exactly one of its two
    /// operands has a [`PackedFormat`] format:
    /// * `Some(true)` — only `a` is in `PackedFormat`;
    /// * `Some(false)` — only `b` is in `PackedFormat`;
    /// * `None` — both or neither are, or `self` is any other variant.
    pub fn prefer_col_outer(&self) -> Option<bool> {
        match self {
            FusedSpec::AddMatMul { a, b, .. } => {
                let a_packed = a.format().is::<PackedFormat>();
                let b_packed = b.format().is::<PackedFormat>();
                // Identical formats give no reason to favour one loop order.
                (a_packed != b_packed).then_some(a_packed)
            }
            _ => None,
        }
    }
}
/// Kernel-level counterpart of [`FusedSpec`]: one instruction for a matrix-multiplication
/// kernel, with every operand reduced to a scalar of type `TI`, a raw pointer, or a `usize`.
///
/// `TI` is the type of the scalar operands (`Scalar*`, `LeakyRelu`) and of the buffers
/// behind the `*const TI` pointers (`LoadTile`, `PerRow*`, `PerCol*`, `AddRowColProducts`).
///
/// Layout: `#[repr(C, usize)]` lays the value out as a `usize` tag followed by a
/// `repr(C)` union of the variant payloads. Tags are assigned in declaration order,
/// starting at 0. The `check_non_linear_enum_size` test pins the total size to five
/// machine words for `TI = f32`.
///
/// NOTE(review): the fixed layout suggests these values are read by hand-written
/// (likely assembly) kernels that dispatch on the tag. Do not reorder, insert or remove
/// variants without updating every such consumer — confirm where they live.
#[repr(C, usize)]
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
#[rustfmt::skip]
pub enum FusedKerSpec<TI: Copy> {
    /// Presumably marks the end of a sequence of specs — confirm against the kernels.
    Done,
    /// Presumably resets the accumulator tile — confirm.
    Clear,
    /// Two tile pointers. NOTE(review): the roles of the two pointers are not visible here.
    LoadTile(*const TI, *const TI),
    // Operations with a single scalar operand. `SubF` is presumably the flipped
    // subtraction (operand minus accumulator) — confirm.
    ScalarMin(TI),
    ScalarMax(TI),
    ScalarAdd(TI),
    ScalarMul(TI),
    ScalarSub(TI),
    ScalarSubF(TI),
    /// Leaky ReLU with its parameter passed by value (the `FusedSpec` form borrows a tensor).
    LeakyRelu(TI),
    // Operations reading a buffer through a pointer, presumably one value per row.
    PerRowMin(*const TI),
    PerRowMax(*const TI),
    PerRowAdd(*const TI),
    PerRowMul(*const TI),
    PerRowSub(*const TI),
    PerRowSubF(*const TI),
    // Operations reading a buffer through a pointer, presumably one value per column.
    PerColMin(*const TI),
    PerColMax(*const TI),
    PerColAdd(*const TI),
    PerColMul(*const TI),
    PerColSub(*const TI),
    PerColSubF(*const TI),
    /// Same payload as `FusedSpec::QScale`.
    QScale(isize, RoundingPolicy, i32),
    /// Same payload as `FusedSpec::RoundingShiftRight`.
    RoundingShiftRight(usize, RoundingPolicy),
    /// Same payload as `FusedSpec::ShiftLeft`.
    ShiftLeft(usize),
    /// Kernel-side form of `FusedSpec::AddUnicast`, carrying an `OutputStoreKer`.
    AddUnicast(OutputStoreKer),
    /// Kernel-side form of `FusedSpec::AddRowColProducts`, with the tensors replaced by raw pointers.
    AddRowColProducts(*const TI, *const TI),
    /// Kernel-side form of `FusedSpec::Store`, carrying an `OutputStoreKer`.
    Store(OutputStoreKer),
    /// Matrix-product step. `pa`/`pb` are untyped byte pointers to the two operands
    /// (presumably already in their packed layout); `k` is presumably the reduction
    /// length; `packing` mirrors the field of the same name in `FusedSpec::AddMatMul`.
    AddMatMul { k: usize, pa: *const u8, pb: *const u8, packing: usize },
}
// SAFETY: several variants carry raw pointers, so the compiler does not make this type
// `Send`/`Sync` automatically. These impls assume the pointees stay valid and are not
// mutated while a spec may be used from another thread.
// NOTE(review): that invariant is not enforced here — confirm it with the code that builds
// the specs. Also note the impls only require `TI: Copy`, not `TI: Send` / `TI: Sync`,
// so they also cover non-thread-safe `Copy` types. Adding those bounds looks harmless for
// numeric `TI`, but it must be checked against every generic user first.
unsafe impl<TI: Copy> Send for FusedKerSpec<TI> {}
unsafe impl<TI: Copy> Sync for FusedKerSpec<TI> {}
#[cfg(test)]
#[test]
fn check_non_linear_enum_size() {
    // These sizes are a layout contract for code that reads the specs directly:
    // - `RoundingPolicy` must fit exactly in one machine word.
    // - `FusedKerSpec` must be one word of tag plus an `OutputStoreKer` payload,
    //   i.e. five words in total.
    let word = std::mem::size_of::<usize>();
    let ker_spec = std::mem::size_of::<FusedKerSpec<f32>>();

    assert_eq!(std::mem::size_of::<RoundingPolicy>(), word);
    assert_eq!(ker_spec, word + std::mem::size_of::<OutputStoreKer>());
    assert_eq!(ker_spec, 5 * word);
}