1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
use num_traits::Zero;
use std::fmt;
use std::ops::{Add, Deref, Mul};

use crate::internal::*;
use crate::ops::quant::QParams;

use tract_linalg::mmm::{FusedSpec, MatMatMul, QMatMatMul};

/// Owns a boxed matrix-multiplication kernel from `tract_linalg`, in one of two flavours.
///
/// The type parameters are passed straight through to the wrapped trait object:
/// `TA` and `TB` are the element types behind the `a` and `b` input pointers given to
/// `run`, `TC` is the element type behind the `c` output pointer, and `TI` is the
/// element type carried by the `FusedSpec` operations.
///
/// Both variants can be viewed as a `MatMatMul` through `as_mmm`; only `Quant` exposes
/// the zero-point / scale-factor setters used by `set_quant_params`.
#[derive(Clone, Debug, Educe)]
#[educe(Hash)]
pub enum MMMWrapper<TA, TB, TC, TI>
where
    TA: Datum + Copy + Zero,
    TB: Datum + Copy + Zero,
    TC: Datum + Copy,
    TI: Datum + Copy + Add + Mul + Zero + fmt::Debug,
{
    /// A kernel with no quantization parameters attached.
    Plain(Box<dyn MatMatMul<TA, TB, TC, TI>>),
    /// A kernel that also accepts zero points and a scale factor.
    Quant(Box<dyn QMatMatMul<TA, TB, TC, TI>>),
}

impl<TA, TB, TC, TI> MMMWrapper<TA, TB, TC, TI>
where
    TA: Datum + Copy + Zero,
    TB: Datum + Copy + Zero,
    TC: Datum + Copy,
    TI: Datum + Copy + Add + Mul + Zero + fmt::Debug,
{
    /// Borrows the wrapped kernel as a `MatMatMul`, whichever variant is held.
    ///
    /// For `Quant`, this is whatever the quantized kernel returns from its own `as_mmm`.
    pub fn as_mmm(&self) -> &dyn MatMatMul<TA, TB, TC, TI> {
        match self {
            MMMWrapper::Quant(quant) => quant.as_mmm(),
            MMMWrapper::Plain(plain) => plain.as_ref(),
        }
    }

    /// Mutable counterpart of [`as_mmm`](Self::as_mmm).
    pub fn as_mmm_mut(&mut self) -> &mut dyn MatMatMul<TA, TB, TC, TI> {
        match self {
            MMMWrapper::Quant(quant) => quant.as_mmm_mut(),
            MMMWrapper::Plain(plain) => plain.as_mut(),
        }
    }

    /// Borrows the quantized kernel, or returns `None` when the wrapper is `Plain`.
    pub fn as_quant(&self) -> Option<&dyn QMatMatMul<TA, TB, TC, TI>> {
        if let MMMWrapper::Quant(quant) = self {
            Some(quant.deref())
        } else {
            None
        }
    }

    /// Mutably borrows the quantized kernel, or returns `None` when the wrapper is `Plain`.
    pub fn as_quant_mut(&mut self) -> Option<&mut dyn QMatMatMul<TA, TB, TC, TI>> {
        if let MMMWrapper::Quant(quant) = self {
            Some(quant.as_mut())
        } else {
            None
        }
    }

    /// Runs the wrapped kernel, forwarding every argument unchanged.
    ///
    /// Each variant dispatches to its own `run`; the quantized kernel is *not* routed
    /// through `as_mmm`.
    ///
    /// # Safety
    ///
    /// `a`, `b` and `c` are raw pointers handed directly to the underlying kernel. The
    /// caller must satisfy whatever that kernel's `run` requires of them.
    /// NOTE(review): presumably they must point to buffers laid out and sized as the
    /// kernel was configured to expect — confirm against the `tract_linalg` docs.
    pub unsafe fn run(&self, a: *const TA, b: *const TB, c: *mut TC, non_linear: &[FusedSpec<TI>]) {
        match self {
            MMMWrapper::Plain(plain) => plain.run(a, b, c, non_linear),
            MMMWrapper::Quant(quant) => quant.run(a, b, c, non_linear),
        }
    }

    /// Copies the settings present in `params` onto the quantized kernel.
    ///
    /// * `zero_point_a` / `zero_point_b`: a rank-0 tensor is applied as a scalar zero
    ///   point, any other rank is applied as a vector copied from the tensor's slice.
    /// * `zero_point_c`: always applied as a scalar, cast to `TC`.
    /// * `scale_factor`: applied as is.
    ///
    /// Fields that are `None` leave the corresponding kernel setting untouched.
    ///
    /// # Errors
    ///
    /// Fails if the wrapper holds a `Plain` kernel, or if reading a zero-point tensor
    /// as a scalar or slice of the expected type fails. Settings applied before a
    /// failure are not rolled back.
    pub fn set_quant_params(&mut self, params: &QParams) -> TractResult<()> {
        let kernel = self.as_quant_mut().ok_or("try to zero_point on a float mat mul")?;
        // NOTE(review): the whole configuration sequence sits in one `unsafe` block, as
        // in the original code; which of these calls actually need it depends on
        // declarations outside this file.
        unsafe {
            if let Some(zero_a) = params.zero_point_a.as_ref() {
                match zero_a.rank() {
                    0 => kernel.set_zero_point_a_scalar(*zero_a.to_scalar()?),
                    _ => kernel.set_zero_point_a_vector(zero_a.as_slice()?.to_vec()),
                }
            }
            if let Some(zero_b) = params.zero_point_b.as_ref() {
                match zero_b.rank() {
                    0 => kernel.set_zero_point_b_scalar(*zero_b.to_scalar()?),
                    _ => kernel.set_zero_point_b_vector(zero_b.as_slice()?.to_vec()),
                }
            }
            if let Some(zero_c) = params.zero_point_c.as_ref() {
                kernel.set_zero_point_c_scalar(zero_c.cast_to_scalar()?)
            }
            if let Some(scale) = params.scale_factor {
                kernel.set_scale_factor(scale);
            }
        }
        Ok(())
    }
}

/// Displays the wrapper exactly as the wrapped kernel displays itself.
impl<TA, TB, TC, TI> fmt::Display for MMMWrapper<TA, TB, TC, TI>
where
    TA: Datum + Copy + Zero,
    TB: Datum + Copy + Zero,
    TC: Datum + Copy,
    TI: Datum + Copy + Add + Mul + Zero + fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // A fresh `{}` spec is used for the inner kernel, so width/fill flags set on
        // the outer formatter are not forwarded to it.
        match self {
            MMMWrapper::Plain(plain) => f.write_fmt(format_args!("{}", plain)),
            MMMWrapper::Quant(quant) => f.write_fmt(format_args!("{}", quant)),
        }
    }
}