// tract_core/ops/matmul/optimized.rs

1use crate::internal::*;
2use crate::ops::cast::{Cast, cast};
3use crate::ops::change_axes::wire_with_rank_broadcast;
4use crate::ops::nn::LeakyRelu;
5use ndarray::*;
6use tract_itertools::Itertools;
7
8use tract_linalg::mmm::{
9    AsInputValue, EagerPackedInput, FusedSpec, MMMInputValue, MatMatMul, OutputStoreSpec,
10    PanelExtractInput, PanelExtractor,
11};
12use tract_linalg::pack::PackedFormat;
13use tract_linalg::{BinOp, Scaler};
14use tract_smallvec::ToSmallVec;
15
16use super::ModePicker;
17
/// One step of the fused micro-op program executed by `OptMatMul`.
///
/// Each variant is an unresolved ("proto") version of a tract-linalg
/// `FusedSpec`: the `usize` fields are indices into the node's input list,
/// resolved to concrete tensor views at eval time.
#[derive(Clone, Debug)]
pub enum ProtoFusedSpec {
    /// The matmul product itself. `a` and `b` are input slots holding
    /// single-element Opaque tensors wrapping pre-packed `MMMInputValue`
    /// payloads (see `resolve`); `packings` gives, per kernel mode, a
    /// packing index into `MatMatMul::packings` plus an optional panel
    /// extractor used to repack A on the fly.
    AddMatMul {
        geo: AddMatMulGeometry,
        a: usize,
        b: usize,
        packings: Vec<(usize, Option<PanelExtractor>)>,
    },
    /// Binary op with a scalar taken from the given input slot.
    BinScalar(usize, BinOp),
    /// Leaky ReLU; the slot holds the alpha tensor.
    LeakyRelu(usize),
    /// Row-wise binary op; the mapping translates output coordinates into
    /// coordinates of the operand in the given input slot.
    BinPerRow(usize, BinOp, MapOutputAxisToInput),
    /// Column-wise counterpart of `BinPerRow`.
    BinPerCol(usize, BinOp, MapOutputAxisToInput),
    /// Add the outer product of a row vector and a column vector (two slots).
    AddRowColProducts(usize, usize),
    /// Accumulate a full matrix read from the given input slot.
    AddUnicast(OutputStoreSpec, usize, MapOutputAxisToInput),
    /// Multiply accumulators by a compile-time scaler.
    Scaler(Scaler),
    /// Final store of accumulators into the output, one spec per kernel mode.
    Store(Vec<OutputStoreSpec>),
}
35
36impl ProtoFusedSpec {
37    pub fn format(&self, mmm: &dyn MatMatMul, mode: usize) -> String {
38        use ProtoFusedSpec::*;
39        match self {
40            AddMatMul { geo, packings: packing, .. } => {
41                let (a, b) = &mmm.packings()[packing[mode].0];
42                format!("matmul(k={}, {a:?}•{b:?})", geo.k)
43            }
44            BinScalar(_, op) => format!("scalar{op:?}"),
45            LeakyRelu(alpha) => format!("leaky_relu({alpha:?})"),
46            BinPerRow(_, op, _) => format!("row{op:?}"),
47            BinPerCol(_, op, _) => format!("col{op:?}"),
48            AddRowColProducts(_, _) => "add_row_col_product".to_string(),
49            AddUnicast(_, _, _) => "add_to_matrix".to_string(),
50            Scaler(s) => format!("scale({})", 1f32 * *s),
51            Store(_oss) => "store".to_string(),
52        }
53    }
54
    /// Resolve this proto spec into a concrete `FusedSpec` for the output
    /// tile located at `output_coords` (general, per-tile looping path).
    ///
    /// `mode` selects the kernel variant: it indexes `packings` for matmul
    /// and the per-mode specs of `Store`.
    pub fn resolve<'t>(
        &'t self,
        inputs: &'t [TValue],
        output_coords: &[usize],
        output: &Tensor,
        mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        #[allow(clippy::let_and_return)]
        let fs = match self {
            ProtoFusedSpec::AddMatMul { geo, a, b, packings } => {
                let mut a = inputs[*a].view();
                let mut b = inputs[*b].view();
                // SAFETY: output_coords come from iterating the output shape,
                // so the mapped axis offsets are in bounds.
                unsafe {
                    geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a);
                }
                // Operands are single-element Opaque tensors wrapping
                // pre-packed MMMInputValue payloads.
                let a = a.as_slice::<Opaque>().unwrap()[0]
                    .downcast_ref::<Box<dyn MMMInputValue>>()
                    .unwrap();
                unsafe {
                    geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b);
                }
                let b = b.as_slice::<Opaque>().unwrap()[0]
                    .downcast_ref::<Box<dyn MMMInputValue>>()
                    .unwrap();
                let (_a_packing, b_packing) = &mmm.packings()[packings[mode].0];
                // When this mode needs A repacked, wrap the eagerly packed
                // data in a panel extractor instead of borrowing it directly.
                let pa = if let Some(extractor) = &packings[mode].1 {
                    let data = a.downcast_ref::<EagerPackedInput>().unwrap();
                    AsInputValue::Owned(Box::new(PanelExtractInput {
                        format: extractor.clone(),
                        data: data.clone(),
                    }))
                } else {
                    AsInputValue::Borrowed(&**a)
                };
                // B's packed format must match the kernel's expectation, or
                // at least be a PackedFormat sharing the same r.
                assert!(
                    b_packing.same_as(b.format())
                        || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                );
                debug_assert!(pa.k().to_dim().compatible_with(&geo.k.to_dim()));
                debug_assert!(b.k().to_dim().compatible_with(&geo.k.to_dim()));
                FusedSpec::AddMatMul {
                    a: pa,
                    b: AsInputValue::Borrowed(&**b),
                    packing: packings[mode].0,
                }
            }
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, map) => {
                let mut v = inputs[*v].view();
                // SAFETY: same in-bounds argument as for the matmul operands.
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, map) => {
                let mut v = inputs[*v].view();
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, map) => unsafe {
                let mut view = inputs[*v].view();
                map.translate_view(output_coords, &mut view);
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                // View into the output tile at the current coordinates.
                let view = output.view_offsetting_unchecked(output_coords);
                FusedSpec::Store(oss[mode].wrap(&view))
            },
        };
        fs
    }
130
131    pub fn is_trivial(&self) -> bool {
132        match self {
133            ProtoFusedSpec::AddMatMul { geo, .. } => geo.k.as_i64().is_some(),
134            _ => true,
135        }
136    }
137
    /// Fast-path resolution: like [`Self::resolve`] but without output
    /// coordinates (the trivial path runs the kernel exactly once) and with
    /// bound checks downgraded to debug assertions.
    ///
    /// Callers must have established the trivial path (see
    /// `OptMatMul::can_use_trivial_path`), which implies a single packing
    /// with no panel extraction and opaque single-element matmul operands.
    pub fn resolve_trivial<'t>(
        &'t self,
        inputs: &'t [TValue],
        output: &mut Tensor,
        _mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        #[allow(clippy::let_and_return)]
        let fs = match self {
            ProtoFusedSpec::AddMatMul { a, b, packings, .. } => unsafe {
                // SAFETY: each unchecked access below is mirrored by a
                // debug_assert!; release builds rely on invariants enforced
                // when the typed model was optimized.
                debug_assert!(inputs.get(*a).is_some());
                debug_assert!(inputs.get(*b).is_some());
                let a = inputs.get_unchecked(*a);
                let b = inputs.get_unchecked(*b);
                debug_assert!(a.datum_type().is_opaque());
                debug_assert!(a.len() == 1);
                debug_assert!(b.datum_type().is_opaque());
                debug_assert!(b.len() == 1);
                let a = a.as_slice_unchecked::<Opaque>().get_unchecked(0);
                let b = b.as_slice_unchecked::<Opaque>().get_unchecked(0);
                debug_assert!(a.is::<Box<dyn MMMInputValue>>());
                debug_assert!(b.is::<Box<dyn MMMInputValue>>());
                let a = a.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
                let b = b.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
                debug_assert!(packings.len() == 1);
                debug_assert!(packings[0].1.is_none()); // no panel extraction
                #[cfg(debug_assertions)]
                {
                    // Packed operand formats must match what the kernel
                    // packing expects (or at least share the same r).
                    let (a_packing, b_packing) = &_mmm.packings()[packings[mode].0];
                    debug_assert!(
                        a_packing.same_as(a.format())
                            || (a_packing.is::<PackedFormat>() && a_packing.r() == a.format().r())
                    );
                    debug_assert!(
                        b_packing.same_as(b.format())
                            || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                    );
                }
                FusedSpec::AddMatMul {
                    a: AsInputValue::Borrowed(&**a),
                    b: AsInputValue::Borrowed(&**b),
                    packing: packings[mode].0,
                }
            },
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            // No axis translation needed: all non-m/n axes are size one on
            // the trivial path.
            ProtoFusedSpec::BinPerRow(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, _) => unsafe {
                let view = inputs[*v].view();
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                FusedSpec::Store(oss[mode].wrap(&output.view_mut()))
            },
        };
        fs
    }
206
207    fn check_inputs(&self, inputs: &[&TypedFact]) -> TractResult<()> {
208        use ProtoFusedSpec::*;
209        match self {
210            AddMatMul { a, b, .. } => {
211                ensure!(inputs[*a].datum_type == Opaque::datum_type());
212                ensure!(inputs[*b].datum_type == Opaque::datum_type());
213            }
214            BinScalar(v, _)
215            | LeakyRelu(v)
216            | BinPerCol(v, _, _)
217            | BinPerRow(v, _, _)
218            | AddUnicast(_, v, _) => {
219                ensure!(inputs[*v].datum_type.is_number());
220            }
221            AddRowColProducts(row, col) => {
222                ensure!(inputs[*row].datum_type.is_number());
223                ensure!(inputs[*col].datum_type.is_number());
224            }
225            _ => (),
226        };
227        Ok(())
228    }
229
230    fn cost(&self, m: &TDim, n: &TDim, idt: DatumType) -> TVec<(Cost, TDim)> {
231        match self {
232            ProtoFusedSpec::AddMatMul { geo, .. } => {
233                tvec!((Cost::FMA(idt), m.clone() * n * &geo.k))
234            }
235            _ => tvec!(), /* FIXME maybe */
236        }
237    }
238
239    fn rm_c_axis(&mut self, axis: usize) {
240        use ProtoFusedSpec::*;
241        match self {
242            AddMatMul { geo, .. } => {
243                geo.c_to_a_axis_mapping.rm_c_axis(axis);
244                geo.c_to_b_axis_mapping.rm_c_axis(axis);
245            }
246            BinScalar(..) | Scaler(..) | AddRowColProducts(_, _) | LeakyRelu(_) => {}
247            BinPerRow(_, _, map) | BinPerCol(_, _, map) => map.rm_c_axis(axis),
248            AddUnicast(_, _, map) => {
249                map.rm_c_axis(axis);
250            }
251            Store(oss, ..) => {
252                for oss in oss {
253                    match oss {
254                        OutputStoreSpec::View { m_axis, n_axis, .. } => {
255                            if let Some(m) = m_axis {
256                                *m -= (*m > axis) as usize
257                            };
258                            if let Some(n) = n_axis {
259                                *n -= (*n > axis) as usize
260                            }
261                        }
262                        OutputStoreSpec::Strides { .. } => {}
263                    }
264                }
265            }
266        }
267    }
268}
269
/// Maps output (C) axes to axes of one input operand: a list of
/// `(output_axis, input_axis)` pairs used to offset input views while
/// iterating over output tiles.
#[derive(Clone, Debug)]
pub struct MapOutputAxisToInput(pub TVec<(usize, usize)>);
272
273impl MapOutputAxisToInput {
274    #[inline]
275    unsafe fn translate_view(&self, output_coords: &[usize], v: &mut TensorView) {
276        for &(out_axis, in_axis) in &self.0 {
277            unsafe { v.offset_axis(in_axis, output_coords[out_axis] as isize) }
278        }
279    }
280
281    #[inline]
282    fn rm_c_axis(&mut self, axis: usize) {
283        for (c, _) in &mut self.0 {
284            *c -= (*c > axis) as usize;
285        }
286    }
287}
288
/// Static geometry of a matmul micro-op: the contracted dimension and the
/// axis mappings from the output (C) to each operand.
#[derive(Clone, Debug)]
pub struct AddMatMulGeometry {
    // Contracted (inner) dimension, possibly symbolic.
    pub k: TDim,
    pub c_to_a_axis_mapping: MapOutputAxisToInput,
    pub c_to_b_axis_mapping: MapOutputAxisToInput,
}
295
/// Optimized, fused matrix-multiplication op.
///
/// Holds one or more candidate linalg kernels (`mmm`), a `mode_picker`
/// selecting among them at runtime, and a `micro_ops` program executed per
/// output tile.
#[derive(Clone, Debug)]
pub struct OptMatMul {
    // Precomputed output fact (shape and datum type).
    pub c_fact: TypedFact,
    // Fused program, run in order; by construction the last entry is
    // expected to be a Store (see fuse_op and the Cast fusion in fuse).
    pub micro_ops: Vec<ProtoFusedSpec>,
    // Candidate kernels, indexed by the mode chosen by `mode_picker`.
    pub mmm: Vec<Box<dyn MatMatMul>>,
    pub mode_picker: ModePicker,
    // Output axes playing the m and n roles (None when degenerate).
    pub c_m_axis: Option<usize>,
    pub c_n_axis: Option<usize>,
    pub trivial_packing: bool,
    // Cached result of can_use_trivial_path; kept in sync via
    // update_trivial_path.
    pub trivial_path: bool,
}
307
308impl Op for OptMatMul {
309    fn name(&self) -> StaticName {
310        "OptMatMul".into()
311    }
312
313    fn info(&self) -> TractResult<Vec<String>> {
314        let m = self.c_m_axis.map(|ix| &self.c_fact.shape[ix]).unwrap_or(&TDim::Val(1));
315        let n = self.c_n_axis.map(|ix| &self.c_fact.shape[ix]).unwrap_or(&TDim::Val(1));
316        let mut infos = vec![format!(
317            "c_shape:{:?}, c_m_axis:{:?} c_n_axis:{:?} m:{} n:{}",
318            self.c_fact, self.c_m_axis, self.c_n_axis, m, n,
319        )];
320        if let Some(k) = self.guess_k() {
321            infos.push(format!("Mult: m:{} k:{} n:{} with {:?}", m, k, n, self.mmm));
322        } else {
323            infos.push(format!("Mult: {:?}", self.mmm));
324        }
325        for (mode, mmm) in self.mmm.iter().enumerate() {
326            infos.push(format!(
327                "Ops: {}",
328                self.micro_ops.iter().map(|o| o.format(&**mmm, mode)).join(" >>> ")
329            ));
330        }
331        Ok(infos)
332    }
333
334    op_as_typed_op!();
335}
336
impl EvalOp for OptMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Run the fused matmul program.
    ///
    /// Allocates the output uninitialized (the Store micro-op is expected
    /// to write every cell), picks a kernel mode from the runtime n
    /// dimension, then either runs the kernel once (trivial path) or loops
    /// over every non-m/n output coordinate, re-resolving the micro-op
    /// program per tile.
    fn eval_with_session(
        &self,
        _node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        unsafe {
            let c_shape = self.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
            let mut c = Tensor::uninitialized_dt(self.c_fact.datum_type, &c_shape)?;
            let m = self.c_m_axis.map(|c_m| c.shape()[c_m]).unwrap_or(1);
            let n = self.c_n_axis.map(|c_n| c.shape()[c_n]).unwrap_or(1);
            let mode = self.mode_picker.pick(n)?;
            let mmm = &*self.mmm[mode];
            // Reuse the session-cached scratch space when the selected
            // kernel accepts it; otherwise drop and reallocate.
            let mut cell = session.cached_mmm_scratch_space.borrow_mut();
            if !cell.as_ref().is_some_and(|scratch| mmm.can_use_scratch_space(&**scratch)) {
                *cell = None
            }
            let scratch = cell.get_or_insert_with(|| mmm.allocate_scratch_space());
            if self.trivial_path {
                // Single kernel invocation, no per-tile resolution needed.
                let uops: Vec<FusedSpec> = self
                    .micro_ops
                    .iter()
                    .map(|o| o.resolve_trivial(&inputs, &mut c, mmm, mode))
                    .collect();
                mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)?;
                Ok(tvec!(c.into_tvalue()))
            } else {
                // Placeholder specs, fully overwritten on each iteration.
                let mut uops = vec![FusedSpec::ShiftLeft(0); self.micro_ops.len()];
                // Iterate over all output coordinates except m and n, which
                // are handled inside the kernel (kept at 1 here).
                let mut looping_shape: TVec<usize> = c_shape.to_smallvec();
                if let Some(ax) = self.c_m_axis {
                    looping_shape[ax] = 1;
                }
                if let Some(ax) = self.c_n_axis {
                    looping_shape[ax] = 1;
                }
                for c_coords in indices(&*looping_shape) {
                    for ix in 0..self.micro_ops.len() {
                        *uops.get_unchecked_mut(ix) = self.micro_ops.get_unchecked(ix).resolve(
                            &inputs,
                            c_coords.slice(),
                            &c,
                            mmm,
                            mode,
                        );
                    }
                    mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)
                        .context("In mmm.run_with_scratch_space")?;
                }
                Ok(tvec!(c.into_tvalue()))
            }
        }
    }
}
395
396impl TypedOp for OptMatMul {
397    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
398        ensure!(self.c_m_axis.map(|ax| ax < self.c_fact.rank()).unwrap_or(true));
399        ensure!(self.c_n_axis.map(|ax| ax < self.c_fact.rank()).unwrap_or(true));
400        ensure!(self.trivial_path == self.can_use_trivial_path());
401        ensure!(self.mmm.iter().map(|mmm| mmm.internal_type()).all_equal());
402        for op in &self.micro_ops {
403            op.check_inputs(inputs)?;
404        }
405        Ok(tvec!(self.c_fact.clone()))
406    }
407
408    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
409        let mut sums = HashMap::new();
410        for op in &self.micro_ops {
411            for (cost, count) in op.cost(self.m(), self.n(), self.mmm[0].internal_type()) {
412                *sums.entry(cost).or_default() += count;
413            }
414        }
415        let loops = self
416            .c_fact
417            .shape
418            .iter()
419            .enumerate()
420            .map(|(ix, d)| {
421                if Some(ix) == self.c_m_axis || Some(ix) == self.c_n_axis {
422                    1.to_dim()
423                } else {
424                    d.clone()
425                }
426            })
427            .product::<TDim>();
428        for s in &mut sums.values_mut() {
429            *s *= &loops;
430        }
431        Ok(sums.into_iter().collect())
432    }
433
    /// Declutter pass: try to absorb this node's single successor into the
    /// matmul micro-op program, or commute past reshaping ops so a fusable
    /// op becomes the direct successor on a later pass.
    fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
        use crate::ops;
        // Only consider fusion with a single consumer, and never when this
        // node is a graph output.
        rule_if!(node.outputs.len() == 1);
        rule_if!(node.outputs[0].successors.len() == 1);
        rule_if!(!model.output_outlets()?.contains(&node.id.into()));
        let succ = model.node(node.outputs[0].successors[0].node);
        // Shared patch used by the fuse_op/fuse_binary branches; other
        // branches build their own patch.
        let mut patch = TypedModelPatch::new(format!("fusing {succ}"));

        // Generic binary op: fuse as scalar/row/col op, flipping the
        // operator when the matmul output is the left operand.
        if let Some(op) = succ.op_as::<ops::binary::TypedBinOp>() {
            rule_if_some!(mut binop = op.0.as_linalg_binop());
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }
        // Same treatment for the optimized by-scalar binary form.
        if let Some(op) = succ.op_as::<ops::binary::OptBinByScalar>() {
            rule_if_some!(mut binop = op.binop.as_linalg_binop());
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }

        if let Some(op) = succ.op_as::<ops::element_wise::ElementWiseOp>().map(|ew| ew.0.as_ref()) {
            // Requantization scaling folds directly into the kernel.
            if let Some(op) = op.downcast_ref::<ops::math::QScale>() {
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::Scaler(op.scaler)],
                    &[],
                );
            }
            // LeakyRelu folds when every candidate kernel supports it;
            // alpha is added as a constant input cast to the accumulator
            // type, wired in the last input slot.
            if let Some(op) = op.downcast_ref::<LeakyRelu>() {
                rule_if!(
                    self.mmm
                        .iter()
                        .all(|mmm| mmm.can_fuse(&FusedSpec::LeakyRelu(&tensor0(op.alpha))))
                );
                let alpha = patch.add_const(
                    node.name.to_string() + ".alpha",
                    tensor0(op.alpha).cast_to_dt(self.mmm[0].internal_type())?.into_owned(),
                )?;
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::LeakyRelu(node.inputs.len())],
                    &[alpha],
                );
            }
        }
        // A downstream cast can be absorbed by retyping our output fact, as
        // long as the kernels can store the target type (or this is the
        // i32-accumulator -> quantized i8/u8 case).
        if let Some(cast_to) = succ.op_as::<ops::cast::Cast>().map(|cast| cast.to) {
            if ((cast_to.unquantized() == i8::datum_type()
                || cast_to.unquantized() == u8::datum_type())
                && self.c_fact.datum_type == i32::datum_type())
                || self.mmm.iter().all(|m| m.stores().contains(&cast_to))
            {
                if let Some(ProtoFusedSpec::Store(stores)) = self.micro_ops.last() {
                    // Strided stores are not retargetable: bail out.
                    if stores.iter().any(|s| matches!(s, OutputStoreSpec::Strides { .. })) {
                        return Ok(None);
                    }
                    let c_fact = cast_to.fact(self.c_fact.shape.clone());
                    let mut patch = TypedModelPatch::fuse_with_next(
                        model,
                        node,
                        Self { c_fact, ..self.clone() },
                    )?;
                    patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
                    return Ok(Some(patch));
                }
            }
        }
        // Removal of a non-m/n output axis: rewrite our own geometry and
        // drop the downstream Rm.
        if let Some(AxisOp::Rm(axis)) = succ.op_as::<ops::AxisOp>() {
            rule_if!(Some(*axis) != self.c_m_axis);
            rule_if!(Some(*axis) != self.c_n_axis);
            let mut new_op = self.clone();
            new_op.c_fact.shape.remove_axis(*axis)?;
            // m/n axis indices shift down when above the removed axis.
            if let Some(c_m_axis) = &mut new_op.c_m_axis {
                *c_m_axis -= (*c_m_axis > *axis) as usize;
            }
            if let Some(c_n_axis) = &mut new_op.c_n_axis {
                *c_n_axis -= (*c_n_axis > *axis) as usize;
            }
            for uop in &mut new_op.micro_ops {
                uop.rm_c_axis(*axis);
            }
            let mut patch = TypedModelPatch::fuse_with_next(model, node, new_op)?;
            patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
            return Ok(Some(patch));
        }
        // An axis shuffle sitting between us and a fusable op: swap the two
        // downstream nodes so the fusable op becomes our direct successor.
        if succ.op_is::<AxisOp>() || succ.op_is::<IntoShape>() {
            if let &[next] = &*succ.outputs[0].successors {
                let next_node = model.node(next.node);
                if let Some(cast) = next_node.op_as::<Cast>() {
                    let mut patch = TypedModelPatch::default();
                    let mut wire = patch.tap_model(model, node.id.into())?;
                    wire = patch.wire_node(&next_node.name, cast.clone(), &[wire])?[0];
                    wire = patch.wire_node(&succ.name, succ.op.clone(), &[wire])?[0];
                    patch.shunt_outside(model, next_node.id.into(), wire)?;
                    return Ok(Some(patch));
                } else if let Some(op) = next_node.op_as::<ops::binary::TypedBinOp>() {
                    rule_if!(op.0.as_linalg_binop().is_some());
                    let flipped = succ.inputs[0].node == node.id;
                    let other_outlet = next_node.inputs[flipped as usize];
                    // Only safe with a uniform operand: its value is not
                    // affected by the axis shuffle.
                    if let Some(uni) = &model.outlet_fact(other_outlet)?.uniform {
                        let mut patch = TypedModelPatch::default();
                        let cst =
                            patch.add_const(&model.node(other_outlet.node).name, uni.clone())?;
                        let output = patch.tap_model(model, node.id.into())?;
                        let wire = wire_with_rank_broadcast(
                            &next_node.name,
                            &mut patch,
                            op.clone(),
                            &if flipped { [output, cst] } else { [cst, output] },
                        )?;
                        let wire = patch.wire_node(&succ.name, succ.op.clone(), &wire)?[0];
                        patch.shunt_outside(model, next_node.id.into(), wire)?;
                        return Ok(Some(patch));
                    }
                }
            }
        }
        // Unicast binary: a full-matrix Add with a matching shape becomes an
        // AddUnicast micro-op; other binops fall back to fuse_binary.
        if let Some(op) = succ.op_as::<ops::binary::OptBinUnicast>() {
            let in_1_fact = model.outlet_fact(succ.inputs[0])?;
            let in_2_fact = model.outlet_fact(succ.inputs[1])?;
            if op.binop.is::<ops::math::Add>()
                && self.mmm.len() == 1
                && in_1_fact.without_value() == in_2_fact.without_value()
            {
                let other_slot = 1 - node.outputs[0].successors[0].slot;
                let other_input = succ.inputs[other_slot];
                let other_input = patch.tap_model(model, other_input)?;
                let other_fact = patch.outlet_fact(other_input)?;

                if other_fact.shape == self.c_fact.shape {
                    let other_storage = unsafe { self.mmm[0].c_view(self.c_m_axis, self.c_n_axis) };
                    // Identity axis mapping: the added matrix is congruent
                    // with the output.
                    let mapping =
                        MapOutputAxisToInput((0..other_fact.rank()).map(|x| (x, x)).collect());
                    return self.fuse_op(
                        model,
                        node,
                        patch,
                        vec![ProtoFusedSpec::AddUnicast(other_storage, node.inputs.len(), mapping)],
                        &[other_input],
                    );
                }
            } else {
                rule_if_some!(mut binop = op.binop.as_linalg_binop());
                let flipped = succ.inputs[0].node == node.id;
                if flipped {
                    binop = binop.flip();
                }
                let other_outlet = succ.inputs[flipped as usize];
                return self.fuse_binary(model, node, patch, other_outlet, binop);
            }
        };
        Ok(None)
    }
597
598    as_op!();
599}
600
601impl OptMatMul {
602    pub fn new(
603        mmm: Vec<Box<dyn MatMatMul>>,
604        mode_picker: ModePicker,
605        c_fact: TypedFact,
606        c_m_axis: Option<usize>,
607        c_n_axis: Option<usize>,
608        micro_ops: Vec<ProtoFusedSpec>,
609        trivial_packing: bool,
610    ) -> TractResult<Self> {
611        if let Some(m) = c_m_axis {
612            ensure!(m < c_fact.rank());
613        }
614        if let Some(n) = c_n_axis {
615            ensure!(n < c_fact.rank());
616        }
617        let mut it = OptMatMul {
618            mmm,
619            mode_picker,
620            c_fact,
621            c_m_axis,
622            c_n_axis,
623            micro_ops,
624            trivial_path: false,
625            trivial_packing,
626        };
627        it.update_trivial_path();
628        Ok(it)
629    }
630
631    // for auditing only (may return None if no AddMatMul is found)
632    pub fn guess_k(&self) -> Option<TDim> {
633        self.micro_ops
634            .iter()
635            .find_map(
636                |o| {
637                    if let ProtoFusedSpec::AddMatMul { geo, .. } = o { Some(geo) } else { None }
638                },
639            )
640            .map(|geo| geo.k.clone())
641    }
642
643    #[inline]
644    pub fn m(&self) -> &TDim {
645        self.c_m_axis.map(|ax| &self.c_fact.shape[ax]).unwrap_or(&TDim::Val(1))
646    }
647
648    #[inline]
649    pub fn n(&self) -> &TDim {
650        self.c_n_axis.map(|ax| &self.c_fact.shape[ax]).unwrap_or(&TDim::Val(1))
651    }
652
    /// Recompute the cached `trivial_path` flag; must be called after any
    /// mutation of the shape, axes or micro-ops.
    fn update_trivial_path(&mut self) {
        self.trivial_path = self.can_use_trivial_path();
    }
656
657    fn can_use_trivial_path(&self) -> bool {
658        self.c_fact.shape.is_concrete()
659            && self.c_fact.shape.iter().enumerate().all(|(ax, dim)| {
660                Some(ax) == self.c_m_axis || Some(ax) == self.c_n_axis || dim.is_one()
661            })
662            && self.trivial_packing
663            && self.micro_ops.iter().all(|o| o.is_trivial())
664    }
665
    /// Replace the successor with a clone of this op that has
    /// `fused_micro_op` inserted just before the trailing Store, taking
    /// over the successor's output fact and appending its extra inputs.
    fn fuse_op(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        fused_micro_op: Vec<ProtoFusedSpec>,
        additional_inputs: &[OutletId],
    ) -> TractResult<Option<TypedModelPatch>> {
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut new_op = self.clone();
        // Empty range: splice removes nothing and inserts right before the
        // last micro-op (the Store).
        let before_last = new_op.micro_ops.len() - 1..new_op.micro_ops.len() - 1;
        new_op.micro_ops.splice(before_last, fused_micro_op);
        new_op.c_fact = succ.outputs[0].fact.clone();
        new_op.update_trivial_path();
        let mut inputs = patch.taps(model, &node.inputs)?;
        inputs.extend(additional_inputs.iter().cloned());
        let output = patch.wire_node(&succ.name, new_op, &inputs)?;
        patch.shunt_outside(model, succ.id.into(), output[0])?;
        Ok(Some(patch))
    }
686
    /// Fuse a binary op whose non-matmul operand lives at `value`.
    ///
    /// Dispatches on the operand's shape: a one-element tensor becomes a
    /// `BinScalar`, a vector whose whole volume sits on the m (resp. n)
    /// axis becomes a `BinPerRow` (resp. `BinPerCol`). Anything else is not
    /// fusable here and returns `None`.
    fn fuse_binary(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        value: OutletId,
        binop: BinOp,
    ) -> TractResult<Option<TypedModelPatch>> {
        let fact = model.outlet_fact(value)?;
        let mut v = patch.tap_model(model, value)?;
        // Kernels operate in their internal accumulator type: cast if needed.
        if fact.datum_type != self.mmm[0].internal_type() {
            v = patch.wire_node(
                format!("{}.cast-input-{}", node.name, node.inputs.len()),
                cast(self.mmm[0].internal_type()),
                &[v],
            )?[0];
        }
        // The operand is appended as a new input: its slot index is the
        // current input count.
        let value = node.inputs.len();
        let additional_input = tvec!(v);
        if fact.shape.volume() == 1.to_dim() {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinScalar(value, binop)],
                &additional_input,
            );
        }
        let other_shape = fact.shape.to_owned();
        // Vector aligned with m: operand volume entirely on the m axis.
        if self.c_m_axis.is_some_and(|ax| {
            other_shape[ax] == self.c_fact.shape[ax] && other_shape[ax] == other_shape.volume()
        }) {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerRow(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_m_axis.unwrap(), self.c_m_axis.unwrap()))),
                )],
                &additional_input,
            );
        }
        // Symmetric case: vector aligned with n.
        if self.c_n_axis.is_some_and(|ax| {
            other_shape[ax] == self.c_fact.shape[ax] && other_shape[ax] == other_shape.volume()
        }) {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerCol(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_n_axis.unwrap(), self.c_n_axis.unwrap()))),
                )],
                &additional_input,
            );
        }
        Ok(None)
    }
748}