use crate::internal::*;
use crate::ops::nn::DataShape;
use ndarray::prelude::*;
use tract_linalg::mmm::*;

/// Direct (im2col-free) convolution: filters are pre-packed for the matmul
/// kernel, and the input is addressed in place through precomputed offsets.
#[derive(Debug, Clone, new)]
pub struct Direct {
    /// Pre-selected matrix-matrix multiplication kernel.
    tile: Box<dyn MatMatMul<f32>>,
    /// Column offsets into the input, one per output spatial position.
    data_offsets: Vec<isize>,
    /// Row offsets into the input, one per filter tap.
    kernel_offsets: Vec<isize>,
    input_shape: DataShape,
    output_shape: DataShape,
    /// Filters pre-packed in the layout expected by `tile`.
    packed_filters: Tensor,
    /// Element-wise operations fused into the matmul kernel's output pass.
    fused_ops: Vec<FusedSpec<f32>>,
}

impl Direct {
    pub fn output_shape(&self) -> &[usize] {
        &self.output_shape.shape
    }
}

impl Op for Direct {
    fn name(&self) -> Cow<str> {
        "ConvDirect".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        let mut info = vec![format!("{:?}", self.tile)];
        for op in &self.fused_ops {
            info.push(format!(" + {:?}", op));
        }
        Ok(info)
    }

    /// Try to absorb this node's single successor into the matmul kernel:
    /// a per-channel multiply or add, or a scalar min/max clamp, becomes a
    /// `FusedSpec` appended to `fused_ops`.
    fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
        if let Some(succ) = model.single_succ(node.id)? {
            let fused_micro_op = (|| -> TractResult<Option<TVec<FusedSpec<f32>>>> {
                if let Some(op) = succ.op_as::<crate::ops::binary::UnaryOp>() {
                    if op.a.shape() == &[*self.output_shape.c()] {
                        if op.mini_op.is::<crate::ops::math::Mul>() {
                            return Ok(Some(tvec!(FusedSpec::PerRowMul(
                                op.a.as_slice::<f32>()?.to_vec(),
                            ))));
                        } else if op.mini_op.is::<crate::ops::math::Add>() {
                            return Ok(Some(tvec!(FusedSpec::PerRowAdd(
                                op.a.as_slice::<f32>()?.to_vec(),
                            ))));
                        }
                    }
                } else if let Some(op) = succ.op_as::<crate::ops::math::ScalarMax>() {
                    return Ok(Some(tvec!(FusedSpec::Max(op.max))));
                } else if let Some(op) = succ.op_as::<crate::ops::math::ScalarMin>() {
                    return Ok(Some(tvec!(FusedSpec::Min(op.min))));
                } else if let Some(op) = succ.op_as::<crate::ops::math::ScalarMinMax>() {
                    return Ok(Some(tvec!(FusedSpec::Min(op.min), FusedSpec::Max(op.max),)));
                }
                Ok(None)
            })()?;
            if let Some(op) = fused_micro_op {
                let mut ops = self.fused_ops.clone();
                ops.extend(op.into_iter());
                return Ok(Some(TypedModelPatch::fuse_with_next(
                    model,
                    node,
                    Direct { fused_ops: ops, ..self.clone() },
                )?));
            }
        }
        Ok(None)
    }

    /// One fused multiply-add per (m, k, n) triple of the matmul, times the batch size.
    fn cost(&self, inputs: &[&TypedTensorInfo]) -> TractResult<TVec<(Cost, TDim)>> {
        let batch = inputs[0].shape.dim(0);
        Ok(tvec!((
            Cost::FMA(f32::datum_type()),
            batch * self.tile.n() * self.tile.m() * self.tile.k()
        )))
    }

    fn validation(&self) -> Validation {
        Validation::Rounding
    }

    op_as_typed_op!();
}

impl StatelessOp for Direct {
    /// Runs one matmul per batch item: A is the packed filters, B is the input
    /// read in place through the precomputed offsets, C is written in place
    /// using the output strides.
    fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        let input = args_1!(inputs);
        unsafe {
            let input = input.to_array_view::<f32>()?;
            // Allocated uninitialized, filled in by the matmul kernel below.
            let mut output = ArrayD::<f32>::uninitialized(&*self.output_shape.shape);
            let filters = self.packed_filters.as_ptr::<f32>()?;
            for n in 0..*self.input_shape.n() {
                let input = input.slice_axis(Axis(0), (n..=n).into());
                let mut output = output.slice_axis_mut(Axis(0), (n..=n).into());
                self.tile.run(
                    &self.tile.a_from_packed(filters),
                    &self.tile.b_from_data_and_offsets(
                        input.as_ptr(),
                        &self.kernel_offsets,
                        &self.data_offsets,
                    ),
                    &mut self.tile.c_from_data_and_strides(
                        output.as_mut_ptr(),
                        *self.output_shape.c_stride() as isize,
                        *self.output_shape.w_stride() as isize,
                    ),
                    &*self.fused_ops,
                );
            }
            Ok(tvec!(output.into_arc_tensor()))
        }
    }
}

impl TypedOp for Direct {
    typed_op_as_op!();

    fn output_facts(&self, inputs: &[&TypedTensorInfo]) -> TractResult<TVec<TypedTensorInfo>> {
        Ok(tvec!(TypedTensorInfo::dt_shape(inputs[0].datum_type, &*self.output_shape.shape)?))
    }
}
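
// Illustrative sketch only, not used by the op above: one plausible way the
// `kernel_offsets` / `data_offsets` fed to `b_from_data_and_offsets` could be
// derived for a stride-1, unpadded 1-D convolution over a contiguous
// channels-then-width input. The helper name and layout assumptions below are
// hypothetical and only meant to illustrate the offset-addressing idea.
#[cfg(test)]
mod offset_sketch {
    /// Returns, for `channels` input channels, `input_len` positions and a
    /// `kernel_len`-tap filter:
    /// - one offset per (channel, tap) pair: the rows (K axis) of the virtual matmul,
    /// - one offset per output position: the columns (N axis) of the virtual matmul.
    fn valid_conv_1d_offsets(
        channels: usize,
        input_len: usize,
        kernel_len: usize,
    ) -> (Vec<isize>, Vec<isize>) {
        let kernel_offsets = (0..channels)
            .flat_map(|c| (0..kernel_len).map(move |k| (c * input_len + k) as isize))
            .collect();
        let data_offsets = (0..input_len + 1 - kernel_len).map(|x| x as isize).collect();
        (kernel_offsets, data_offsets)
    }

    #[test]
    fn offset_counts_match_matmul_geometry() {
        let (rows, cols) = valid_conv_1d_offsets(2, 5, 3);
        assert_eq!(rows.len(), 2 * 3); // K = channels * kernel taps
        assert_eq!(cols.len(), 5 - 3 + 1); // N = output positions
    }
}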