// tract_core/ops/matmul/optimized.rs

1use crate::internal::*;
2use crate::ops::cast::cast;
3use crate::ops::change_axes::wire_with_rank_broadcast;
4use crate::ops::nn::LeakyRelu;
5use ndarray::*;
6use tract_itertools::Itertools;
7
8use tract_linalg::mmm::{
9    AsInputValue, EagerPackedInput, FusedSpec, MMMInputValue, MatMatMul, OutputStoreSpec,
10    PanelExtractInput, PanelExtractor,
11};
12use tract_linalg::pack::PackedFormat;
13use tract_linalg::{BinOp, Scaler};
14use tract_smallvec::ToSmallVec;
15
16use super::ModePicker;
17
18#[derive(Clone, Debug)]
19pub enum ProtoFusedSpec {
20    AddMatMul {
21        geo: AddMatMulGeometry,
22        a: usize,
23        b: usize,
24        packings: Vec<(usize, Option<PanelExtractor>)>,
25    },
26    BinScalar(usize, BinOp),
27    LeakyRelu(usize),
28    BinPerRow(usize, BinOp, MapOutputAxisToInput),
29    BinPerCol(usize, BinOp, MapOutputAxisToInput),
30    AddRowColProducts(usize, usize),
31    AddUnicast(OutputStoreSpec, usize, MapOutputAxisToInput),
32    Scaler(Scaler),
33    Store(Vec<OutputStoreSpec>),
34}
35
36impl ProtoFusedSpec {
37    pub fn format(&self, mmm: &dyn MatMatMul, mode: usize) -> String {
38        use ProtoFusedSpec::*;
39        match self {
40            AddMatMul { geo, packings: packing, .. } => {
41                let (a, b) = &mmm.packings()[packing[mode].0];
42                format!("matmul(k={}, {a:?}•{b:?})", geo.k)
43            }
44            BinScalar(_, op) => format!("scalar{op:?}"),
45            LeakyRelu(alpha) => format!("leaky_relu({alpha:?})"),
46            BinPerRow(_, op, _) => format!("row{op:?}"),
47            BinPerCol(_, op, _) => format!("col{op:?}"),
48            AddRowColProducts(_, _) => "add_row_col_product".to_string(),
49            AddUnicast(_, _, _) => "add_to_matrix".to_string(),
50            Scaler(s) => format!("scale({})", 1f32 * *s),
51            Store(_oss) => "store".to_string(),
52        }
53    }
54
55    pub fn resolve<'t>(
56        &'t self,
57        inputs: &'t [TValue],
58        output_coords: &[usize],
59        output: &Tensor,
60        mmm: &dyn MatMatMul,
61        mode: usize,
62    ) -> FusedSpec<'t> {
63        #[allow(clippy::let_and_return)]
64        let fs = match self {
65            ProtoFusedSpec::AddMatMul { geo, a, b, packings } => {
66                let mut a = inputs[*a].view();
67                let mut b = inputs[*b].view();
68                unsafe {
69                    geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a);
70                }
71                let a = a.as_slice::<Opaque>().unwrap()[0]
72                    .downcast_ref::<Box<dyn MMMInputValue>>()
73                    .unwrap();
74                unsafe {
75                    geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b);
76                }
77                let b = b.as_slice::<Opaque>().unwrap()[0]
78                    .downcast_ref::<Box<dyn MMMInputValue>>()
79                    .unwrap();
80                let (_a_packing, b_packing) = &mmm.packings()[packings[mode].0];
81                let pa = if let Some(extractor) = &packings[mode].1 {
82                    let data = a.downcast_ref::<EagerPackedInput>().unwrap();
83                    AsInputValue::Owned(Box::new(PanelExtractInput {
84                        format: extractor.clone(),
85                        data: data.clone(),
86                    }))
87                } else {
88                    AsInputValue::Borrowed(&**a)
89                };
90                assert!(
91                    b_packing.same_as(b.format())
92                        || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
93                );
94                debug_assert!(pa.k().to_dim().compatible_with(&geo.k.to_dim()));
95                debug_assert!(b.k().to_dim().compatible_with(&geo.k.to_dim()));
96                FusedSpec::AddMatMul {
97                    a: pa,
98                    b: AsInputValue::Borrowed(&**b),
99                    packing: packings[mode].0,
100                }
101            }
102            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
103            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
104            ProtoFusedSpec::BinPerRow(v, op, map) => {
105                let mut v = inputs[*v].view();
106                unsafe { map.translate_view(output_coords, &mut v) }
107                FusedSpec::BinPerRow(v, *op)
108            }
109            ProtoFusedSpec::BinPerCol(v, op, map) => {
110                let mut v = inputs[*v].view();
111                unsafe { map.translate_view(output_coords, &mut v) }
112                FusedSpec::BinPerCol(v, *op)
113            }
114            ProtoFusedSpec::AddRowColProducts(row, col) => {
115                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
116            }
117            ProtoFusedSpec::AddUnicast(store, v, map) => unsafe {
118                let mut view = inputs[*v].view();
119                map.translate_view(output_coords, &mut view);
120                FusedSpec::AddUnicast(store.wrap(&view))
121            },
122            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
123            ProtoFusedSpec::Store(oss) => unsafe {
124                let view = output.view_offsetting_unchecked(output_coords);
125                FusedSpec::Store(oss[mode].wrap(&view))
126            },
127        };
128        fs
129    }
130
131    pub fn is_trivial(&self) -> bool {
132        match self {
133            ProtoFusedSpec::AddMatMul { geo, .. } => geo.k.as_i64().is_some(),
134            _ => true,
135        }
136    }
137
138    pub fn resolve_trivial<'t>(
139        &'t self,
140        inputs: &'t [TValue],
141        output: &mut Tensor,
142        _mmm: &dyn MatMatMul,
143        mode: usize,
144    ) -> FusedSpec<'t> {
145        #[allow(clippy::let_and_return)]
146        let fs = match self {
147            ProtoFusedSpec::AddMatMul { a, b, packings, .. } => unsafe {
148                debug_assert!(inputs.get(*a).is_some());
149                debug_assert!(inputs.get(*b).is_some());
150                let a = inputs.get_unchecked(*a);
151                let b = inputs.get_unchecked(*b);
152                debug_assert!(a.datum_type().is_opaque());
153                debug_assert!(a.len() == 1);
154                debug_assert!(b.datum_type().is_opaque());
155                debug_assert!(b.len() == 1);
156                let a = a.as_slice_unchecked::<Opaque>().get_unchecked(0);
157                let b = b.as_slice_unchecked::<Opaque>().get_unchecked(0);
158                debug_assert!(a.is::<Box<dyn MMMInputValue>>());
159                debug_assert!(b.is::<Box<dyn MMMInputValue>>());
160                let a = a.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
161                let b = b.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
162                debug_assert!(packings.len() == 1);
163                debug_assert!(packings[0].1.is_none()); // no panel extraction
164                #[cfg(debug_assertions)]
165                {
166                    let (a_packing, b_packing) = &_mmm.packings()[packings[mode].0];
167                    debug_assert!(
168                        a_packing.same_as(a.format())
169                            || (a_packing.is::<PackedFormat>() && a_packing.r() == a.format().r())
170                    );
171                    debug_assert!(
172                        b_packing.same_as(b.format())
173                            || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
174                    );
175                }
176                FusedSpec::AddMatMul {
177                    a: AsInputValue::Borrowed(&**a),
178                    b: AsInputValue::Borrowed(&**b),
179                    packing: packings[mode].0,
180                }
181            },
182            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
183            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
184            ProtoFusedSpec::BinPerRow(v, op, _) => {
185                let v = inputs[*v].view();
186                FusedSpec::BinPerRow(v, *op)
187            }
188            ProtoFusedSpec::BinPerCol(v, op, _) => {
189                let v = inputs[*v].view();
190                FusedSpec::BinPerCol(v, *op)
191            }
192            ProtoFusedSpec::AddRowColProducts(row, col) => {
193                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
194            }
195            ProtoFusedSpec::AddUnicast(store, v, _) => unsafe {
196                let view = inputs[*v].view();
197                FusedSpec::AddUnicast(store.wrap(&view))
198            },
199            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
200            ProtoFusedSpec::Store(oss) => unsafe {
201                FusedSpec::Store(oss[mode].wrap(&output.view_mut()))
202            },
203        };
204        fs
205    }
206
207    fn check_inputs(&self, inputs: &[&TypedFact]) -> TractResult<()> {
208        use ProtoFusedSpec::*;
209        match self {
210            AddMatMul { a, b, .. } => {
211                ensure!(inputs[*a].datum_type == Opaque::datum_type());
212                ensure!(inputs[*b].datum_type == Opaque::datum_type());
213            }
214            BinScalar(v, _)
215            | LeakyRelu(v)
216            | BinPerCol(v, _, _)
217            | BinPerRow(v, _, _)
218            | AddUnicast(_, v, _) => {
219                ensure!(inputs[*v].datum_type.is_number());
220            }
221            AddRowColProducts(row, col) => {
222                ensure!(inputs[*row].datum_type.is_number());
223                ensure!(inputs[*col].datum_type.is_number());
224            }
225            _ => (),
226        };
227        Ok(())
228    }
229
230    fn cost(&self, m: &TDim, n: &TDim, idt: DatumType) -> TVec<(Cost, TDim)> {
231        match self {
232            ProtoFusedSpec::AddMatMul { geo, .. } => {
233                tvec!((Cost::FMA(idt), m.clone() * n * &geo.k))
234            }
235            _ => tvec!(), /* FIXME maybe */
236        }
237    }
238
239    fn rm_c_axis(&mut self, axis: usize) {
240        use ProtoFusedSpec::*;
241        match self {
242            AddMatMul { geo, .. } => {
243                geo.c_to_a_axis_mapping.rm_c_axis(axis);
244                geo.c_to_b_axis_mapping.rm_c_axis(axis);
245            }
246            BinScalar(..) | Scaler(..) | AddRowColProducts(_, _) | LeakyRelu(_) => {}
247            BinPerRow(_, _, map) | BinPerCol(_, _, map) => map.rm_c_axis(axis),
248            AddUnicast(_, _, map) => {
249                map.rm_c_axis(axis);
250            }
251            Store(oss, ..) => {
252                for oss in oss {
253                    match oss {
254                        OutputStoreSpec::View { m_axis, n_axis, .. } => {
255                            if let Some(m) = m_axis {
256                                *m -= (*m > axis) as usize
257                            };
258                            if let Some(n) = n_axis {
259                                *n -= (*n > axis) as usize
260                            }
261                        }
262                        OutputStoreSpec::Strides { .. } => {}
263                    }
264                }
265            }
266        }
267    }
268}
269
270#[derive(Clone, Debug)]
271pub struct MapOutputAxisToInput(pub TVec<(usize, usize)>);
272
273impl MapOutputAxisToInput {
274    #[inline]
275    unsafe fn translate_view(&self, output_coords: &[usize], v: &mut TensorView) {
276        for &(out_axis, in_axis) in &self.0 {
277            unsafe { v.offset_axis(in_axis, output_coords[out_axis] as isize) }
278        }
279    }
280
281    #[inline]
282    fn rm_c_axis(&mut self, axis: usize) {
283        for (c, _) in &mut self.0 {
284            *c -= (*c > axis) as usize;
285        }
286    }
287}
288
289#[derive(Clone, Debug)]
290pub struct AddMatMulGeometry {
291    pub k: TDim,
292    pub c_to_a_axis_mapping: MapOutputAxisToInput,
293    pub c_to_b_axis_mapping: MapOutputAxisToInput,
294}
295
296#[derive(Clone, Debug)]
297pub struct OptMatMul {
298    pub c_fact: TypedFact,
299    pub micro_ops: Vec<ProtoFusedSpec>,
300    pub mmm: Vec<Box<dyn MatMatMul>>,
301    pub mode_picker: ModePicker,
302    pub c_m_axis: Option<usize>,
303    pub c_n_axis: Option<usize>,
304    pub trivial_packing: bool,
305    pub trivial_path: bool,
306}
307
308impl Op for OptMatMul {
309    fn name(&self) -> StaticName {
310        "OptMatMul".into()
311    }
312
313    fn info(&self) -> TractResult<Vec<String>> {
314        let m = self.c_m_axis.map(|ix| &self.c_fact.shape[ix]).unwrap_or(&TDim::Val(1));
315        let n = self.c_n_axis.map(|ix| &self.c_fact.shape[ix]).unwrap_or(&TDim::Val(1));
316        let mut infos = vec![format!(
317            "c_shape:{:?}, c_m_axis:{:?} c_n_axis:{:?} m:{} n:{}",
318            self.c_fact, self.c_m_axis, self.c_n_axis, m, n,
319        )];
320        if let Some(k) = self.guess_k() {
321            infos.push(format!("Mult: m:{} k:{} n:{} with {:?}", m, k, n, self.mmm));
322        } else {
323            infos.push(format!("Mult: {:?}", self.mmm));
324        }
325        for (mode, mmm) in self.mmm.iter().enumerate() {
326            infos.push(format!(
327                "Ops: {}",
328                self.micro_ops.iter().map(|o| o.format(&**mmm, mode)).join(" >>> ")
329            ));
330        }
331        Ok(infos)
332    }
333
334    op_as_typed_op!();
335}
336
337impl EvalOp for OptMatMul {
338    fn is_stateless(&self) -> bool {
339        true
340    }
341
342    fn eval_with_session(
343        &self,
344        _node_id: usize,
345        session: &SessionState,
346        inputs: TVec<TValue>,
347    ) -> TractResult<TVec<TValue>> {
348        unsafe {
349            let c_shape = self.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
350            let mut c = Tensor::uninitialized_dt(self.c_fact.datum_type, &c_shape)?;
351            let m = self.c_m_axis.map(|c_m| c.shape()[c_m]).unwrap_or(1);
352            let n = self.c_n_axis.map(|c_n| c.shape()[c_n]).unwrap_or(1);
353            let mode = self.mode_picker.pick(n)?;
354            let mmm = &*self.mmm[mode];
355            let mut cell = session.cached_mmm_scratch_space.borrow_mut();
356            if !cell.as_ref().is_some_and(|scratch| mmm.can_use_scratch_space(&**scratch)) {
357                *cell = None
358            }
359            let scratch = cell.get_or_insert_with(|| mmm.allocate_scratch_space());
360            if self.trivial_path {
361                let uops: Vec<FusedSpec> = self
362                    .micro_ops
363                    .iter()
364                    .map(|o| o.resolve_trivial(&inputs, &mut c, mmm, mode))
365                    .collect();
366                mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)?;
367                Ok(tvec!(c.into_tvalue()))
368            } else {
369                let mut uops = vec![FusedSpec::ShiftLeft(0); self.micro_ops.len()];
370                let mut looping_shape: TVec<usize> = c_shape.to_smallvec();
371                if let Some(ax) = self.c_m_axis {
372                    looping_shape[ax] = 1;
373                }
374                if let Some(ax) = self.c_n_axis {
375                    looping_shape[ax] = 1;
376                }
377                for c_coords in indices(&*looping_shape) {
378                    for ix in 0..self.micro_ops.len() {
379                        *uops.get_unchecked_mut(ix) = self.micro_ops.get_unchecked(ix).resolve(
380                            &inputs,
381                            c_coords.slice(),
382                            &c,
383                            mmm,
384                            mode,
385                        );
386                    }
387                    mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)
388                        .context("In mmm.run_with_scratch_space")?;
389                }
390                Ok(tvec!(c.into_tvalue()))
391            }
392        }
393    }
394}
395
396impl TypedOp for OptMatMul {
397    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
398        ensure!(self.c_m_axis.map(|ax| ax < self.c_fact.rank()).unwrap_or(true));
399        ensure!(self.c_n_axis.map(|ax| ax < self.c_fact.rank()).unwrap_or(true));
400        ensure!(self.trivial_path == self.can_use_trivial_path());
401        ensure!(self.mmm.iter().map(|mmm| mmm.internal_type()).all_equal());
402        for op in &self.micro_ops {
403            op.check_inputs(inputs)?;
404        }
405        Ok(tvec!(self.c_fact.clone()))
406    }
407
408    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
409        let mut sums = HashMap::new();
410        for op in &self.micro_ops {
411            for (cost, count) in op.cost(self.m(), self.n(), self.mmm[0].internal_type()) {
412                *sums.entry(cost).or_default() += count;
413            }
414        }
415        let loops = self
416            .c_fact
417            .shape
418            .iter()
419            .enumerate()
420            .map(|(ix, d)| {
421                if Some(ix) == self.c_m_axis || Some(ix) == self.c_n_axis {
422                    1.to_dim()
423                } else {
424                    d.clone()
425                }
426            })
427            .product::<TDim>();
428        for s in &mut sums.values_mut() {
429            *s *= &loops;
430        }
431        Ok(sums.into_iter().collect())
432    }
433
434    fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
435        use crate::ops;
436        if node.outputs.len() != 1
437            || node.outputs[0].successors.len() != 1
438            || model.output_outlets()?.contains(&node.id.into())
439        {
440            return Ok(None);
441        }
442        let succ = model.node(node.outputs[0].successors[0].node);
443        let mut patch = TypedModelPatch::new(format!("fusing {succ}"));
444
445        if let Some(op) = succ.op_as::<ops::binary::TypedBinOp>() {
446            let mut binop = if let Some(op) = op.0.as_linalg_binop() {
447                op
448            } else {
449                return Ok(None);
450            };
451            let flipped = succ.inputs[0].node == node.id;
452            if flipped {
453                binop = binop.flip();
454            }
455            let other_outlet = succ.inputs[flipped as usize];
456            return self.fuse_binary(model, node, patch, other_outlet, binop);
457        }
458        if let Some(op) = succ.op_as::<ops::binary::OptBinByScalar>() {
459            let mut binop = if let Some(op) = op.binop.as_linalg_binop() {
460                op
461            } else {
462                return Ok(None);
463            };
464            let flipped = succ.inputs[0].node == node.id;
465            if flipped {
466                binop = binop.flip();
467            }
468            let other_outlet = succ.inputs[flipped as usize];
469            return self.fuse_binary(model, node, patch, other_outlet, binop);
470        }
471
472        if let Some(op) = succ.op_as::<ops::element_wise::ElementWiseOp>().map(|ew| ew.0.as_ref()) {
473            if let Some(op) = op.downcast_ref::<ops::math::QScale>() {
474                return self.fuse_op(
475                    model,
476                    node,
477                    patch,
478                    vec![ProtoFusedSpec::Scaler(op.scaler)],
479                    &[],
480                );
481            }
482            if let Some(op) = op.downcast_ref::<LeakyRelu>() {
483                if !self
484                    .mmm
485                    .iter()
486                    .all(|mmm| mmm.can_fuse(&FusedSpec::LeakyRelu(&tensor0(op.alpha))))
487                {
488                    return Ok(None);
489                }
490                let alpha = patch.add_const(
491                    node.name.to_string() + ".alpha",
492                    tensor0(op.alpha).cast_to_dt(self.mmm[0].internal_type())?.into_owned(),
493                )?;
494                return self.fuse_op(
495                    model,
496                    node,
497                    patch,
498                    vec![ProtoFusedSpec::LeakyRelu(node.inputs.len())],
499                    &[alpha],
500                );
501            }
502        }
503        if let Some(cast_to) = succ.op_as::<ops::cast::Cast>().map(|cast| cast.to) {
504            if (cast_to.unquantized() == i8::datum_type()
505                || cast_to.unquantized() == u8::datum_type())
506                && self.c_fact.datum_type == i32::datum_type()
507            {
508                if let Some(ProtoFusedSpec::Store(stores)) = self.micro_ops.last() {
509                    if stores.iter().any(|s| matches!(s, OutputStoreSpec::Strides { .. })) {
510                        return Ok(None);
511                    }
512                    let c_fact = cast_to.fact(self.c_fact.shape.clone());
513                    let mut patch = TypedModelPatch::fuse_with_next(
514                        model,
515                        node,
516                        Self { c_fact, ..self.clone() },
517                    )?;
518                    patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
519                    return Ok(Some(patch));
520                }
521            }
522        }
523        if let Some(AxisOp::Rm(axis)) = succ.op_as::<ops::AxisOp>() {
524            if Some(*axis) == self.c_m_axis || Some(*axis) == self.c_n_axis {
525                return Ok(None);
526            }
527            let mut new_op = self.clone();
528            new_op.c_fact.shape.remove_axis(*axis)?;
529            if let Some(c_m_axis) = &mut new_op.c_m_axis {
530                *c_m_axis -= (*c_m_axis > *axis) as usize;
531            }
532            if let Some(c_n_axis) = &mut new_op.c_n_axis {
533                *c_n_axis -= (*c_n_axis > *axis) as usize;
534            }
535            for uop in &mut new_op.micro_ops {
536                uop.rm_c_axis(*axis);
537            }
538            let mut patch = TypedModelPatch::fuse_with_next(model, node, new_op)?;
539            patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
540            return Ok(Some(patch));
541        }
542        if succ.op_is::<AxisOp>() {
543            if let &[next] = &*succ.outputs[0].successors {
544                let bin = model.node(next.node);
545                if let Some(op) = bin.op_as::<ops::binary::TypedBinOp>() {
546                    if op.0.as_linalg_binop().is_none() {
547                        return Ok(None);
548                    };
549                    let flipped = succ.inputs[0].node == node.id;
550                    let other_outlet = bin.inputs[flipped as usize];
551                    if let Some(uni) = &model.outlet_fact(other_outlet)?.uniform {
552                        let mut patch = TypedModelPatch::default();
553                        let cst =
554                            patch.add_const(&model.node(other_outlet.node).name, uni.clone())?;
555                        let output = patch.tap_model(model, node.id.into())?;
556                        let wire = wire_with_rank_broadcast(
557                            &bin.name,
558                            &mut patch,
559                            op.clone(),
560                            &if flipped { [output, cst] } else { [cst, output] },
561                        )?;
562                        let wire = patch.wire_node(&succ.name, succ.op.clone(), &wire)?[0];
563                        patch.shunt_outside(model, bin.id.into(), wire)?;
564                        return Ok(Some(patch));
565                    }
566                }
567            }
568        }
569        if let Some(op) = succ.op_as::<ops::binary::OptBinUnicast>() {
570            let in_1_fact = model.outlet_fact(succ.inputs[0])?;
571            let in_2_fact = model.outlet_fact(succ.inputs[1])?;
572            if op.binop.is::<ops::math::Add>()
573                && self.mmm.len() == 1
574                && in_1_fact.without_value() == in_2_fact.without_value()
575            {
576                let other_slot = 1 - node.outputs[0].successors[0].slot;
577                let other_input = succ.inputs[other_slot];
578                let other_input = patch.tap_model(model, other_input)?;
579                let other_fact = patch.outlet_fact(other_input)?;
580
581                if other_fact.shape == self.c_fact.shape {
582                    let other_storage = unsafe { self.mmm[0].c_view(self.c_m_axis, self.c_n_axis) };
583                    let mapping =
584                        MapOutputAxisToInput((0..other_fact.rank()).map(|x| (x, x)).collect());
585                    return self.fuse_op(
586                        model,
587                        node,
588                        patch,
589                        vec![ProtoFusedSpec::AddUnicast(other_storage, node.inputs.len(), mapping)],
590                        &[other_input],
591                    );
592                }
593            } else {
594                let mut binop = if let Some(op) = op.binop.as_linalg_binop() {
595                    op
596                } else {
597                    return Ok(None);
598                };
599                let flipped = succ.inputs[0].node == node.id;
600                if flipped {
601                    binop = binop.flip();
602                }
603                let other_outlet = succ.inputs[flipped as usize];
604                return self.fuse_binary(model, node, patch, other_outlet, binop);
605            }
606        };
607        Ok(None)
608    }
609
610    as_op!();
611}
612
613impl OptMatMul {
614    pub fn new(
615        mmm: Vec<Box<dyn MatMatMul>>,
616        mode_picker: ModePicker,
617        c_fact: TypedFact,
618        c_m_axis: Option<usize>,
619        c_n_axis: Option<usize>,
620        micro_ops: Vec<ProtoFusedSpec>,
621        trivial_packing: bool,
622    ) -> TractResult<Self> {
623        if let Some(m) = c_m_axis {
624            ensure!(m < c_fact.rank());
625        }
626        if let Some(n) = c_n_axis {
627            ensure!(n < c_fact.rank());
628        }
629        let mut it = OptMatMul {
630            mmm,
631            mode_picker,
632            c_fact,
633            c_m_axis,
634            c_n_axis,
635            micro_ops,
636            trivial_path: false,
637            trivial_packing,
638        };
639        it.update_trivial_path();
640        Ok(it)
641    }
642
643    // for auditing only (may return None if no AddMatMul is found)
644    pub fn guess_k(&self) -> Option<TDim> {
645        self.micro_ops
646            .iter()
647            .find_map(
648                |o| {
649                    if let ProtoFusedSpec::AddMatMul { geo, .. } = o {
650                        Some(geo)
651                    } else {
652                        None
653                    }
654                },
655            )
656            .map(|geo| geo.k.clone())
657    }
658
659    #[inline]
660    pub fn m(&self) -> &TDim {
661        self.c_m_axis.map(|ax| &self.c_fact.shape[ax]).unwrap_or(&TDim::Val(1))
662    }
663
664    #[inline]
665    pub fn n(&self) -> &TDim {
666        self.c_n_axis.map(|ax| &self.c_fact.shape[ax]).unwrap_or(&TDim::Val(1))
667    }
668
669    fn update_trivial_path(&mut self) {
670        self.trivial_path = self.can_use_trivial_path();
671    }
672
673    fn can_use_trivial_path(&self) -> bool {
674        self.c_fact.shape.is_concrete()
675            && self.c_fact.shape.iter().enumerate().all(|(ax, dim)| {
676                Some(ax) == self.c_m_axis || Some(ax) == self.c_n_axis || dim.is_one()
677            })
678            && self.trivial_packing
679            && self.micro_ops.iter().all(|o| o.is_trivial())
680    }
681
682    fn fuse_op(
683        &self,
684        model: &TypedModel,
685        node: &TypedNode,
686        mut patch: TypedModelPatch,
687        fused_micro_op: Vec<ProtoFusedSpec>,
688        additional_inputs: &[OutletId],
689    ) -> TractResult<Option<TypedModelPatch>> {
690        let succ = model.node(node.outputs[0].successors[0].node);
691        let mut new_op = self.clone();
692        let before_last = new_op.micro_ops.len() - 1..new_op.micro_ops.len() - 1;
693        new_op.micro_ops.splice(before_last, fused_micro_op);
694        new_op.c_fact = succ.outputs[0].fact.clone();
695        new_op.update_trivial_path();
696        let mut inputs = patch.taps(model, &node.inputs)?;
697        inputs.extend(additional_inputs.iter().cloned());
698        let output = patch.wire_node(&succ.name, new_op, &inputs)?;
699        patch.shunt_outside(model, succ.id.into(), output[0])?;
700        Ok(Some(patch))
701    }
702
703    fn fuse_binary(
704        &self,
705        model: &TypedModel,
706        node: &TypedNode,
707        mut patch: TypedModelPatch,
708        value: OutletId,
709        binop: BinOp,
710    ) -> TractResult<Option<TypedModelPatch>> {
711        let fact = model.outlet_fact(value)?;
712        let mut v = patch.tap_model(model, value)?;
713        if fact.datum_type != self.mmm[0].internal_type() {
714            v = patch.wire_node(
715                format!("{}.cast-input-{}", node.name, node.inputs.len()),
716                cast(self.mmm[0].internal_type()),
717                &[v],
718            )?[0];
719        }
720        let value = node.inputs.len();
721        let additional_input = tvec!(v);
722        if fact.shape.volume() == 1.to_dim() {
723            return self.fuse_op(
724                model,
725                node,
726                patch,
727                vec![ProtoFusedSpec::BinScalar(value, binop)],
728                &additional_input,
729            );
730        }
731        let other_shape = fact.shape.to_owned();
732        if self.c_m_axis.is_some_and(|ax| {
733            other_shape[ax] == self.c_fact.shape[ax] && other_shape[ax] == other_shape.volume()
734        }) {
735            return self.fuse_op(
736                model,
737                node,
738                patch,
739                vec![ProtoFusedSpec::BinPerRow(
740                    value,
741                    binop,
742                    MapOutputAxisToInput(tvec!((self.c_m_axis.unwrap(), self.c_m_axis.unwrap()))),
743                )],
744                &additional_input,
745            );
746        }
747        if self.c_n_axis.is_some_and(|ax| {
748            other_shape[ax] == self.c_fact.shape[ax] && other_shape[ax] == other_shape.volume()
749        }) {
750            return self.fuse_op(
751                model,
752                node,
753                patch,
754                vec![ProtoFusedSpec::BinPerCol(
755                    value,
756                    binop,
757                    MapOutputAxisToInput(tvec!((self.c_n_axis.unwrap(), self.c_n_axis.unwrap()))),
758                )],
759                &additional_input,
760            );
761        }
762        Ok(None)
763    }
764}