tract_core/ops/matmul/optimized.rs

use crate::internal::*;
use crate::ops::cast::cast;
use crate::ops::change_axes::wire_with_rank_broadcast;
use crate::ops::nn::LeakyRelu;
use ndarray::*;
use tract_itertools::Itertools;

use tract_linalg::frame::PackedFormat;
use tract_linalg::mmm::panel_extract::{PanelExtractInput, PanelExtractor};
use tract_linalg::mmm::{
    AsInputValue, EagerPackedInput, FusedSpec, MMMInputValue, MatMatMul, OutputStoreSpec,
};
use tract_linalg::{BinOp, Scaler};
use tract_smallvec::ToSmallVec;

use super::ModePicker;

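/// One fused micro-op of the optimized matmul pipeline, in deferred form.
/// Tensor operands are referenced by input slot index; the concrete
/// `FusedSpec` is resolved against the actual tensors at evaluation time.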
#[derive(Clone, Debug)]
pub enum ProtoFusedSpec {
    AddMatMul {
        geo: AddMatMulGeometry,
        a: usize,
        b: usize,
        packings: Vec<(usize, Option<PanelExtractor>)>,
    },
    BinScalar(usize, BinOp),
    LeakyRelu(usize),
    BinPerRow(usize, BinOp, MapOutputAxisToInput),
    BinPerCol(usize, BinOp, MapOutputAxisToInput),
    AddRowColProducts(usize, usize),
    AddUnicast(OutputStoreSpec, usize, MapOutputAxisToInput),
    Scaler(Scaler),
    Store(Vec<OutputStoreSpec>),
}

impl ProtoFusedSpec {
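    /// Human-readable summary of this micro-op, used by `OptMatMul::info`.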
    pub fn format(&self, mmm: &dyn MatMatMul, mode: usize) -> String {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, packings: packing, .. } => {
                let (a, b) = &mmm.packings()[packing[mode].0];
                format!("matmul(k={}, {a:?}•{b:?})", geo.k)
            }
            BinScalar(_, op) => format!("scalar{op:?}"),
            LeakyRelu(alpha) => format!("leaky_relu({alpha:?})"),
            BinPerRow(_, op, _) => format!("row{op:?}"),
            BinPerCol(_, op, _) => format!("col{op:?}"),
            AddRowColProducts(_, _) => "add_row_col_product".to_string(),
            AddUnicast(_, _, _) => "add_to_matrix".to_string(),
            Scaler(s) => format!("scale({})", 1f32 * *s),
            Store(_oss) => "store".to_string(),
        }
    }

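    /// Resolve this proto-spec into a concrete `FusedSpec` for one output
    /// tile, offsetting the input views to match the current output
    /// coordinates.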
    pub fn resolve<'t>(
        &'t self,
        inputs: &'t [TValue],
        output_coords: &[usize],
        output: &Tensor,
        mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        let fs = match self {
            ProtoFusedSpec::AddMatMul { geo, a, b, packings } => {
                let mut a = inputs[*a].view();
                let mut b = inputs[*b].view();
                unsafe {
                    geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a);
                }
                let a = a.as_slice::<Opaque>().unwrap()[0]
                    .downcast_ref::<Box<dyn MMMInputValue>>()
                    .unwrap();
                unsafe {
                    geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b);
                }
                let b = b.as_slice::<Opaque>().unwrap()[0]
                    .downcast_ref::<Box<dyn MMMInputValue>>()
                    .unwrap();
                let (_a_packing, b_packing) = &mmm.packings()[packings[mode].0];
                let pa = if let Some(extractor) = &packings[mode].1 {
                    let data = a.downcast_ref::<EagerPackedInput>().unwrap();
                    AsInputValue::Owned(Box::new(PanelExtractInput {
                        format: extractor.clone(),
                        data: data.clone(),
                    }))
                } else {
                    AsInputValue::Borrowed(&**a)
                };
                assert!(
                    b_packing.same_as(b.format())
                        || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                );
                FusedSpec::AddMatMul {
                    a: pa,
                    b: AsInputValue::Borrowed(&**b),
                    packing: packings[mode].0,
                }
            }
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, map) => {
                let mut v = inputs[*v].view();
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, map) => {
                let mut v = inputs[*v].view();
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, map) => unsafe {
                let mut view = inputs[*v].view();
                map.translate_view(output_coords, &mut view);
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                let view = output.view_offsetting_unchecked(output_coords);
                FusedSpec::Store(oss[mode].wrap(&view))
            },
        };
        fs
    }

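    /// Whether this micro-op is compatible with the trivial one-call
    /// evaluation path; a matmul qualifies only once its k dimension is a
    /// concrete integer.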
    pub fn is_trivial(&self) -> bool {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => geo.k.as_i64().is_some(),
            _ => true,
        }
    }

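    /// Fast-path variant of `resolve` for the single-tile case: no coordinate
    /// translation is needed, and runtime checks become debug assertions.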
    pub fn resolve_trivial<'t>(
        &'t self,
        inputs: &'t [TValue],
        output: &mut Tensor,
        _mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        let fs = match self {
            ProtoFusedSpec::AddMatMul { a, b, packings, .. } => unsafe {
                debug_assert!(inputs.get(*a).is_some());
                debug_assert!(inputs.get(*b).is_some());
                let a = inputs.get_unchecked(*a);
                let b = inputs.get_unchecked(*b);
                debug_assert!(a.datum_type().is_opaque());
                debug_assert!(a.len() == 1);
                debug_assert!(b.datum_type().is_opaque());
                debug_assert!(b.len() == 1);
                let a = a.as_slice_unchecked::<Opaque>().get_unchecked(0);
                let b = b.as_slice_unchecked::<Opaque>().get_unchecked(0);
                debug_assert!(a.is::<Box<dyn MMMInputValue>>());
                debug_assert!(b.is::<Box<dyn MMMInputValue>>());
                let a = a.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
                let b = b.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
                #[cfg(debug_assertions)]
                {
                    let (a_packing, b_packing) = &_mmm.packings()[packings[mode].0];
                    debug_assert!(
                        a_packing.same_as(a.format())
                            || (a_packing.is::<PackedFormat>() && a_packing.r() == a.format().r())
                    );
                    debug_assert!(
                        b_packing.same_as(b.format())
                            || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                    );
                }
                FusedSpec::AddMatMul {
                    a: AsInputValue::Borrowed(&**a),
                    b: AsInputValue::Borrowed(&**b),
                    packing: packings[mode].0,
                }
            },
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, _) => unsafe {
                let view = inputs[*v].view();
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                FusedSpec::Store(oss[mode].wrap(&output.view_mut()))
            },
        };
        fs
    }

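    /// Check that the facts feeding this micro-op carry the expected datum
    /// types: opaque pre-packed inputs for the matmul, numbers elsewhere.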
    fn check_inputs(&self, inputs: &[&TypedFact]) -> TractResult<()> {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { a, b, .. } => {
                ensure!(inputs[*a].datum_type == Opaque::datum_type());
                ensure!(inputs[*b].datum_type == Opaque::datum_type());
            }
            BinScalar(v, _)
            | LeakyRelu(v)
            | BinPerCol(v, _, _)
            | BinPerRow(v, _, _)
            | AddUnicast(_, v, _) => {
                ensure!(inputs[*v].datum_type.is_number());
            }
            AddRowColProducts(row, col) => {
                ensure!(inputs[*row].datum_type.is_number());
                ensure!(inputs[*col].datum_type.is_number());
            }
            _ => (),
        };
        Ok(())
    }

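    /// Per-iteration cost estimate: only the matmul contributes, counted as
    /// m * n * k fused multiply-adds.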
    fn cost(&self, m: &TDim, n: &TDim, idt: DatumType) -> TVec<(Cost, TDim)> {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => {
                tvec!((Cost::FMA(idt), m.clone() * n * &geo.k))
            }
            _ => tvec!(), /* FIXME maybe */
        }
    }

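    /// Renumber all stored output-axis references after axis `axis` has been
    /// removed from the output shape.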
    fn rm_c_axis(&mut self, axis: usize) {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, .. } => {
                geo.c_to_a_axis_mapping.rm_c_axis(axis);
                geo.c_to_b_axis_mapping.rm_c_axis(axis);
            }
            BinScalar(..) | Scaler(..) | AddRowColProducts(_, _) | LeakyRelu(_) => {}
            BinPerRow(_, _, map) | BinPerCol(_, _, map) => map.rm_c_axis(axis),
            AddUnicast(_, _, map) => {
                map.rm_c_axis(axis);
            }
            Store(oss, ..) => {
                for oss in oss {
                    match oss {
                        OutputStoreSpec::View { m_axis, n_axis, .. } => {
                            *m_axis -= (*m_axis > axis) as usize;
                            *n_axis -= (*n_axis > axis) as usize;
                        }
                        OutputStoreSpec::Strides { .. } => {}
                    }
                }
            }
        }
    }
}

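/// Pairs of (output axis, input axis). When resolving a micro-op, the input
/// view is offset along each mapped input axis by the current coordinate on
/// the matching output axis, so e.g. `[(0, 0)]` slides the input along its
/// axis 0 together with the output's axis 0.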
#[derive(Clone, Debug)]
pub struct MapOutputAxisToInput(pub TVec<(usize, usize)>);

impl MapOutputAxisToInput {
    #[inline]
    unsafe fn translate_view(&self, output_coords: &[usize], v: &mut TensorView) {
        for &(out_axis, in_axis) in &self.0 {
            v.offset_axis(in_axis, output_coords[out_axis] as isize)
        }
    }

    #[inline]
    fn rm_c_axis(&mut self, axis: usize) {
        for (c, _) in &mut self.0 {
            *c -= (*c > axis) as usize;
        }
    }
}

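/// Static geometry of a matmul micro-op: the shared k dimension and the
/// mappings from output axes to the axes of the packed A and B inputs.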
#[derive(Clone, Debug)]
pub struct AddMatMulGeometry {
    pub k: TDim,
    pub c_to_a_axis_mapping: MapOutputAxisToInput,
    pub c_to_b_axis_mapping: MapOutputAxisToInput,
}

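/// The optimized matrix multiplication operator: one or more kernel
/// implementations, a mode picker choosing between them at run time, and the
/// pipeline of fused micro-ops executed around the kernel.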
#[derive(Clone, Debug)]
pub struct OptMatMul {
    pub c_fact: TypedFact,
    pub micro_ops: Vec<ProtoFusedSpec>,
    pub mmm: Vec<Box<dyn MatMatMul>>,
    pub mode_picker: ModePicker,
    pub c_m_axis: usize,
    pub c_n_axis: usize,
    pub trivial_packing: bool,
    pub trivial_path: bool,
}

impl Op for OptMatMul {
    fn name(&self) -> Cow<str> {
        "OptMatMul".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        let m = &self.c_fact.shape[self.c_m_axis];
        let n = &self.c_fact.shape[self.c_n_axis];
        let mut infos = vec![format!(
            "c_shape:{:?}, c_m_axis:{} c_n_axis:{} m:{} n:{}",
            self.c_fact, self.c_m_axis, self.c_n_axis, m, n,
        )];
        if let Some(k) = self.guess_k() {
            infos.push(format!("Mult: m:{} k:{} n:{} with {:?}", m, k, n, self.mmm));
        } else {
            infos.push(format!("Mult: {:?}", self.mmm));
        }
        for (mode, mmm) in self.mmm.iter().enumerate() {
            infos.push(format!(
                "Ops: {}",
                self.micro_ops.iter().map(|o| o.format(&**mmm, mode)).join(" >>> ")
            ));
        }
        Ok(infos)
    }

    op_as_typed_op!();
}

impl EvalOp for OptMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

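    /// Pick a kernel from the output's n dimension, reuse the session's
    /// cached scratch space when compatible, then either run the whole output
    /// in a single kernel call (trivial path) or loop over every output tile
    /// outside the m and n axes, resolving the micro-ops per tile.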
    fn eval_with_session(
        &self,
        session: &SessionState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        unsafe {
            let c_shape = self.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
            let mut c = Tensor::uninitialized_dt(self.c_fact.datum_type, &c_shape)?;
            let mode = self.mode_picker.pick(c_shape[self.c_n_axis])?;
            let mmm = &*self.mmm[mode];
            let mut cell = session.cached_mmm_scratch_space.borrow_mut();
            if !cell.as_ref().is_some_and(|scratch| mmm.can_use_scratch_space(&**scratch)) {
                *cell = None
            }
            let scratch = cell.get_or_insert_with(|| mmm.allocate_scratch_space());

            if self.trivial_path {
                let uops: Vec<FusedSpec> = self
                    .micro_ops
                    .iter()
                    .map(|o| o.resolve_trivial(&inputs, &mut c, mmm, mode))
                    .collect();
                mmm.run_with_scratch_space(
                    *c_shape.get_unchecked(self.c_m_axis),
                    *c_shape.get_unchecked(self.c_n_axis),
                    scratch.as_mut(),
                    &uops,
                )?;
                Ok(tvec!(c.into_tvalue()))
            } else {
                let mut uops = vec![FusedSpec::ShiftLeft(0); self.micro_ops.len()];
                let mut looping_shape: TVec<usize> = c_shape.to_smallvec();
                looping_shape[self.c_m_axis] = 1;
                looping_shape[self.c_n_axis] = 1;
                let m = c_shape[self.c_m_axis];
                let n = c_shape[self.c_n_axis];
                for c_coords in indices(&*looping_shape) {
                    for ix in 0..self.micro_ops.len() {
                        *uops.get_unchecked_mut(ix) = self.micro_ops.get_unchecked(ix).resolve(
                            &inputs,
                            c_coords.slice(),
                            &c,
                            mmm,
                            mode,
                        );
                    }
                    mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)
                        .context("In mmm.run_with_scratch_space")?;
                }
                Ok(tvec!(c.into_tvalue()))
            }
        }
    }
}

impl TypedOp for OptMatMul {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(self.c_m_axis < self.c_fact.rank());
        ensure!(self.c_n_axis < self.c_fact.rank());
        ensure!(self.trivial_path == self.can_use_trivial_path());
        ensure!(self.mmm.iter().map(|mmm| mmm.internal_type()).all_equal());
        for op in &self.micro_ops {
            op.check_inputs(inputs)?;
        }
        Ok(tvec!(self.c_fact.clone()))
    }

    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let mut sums = HashMap::new();
        let m = &self.c_fact.shape[self.c_m_axis];
        let n = &self.c_fact.shape[self.c_n_axis];
        for op in &self.micro_ops {
            for (cost, count) in op.cost(m, n, self.mmm[0].internal_type()) {
                *sums.entry(cost).or_default() += count;
            }
        }
        let loops = self
            .c_fact
            .shape
            .iter()
            .enumerate()
            .map(|(ix, d)| {
                if ix == self.c_m_axis || ix == self.c_n_axis {
                    1.to_dim()
                } else {
                    d.clone()
                }
            })
            .product::<TDim>();
        for s in sums.values_mut() {
            *s *= &loops;
        }
        Ok(sums.into_iter().collect())
    }

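    /// Attempt to fuse the single successor into this op: linear binary ops,
    /// quantized scaling, leaky relu, i8/u8 casts of an i32 output, axis
    /// removals and matching unicast adds become extra micro-ops; a binary op
    /// sitting behind an AxisOp is instead rewired so its constant operand
    /// crosses the axis change.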
    fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
        use crate::ops;
        if node.outputs.len() != 1
            || node.outputs[0].successors.len() != 1
            || model.output_outlets()?.contains(&node.id.into())
        {
            return Ok(None);
        }
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut patch = TypedModelPatch::new(format!("fusing {succ}"));

        if let Some(op) = succ.op_as::<ops::binary::TypedBinOp>() {
            let mut binop =
                if let Some(op) = op.0.as_linalg_binop() { op } else { return Ok(None) };
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }
        if let Some(op) = succ.op_as::<ops::binary::OptBinByScalar>() {
            let mut binop =
                if let Some(op) = op.binop.as_linalg_binop() { op } else { return Ok(None) };
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }

        if let Some(op) = succ.op_as::<ops::element_wise::ElementWiseOp>().map(|ew| ew.0.as_ref()) {
            if let Some(op) = op.downcast_ref::<ops::math::QScale>() {
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::Scaler(op.scaler)],
                    &[],
                );
            }
            if let Some(op) = op.downcast_ref::<LeakyRelu>() {
                if !self
                    .mmm
                    .iter()
                    .all(|mmm| mmm.can_fuse(&FusedSpec::LeakyRelu(&tensor0(op.alpha))))
                {
                    return Ok(None);
                }
                let alpha = patch.add_const(
                    node.name.to_string() + ".alpha",
                    tensor0(op.alpha).cast_to_dt(self.mmm[0].internal_type())?.into_owned(),
                )?;
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::LeakyRelu(node.inputs.len())],
                    &[alpha],
                );
            }
        }
        if let Some(cast_to) = succ.op_as::<ops::cast::Cast>().map(|cast| cast.to) {
            if (cast_to.unquantized() == i8::datum_type()
                || cast_to.unquantized() == u8::datum_type())
                && self.c_fact.datum_type == i32::datum_type()
            {
                if let Some(ProtoFusedSpec::Store(stores)) = self.micro_ops.last() {
                    if stores.iter().any(|s| matches!(s, OutputStoreSpec::Strides { .. })) {
                        return Ok(None);
                    }
                    let c_fact = cast_to.fact(self.c_fact.shape.clone());
                    let mut patch = TypedModelPatch::fuse_with_next(
                        model,
                        node,
                        Self { c_fact, ..self.clone() },
                    )?;
                    patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
                    return Ok(Some(patch));
                }
            }
        }
        if let Some(AxisOp::Rm(axis)) = succ.op_as::<ops::AxisOp>() {
            if *axis == self.c_m_axis || *axis == self.c_n_axis {
                return Ok(None);
            }
            let mut new_op = self.clone();
            new_op.c_fact.shape.remove_axis(*axis)?;
            new_op.c_m_axis -= (new_op.c_m_axis > *axis) as usize;
            new_op.c_n_axis -= (new_op.c_n_axis > *axis) as usize;
            for uop in &mut new_op.micro_ops {
                uop.rm_c_axis(*axis);
            }
            let mut patch = TypedModelPatch::fuse_with_next(model, node, new_op)?;
            patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
            return Ok(Some(patch));
        }
        if succ.op_is::<AxisOp>() {
            if let &[next] = &*succ.outputs[0].successors {
                let bin = model.node(next.node);
                if let Some(op) = bin.op_as::<ops::binary::TypedBinOp>() {
                    if op.0.as_linalg_binop().is_none() {
                        return Ok(None);
                    };
                    let flipped = succ.inputs[0].node == node.id;
                    let other_outlet = bin.inputs[flipped as usize];
                    if let Some(uni) = &model.outlet_fact(other_outlet)?.uniform {
                        let mut patch = TypedModelPatch::default();
                        let cst =
                            patch.add_const(&model.node(other_outlet.node).name, uni.clone())?;
                        let output = patch.tap_model(model, node.id.into())?;
                        let wire = wire_with_rank_broadcast(
                            &bin.name,
                            &mut patch,
                            op.clone(),
                            &if flipped { [output, cst] } else { [cst, output] },
                        )?;
                        let wire = patch.wire_node(&succ.name, succ.op.clone(), &wire)?[0];
                        patch.shunt_outside(model, bin.id.into(), wire)?;
                        return Ok(Some(patch));
                    }
                }
            }
        }
        if let Some(op) = succ.op_as::<ops::binary::OptBinUnicast>() {
            let in_1_fact = model.outlet_fact(succ.inputs[0])?;
            let in_2_fact = model.outlet_fact(succ.inputs[1])?;
            if op.binop.is::<ops::math::Add>()
                && self.mmm.len() == 1
                && in_1_fact.without_value() == in_2_fact.without_value()
            {
                let other_slot = 1 - node.outputs[0].successors[0].slot;
                let other_input = succ.inputs[other_slot];
                let other_input = patch.tap_model(model, other_input)?;
                let other_fact = patch.outlet_fact(other_input)?;

                if other_fact.shape == self.c_fact.shape {
                    let other_storage = unsafe { self.mmm[0].c_view(self.c_m_axis, self.c_n_axis) };
                    let mapping =
                        MapOutputAxisToInput((0..other_fact.rank()).map(|x| (x, x)).collect());
                    return self.fuse_op(
                        model,
                        node,
                        patch,
                        vec![ProtoFusedSpec::AddUnicast(other_storage, node.inputs.len(), mapping)],
                        &[other_input],
                    );
                }
            } else {
                let mut binop =
                    if let Some(op) = op.binop.as_linalg_binop() { op } else { return Ok(None) };
                let flipped = succ.inputs[0].node == node.id;
                if flipped {
                    binop = binop.flip();
                }
                let other_outlet = succ.inputs[flipped as usize];
                return self.fuse_binary(model, node, patch, other_outlet, binop);
            }
        };
        Ok(None)
    }

    as_op!();
}

impl OptMatMul {
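    /// Validate the output axes, build the op, and precompute whether the
    /// trivial evaluation path applies.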
    pub fn new(
        mmm: Vec<Box<dyn MatMatMul>>,
        mode_picker: ModePicker,
        c_fact: TypedFact,
        c_m_axis: usize,
        c_n_axis: usize,
        micro_ops: Vec<ProtoFusedSpec>,
        trivial_packing: bool,
    ) -> TractResult<Self> {
        ensure!(c_m_axis < c_fact.rank());
        ensure!(c_n_axis < c_fact.rank());
        let mut it = OptMatMul {
            mmm,
            mode_picker,
            c_fact,
            c_m_axis,
            c_n_axis,
            micro_ops,
            trivial_path: false,
            trivial_packing,
        };
        it.update_trivial_path();
        Ok(it)
    }

    /// Best-effort k: reads the k of the first matmul micro-op, if any. Used
    /// for cost and info reporting.
    fn guess_k(&self) -> Option<TDim> {
        self.micro_ops
            .iter()
            .find_map(|o| {
                if let ProtoFusedSpec::AddMatMul { geo, .. } = o {
                    Some(geo)
                } else {
                    None
                }
            })
            .map(|geo| geo.k.clone())
    }

    fn update_trivial_path(&mut self) {
        self.trivial_path = self.can_use_trivial_path();
    }

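    /// The trivial path applies when the output shape is concrete, every axis
    /// other than m and n has dimension 1, packing is trivial and every
    /// micro-op is trivial: the kernel then runs exactly once.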
    fn can_use_trivial_path(&self) -> bool {
        self.c_fact.shape.is_concrete()
            && self
                .c_fact
                .shape
                .iter()
                .enumerate()
                .all(|(ax, dim)| ax == self.c_m_axis || ax == self.c_n_axis || dim.is_one())
            && self.trivial_packing
            && self.micro_ops.iter().all(|o| o.is_trivial())
    }

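    /// Build a patch replacing this node and its successor by a copy of this
    /// op whose pipeline has `fused_micro_op` spliced in just before the
    /// final store.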
    fn fuse_op(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        fused_micro_op: Vec<ProtoFusedSpec>,
        additional_inputs: &[OutletId],
    ) -> TractResult<Option<TypedModelPatch>> {
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut new_op = self.clone();
        let before_last = new_op.micro_ops.len() - 1..new_op.micro_ops.len() - 1;
        new_op.micro_ops.splice(before_last, fused_micro_op);
        new_op.c_fact = succ.outputs[0].fact.clone();
        new_op.update_trivial_path();
        let mut inputs = patch.taps(model, &node.inputs)?;
        inputs.extend(additional_inputs.iter().cloned());
        let output = patch.wire_node(&succ.name, new_op, &inputs)?;
        patch.shunt_outside(model, succ.id.into(), output[0])?;
        Ok(Some(patch))
    }

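    /// Fuse a binary operation whose other operand is a scalar, a per-row
    /// (m-axis) vector or a per-column (n-axis) vector, casting it to the
    /// kernel's internal type first if needed.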
    fn fuse_binary(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        value: OutletId,
        binop: BinOp,
    ) -> TractResult<Option<TypedModelPatch>> {
        let fact = model.outlet_fact(value)?;
        let mut v = patch.tap_model(model, value)?;
        if fact.datum_type != self.mmm[0].internal_type() {
            v = patch.wire_node(
                format!("{}.cast-input-{}", node.name, node.inputs.len()),
                cast(self.mmm[0].internal_type()),
                &[v],
            )?[0];
        }
        let value = node.inputs.len();
        let additional_input = tvec!(v);
        if fact.shape.volume() == 1.to_dim() {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinScalar(value, binop)],
                &additional_input,
            );
        }
        let other_shape = fact.shape.to_owned();
        if other_shape[self.c_m_axis] == self.c_fact.shape[self.c_m_axis]
            && other_shape[self.c_m_axis] == other_shape.volume()
        {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerRow(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_m_axis, self.c_m_axis))),
                )],
                &additional_input,
            );
        }
        if other_shape[self.c_n_axis] == self.c_fact.shape[self.c_n_axis]
            && other_shape[self.c_n_axis] == other_shape.volume()
        {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerCol(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_n_axis, self.c_n_axis))),
                )],
                &additional_input,
            );
        }
        Ok(None)
    }
}