use std::fmt;
use std::ops::{Add, Mul};

use ndarray::*;
use num_traits::Zero;
use tract_linalg::frame::PackB;
use tract_linalg::mmm::FusedSpec;

use crate::internal::*;

use super::MMMWrapper;
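/// Packs the B operand of a matrix product into the blocked layout expected
/// by the tract-linalg kernel, handling an optional batch prefix.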
#[derive(Debug, Clone, PartialEq, Educe)]
#[educe(Hash)]
pub struct MatMatMulPackB<T>
where
T: Copy + Datum + Zero,
{
    /// Packer descriptor from tract-linalg for the target kernel.
    pub(crate) pack_b: PackB<T>,
    /// Stride, in elements, between two consecutive rows of B.
    pub(crate) row_stride: isize,
    /// Stride, in elements, between two consecutive columns of B.
    pub(crate) col_stride: isize,
    /// Shape of the packed output tensor (batch prefix plus packed payload).
    pub(crate) output_shape: TVec<usize>,
}
impl<T> DynHash for MatMatMulPackB<T>
where
T: Copy + Datum + Zero,
{
fn dyn_hash(&self, hasher: &mut dyn std::hash::Hasher) {
tract_linalg::hash::dyn_hash(&self, hasher)
}
}
impl<T> Op for MatMatMulPackB<T>
where
T: Copy + Datum + Zero,
{
fn name(&self) -> Cow<str> {
"MatMatMulPackB".into()
}
fn same_as(&self, other: &dyn Op) -> bool {
other.downcast_ref::<Self>().map(|other| other == self).unwrap_or(false)
}
op_core_lir!();
op_as_typed_op!();
not_a_pulsed_op!();
}
impl<T> StatelessOp for MatMatMulPackB<T>
where
T: Copy + Datum + Zero,
{
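    /// Packs B into kernel order. With non-trivial batch dimensions, each
    /// trailing 2-D matrix is packed independently into its output slot.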
    fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        let b = args_1!(inputs);
        // Allocate the packed output uninitialized: the packing calls below
        // fill every slot, and the kernel requires this specific alignment.
        let mut packed = unsafe {
            Tensor::uninitialized_aligned::<T>(&*self.output_shape, self.pack_b.alignment())?
        };
        if b.shape()[..b.shape().len() - 2].iter().any(|d| *d > 1) {
            // Batched case: visit every index of the batch prefix and pack the
            // corresponding trailing 2-D matrix into its slot of the output.
            let b = b.to_array_view::<T>()?;
            let b_prefix = &b.shape()[..b.shape().len() - 2];
            for prefix in indices(b_prefix).into_iter() {
                let mut b = b.view();
                let mut p = packed.to_array_view_mut()?;
                for &dim in prefix.slice() {
                    b.index_axis_inplace(Axis(0), dim);
                    p.index_axis_inplace(Axis(0), dim);
                }
                self.pack_b.pack(p.as_mut_ptr(), b.as_ptr(), self.row_stride, self.col_stride)
            }
        } else {
            // All prefix dimensions are 1: a single matrix, packed in one call.
            self.pack_b.pack(packed.as_ptr_mut()?, b.as_ptr()?, self.row_stride, self.col_stride)
        }
Ok(tvec!(packed.into_arc_tensor()))
}
}
impl<T> TypedOp for MatMatMulPackB<T>
where
T: Copy + Datum + Zero,
{
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
Ok(tvec!(TypedFact::dt_shape(inputs[0].datum_type, &*self.output_shape)?))
}
as_op!();
}
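/// The fully-typed, fixed-shape ("finite") form of a matrix product with a
/// constant A operand: A is pre-packed, one tensor per batch prefix slot, and
/// follow-up element-wise operations may be fused into the kernel run as
/// micro-ops.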
#[derive(Debug, Clone, Educe)]
#[educe(Hash)]
pub(crate) struct MatMatMulUnaryFinite<TA, TB, TC, TI>
where
TA: Datum + Copy + Zero,
TB: Datum + Copy + Zero,
TC: Datum + Copy,
TI: Datum + Copy + Add + Mul + Zero + fmt::Debug,
{
    /// Is C stored transposed?
    pub(crate) c_trans: bool,
    /// Full broadcast shape of C.
    pub(crate) bc_c_shape: TVec<usize>,
    /// Type and shape fact advertised for C.
    pub(crate) c_fact: TypedFact,
    /// Dimensions and element strides of C's batch prefix, if C is batched.
    pub(crate) c_prefix_dim_and_stride: Option<(TVec<usize>, TVec<isize>)>,
    /// Pre-packed A operands, one per batch prefix slot (broadcastable).
    pub(crate) packed_as: ArrayD<Arc<Tensor>>,
    /// Optional per-slot lists of micro-ops fused into the kernel run.
    pub(crate) fused_ops: Option<ArrayD<Vec<FusedSpec<TI>>>>,
    /// The wrapped tract-linalg matrix-matrix multiplier.
    pub(crate) mmm: MMMWrapper<TA, TB, TC, TI>,
}
impl<TA, TB, TC, TI> DynHash for MatMatMulUnaryFinite<TA, TB, TC, TI>
where
TA: Datum + Copy + Zero,
TB: Datum + Copy + Zero,
TC: Datum + Copy,
TI: Datum + Copy + Add + Mul + Zero + fmt::Debug,
{
fn dyn_hash(&self, hasher: &mut dyn std::hash::Hasher) {
tract_linalg::hash::dyn_hash(&self, hasher)
}
}
impl<TA, TB, TC, TI> Op for MatMatMulUnaryFinite<TA, TB, TC, TI>
where
TA: Datum + Copy + Zero,
TB: Datum + Copy + Zero,
TC: Datum + Copy,
TI: Datum + Copy + Add + Mul + Zero + fmt::Debug,
{
    fn name(&self) -> Cow<str> {
        // Degenerate n == 1 products are reported as matrix-vector products.
        if self.mmm.as_mmm().n() == 1 { "MatVecMul" } else { "MatMatMul" }.into()
    }
fn info(&self) -> TractResult<Vec<String>> {
let mut infos = vec![format!(
"c_prefix: {:?} m:{} k:{} n:{} c_trans:{:?}",
self.c_prefix_dim_and_stride,
self.mmm.as_mmm().m(),
self.mmm.as_mmm().k(),
self.mmm.as_mmm().n(),
self.c_trans
)];
infos.push(format!("Mult: {}", self.mmm));
if let Some(f) = &self.fused_ops {
infos.push(format!("{:?}", f));
}
Ok(infos)
}
op_core_lir!();
op_as_typed_op!();
not_a_pulsed_op!();
}
impl<TA, TB, TC, TI> StatelessOp for MatMatMulUnaryFinite<TA, TB, TC, TI>
where
TA: Datum + Copy + Zero,
TB: Datum + Copy + Zero,
TC: Datum + Copy,
TI: Datum + Copy + Add + Mul + Zero + fmt::Debug,
{
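    /// Runs the kernel once per batch prefix index, broadcasting A, B and the
    /// fused micro-ops over degenerate batch axes.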
    fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        unsafe {
            let b = args_1!(inputs);
            // This op only exists once shapes are fully determined ("finite"),
            // so as_finite() cannot fail here.
            let mut c = Tensor::uninitialized::<TC>(&*self.c_fact.shape.as_finite().unwrap())?;
            if let Some((prefix_dim, prefix_strides)) = &self.c_prefix_dim_and_stride {
                let b = b.to_array_view::<TB>()?;
                let mut c = c.to_array_view_mut::<TC>()?;
                for prefix in indices(&**prefix_dim).into_iter() {
                    let mut a = self.packed_as.view();
                    let mut b = b.view();
                    let mut c: *mut TC = c.as_mut_ptr();
                    for (ix, &dim) in prefix.slice().iter().enumerate() {
                        // Clamping the index to the axis length implements
                        // broadcasting over degenerate (length-1) batch axes.
                        let d = dim.min(a.shape()[0] - 1);
                        a.index_axis_inplace(Axis(0), d);
                        let d = dim.min(b.shape()[0] - 1);
                        b.index_axis_inplace(Axis(0), d);
                        c = c.offset(prefix_strides[ix] * dim as isize);
                    }
                    // The view is now fully indexed: its single element is the
                    // packed A for this batch item.
                    let pa: &Tensor = a.iter().next().unwrap();
if let Some(fused) = &self.fused_ops {
let mut fused = fused.view();
for &dim in prefix.slice() {
let d = dim.min(fused.shape()[0] - 1);
fused.index_axis_inplace(Axis(0), d);
}
self.mmm.run(pa.as_ptr()?, b.as_ptr(), c, &fused.as_slice().unwrap()[0]);
} else {
self.mmm.run(pa.as_ptr()?, b.as_ptr(), c, &[]);
}
}
            } else if let Some(fused) = &self.fused_ops {
                self.mmm.run(
                    self.packed_as.as_slice().unwrap()[0].as_ptr()?,
                    b.as_ptr()?,
                    c.as_ptr_mut()?,
                    &fused.as_slice().unwrap()[0],
                );
            } else {
                self.mmm.run(
                    self.packed_as.as_slice().unwrap()[0].as_ptr()?,
                    b.as_ptr()?,
                    c.as_ptr_mut()?,
                    &[],
                );
            }
Ok(tvec!(c.into_arc_tensor()))
}
}
}
impl<TA, TB, TC, TI> TypedOp for MatMatMulUnaryFinite<TA, TB, TC, TI>
where
TA: Datum + Copy + Zero,
TB: Datum + Copy + Zero,
TC: Datum + Copy,
TI: Datum + Copy + Add + Mul + Zero + fmt::Debug,
{
fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
Ok(tvec!(self.c_fact.clone()))
}
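    /// Cost model: m·k·n fused multiply-adds per matrix product, times the
    /// number of products in the batch prefix; the parameters are the packed
    /// A coefficients.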
fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
let mmm = self.mmm.as_mmm();
let mul = self.c_prefix_dim_and_stride.as_ref().map(|c| c.0.iter().product()).unwrap_or(1);
Ok(tvec!(
(Cost::FMA(TI::datum_type()), (mul * mmm.m() * mmm.n() * mmm.k()).to_dim()),
(
Cost::Params(TA::datum_type()),
self.packed_as.iter().fold(0.to_dim(), |sum, a| sum + a.len())
)
))
}
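    /// Tries to absorb the single successor into this op: a shape-only AxisOp
    /// only rewrites the output fact, while supported unary math ops become
    /// fused micro-ops run inside the kernel.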
fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
use crate::ops;
        if let Some(succ) = model.single_succ(node.id)? {
            // A shape-only AxisOp (a pure reshape) is absorbed by rewriting
            // the output fact; the multiplication itself is unchanged.
            if let Some(op) = succ.op_as::<ops::AxisOp>() {
                if op.only_shape() {
return Ok(Some(TypedModelPatch::fuse_with_next(
model,
&node,
Self { c_fact: succ.outputs[0].fact.clone(), ..self.clone() },
)?));
}
}
            // A successor that applies a constant through a binary op can
            // sometimes be folded into the kernel as a fused micro-op.
            let fused_micro_op = if let Some(op) = succ.op_as::<ops::binary::UnaryOp>() {
                let m = self.mmm.as_mmm().m();
                if op.a.len() == m
                    && op.a.shape()[op.a.rank() - 1 - ((!self.c_trans) as usize)] == m
                {
                    // The constant is a length-m vector along the rows of C:
                    // fuse it as a per-row micro-op.
                    if op.mini_op.is::<ops::math::Mul>() {
                        Some(tvec!(FusedSpec::PerRowMul(op.a.as_slice::<TI>()?.to_vec())))
                    } else if op.mini_op.is::<ops::math::Add>() {
                        Some(tvec!(FusedSpec::PerRowAdd(op.a.as_slice::<TI>()?.to_vec())))
                    } else {
                        None
                    }
                } else if op.a.len() == 1 {
                    // The constant is a scalar: fuse it as a scalar micro-op.
                    if op.mini_op.is::<ops::math::Max>() {
                        Some(tvec!(FusedSpec::Max(op.a.cast_to_scalar()?)))
                    } else if op.mini_op.is::<ops::math::Min>() {
                        Some(tvec!(FusedSpec::Min(op.a.cast_to_scalar()?)))
                    } else if op.mini_op.is::<ops::math::Mul>() {
                        Some(tvec!(FusedSpec::ScalarMul(op.a.cast_to_scalar()?)))
                    } else {
                        None
                    }
                } else {
                    None
                }
} else {
None
};
if let Some(op) = fused_micro_op {
                let mut new_op = self.clone();
                // Append the micro-op to every cell of fused_ops, first
                // creating an all-ones-shaped array (broadcastable over the
                // batch prefix) if no micro-op had been fused yet.
                new_op
                    .fused_ops
                    .get_or_insert_with(|| {
let shape = vec![
1;
self.c_prefix_dim_and_stride
.as_ref()
.map(|c| c.0.len())
.unwrap_or(0)
];
ArrayD::from_shape_fn(shape, |_| vec![])
})
.map_inplace(|v| v.extend(op.iter().cloned()));
return Ok(Some(TypedModelPatch::fuse_with_next(model, &node, new_op)?));
}
}
Ok(None)
}
as_op!();
}