tract_core/ops/matmul/optimized.rs

use crate::internal::*;
use crate::ops::cast::cast;
use crate::ops::change_axes::wire_with_rank_broadcast;
use crate::ops::nn::LeakyRelu;
use ndarray::*;
use tract_itertools::Itertools;

use tract_linalg::mmm::{
    AsInputValue, EagerPackedInput, FusedSpec, MMMInputValue, MatMatMul, OutputStoreSpec,
    PanelExtractInput, PanelExtractor,
};
use tract_linalg::pack::PackedFormat;
use tract_linalg::{BinOp, Scaler};
use tract_smallvec::ToSmallVec;

use super::ModePicker;

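/// A fused micro-op in symbolic form: the per-tile `FusedSpec` it will become
/// is only resolved at eval time, once input tensors and output coordinates
/// are known. Most variants carry `usize` indices into the op's input slots.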
#[derive(Clone, Debug)]
pub enum ProtoFusedSpec {
    AddMatMul {
        geo: AddMatMulGeometry,
        a: usize,
        b: usize,
        packings: Vec<(usize, Option<PanelExtractor>)>,
    },
    BinScalar(usize, BinOp),
    LeakyRelu(usize),
    BinPerRow(usize, BinOp, MapOutputAxisToInput),
    BinPerCol(usize, BinOp, MapOutputAxisToInput),
    AddRowColProducts(usize, usize),
    AddUnicast(OutputStoreSpec, usize, MapOutputAxisToInput),
    Scaler(Scaler),
    Store(Vec<OutputStoreSpec>),
}

impl ProtoFusedSpec {
    pub fn format(&self, mmm: &dyn MatMatMul, mode: usize) -> String {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, packings, .. } => {
                let (a, b) = &mmm.packings()[packings[mode].0];
                format!("matmul(k={}, {a:?}•{b:?})", geo.k)
            }
            BinScalar(_, op) => format!("scalar{op:?}"),
            // the usize is an input slot index, not the alpha value itself
            LeakyRelu(input) => format!("leaky_relu(#{input})"),
            BinPerRow(_, op, _) => format!("row{op:?}"),
            BinPerCol(_, op, _) => format!("col{op:?}"),
            AddRowColProducts(_, _) => "add_row_col_product".to_string(),
            AddUnicast(_, _, _) => "add_to_matrix".to_string(),
            // Scaler implements Mul<f32>, so 1.0 * s recovers the scale as f32
            Scaler(s) => format!("scale({})", 1f32 * *s),
            Store(_oss) => "store".to_string(),
        }
    }

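    /// Resolve this proto-spec into a concrete `FusedSpec` for the output tile
    /// at `output_coords`, translating output coordinates into per-input views
    /// and picking the packing for the selected kernel `mode`.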
    pub fn resolve<'t>(
        &'t self,
        inputs: &'t [TValue],
        output_coords: &[usize],
        output: &Tensor,
        mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        match self {
            ProtoFusedSpec::AddMatMul { geo, a, b, packings } => {
                let mut a = inputs[*a].view();
                let mut b = inputs[*b].view();
                unsafe {
                    geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a);
                }
                let a = a.as_slice::<Opaque>().unwrap()[0]
                    .downcast_ref::<Box<dyn MMMInputValue>>()
                    .unwrap();
                unsafe {
                    geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b);
                }
                let b = b.as_slice::<Opaque>().unwrap()[0]
                    .downcast_ref::<Box<dyn MMMInputValue>>()
                    .unwrap();
                let (_a_packing, b_packing) = &mmm.packings()[packings[mode].0];
                // if A's stored packing does not match the kernel's, repack it
                // on the fly through the panel extractor
                let pa = if let Some(extractor) = &packings[mode].1 {
                    let data = a.downcast_ref::<EagerPackedInput>().unwrap();
                    AsInputValue::Owned(Box::new(PanelExtractInput {
                        format: extractor.clone(),
                        data: data.clone(),
                    }))
                } else {
                    AsInputValue::Borrowed(&**a)
                };
                assert!(
                    b_packing.same_as(b.format())
                        || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                );
                debug_assert!(pa.k().to_dim().compatible_with(&geo.k.to_dim()));
                debug_assert!(b.k().to_dim().compatible_with(&geo.k.to_dim()));
                FusedSpec::AddMatMul {
                    a: pa,
                    b: AsInputValue::Borrowed(&**b),
                    packing: packings[mode].0,
                }
            }
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, map) => {
                let mut v = inputs[*v].view();
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, map) => {
                let mut v = inputs[*v].view();
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, map) => unsafe {
                let mut view = inputs[*v].view();
                map.translate_view(output_coords, &mut view);
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                let view = output.view_offsetting_unchecked(output_coords);
                FusedSpec::Store(oss[mode].wrap(&view))
            },
        }
    }

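    /// A spec is trivial when it can be resolved without output coordinates:
    /// only `AddMatMul` is conditional, requiring a concrete (integer) k.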
    pub fn is_trivial(&self) -> bool {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => geo.k.as_i64().is_some(),
            _ => true,
        }
    }

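    /// Fast-path counterpart of `resolve` for the case where the whole output
    /// is computed in a single kernel call: no coordinate translation, and the
    /// checks that `resolve` performs at runtime become debug assertions.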
    pub fn resolve_trivial<'t>(
        &'t self,
        inputs: &'t [TValue],
        output: &mut Tensor,
        _mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        match self {
            ProtoFusedSpec::AddMatMul { a, b, packings, .. } => unsafe {
                debug_assert!(inputs.get(*a).is_some());
                debug_assert!(inputs.get(*b).is_some());
                let a = inputs.get_unchecked(*a);
                let b = inputs.get_unchecked(*b);
                debug_assert!(a.datum_type().is_opaque());
                debug_assert!(a.len() == 1);
                debug_assert!(b.datum_type().is_opaque());
                debug_assert!(b.len() == 1);
                let a = a.as_slice_unchecked::<Opaque>().get_unchecked(0);
                let b = b.as_slice_unchecked::<Opaque>().get_unchecked(0);
                debug_assert!(a.is::<Box<dyn MMMInputValue>>());
                debug_assert!(b.is::<Box<dyn MMMInputValue>>());
                let a = a.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
                let b = b.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
                #[cfg(debug_assertions)]
                {
                    let (a_packing, b_packing) = &_mmm.packings()[packings[mode].0];
                    debug_assert!(
                        a_packing.same_as(a.format())
                            || (a_packing.is::<PackedFormat>() && a_packing.r() == a.format().r())
                    );
                    debug_assert!(
                        b_packing.same_as(b.format())
                            || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                    );
                }
                FusedSpec::AddMatMul {
                    a: AsInputValue::Borrowed(&**a),
                    b: AsInputValue::Borrowed(&**b),
                    packing: packings[mode].0,
                }
            },
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, _) => unsafe {
                let view = inputs[*v].view();
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                FusedSpec::Store(oss[mode].wrap(&output.view_mut()))
            },
        }
    }

    fn check_inputs(&self, inputs: &[&TypedFact]) -> TractResult<()> {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { a, b, .. } => {
                ensure!(inputs[*a].datum_type == Opaque::datum_type());
                ensure!(inputs[*b].datum_type == Opaque::datum_type());
            }
            BinScalar(v, _)
            | LeakyRelu(v)
            | BinPerCol(v, _, _)
            | BinPerRow(v, _, _)
            | AddUnicast(_, v, _) => {
                ensure!(inputs[*v].datum_type.is_number());
            }
            AddRowColProducts(row, col) => {
                ensure!(inputs[*row].datum_type.is_number());
                ensure!(inputs[*col].datum_type.is_number());
            }
            _ => (),
        };
        Ok(())
    }

    fn cost(&self, m: &TDim, n: &TDim, idt: DatumType) -> TVec<(Cost, TDim)> {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => {
                tvec!((Cost::FMA(idt), m.clone() * n * &geo.k))
            }
            _ => tvec!(), /* FIXME maybe */
        }
    }

    fn rm_c_axis(&mut self, axis: usize) {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, .. } => {
                geo.c_to_a_axis_mapping.rm_c_axis(axis);
                geo.c_to_b_axis_mapping.rm_c_axis(axis);
            }
            BinScalar(..) | Scaler(..) | AddRowColProducts(_, _) | LeakyRelu(_) => {}
            BinPerRow(_, _, map) | BinPerCol(_, _, map) => map.rm_c_axis(axis),
            AddUnicast(_, _, map) => {
                map.rm_c_axis(axis);
            }
            Store(oss) => {
                for store in oss {
                    match store {
                        OutputStoreSpec::View { m_axis, n_axis, .. } => {
                            *m_axis -= (*m_axis > axis) as usize;
                            *n_axis -= (*n_axis > axis) as usize;
                        }
                        OutputStoreSpec::Strides { .. } => {}
                    }
                }
            }
        }
    }
}

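/// Maps output (C) axes to axes of one input, as (output_axis, input_axis)
/// pairs, so that output tile coordinates can be translated into an offset
/// view of that input.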
#[derive(Clone, Debug)]
pub struct MapOutputAxisToInput(pub TVec<(usize, usize)>);

impl MapOutputAxisToInput {
    #[inline]
    unsafe fn translate_view(&self, output_coords: &[usize], v: &mut TensorView) {
        for &(out_axis, in_axis) in &self.0 {
            v.offset_axis(in_axis, output_coords[out_axis] as isize)
        }
    }

    #[inline]
    fn rm_c_axis(&mut self, axis: usize) {
        for (c, _) in &mut self.0 {
            *c -= (*c > axis) as usize;
        }
    }
}

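/// Geometry of one matrix product: the shared k dimension, plus the mappings
/// from output axes to the axes of the A and B inputs.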
#[derive(Clone, Debug)]
pub struct AddMatMulGeometry {
    pub k: TDim,
    pub c_to_a_axis_mapping: MapOutputAxisToInput,
    pub c_to_b_axis_mapping: MapOutputAxisToInput,
}

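/// Optimized matrix multiplication: one or more candidate linalg kernels (the
/// `ModePicker` chooses between them at eval time) driving a pipeline of
/// fused micro-ops that ends with a `Store`.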
#[derive(Clone, Debug)]
pub struct OptMatMul {
    pub c_fact: TypedFact,
    pub micro_ops: Vec<ProtoFusedSpec>,
    pub mmm: Vec<Box<dyn MatMatMul>>,
    pub mode_picker: ModePicker,
    pub c_m_axis: usize,
    pub c_n_axis: usize,
    pub trivial_packing: bool,
    pub trivial_path: bool,
}

impl Op for OptMatMul {
    fn name(&self) -> Cow<str> {
        "OptMatMul".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        let m = &self.c_fact.shape[self.c_m_axis];
        let n = &self.c_fact.shape[self.c_n_axis];
        let mut infos = vec![format!(
            "c_shape:{:?}, c_m_axis:{} c_n_axis:{} m:{} n:{}",
            self.c_fact, self.c_m_axis, self.c_n_axis, m, n,
        )];
        if let Some(k) = self.guess_k() {
            infos.push(format!("Mult: m:{} k:{} n:{} with {:?}", m, k, n, self.mmm));
        } else {
            infos.push(format!("Mult: {:?}", self.mmm));
        }
        for (mode, mmm) in self.mmm.iter().enumerate() {
            infos.push(format!(
                "Ops: {}",
                self.micro_ops.iter().map(|o| o.format(&**mmm, mode)).join(" >>> ")
            ));
        }
        Ok(infos)
    }

    op_as_typed_op!();
}

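// Evaluation picks a kernel mode from the concrete n dimension, reuses the
// session's cached scratch space when compatible, then either runs the whole
// output in one kernel call (trivial path) or loops over the non-m/n axes.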
impl EvalOp for OptMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval_with_session(
        &self,
        session: &SessionState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        unsafe {
            let c_shape = self.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
            let mut c = Tensor::uninitialized_dt(self.c_fact.datum_type, &c_shape)?;
            let mode = self.mode_picker.pick(c_shape[self.c_n_axis])?;
            let mmm = &*self.mmm[mode];
            // invalidate the cached scratch space if the selected kernel cannot reuse it
            let mut cell = session.cached_mmm_scratch_space.borrow_mut();
            if !cell.as_ref().is_some_and(|scratch| mmm.can_use_scratch_space(&**scratch)) {
                *cell = None
            }
            let scratch = cell.get_or_insert_with(|| mmm.allocate_scratch_space());

            if self.trivial_path {
                let uops: Vec<FusedSpec> = self
                    .micro_ops
                    .iter()
                    .map(|o| o.resolve_trivial(&inputs, &mut c, mmm, mode))
                    .collect();
                mmm.run_with_scratch_space(
                    *c_shape.get_unchecked(self.c_m_axis),
                    *c_shape.get_unchecked(self.c_n_axis),
                    scratch.as_mut(),
                    &uops,
                )?;
                Ok(tvec!(c.into_tvalue()))
            } else {
                let mut uops = vec![FusedSpec::ShiftLeft(0); self.micro_ops.len()];
                // iterate over all output coordinates except the m and n axes
                let mut looping_shape: TVec<usize> = c_shape.to_smallvec();
                looping_shape[self.c_m_axis] = 1;
                looping_shape[self.c_n_axis] = 1;
                let m = c_shape[self.c_m_axis];
                let n = c_shape[self.c_n_axis];
                for c_coords in indices(&*looping_shape) {
                    for ix in 0..self.micro_ops.len() {
                        *uops.get_unchecked_mut(ix) = self.micro_ops.get_unchecked(ix).resolve(
                            &inputs,
                            c_coords.slice(),
                            &c,
                            mmm,
                            mode,
                        );
                    }
                    mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)
                        .context("In mmm.run_with_scratch_space")?;
                }
                Ok(tvec!(c.into_tvalue()))
            }
        }
    }
}

impl TypedOp for OptMatMul {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(self.c_m_axis < self.c_fact.rank());
        ensure!(self.c_n_axis < self.c_fact.rank());
        ensure!(self.trivial_path == self.can_use_trivial_path());
        ensure!(self.mmm.iter().map(|mmm| mmm.internal_type()).all_equal());
        for op in &self.micro_ops {
            op.check_inputs(inputs)?;
        }
        Ok(tvec!(self.c_fact.clone()))
    }

    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let mut sums = HashMap::new();
        let m = &self.c_fact.shape[self.c_m_axis];
        let n = &self.c_fact.shape[self.c_n_axis];
        for op in &self.micro_ops {
            for (cost, count) in op.cost(m, n, self.mmm[0].internal_type()) {
                *sums.entry(cost).or_default() += count;
            }
        }
        // per-tile costs are repeated once for each iteration of the outer loop
        let loops = self
            .c_fact
            .shape
            .iter()
            .enumerate()
            .map(|(ix, d)| {
                if ix == self.c_m_axis || ix == self.c_n_axis {
                    1.to_dim()
                } else {
                    d.clone()
                }
            })
            .product::<TDim>();
        for s in sums.values_mut() {
            *s *= &loops;
        }
        Ok(sums.into_iter().collect())
    }

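    /// Graph-level fusion: if this node's single successor is an operation the
    /// kernels can absorb (binary op, leaky relu, quantizing scale or cast,
    /// axis removal, unicast add), rewrite the graph to pull it into the
    /// micro-op pipeline.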
    fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
        use crate::ops;
        if node.outputs.len() != 1
            || node.outputs[0].successors.len() != 1
            || model.output_outlets()?.contains(&node.id.into())
        {
            return Ok(None);
        }
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut patch = TypedModelPatch::new(format!("fusing {succ}"));

        if let Some(op) = succ.op_as::<ops::binary::TypedBinOp>() {
            let mut binop = if let Some(op) = op.0.as_linalg_binop() {
                op
            } else {
                return Ok(None);
            };
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }
        if let Some(op) = succ.op_as::<ops::binary::OptBinByScalar>() {
            let mut binop = if let Some(op) = op.binop.as_linalg_binop() {
                op
            } else {
                return Ok(None);
            };
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }

        if let Some(op) = succ.op_as::<ops::element_wise::ElementWiseOp>().map(|ew| ew.0.as_ref()) {
            if let Some(op) = op.downcast_ref::<ops::math::QScale>() {
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::Scaler(op.scaler)],
                    &[],
                );
            }
            if let Some(op) = op.downcast_ref::<LeakyRelu>() {
                if !self
                    .mmm
                    .iter()
                    .all(|mmm| mmm.can_fuse(&FusedSpec::LeakyRelu(&tensor0(op.alpha))))
                {
                    return Ok(None);
                }
                let alpha = patch.add_const(
                    node.name.to_string() + ".alpha",
                    tensor0(op.alpha).cast_to_dt(self.mmm[0].internal_type())?.into_owned(),
                )?;
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::LeakyRelu(node.inputs.len())],
                    &[alpha],
                );
            }
        }
        // absorb an i32 -> i8/u8 requantizing cast by retyping the output fact
        if let Some(cast_to) = succ.op_as::<ops::cast::Cast>().map(|cast| cast.to) {
            if (cast_to.unquantized() == i8::datum_type()
                || cast_to.unquantized() == u8::datum_type())
                && self.c_fact.datum_type == i32::datum_type()
            {
                if let Some(ProtoFusedSpec::Store(stores)) = self.micro_ops.last() {
                    if stores.iter().any(|s| matches!(s, OutputStoreSpec::Strides { .. })) {
                        return Ok(None);
                    }
                    let c_fact = cast_to.fact(self.c_fact.shape.clone());
                    let mut patch = TypedModelPatch::fuse_with_next(
                        model,
                        node,
                        Self { c_fact, ..self.clone() },
                    )?;
                    patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
                    return Ok(Some(patch));
                }
            }
        }
        // absorb the removal of an output axis, as long as it is neither m nor n
        if let Some(AxisOp::Rm(axis)) = succ.op_as::<ops::AxisOp>() {
            if *axis == self.c_m_axis || *axis == self.c_n_axis {
                return Ok(None);
            }
            let mut new_op = self.clone();
            new_op.c_fact.shape.remove_axis(*axis)?;
            new_op.c_m_axis -= (new_op.c_m_axis > *axis) as usize;
            new_op.c_n_axis -= (new_op.c_n_axis > *axis) as usize;
            for uop in &mut new_op.micro_ops {
                uop.rm_c_axis(*axis);
            }
            let mut patch = TypedModelPatch::fuse_with_next(model, node, new_op)?;
            patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
            return Ok(Some(patch));
        }
        // swap a uniform-constant binary op with the axis change that follows
        // us, so a later pass may fuse it into this op
        if succ.op_is::<AxisOp>() {
            if let &[next] = &*succ.outputs[0].successors {
                let bin = model.node(next.node);
                if let Some(op) = bin.op_as::<ops::binary::TypedBinOp>() {
                    if op.0.as_linalg_binop().is_none() {
                        return Ok(None);
                    };
                    let flipped = succ.inputs[0].node == node.id;
                    let other_outlet = bin.inputs[flipped as usize];
                    if let Some(uni) = &model.outlet_fact(other_outlet)?.uniform {
                        let mut patch = TypedModelPatch::default();
                        let cst =
                            patch.add_const(&model.node(other_outlet.node).name, uni.clone())?;
                        let output = patch.tap_model(model, node.id.into())?;
                        let wire = wire_with_rank_broadcast(
                            &bin.name,
                            &mut patch,
                            op.clone(),
                            &if flipped { [output, cst] } else { [cst, output] },
                        )?;
                        let wire = patch.wire_node(&succ.name, succ.op.clone(), &wire)?[0];
                        patch.shunt_outside(model, bin.id.into(), wire)?;
                        return Ok(Some(patch));
                    }
                }
            }
        }
        if let Some(op) = succ.op_as::<ops::binary::OptBinUnicast>() {
            let in_1_fact = model.outlet_fact(succ.inputs[0])?;
            let in_2_fact = model.outlet_fact(succ.inputs[1])?;
            if op.binop.is::<ops::math::Add>()
                && self.mmm.len() == 1
                && in_1_fact.without_value() == in_2_fact.without_value()
            {
                let other_slot = 1 - node.outputs[0].successors[0].slot;
                let other_input = succ.inputs[other_slot];
                let other_input = patch.tap_model(model, other_input)?;
                let other_fact = patch.outlet_fact(other_input)?;

                if other_fact.shape == self.c_fact.shape {
                    let other_storage = unsafe { self.mmm[0].c_view(self.c_m_axis, self.c_n_axis) };
                    let mapping =
                        MapOutputAxisToInput((0..other_fact.rank()).map(|x| (x, x)).collect());
                    return self.fuse_op(
                        model,
                        node,
                        patch,
                        vec![ProtoFusedSpec::AddUnicast(other_storage, node.inputs.len(), mapping)],
                        &[other_input],
                    );
                }
            } else {
                let mut binop = if let Some(op) = op.binop.as_linalg_binop() {
                    op
                } else {
                    return Ok(None);
                };
                let flipped = succ.inputs[0].node == node.id;
                if flipped {
                    binop = binop.flip();
                }
                let other_outlet = succ.inputs[flipped as usize];
                return self.fuse_binary(model, node, patch, other_outlet, binop);
            }
        }
        Ok(None)
    }

    as_op!();
}

impl OptMatMul {
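    /// Build an OptMatMul, validating the m and n axes against the output
    /// fact and computing the trivial-path flag from the micro-ops.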
    pub fn new(
        mmm: Vec<Box<dyn MatMatMul>>,
        mode_picker: ModePicker,
        c_fact: TypedFact,
        c_m_axis: usize,
        c_n_axis: usize,
        micro_ops: Vec<ProtoFusedSpec>,
        trivial_packing: bool,
    ) -> TractResult<Self> {
        ensure!(c_m_axis < c_fact.rank());
        ensure!(c_n_axis < c_fact.rank());
        let mut it = OptMatMul {
            mmm,
            mode_picker,
            c_fact,
            c_m_axis,
            c_n_axis,
            micro_ops,
            trivial_path: false,
            trivial_packing,
        };
        it.update_trivial_path();
        Ok(it)
    }

    /// For auditing only: the k of the first AddMatMul micro-op, or None if
    /// there is none.
    pub fn guess_k(&self) -> Option<TDim> {
        self.micro_ops
            .iter()
            .find_map(|o| {
                if let ProtoFusedSpec::AddMatMul { geo, .. } = o {
                    Some(geo)
                } else {
                    None
                }
            })
            .map(|geo| geo.k.clone())
    }

    pub fn m(&self) -> &TDim {
        &self.c_fact.shape[self.c_m_axis]
    }

    pub fn n(&self) -> &TDim {
        &self.c_fact.shape[self.c_n_axis]
    }

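    /// The trivial path runs the kernel exactly once: the output shape must be
    /// concrete and degenerate outside the m and n axes, packing must be fixed
    /// (`trivial_packing`), and every micro-op must be resolvable without
    /// output coordinates.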
    fn update_trivial_path(&mut self) {
        self.trivial_path = self.can_use_trivial_path();
    }

    fn can_use_trivial_path(&self) -> bool {
        self.c_fact.shape.is_concrete()
            && self
                .c_fact
                .shape
                .iter()
                .enumerate()
                .all(|(ax, dim)| ax == self.c_m_axis || ax == self.c_n_axis || dim.is_one())
            && self.trivial_packing
            && self.micro_ops.iter().all(|o| o.is_trivial())
    }

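    /// Rewire the graph so that `node`'s successor is replaced by this op with
    /// `fused_micro_op` spliced in just before the final `Store` step.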
    fn fuse_op(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        fused_micro_op: Vec<ProtoFusedSpec>,
        additional_inputs: &[OutletId],
    ) -> TractResult<Option<TypedModelPatch>> {
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut new_op = self.clone();
        // empty range just before the last micro-op: splice inserts the fused
        // ops between the current pipeline and the trailing Store
        let before_last = new_op.micro_ops.len() - 1..new_op.micro_ops.len() - 1;
        new_op.micro_ops.splice(before_last, fused_micro_op);
        new_op.c_fact = succ.outputs[0].fact.clone();
        new_op.update_trivial_path();
        let mut inputs = patch.taps(model, &node.inputs)?;
        inputs.extend(additional_inputs.iter().cloned());
        let output = patch.wire_node(&succ.name, new_op, &inputs)?;
        patch.shunt_outside(model, succ.id.into(), output[0])?;
        Ok(Some(patch))
    }

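    /// Fuse a binary operation whose other operand is `value`, casting it to
    /// the kernel's internal type if needed, and mapping it to a scalar,
    /// per-row or per-column fused spec depending on its shape.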
    fn fuse_binary(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        value: OutletId,
        binop: BinOp,
    ) -> TractResult<Option<TypedModelPatch>> {
        let fact = model.outlet_fact(value)?;
        let mut v = patch.tap_model(model, value)?;
        if fact.datum_type != self.mmm[0].internal_type() {
            v = patch.wire_node(
                format!("{}.cast-input-{}", node.name, node.inputs.len()),
                cast(self.mmm[0].internal_type()),
                &[v],
            )?[0];
        }
        let value = node.inputs.len();
        let additional_input = tvec!(v);
        if fact.shape.volume() == 1.to_dim() {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinScalar(value, binop)],
                &additional_input,
            );
        }
        let other_shape = fact.shape.to_owned();
        // a vector along the m axis (volume == m) fuses as a per-row operand
        if other_shape[self.c_m_axis] == self.c_fact.shape[self.c_m_axis]
            && other_shape[self.c_m_axis] == other_shape.volume()
        {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerRow(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_m_axis, self.c_m_axis))),
                )],
                &additional_input,
            );
        }
        // a vector along the n axis (volume == n) fuses as a per-column operand
        if other_shape[self.c_n_axis] == self.c_fact.shape[self.c_n_axis]
            && other_shape[self.c_n_axis] == other_shape.volume()
        {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerCol(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_n_axis, self.c_n_axis))),
                )],
                &additional_input,
            );
        }
        Ok(None)
    }
}