use crate::hash::{DynHash, SloppyHash};
use std::fmt;
use std::fmt::Debug;
use std::hash::Hash;
use std::ops::{Add, Deref, Mul, Neg};
use num_traits::{AsPrimitive, Bounded, Zero};
use super::MatMatMul;
use super::*;
/// A quantized matrix multiplier: a plain [`MatMatMul`] augmented with optional zero points
/// for A, B and C and an optional requantization scale factor.
pub trait QMatMatMul<TA, TB, TC, TI>:
    Debug + fmt::Display + dyn_clone::DynClone + Send + Sync + DynHash
where
    TA: 'static + Copy + Zero,
    TB: 'static + Copy + Zero,
    TC: 'static + Copy + Debug,
    TI: 'static + Copy + Add + Mul + Zero + Debug,
{
    /// Borrows the wrapped non-quantized multiplier.
    fn as_mmm(&self) -> &dyn MatMatMul<TA, TB, TC, TI>;

    /// Mutably borrows the wrapped non-quantized multiplier.
    fn as_mmm_mut(&mut self) -> &mut dyn MatMatMul<TA, TB, TC, TI>;

    /// Computes the quantized product of `a` and `b` into `c`, applying `non_linear` fused ops.
    ///
    /// # Safety
    /// NOTE(review): `a`, `b` and `c` are raw pointers; presumably they must reference buffers
    /// laid out as the wrapped multiplier's storage specs expect and be large enough for its
    /// m/k/n — confirm against `MatMatMul::run`.
    unsafe fn run(&self, a: *const TA, b: *const TB, c: *mut TC, non_linear: &[FusedSpec<TI>]);

    /// Uses one zero point for every row of A.
    ///
    /// # Safety
    /// Declared `unsafe` by the trait; NOTE(review): no extra caller contract is visible here.
    unsafe fn set_zero_point_a_scalar(&mut self, value: TA);

    /// Uses one zero point per row of A.
    ///
    /// # Safety
    /// Declared `unsafe` by the trait; NOTE(review): no extra caller contract is visible here.
    unsafe fn set_zero_point_a_vector(&mut self, values: Vec<TA>);

    /// Uses one zero point for every column of B.
    ///
    /// # Safety
    /// Declared `unsafe` by the trait; NOTE(review): no extra caller contract is visible here.
    unsafe fn set_zero_point_b_scalar(&mut self, value: TB);

    /// Uses one zero point per column of B.
    ///
    /// # Safety
    /// Declared `unsafe` by the trait; NOTE(review): no extra caller contract is visible here.
    unsafe fn set_zero_point_b_vector(&mut self, values: Vec<TB>);

    /// Sets the zero point of the output C.
    ///
    /// # Safety
    /// Declared `unsafe` by the trait; NOTE(review): no extra caller contract is visible here.
    unsafe fn set_zero_point_c_scalar(&mut self, value: TC);

    /// Sets the requantization scale factor applied to the accumulated result.
    ///
    /// # Safety
    /// Declared `unsafe` by the trait; NOTE(review): no extra caller contract is visible here.
    unsafe fn set_scale_factor(&mut self, factor: f32);
}
// Implements `Clone` for `Box<dyn QMatMatMul<TA, TB, TC, TI>>` through `dyn_clone`.
// Note: the generated impl additionally requires `TI: SloppyHash`, a bound the trait
// declaration above does not itself impose on `TI`.
dyn_clone::clone_trait_object!(<TA, TB, TC, TI> QMatMatMul<TA, TB, TC, TI> where
TA: Copy + Zero + 'static,
TB: Copy + Zero + 'static,
TC: Copy + Debug + 'static,
TI: Copy + Add + Mul + Zero + Debug + SloppyHash + 'static,
);
impl<TA, TB, TC, TI> Hash for Box<dyn QMatMatMul<TA, TB, TC, TI>>
where
    TA: 'static + Copy + Zero,
    TB: 'static + Copy + Zero,
    TC: 'static + Copy + Debug,
    TI: 'static + Copy + Add + Mul + Zero + Debug,
{
    /// Hashes a boxed quantized multiplier by delegating to the object's `DynHash`
    /// implementation (the generic `Hasher` is passed on as `&mut dyn Hasher`).
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let inner: &dyn QMatMatMul<TA, TB, TC, TI> = &**self;
        inner.dyn_hash(state)
    }
}
/// A quantization parameter (e.g. a zero point) that is either shared or given per lane.
///
/// For A the vector form holds one value per row, for B one value per column
/// (see how `run` and the test reference index it).
#[derive(Debug, Clone)]
pub enum QuantizedParam<T>
where
    T: crate::hash::SloppyHash,
{
    /// A single value used for every row/column.
    Scalar(T),
    /// One value per row (A) or per column (B).
    Vector(Vec<T>),
}
impl<T: crate::hash::SloppyHash> std::hash::Hash for QuantizedParam<T> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
match self {
Self::Scalar(t) => t.sloppy_hash(state),
Self::Vector(t) => {
t.len().hash(state);
t.iter().for_each(|v| v.sloppy_hash(state))
}
}
}
}
/// Quantized wrapper around a `MatMatMulImpl`: holds the zero points and scale factor that
/// `run` turns into extra fused ops around the plain integer product.
#[derive(Debug, Clone)]
pub struct QMatMatMulImpl<K, TA, TB, TC, TI>
where
    K: 'static + MatMatMulKer<TA, TB, TC, TI>,
    TA: 'static + Copy + Zero + SloppyHash,
    TB: 'static + Copy + Zero + SloppyHash,
    TC: 'static + Copy + Debug + SloppyHash,
    TI: 'static + Copy + Add + Mul + Zero + Debug + SloppyHash,
{
    /// The wrapped, non-quantized multiplier that does the actual work.
    pub mmm: MatMatMulImpl<K, TA, TB, TC, TI>,
    /// Zero point of A (`None` when unset).
    pub zero_point_a: Option<QuantizedParam<TA>>,
    /// Zero point of B (`None` when unset).
    pub zero_point_b: Option<QuantizedParam<TB>>,
    /// Zero point added to the output C (`None` when unset).
    pub zero_point_c: Option<TC>,
    /// `(multiplier, shift)` as computed by `set_scale_factor` (`None` when unset).
    pub scale_factor: Option<(TI, usize)>,
}
impl<K, TA, TB, TC, TI> QMatMatMulImpl<K, TA, TB, TC, TI>
where
    TA: Copy + Zero + SloppyHash + AsPrimitive<TI>,
    TB: Copy + Zero + SloppyHash + AsPrimitive<TI> + Debug + 'static,
    TC: Copy + Debug + SloppyHash + 'static,
    TI: Copy + Add + Mul + Zero + Debug + SloppyHash + 'static,
    K: MatMatMulKer<TA, TB, TC, TI> + 'static,
{
    /// Sums every row of A over the `k` dimension, returning `m` sums.
    ///
    /// `a` is walked as a packed buffer of `mr`-row panels: for each panel, `k` groups of `mr`
    /// consecutive values (one per row of the panel). In the trailing partial panel the values
    /// for rows `>= m` are skipped (but still stepped over).
    ///
    /// # Panics
    /// Panics if A's storage is not `MatrixStoreSpec::Packed`.
    fn sum_a_over_k(&self, mut a: *const TA) -> Vec<TI> {
        match &self.mmm.a_storage {
            MatrixStoreSpec::Packed { .. } => {
                let mr = K::mr();
                let rows = self.m;
                let mut sums = vec![TI::zero(); rows];
                // Whole panels, plus one more if m is not a multiple of mr.
                let panels = rows / mr + if rows % mr != 0 { 1 } else { 0 };
                unsafe {
                    for panel in 0..panels {
                        let first_row = panel * mr;
                        for _ in 0..self.k {
                            for lane in 0..mr {
                                let row = first_row + lane;
                                if row < rows {
                                    sums[row] = sums[row] + (*a).as_();
                                }
                                a = a.offset(1);
                            }
                        }
                    }
                }
                sums
            }
            other => panic!("Storage for A {:?} not supported for quantized ops", other),
        }
    }

    /// Sums every column of B over the `k` dimension, returning `n` sums.
    ///
    /// Supports packed B (`nr`-column panels, walked like A's panels in `sum_a_over_k`) and
    /// `OffsetsAndPtrs` storage, where element (k, n) sits at byte offset
    /// `row_byte_offsets[k] + col_byte_offsets[n]` from `b`.
    ///
    /// # Panics
    /// Panics for any other B storage.
    fn sum_b_over_k(&self, mut b: *const TB) -> Vec<TI> {
        let cols = self.n;
        let mut sums = vec![TI::zero(); cols];
        match &self.mmm.b_storage {
            MatrixStoreSpec::Packed { .. } => unsafe {
                let nr = K::nr();
                let panels = cols / nr + if cols % nr != 0 { 1 } else { 0 };
                for panel in 0..panels {
                    let first_col = panel * nr;
                    for _ in 0..self.k {
                        for lane in 0..nr {
                            let col = first_col + lane;
                            if col < cols {
                                sums[col] = sums[col] + (*b).as_();
                            }
                            b = b.offset(1);
                        }
                    }
                }
            },
            MatrixStoreSpec::OffsetsAndPtrs { row_byte_offsets, col_byte_offsets, .. } => unsafe {
                // Offsets are in bytes; convert to element offsets for pointer arithmetic.
                let item_size = std::mem::size_of::<TB>() as isize;
                for col in 0..cols {
                    for depth in 0..self.k {
                        let offset = (row_byte_offsets[depth] + col_byte_offsets[col]) / item_size;
                        sums[col] = sums[col] + (*b.offset(offset)).as_();
                    }
                }
            },
            other => panic!("Storage {:?} for B not supported for quantized ops", other),
        }
        sums
    }
}
impl<K, TA, TB, TC, TI> From<MatMatMulImpl<K, TA, TB, TC, TI>> for QMatMatMulImpl<K, TA, TB, TC, TI>
where
    K: 'static + MatMatMulKer<TA, TB, TC, TI>,
    TA: 'static + Copy + Zero + SloppyHash,
    TB: 'static + Copy + Zero + SloppyHash,
    TC: 'static + Copy + Debug + SloppyHash,
    TI: 'static + Copy + Add + Mul + Zero + Debug + SloppyHash,
{
    /// Wraps a plain multiplier with no zero points and no scale factor configured.
    fn from(inner: MatMatMulImpl<K, TA, TB, TC, TI>) -> Self {
        Self {
            mmm: inner,
            zero_point_a: None,
            zero_point_b: None,
            zero_point_c: None,
            scale_factor: None,
        }
    }
}
impl<K, TA, TB, TC, TI> Deref for QMatMatMulImpl<K, TA, TB, TC, TI>
where
TA: Copy + Zero + SloppyHash + 'static,
TB: Copy + Zero + SloppyHash + 'static,
TC: Copy + Debug + SloppyHash + 'static,
TI: Copy + Add + Mul + Zero + Debug + SloppyHash + 'static,
K: MatMatMulKer<TA, TB, TC, TI> + 'static,
{
type Target = MatMatMulImpl<K, TA, TB, TC, TI>;
fn deref(&self) -> &Self::Target {
&self.mmm
}
}
// SAFETY: NOTE(review): this unconditionally asserts `QMatMatMulImpl` is `Send` for every
// parameter set matching the bounds below, without requiring `K`, `TA`, `TB`, `TC` or `TI`
// to be `Send`. Besides `mmm`, the fields are owned zero points (`Option`/`Vec`) and an
// optional `(TI, usize)` scale factor. Soundness presumably relies on `MatMatMulImpl`
// holding no thread-affine state and on the element types being plain numbers; confirm
// against `MatMatMulImpl`'s definition.
unsafe impl<K, TA, TB, TC, TI> Send for QMatMatMulImpl<K, TA, TB, TC, TI>
where
TA: Copy + Zero + SloppyHash + 'static,
TB: Copy + Zero + SloppyHash + 'static,
TC: Copy + Debug + SloppyHash + 'static,
TI: Copy + Add + Mul + Zero + Debug + SloppyHash + 'static,
K: MatMatMulKer<TA, TB, TC, TI> + 'static,
{
}
// SAFETY: NOTE(review): as for `Send` on this type, `Sync` is asserted without requiring
// the type parameters to be `Sync`. All `QMatMatMul` methods that take `&self` (`as_mmm`,
// `run`) only read the configuration, and mutation goes through `&mut self` setters.
// This presumes the wrapped `MatMatMulImpl` is itself safe to share; confirm.
unsafe impl<K, TA, TB, TC, TI> Sync for QMatMatMulImpl<K, TA, TB, TC, TI>
where
TA: Copy + Zero + SloppyHash + 'static,
TB: Copy + Zero + SloppyHash + 'static,
TC: Copy + Debug + SloppyHash + 'static,
TI: Copy + Add + Mul + Zero + Debug + SloppyHash + 'static,
K: MatMatMulKer<TA, TB, TC, TI> + 'static,
{
}
impl<K, TA, TB, TC, TI> QMatMatMul<TA, TB, TC, TI> for QMatMatMulImpl<K, TA, TB, TC, TI>
where
    TA: Copy + Zero + Debug + SloppyHash + AsPrimitive<TI> + 'static,
    TB: Copy + Zero + Debug + SloppyHash + AsPrimitive<TI> + 'static,
    TC: Copy + Debug + Bounded + AsPrimitive<TI> + SloppyHash + 'static,
    TI: Copy + Add + Mul<Output = TI> + Zero + Neg<Output = TI> + Debug + SloppyHash + 'static,
    K: MatMatMulKer<TA, TB, TC, TI> + 'static,
    usize: AsPrimitive<TI>,
    i32: AsPrimitive<TI>,
{
    fn as_mmm(&self) -> &dyn MatMatMul<TA, TB, TC, TI> {
        &self.mmm
    }

    fn as_mmm_mut(&mut self) -> &mut dyn MatMatMul<TA, TB, TC, TI> {
        &mut self.mmm
    }

    unsafe fn set_zero_point_a_scalar(&mut self, value: TA) {
        self.zero_point_a = Some(QuantizedParam::Scalar(value))
    }

    unsafe fn set_zero_point_b_scalar(&mut self, value: TB) {
        self.zero_point_b = Some(QuantizedParam::Scalar(value))
    }

    unsafe fn set_zero_point_c_scalar(&mut self, value: TC) {
        self.zero_point_c = Some(value)
    }

    /// Stores one zero point per row of A, padded by repeating the last value up to `m`
    /// rounded up to a multiple of `mr` (a whole number of row panels).
    ///
    /// # Panics
    /// Panics if `values` is empty while padding is needed.
    unsafe fn set_zero_point_a_vector(&mut self, mut values: Vec<TA>) {
        let mr = K::mr();
        // ceil(m / mr) * mr. The previous `m + mr - 1 / mr * mr` parsed as
        // `m + mr - ((1 / mr) * mr)` because `/` and `*` bind tighter than `+`/`-`.
        // NOTE(review): padding presumably lets a partial last panel read a full panel's
        // worth of zero points; confirm against the fused-op handling in `MatMatMulImpl::run`.
        let wanted = (self.m() + mr - 1) / mr * mr;
        if values.len() < wanted {
            let last = *values.last().expect("zero point vector for A must not be empty");
            values.resize(wanted, last);
        }
        self.zero_point_a = Some(QuantizedParam::Vector(values))
    }

    /// Stores one zero point per column of B, padded by repeating the last value up to `n`
    /// rounded up to a multiple of `nr` (a whole number of column panels).
    ///
    /// # Panics
    /// Panics if `values` is empty while padding is needed.
    unsafe fn set_zero_point_b_vector(&mut self, mut values: Vec<TB>) {
        let nr = K::nr();
        // ceil(n / nr) * nr (same precedence fix as for A).
        let wanted = (self.n() + nr - 1) / nr * nr;
        if values.len() < wanted {
            let last = *values.last().expect("zero point vector for B must not be empty");
            values.resize(wanted, last);
        }
        self.zero_point_b = Some(QuantizedParam::Vector(values))
    }

    /// Splits `factor` into a fixed-point multiplier and a right shift, consumed by
    /// `FusedSpec::QTowardsPlusInf` in `run`.
    ///
    /// The multiplier is the mantissa of `factor` rescaled into [0.5, 1) times 2^31; the shift
    /// is `126 - biased_exponent`. NOTE(review): zero and subnormal factors are not treated
    /// specially (their mantissa lacks the implicit bit); confirm callers never pass them.
    ///
    /// # Panics
    /// Panics if `factor` is negative, NaN or >= 1.0. Previously `126 - exponent`
    /// underflowed in those cases: a debug-mode overflow panic, or in release a silently
    /// wrapped, meaningless shift.
    unsafe fn set_scale_factor(&mut self, factor: f32) {
        let factor_bits = factor.to_bits();
        // Biased exponent; the sign bit lands at bit 8, so negative factors exceed 126 too.
        let current_exponent = factor_bits >> 23;
        assert!(
            current_exponent <= 126,
            "scale factor must be in [0, 1), got {}",
            factor
        );
        // Keep the mantissa, force the exponent of 0.5 so the value lies in [0.5, 1).
        let bumped_multi = f32::from_bits(factor_bits & 0x007fffff | 0x3f000000);
        let int_multi = (bumped_multi * (1i64 << 31) as f32).round() as i32;
        let shift = 126 - current_exponent;
        self.scale_factor = Some((int_multi.as_(), shift as usize));
    }

    /// Runs the quantized product into `c`.
    ///
    /// Zero points are folded in by expanding
    /// `Σ_k (a - a0)(b - b0) = Σ_k a·b - a0·Σ_k b - b0·Σ_k a + k·a0·b0`:
    /// the correction terms are prepended to the caller's `non_linear` ops. The scale factor
    /// and C zero point (when set) are appended, followed by a clamp to `TC`'s range.
    unsafe fn run(&self, a: *const TA, b: *const TB, c: *mut TC, non_linear: &[FusedSpec<TI>]) {
        let mut non_linear = non_linear.to_vec();
        // -a0 · Σ_k b[k, n]
        if let Some(ref a0) = self.zero_point_a {
            let mut sum_b_over_k = self.sum_b_over_k(b);
            for n in 0..self.n {
                sum_b_over_k[n] = sum_b_over_k[n].neg();
            }
            let term = match a0 {
                QuantizedParam::Scalar(a0) => {
                    for n in 0..self.n {
                        sum_b_over_k[n] = sum_b_over_k[n] * a0.as_();
                    }
                    FusedSpec::PerColAdd(sum_b_over_k)
                }
                QuantizedParam::Vector(a0) => {
                    let a0 = a0.iter().map(|a| a.as_()).collect();
                    FusedSpec::AddRowColProducts(a0, sum_b_over_k)
                }
            };
            non_linear.insert(0, term);
        }
        // b0 · (k·a0 - Σ_k a[m, k]); the k·a0·b0 cross term only exists when a0 is set.
        if let Some(ref b0) = self.zero_point_b {
            let mut sum_a_over_k = self.sum_a_over_k(a);
            for m in 0..self.m {
                sum_a_over_k[m] = sum_a_over_k[m].neg();
                if let Some(ref a0) = self.zero_point_a {
                    match a0 {
                        QuantizedParam::Scalar(a0) => {
                            sum_a_over_k[m] = a0.as_() * self.k.as_() + sum_a_over_k[m];
                        }
                        QuantizedParam::Vector(a0) => {
                            sum_a_over_k[m] = a0[m].as_() * self.k.as_() + sum_a_over_k[m];
                        }
                    }
                }
            }
            let term = match b0 {
                QuantizedParam::Scalar(b0) => {
                    for m in 0..self.m {
                        sum_a_over_k[m] = sum_a_over_k[m] * b0.as_();
                    }
                    FusedSpec::PerRowAdd(sum_a_over_k)
                }
                QuantizedParam::Vector(b0) => {
                    let b0 = b0.iter().map(|b| b.as_()).collect();
                    FusedSpec::AddRowColProducts(sum_a_over_k, b0)
                }
            };
            non_linear.insert(0, term);
        }
        if let Some(scale) = self.scale_factor {
            non_linear.push(FusedSpec::QTowardsPlusInf(scale.0, scale.1));
        }
        if let Some(c0) = self.zero_point_c {
            non_linear.push(FusedSpec::ScalarAdd(c0.as_()));
        }
        // Saturate to the output type's range.
        non_linear.push(FusedSpec::Min(TC::max_value().as_()));
        non_linear.push(FusedSpec::Max(TC::min_value().as_()));
        self.mmm.run(a, b, c, &non_linear);
    }
}
impl<K, TA, TB, TC, TI> fmt::Display for QMatMatMulImpl<K, TA, TB, TC, TI>
where
    K: MatMatMulKer<TA, TB, TC, TI>,
    TA: 'static + Copy + Zero + SloppyHash,
    TB: 'static + Copy + Zero + SloppyHash,
    TC: 'static + Copy + Debug + SloppyHash,
    TI: 'static + Copy + Add + Mul + Zero + Debug + SloppyHash,
{
    /// Displays exactly like the wrapped multiplier (outer formatter flags are not forwarded).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!("{}", self.mmm))
    }
}
impl<TA, TB, TC, TI, K> std::hash::Hash for QMatMatMulImpl<K, TA, TB, TC, TI>
where
    TA: Copy + Zero + SloppyHash + 'static,
    TB: Copy + Zero + SloppyHash + 'static,
    TC: Copy + Debug + SloppyHash + 'static,
    TI: Copy + Add + Mul + Zero + Debug + SloppyHash + 'static,
    K: MatMatMulKer<TA, TB, TC, TI>,
{
    /// Hashes every field that influences the result of `run`.
    ///
    /// Previously `zero_point_c` and `scale_factor` were ignored and the `Option` tags were not
    /// hashed. Multipliers producing different outputs therefore hashed identically, and
    /// `(Some(x), None)` fed the same bytes as `(None, Some(x))` for the two input zero points.
    fn hash<S: std::hash::Hasher>(&self, state: &mut S) {
        self.mmm.hash(state);
        // `Option<QuantizedParam<_>>` is `Hash`, which also covers the Some/None tag.
        self.zero_point_a.hash(state);
        self.zero_point_b.hash(state);
        // TC / TI only implement `SloppyHash`, so tag these options by hand.
        match &self.zero_point_c {
            Some(c0) => {
                true.hash(state);
                c0.sloppy_hash(state);
            }
            None => false.hash(state),
        }
        match &self.scale_factor {
            Some((multiplier, shift)) => {
                true.hash(state);
                multiplier.sloppy_hash(state);
                shift.hash(state);
            }
            None => false.hash(state),
        }
    }
}
impl<TA, TB, TC, TI, K> crate::hash::DynHash for QMatMatMulImpl<K, TA, TB, TC, TI>
where
    K: MatMatMulKer<TA, TB, TC, TI>,
    TA: 'static + Copy + Zero + SloppyHash,
    TB: 'static + Copy + Zero + SloppyHash,
    TC: 'static + Copy + Debug + SloppyHash,
    TI: 'static + Copy + Add + Mul + Zero + Debug + SloppyHash,
{
    /// Object-safe hashing entry point; delegates to `crate::hash::dyn_hash` with this value.
    fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
        let this: &Self = self;
        crate::hash::dyn_hash(this, state)
    }
}
// Test support for quantized matmul kernels: a naive reference problem (`QMatMulProblem`),
// proptest strategies for it, and two exported macros that instantiate test suites for a
// concrete kernel type.
#[cfg(test)]
#[allow(dead_code)]
#[macro_use]
pub mod test {
use super::*;
use crate::align::Buffer;
use proptest::collection::vec;
use proptest::prelude::*;
use std::marker::PhantomData;
use std::ops::{AddAssign, Sub};
/// A quantized matmul problem: `a` is m×k and `b` is k×n, both row-major (see the indexing
/// in `reference`), with their zero points. `TC` (output) and `TI` (accumulator) are only
/// carried as type parameters through `boo`.
#[derive(Debug)]
pub struct QMatMulProblem<TA: SloppyHash, TB: SloppyHash, TC, TI: SloppyHash> {
pub m: usize,
pub k: usize,
pub n: usize,
pub a: Vec<TA>,
pub a0: QuantizedParam<TA>,
pub b: Vec<TB>,
pub b0: QuantizedParam<TB>,
pub boo: PhantomData<(TC, TI)>,
}
// Strategy: either a scalar, or a vector of exactly `n` values (n = rows for A, cols for B).
impl<TI: Arbitrary + 'static + SloppyHash> Arbitrary for QuantizedParam<TI> {
type Parameters = usize;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(n: usize) -> Self::Strategy {
prop_oneof![
any::<TI>().prop_map(QuantizedParam::Scalar),
vec(any::<TI>(), n..=n).prop_map(QuantizedParam::Vector),
]
.boxed()
}
}
// Strategy: m, k, n each in 1..10, matrices of exactly m*k and k*n values, and zero points
// sized for m rows (A) and n columns (B).
impl<TA, TB, TC, TI> Arbitrary for QMatMulProblem<TA, TB, TC, TI>
where
TA: Arbitrary + 'static + Debug + 'static + SloppyHash,
TB: Arbitrary + 'static + Debug + 'static + SloppyHash,
TC: Arbitrary + 'static + Debug + 'static,
TI: Arbitrary + 'static + Debug + 'static + SloppyHash,
{
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
(1usize..10, 1usize..10, 1usize..10)
.prop_flat_map(|(m, k, n)| {
(
Just(m),
Just(k),
Just(n),
vec(any::<TA>(), m * k..=m * k),
any_with::<QuantizedParam<TA>>(m),
vec(any::<TB>(), k * n..=k * n),
any_with::<QuantizedParam<TB>>(n),
)
})
.prop_map(|(m, k, n, a, a0, b, b0)| QMatMulProblem {
m,
k,
n,
a,
a0,
b,
b0,
boo: PhantomData,
})
.boxed()
}
}
impl<TA, TB, TC, TI> QMatMulProblem<TA, TB, TC, TI>
where
TA: Arbitrary + SloppyHash + 'static + Debug + AsPrimitive<TI> + Zero + Copy,
TB: Arbitrary + SloppyHash + 'static + Debug + AsPrimitive<TI> + Zero + Copy,
TC: Arbitrary + SloppyHash + 'static + Debug + Copy + Bounded + AsPrimitive<TI> + Zero + 'static,
TI: Arbitrary
+ 'static
+ Debug
+ Copy
+ AsPrimitive<TC>
+ Add<Output = TI>
+ Mul<Output = TI>
+ Sub<Output = TI>
+ AddAssign
+ Neg<Output = TI>
+ Zero
+ SloppyHash
+ Ord,
usize: AsPrimitive<TI>,
i32: AsPrimitive<TI>,
{
/// Naive reference result: accumulates Σ_k (a - a0)(b - b0) in `TI` (a0 indexed by row,
/// b0 by column when given as vectors), then clamps each entry to `TC`'s range and converts.
pub fn reference(&self) -> Vec<TC> {
let mut i = vec![TI::zero(); self.m * self.n];
for m in 0..self.m {
for n in 0..self.n {
for k in 0..self.k {
let a: TI = self.a[k + self.k * m].as_();
let b: TI = self.b[n + self.n * k].as_();
let a0 = match &self.a0 {
QuantizedParam::Scalar(a0) => a0.as_(),
QuantizedParam::Vector(a0) => a0[m].as_(),
};
let b0 = match &self.b0 {
QuantizedParam::Scalar(b0) => b0.as_(),
QuantizedParam::Vector(b0) => b0[n].as_(),
};
i[n + self.n * m] += (a - a0) * (b - b0);
}
}
}
i.iter()
.map(|i| i.max(&TC::min_value().as_()).min(&TC::max_value().as_()).as_())
.collect()
}
/// Runs the problem through `QMatMatMulImpl` with kernel `K`: packs A and B into aligned
/// buffers, applies the zero points, and multiplies with no extra fused ops (no scale
/// factor, no C zero point).
pub fn run<K: MatMatMulKer<TA, TB, TC, TI>>(&self) -> Vec<TC> {
unsafe {
let mut c = vec![TC::zero(); self.m * self.n];
let mut mmm = QMatMatMulImpl::from(MatMatMulImpl::<K, TA, TB, TC, TI>::new(
self.m, self.k, self.n,
));
let mut packed_a =
Buffer::uninitialized(mmm.a_pack().len(), mmm.a_pack().alignment());
mmm.a_pack().pack(packed_a.as_mut_ptr(), self.a.as_ptr(), self.k as isize, 1);
let mut packed_b =
Buffer::uninitialized(mmm.b_pack().len(), mmm.b_pack().alignment());
mmm.b_pack().pack(packed_b.as_mut_ptr(), self.b.as_ptr(), self.n as isize, 1);
match &self.a0 {
QuantizedParam::Scalar(a0) => mmm.set_zero_point_a_scalar(*a0),
QuantizedParam::Vector(a0) => mmm.set_zero_point_a_vector(a0.clone()),
}
match &self.b0 {
QuantizedParam::Scalar(b0) => mmm.set_zero_point_b_scalar(*b0),
QuantizedParam::Vector(b0) => mmm.set_zero_point_b_vector(b0.clone()),
}
mmm.run(packed_a.as_ptr(), packed_b.as_ptr(), c.as_mut_ptr(), &[]);
c
}
}
}
// Instantiates a `qframe` test module for kernel `$ker` with the given element types: a
// property test comparing `run` against `reference`, plus fixed regression cases using
// per-row/per-column zero points. Every test is a no-op unless `$cond` evaluates to true.
#[macro_export]
macro_rules! qmmm_frame_tests {
($cond:expr, $ker:ty, $ta: ty, $tb: ty, $tc: ty, $ti: ty) => {
mod qframe {
use proptest::prelude::*;
use std::marker::PhantomData;
#[allow(unused_imports)]
use $crate::frame::mmm::qmmm::test::*;
use $crate::frame::mmm::qmmm::QuantizedParam;
proptest::proptest! {
#[test]
fn q_mat_mul_prop(pb in any::<QMatMulProblem<$ta, $tb, $tc, $ti>>()) {
if $cond {
prop_assert_eq!(pb.run::<$ker>(), pb.reference())
}
}
}
#[test]
fn q_mat_mul_1() {
if $cond {
let pb = QMatMulProblem {
m: 1,
k: 1,
n: 1,
a0: QuantizedParam::Vector(vec![1]),
a: vec![0],
b0: QuantizedParam::Vector(vec![1]),
b: vec![0],
boo: PhantomData,
};
assert_eq!(pb.run::<$ker>(), pb.reference());
}
}
// Case expected to hit output saturation (the reference clamps to TC's range).
#[test]
fn q_mat_mul_sat_1() {
if $cond {
let pb = QMatMulProblem {
m: 1,
k: 1,
n: 1,
a0: QuantizedParam::Vector(vec![0]),
a: vec![3],
b0: QuantizedParam::Vector(vec![43]),
b: vec![0],
boo: PhantomData,
};
assert_eq!(pb.run::<$ker>(), pb.reference());
}
}
// Uses the smallest value of the A element type.
#[test]
fn q_mat_mul_sat_2() {
if $cond {
let pb = QMatMulProblem {
m: 1,
k: 1,
n: 1,
a0: QuantizedParam::Vector(vec![0]),
a: vec![<$ta>::min_value()],
b0: QuantizedParam::Vector(vec![0]),
b: vec![1],
boo: PhantomData,
};
assert_eq!(pb.run::<$ker>(), pb.reference());
}
}
#[test]
fn q_mat_mul_n2() {
if $cond {
let pb = QMatMulProblem {
m: 1,
k: 1,
n: 2,
a: vec![0],
a0: QuantizedParam::Vector(vec![1]),
b: vec![0, 0],
b0: QuantizedParam::Vector(vec![0, 1]),
boo: PhantomData,
};
assert_eq!(pb.run::<$ker>(), pb.reference());
}
}
#[test]
fn q_mat_mul_k2() {
if $cond {
let pb = QMatMulProblem {
m: 1,
k: 2,
n: 1,
a: vec![0, 1],
a0: QuantizedParam::Vector(vec![0]),
b: vec![0, 1],
b0: QuantizedParam::Vector(vec![0]),
boo: PhantomData,
};
assert_eq!(pb.run::<$ker>(), pb.reference());
}
}
}
};
}
// Instantiates a `qframe_s` test module with scalar-zero-point cases using negative literals,
// so it is only usable with signed element types. Tests run only when `$cond` is true.
#[macro_export]
macro_rules! qmmm_s_frame_tests {
($cond:expr, $ker:ty, $ta: ty, $tb: ty, $tc: ty, $ti: ty) => {
mod qframe_s {
use std::marker::PhantomData;
#[allow(unused_imports)]
use $crate::frame::mmm::qmmm::test::*;
use $crate::frame::mmm::qmmm::QuantizedParam;
#[test]
fn q_mat_mul_1_1_5() {
if $cond {
let pb = QMatMulProblem {
m: 1,
k: 1,
n: 5,
a: vec![-1],
a0: QuantizedParam::Scalar(0),
b: vec![0, 0, 0, 0, -2],
b0: QuantizedParam::Scalar(0),
boo: PhantomData,
};
assert_eq!(pb.run::<$ker>(), pb.reference());
}
}
#[test]
fn q_mat_mul_1_1_1() {
if $cond {
let pb = QMatMulProblem {
m: 1,
k: 1,
n: 1,
a: vec![11],
a0: QuantizedParam::Scalar(10),
b: vec![-1],
b0: QuantizedParam::Scalar(0),
boo: PhantomData,
};
assert_eq!(pb.run::<$ker>(), pb.reference());
}
}
}
};
}
}