use crate::internal::*;
use crate::ops::nn::DataShape;
use ndarray::prelude::*;
use tract_linalg::mmm::*;

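/// Direct convolution: the filters are packed once ahead of time and the input
/// is read in place through precomputed offsets, so no im2col buffer is built.
/// Evaluation boils down to one matrix-matrix multiplication per batch item.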
#[derive(Debug, Clone, new)]
pub struct Direct {
    // matrix-matrix multiplication kernel selected for this geometry
    tile: Box<dyn MatMatMul<f32>>,
    // one offset per output position, addressing the input in place (B columns)
    data_offsets: Vec<isize>,
    // one offset per kernel tap and input channel (B rows)
    kernel_offsets: Vec<isize>,
    input_shape: DataShape,
    output_shape: DataShape,
    // filters pre-packed in the layout expected by `tile`
    packed_filters: Tensor,
    // elementwise micro-ops fused after the matrix multiplication
    fused_ops: Vec<FusedSpec<f32>>,
}

impl Direct {
    pub fn output_shape(&self) -> &[usize] {
        &self.output_shape.shape
    }
}

impl Op for Direct {
    fn name(&self) -> Cow<str> {
        "ConvDirect".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        let mut info = vec![format!("{:?}", self.tile)];
        for op in &self.fused_ops {
            info.push(format!(" + {:?}", op));
        }
        Ok(info)
    }

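    // Absorb a single elementwise successor into the matmul's fused micro-ops:
    // per-channel Mul/Add (one constant per output channel) or scalar Min/Max.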
    fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
        if let Some(succ) = model.single_succ(node.id)? {
            let fused_micro_op = (|| -> TractResult<Option<TVec<FusedSpec<f32>>>> {
                if let Some(op) = succ.op_as::<crate::ops::binary::UnaryOp>() {
                    if op.a.shape() == &[*self.output_shape.c()] {
                        if op.mini_op.is::<crate::ops::math::Mul>() {
                            return Ok(Some(tvec!(FusedSpec::PerRowMul(
                                op.a.as_slice::<f32>()?.to_vec(),
                            ))));
                        } else if op.mini_op.is::<crate::ops::math::Add>() {
                            return Ok(Some(tvec!(FusedSpec::PerRowAdd(
                                op.a.as_slice::<f32>()?.to_vec(),
                            ))));
                        }
                    }
                } else if let Some(op) = succ.op_as::<crate::ops::math::ScalarMax>() {
                    return Ok(Some(tvec!(FusedSpec::Max(op.max))));
                } else if let Some(op) = succ.op_as::<crate::ops::math::ScalarMin>() {
                    return Ok(Some(tvec!(FusedSpec::Min(op.min))));
                } else if let Some(op) = succ.op_as::<crate::ops::math::ScalarMinMax>() {
                    return Ok(Some(tvec!(FusedSpec::Min(op.min), FusedSpec::Max(op.max))));
                }
                Ok(None)
            })()?;
            if let Some(op) = fused_micro_op {
                let mut ops = self.fused_ops.clone();
                ops.extend(op.into_iter());
                return Ok(Some(TypedModelPatch::fuse_with_next(
                    model,
                    node,
                    Direct { fused_ops: ops, ..self.clone() },
                )?));
            }
        }
        Ok(None)
    }

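    // One FMA per kernel tap and output element: batch * m (output channels)
    // * k (kernel taps x input channels) * n (output positions).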
    fn cost(&self, inputs: &[&TypedTensorInfo]) -> TractResult<TVec<(Cost, TDim)>> {
        let batch = inputs[0].shape.dim(0);
        Ok(tvec!((
            Cost::FMA(f32::datum_type()),
            batch * self.tile.n() * self.tile.m() * self.tile.k()
        )))
    }

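    // Results are only expected to match the reference implementation up to
    // floating point rounding differences.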
    fn validation(&self) -> Validation {
        Validation::Rounding
    }

    op_as_typed_op!();
}

impl StatelessOp for Direct {
    fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        let input = args_1!(inputs);
        unsafe {
            let input = input.to_array_view::<f32>()?;
            let mut output = ArrayD::<f32>::uninitialized(&*self.output_shape.shape);
            let filters = self.packed_filters.as_ptr::<f32>()?;
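            // One matmul per batch item: A is the pre-packed filter tensor, B
            // addresses the input in place through the precomputed offsets, and
            // C writes straight to the output using its channel and width strides.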
            for n in 0..*self.input_shape.n() {
                let input = input.slice_axis(Axis(0), (n..=n).into());
                let mut output = output.slice_axis_mut(Axis(0), (n..=n).into());
                self.tile.run(
                    &self.tile.a_from_packed(filters),
                    &self.tile.b_from_data_and_offsets(
                        input.as_ptr(),
                        &self.kernel_offsets,
                        &self.data_offsets,
                    ),
                    &mut self.tile.c_from_data_and_strides(
                        output.as_mut_ptr(),
                        *self.output_shape.c_stride() as isize,
                        *self.output_shape.w_stride() as isize,
                    ),
                    &*self.fused_ops,
                );
            }
            Ok(tvec!(output.into_arc_tensor()))
        }
    }
}

impl TypedOp for Direct {
    typed_op_as_op!();

    fn output_facts(&self, inputs: &[&TypedTensorInfo]) -> TractResult<TVec<TypedTensorInfo>> {
        Ok(tvec!(TypedTensorInfo::dt_shape(inputs[0].datum_type, &*self.output_shape.shape)?))
    }
}