// tract_core/ops/matmul/optimized.rs
use crate::internal::*;
use crate::ops::cast::{Cast, cast};
use crate::ops::change_axes::wire_with_rank_broadcast;
use crate::ops::nn::LeakyRelu;

use ndarray::*;
use tract_itertools::Itertools;
use tract_linalg::mmm::{
    AsInputValue, EagerPackedInput, FusedSpec, MatMatMul, OutputStoreSpec, PackedMatrixStorage,
    PanelExtractInput, PanelExtractor,
};
use tract_linalg::pack::PackedFormat;
use tract_linalg::{BinOp, Scaler};
use tract_smallvec::ToSmallVec;

use super::ModePicker;
17
/// One step of the fused micro-op pipeline executed by `OptMatMul`.
///
/// Each variant is a symbolic ("proto") form of a `tract_linalg` `FusedSpec`:
/// it stores input-slot indices and axis mappings instead of borrowed data,
/// and is resolved against concrete inputs at eval time (`resolve` /
/// `resolve_trivial`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ProtoFusedSpec {
    /// The matmul product itself.
    AddMatMul {
        geo: AddMatMulGeometry,
        /// Input slot holding the packed A operand.
        a: usize,
        /// Input slot holding the packed B operand.
        b: usize,
        /// One (kernel packing index, optional panel extractor) per mode.
        packings: Vec<(usize, Option<PanelExtractor>)>,
    },
    /// Apply the binary op with a scalar taken from the given input slot.
    BinScalar(usize, BinOp),
    /// Leaky ReLU; alpha is taken from the given input slot.
    LeakyRelu(usize),
    /// Apply the binary op broadcasting one value per output row.
    BinPerRow(usize, BinOp, MapOutputAxisToInput),
    /// Apply the binary op broadcasting one value per output column.
    BinPerCol(usize, BinOp, MapOutputAxisToInput),
    /// Add the outer product of a row vector and a column vector (slots: row, col).
    AddRowColProducts(usize, usize),
    /// Element-wise addition of a full matrix taken from the given input slot.
    AddUnicast(OutputStoreSpec, usize, MapOutputAxisToInput),
    /// Multiply the accumulator by a constant scaler.
    Scaler(Scaler),
    /// Final store of the accumulator; one `OutputStoreSpec` per mode.
    Store(Vec<OutputStoreSpec>),
}
35
impl ProtoFusedSpec {
    /// Short human-readable label for this micro-op, for the given kernel
    /// `mmm` and packing `mode` (used by `OptMatMul::info`).
    pub fn format(&self, mmm: &dyn MatMatMul, mode: usize) -> String {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, packings: packing, .. } => {
                let (a, b) = &mmm.packings()[packing[mode].0];
                format!("matmul(k={}, {a:?}•{b:?})", geo.k)
            }
            BinScalar(_, op) => format!("scalar{op:?}"),
            LeakyRelu(alpha) => format!("leaky_relu({alpha:?})"),
            BinPerRow(_, op, _) => format!("row{op:?}"),
            BinPerCol(_, op, _) => format!("col{op:?}"),
            AddRowColProducts(_, _) => "add_row_col_product".to_string(),
            AddUnicast(_, _, _) => "add_to_matrix".to_string(),
            Scaler(s) => format!("scale({})", 1f32 * *s),
            Store(_oss) => "store".to_string(),
        }
    }

    /// Resolve this proto spec into a concrete `FusedSpec` for the output tile
    /// at `output_coords`, borrowing data from `inputs` and `output`.
    ///
    /// This is the general (non-trivial) path: batch axes are translated
    /// through the stored output-to-input axis mappings for every tile.
    pub fn resolve<'t>(
        &'t self,
        inputs: &'t [TValue],
        output_coords: &[usize],
        output: &Tensor,
        mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        #[allow(clippy::let_and_return)]
        let fs = match self {
            ProtoFusedSpec::AddMatMul { geo, a, b, packings } => {
                // Select the A value for this tile via the C->A batch-axis mapping.
                let a_tensor = &inputs[*a];
                let a_storage = a_tensor.try_storage_as::<PackedMatrixStorage>().unwrap();
                let a_idx =
                    geo.c_to_a_axis_mapping.flat_index(output_coords, a_storage.batch_strides());
                let a = a_storage.value_at_flat(a_idx);

                // Same selection for B via the C->B mapping.
                let b_tensor = &inputs[*b];
                let b_storage = b_tensor.try_storage_as::<PackedMatrixStorage>().unwrap();
                let b_idx =
                    geo.c_to_b_axis_mapping.flat_index(output_coords, b_storage.batch_strides());
                let b = b_storage.value_at_flat(b_idx);

                let (_a_packing, b_packing) = &mmm.packings()[packings[mode].0];
                // If a panel extractor is configured for this mode, wrap the
                // eagerly packed A input so panels are re-extracted on the fly.
                let pa = if let Some(extractor) = &packings[mode].1 {
                    let data = a.downcast_ref::<EagerPackedInput>().unwrap();
                    AsInputValue::Owned(Box::new(PanelExtractInput {
                        format: extractor.clone(),
                        data: data.clone(),
                    }))
                } else {
                    AsInputValue::Borrowed(a)
                };
                // B's actual format must match the kernel's expected packing,
                // or at least share the panel width r for plain packed formats.
                assert!(
                    b_packing.dyn_eq(b.format())
                        || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                );
                debug_assert!(pa.k().to_dim().compatible_with(&geo.k.to_dim()));
                debug_assert!(b.k().to_dim().compatible_with(&geo.k.to_dim()));
                FusedSpec::AddMatMul {
                    a: pa,
                    b: AsInputValue::Borrowed(b),
                    packing: packings[mode].0,
                }
            }
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, map) => {
                let mut v = inputs[*v].view();
                // Offset the view to this tile's batch coordinates.
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, map) => {
                let mut v = inputs[*v].view();
                // Offset the view to this tile's batch coordinates.
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, map) => unsafe {
                let mut view = inputs[*v].view();
                map.translate_view(output_coords, &mut view);
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                // View into the output tensor at this tile's coordinates.
                let view = output.view_offsetting_unchecked(output_coords);
                FusedSpec::Store(oss[mode].wrap(&view))
            },
        };
        fs
    }

    /// A micro-op is "trivial" if it can be resolved once, outside the batch
    /// loop. Only `AddMatMul` with a non-concrete k is disqualifying.
    pub fn is_trivial(&self) -> bool {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => geo.k.as_i64().is_some(),
            _ => true,
        }
    }

    /// Fast-path resolution: no batch loop, no axis translation, single
    /// packing, no panel extraction. Preconditions are checked with
    /// debug_asserts only; `OptMatMul::trivial_path` guards the call site.
    pub fn resolve_trivial<'t>(
        &'t self,
        inputs: &'t [TValue],
        output: &mut Tensor,
        _mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        #[allow(clippy::let_and_return)]
        let fs = match self {
            ProtoFusedSpec::AddMatMul { a, b, packings, .. } => unsafe {
                debug_assert!(inputs.get(*a).is_some());
                debug_assert!(inputs.get(*b).is_some());
                let a = inputs.get_unchecked(*a);
                let b = inputs.get_unchecked(*b);
                debug_assert!(a.is_exotic());
                debug_assert!(a.len() == 1);
                debug_assert!(b.is_exotic());
                debug_assert!(b.len() == 1);
                let a_storage = a.try_storage_as::<PackedMatrixStorage>().unwrap_unchecked();
                let b_storage = b.try_storage_as::<PackedMatrixStorage>().unwrap_unchecked();
                let a = a_storage.value();
                let b = b_storage.value();
                debug_assert!(packings.len() == 1);
                debug_assert!(packings[0].1.is_none()); // no panel extraction
                // In debug builds, verify both operands match the kernel's
                // expected packing formats (or at least the panel width r).
                #[cfg(debug_assertions)]
                {
                    let (a_packing, b_packing) = &_mmm.packings()[packings[mode].0];
                    debug_assert!(
                        a_packing.dyn_eq(a.format())
                            || (a_packing.is::<PackedFormat>() && a_packing.r() == a.format().r())
                    );
                    debug_assert!(
                        b_packing.dyn_eq(b.format())
                            || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                    );
                }
                FusedSpec::AddMatMul {
                    a: AsInputValue::Borrowed(a),
                    b: AsInputValue::Borrowed(b),
                    packing: packings[mode].0,
                }
            },
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, _) => {
                // No translation needed: trivial path has no batch coordinates.
                let v = inputs[*v].view();
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, _) => unsafe {
                let view = inputs[*v].view();
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                FusedSpec::Store(oss[mode].wrap(&output.view_mut()))
            },
        };
        fs
    }

    /// Validate the facts of the input slots this micro-op references.
    fn check_inputs(&self, inputs: &[&TypedFact]) -> TractResult<()> {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { a, b, .. } => {
                // Packed operands are opaque ("exotic") tensors.
                ensure!(inputs[*a].is_exotic());
                ensure!(inputs[*b].is_exotic());
            }
            BinScalar(v, _)
            | LeakyRelu(v)
            | BinPerCol(v, _, _)
            | BinPerRow(v, _, _)
            | AddUnicast(_, v, _) => {
                ensure!(inputs[*v].datum_type.is_number());
            }
            AddRowColProducts(row, col) => {
                ensure!(inputs[*row].datum_type.is_number());
                ensure!(inputs[*col].datum_type.is_number());
            }
            _ => (),
        };
        Ok(())
    }

    /// Cost estimate for one tile: only the matmul contributes (m*n*k FMAs).
    fn cost(&self, m: &TDim, n: &TDim, idt: DatumType) -> TVec<(Cost, TDim)> {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => {
                tvec!((Cost::FMA(idt), m.clone() * n * &geo.k))
            }
            _ => tvec!(), /* FIXME maybe */
        }
    }

    /// Update internal axis bookkeeping after output axis `axis` was removed.
    fn rm_c_axis(&mut self, axis: usize) {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, .. } => {
                geo.c_to_a_axis_mapping.rm_c_axis(axis);
                geo.c_to_b_axis_mapping.rm_c_axis(axis);
            }
            BinScalar(..) | Scaler(..) | AddRowColProducts(_, _) | LeakyRelu(_) => {}
            BinPerRow(_, _, map) | BinPerCol(_, _, map) => map.rm_c_axis(axis),
            AddUnicast(_, _, map) => {
                map.rm_c_axis(axis);
            }
            Store(oss, ..) => {
                for oss in oss {
                    match oss {
                        OutputStoreSpec::View { m_axis, n_axis, .. } => {
                            // Shift m/n axes left when they sat after the removed axis.
                            if let Some(m) = m_axis {
                                *m -= (*m > axis) as usize
                            };
                            if let Some(n) = n_axis {
                                *n -= (*n > axis) as usize
                            }
                        }
                        OutputStoreSpec::Strides { .. } => {}
                    }
                }
            }
        }
    }
}
265
/// Maps output (C) axes onto axes of one input: a list of
/// `(output_axis, input_axis)` pairs.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MapOutputAxisToInput(pub TVec<(usize, usize)>);
268
269impl MapOutputAxisToInput {
270    #[inline]
271    unsafe fn translate_view(&self, output_coords: &[usize], v: &mut TensorView) {
272        for &(out_axis, in_axis) in &self.0 {
273            unsafe { v.offset_axis(in_axis, output_coords[out_axis] as isize) }
274        }
275    }
276
277    #[inline]
278    fn rm_c_axis(&mut self, axis: usize) {
279        for (c, _) in &mut self.0 {
280            *c -= (*c > axis) as usize;
281        }
282    }
283
284    /// Compute a flat index into a PackedMatrixStorage from output coordinates and batch strides.
285    #[inline]
286    pub fn flat_index(&self, output_coords: &[usize], batch_strides: &[isize]) -> usize {
287        self.0
288            .iter()
289            .map(|&(out_axis, in_axis)| output_coords[out_axis] * batch_strides[in_axis] as usize)
290            .sum()
291    }
292}
293
/// Static geometry of a matmul product: the reduction dimension and how
/// output (C) batch axes map onto batch axes of the A and B operands.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AddMatMulGeometry {
    /// Common (reduction) dimension K, possibly symbolic.
    pub k: TDim,
    /// Output-axis to A-batch-axis mapping.
    pub c_to_a_axis_mapping: MapOutputAxisToInput,
    /// Output-axis to B-batch-axis mapping.
    pub c_to_b_axis_mapping: MapOutputAxisToInput,
}
300
/// Optimized, fused matrix-multiplication operator.
///
/// Holds one or more candidate kernels (`mmm`), a picker to choose among them
/// at eval time, and the fused `micro_ops` pipeline ending in a store.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OptMatMul {
    /// Type and shape of the output tensor.
    pub c_fact: TypedFact,
    /// Fused pipeline, resolved per tile (or once, on the trivial path).
    pub micro_ops: Vec<ProtoFusedSpec>,
    /// Candidate kernels, indexed by mode.
    pub mmm: Vec<Box<dyn MatMatMul>>,
    /// Chooses a kernel mode from the runtime n dimension.
    pub mode_picker: ModePicker,
    /// Output axis carrying m; None when absent (m treated as 1).
    pub c_m_axis: Option<usize>,
    /// Output axis carrying n; None when absent (n treated as 1).
    pub c_n_axis: Option<usize>,
    /// True when packing requires no per-tile resolution.
    pub trivial_packing: bool,
    // Cached result of can_use_trivial_path(); kept in sync by
    // update_trivial_path() and checked in output_facts().
    pub trivial_path: bool,
}
312
313impl Op for OptMatMul {
314    fn name(&self) -> StaticName {
315        "OptMatMul".into()
316    }
317
318    fn info(&self) -> TractResult<Vec<String>> {
319        let m = self.c_m_axis.map(|ix| &self.c_fact.shape[ix]).unwrap_or(&TDim::Val(1));
320        let n = self.c_n_axis.map(|ix| &self.c_fact.shape[ix]).unwrap_or(&TDim::Val(1));
321        let mut infos = vec![format!(
322            "c_shape:{:?}, c_m_axis:{:?} c_n_axis:{:?} m:{} n:{}",
323            self.c_fact, self.c_m_axis, self.c_n_axis, m, n,
324        )];
325        if let Some(k) = self.guess_k() {
326            infos.push(format!("Mult: m:{} k:{} n:{} with {:?}", m, k, n, self.mmm));
327        } else {
328            infos.push(format!("Mult: {:?}", self.mmm));
329        }
330        for (mode, mmm) in self.mmm.iter().enumerate() {
331            infos.push(format!(
332                "Ops: {}",
333                self.micro_ops.iter().map(|o| o.format(&**mmm, mode)).join(" >>> ")
334            ));
335        }
336        Ok(infos)
337    }
338
339    op_as_typed_op!();
340}
341
impl EvalOp for OptMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Run the fused matmul pipeline.
    ///
    /// Evaluates the concrete output shape from the session's resolved
    /// symbols, picks a kernel mode from the n dimension, then either runs a
    /// single kernel call (trivial path) or loops over all batch coordinates
    /// (m and n axes pinned to 1), re-resolving each micro-op per tile.
    fn eval_with_session(
        &self,
        _node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        unsafe {
            let c_shape = self.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
            // Output is fully written by the pipeline's Store op, so
            // uninitialized memory is acceptable here.
            let mut c = Tensor::uninitialized_dt(self.c_fact.datum_type, &c_shape)?;
            let m = self.c_m_axis.map(|c_m| c.shape()[c_m]).unwrap_or(1);
            let n = self.c_n_axis.map(|c_n| c.shape()[c_n]).unwrap_or(1);
            let mode = self.mode_picker.pick(n)?;
            let mmm = &*self.mmm[mode];
            // Reuse the session-cached scratch space when this kernel accepts
            // it; otherwise drop it and allocate a fresh one.
            let mut cell = session.cached_mmm_scratch_space.borrow_mut();
            if !cell.as_ref().is_some_and(|scratch| mmm.can_use_scratch_space(&**scratch)) {
                *cell = None
            }
            let scratch = cell.get_or_insert_with(|| mmm.allocate_scratch_space());
            if self.trivial_path {
                // Trivial path: resolve every micro-op once and run a single kernel call.
                let uops: Vec<FusedSpec> = self
                    .micro_ops
                    .iter()
                    .map(|o| o.resolve_trivial(&inputs, &mut c, mmm, mode))
                    .collect();
                mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)?;
                Ok(tvec!(c.into_tvalue()))
            } else {
                // Placeholder specs, overwritten for each batch coordinate below.
                let mut uops = vec![FusedSpec::ShiftLeft(0); self.micro_ops.len()];
                // Loop over batch coordinates only: m and n axes are pinned to 1.
                let mut looping_shape: TVec<usize> = c_shape.to_smallvec();
                if let Some(ax) = self.c_m_axis {
                    looping_shape[ax] = 1;
                }
                if let Some(ax) = self.c_n_axis {
                    looping_shape[ax] = 1;
                }
                for c_coords in indices(&*looping_shape) {
                    for ix in 0..self.micro_ops.len() {
                        // SAFETY: ix < micro_ops.len() == uops.len() by construction.
                        *uops.get_unchecked_mut(ix) = self.micro_ops.get_unchecked(ix).resolve(
                            &inputs,
                            c_coords.slice(),
                            &c,
                            mmm,
                            mode,
                        );
                    }
                    mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)
                        .context("In mmm.run_with_scratch_space")?;
                }
                Ok(tvec!(c.into_tvalue()))
            }
        }
    }
}
400
impl TypedOp for OptMatMul {
    /// Validate internal invariants and per-micro-op input facts, then return
    /// the stored output fact.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(self.c_m_axis.map(|ax| ax < self.c_fact.rank()).unwrap_or(true));
        ensure!(self.c_n_axis.map(|ax| ax < self.c_fact.rank()).unwrap_or(true));
        // trivial_path is a cache; it must agree with the recomputed value.
        ensure!(self.trivial_path == self.can_use_trivial_path());
        ensure!(self.mmm.iter().map(|mmm| mmm.internal_type()).all_equal());
        for op in &self.micro_ops {
            op.check_inputs(inputs)?;
        }
        Ok(tvec!(self.c_fact.clone()))
    }

    /// Total cost: per-tile micro-op costs multiplied by the number of batch
    /// iterations (product of output dims, m and n axes counted as 1).
    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let mut sums = HashMap::new();
        for op in &self.micro_ops {
            for (cost, count) in op.cost(self.m(), self.n(), self.mmm[0].internal_type()) {
                *sums.entry(cost).or_default() += count;
            }
        }
        let loops = self
            .c_fact
            .shape
            .iter()
            .enumerate()
            .map(|(ix, d)| {
                if Some(ix) == self.c_m_axis || Some(ix) == self.c_n_axis {
                    1.to_dim()
                } else {
                    d.clone()
                }
            })
            .product::<TDim>();
        for s in &mut sums.values_mut() {
            *s *= &loops;
        }
        Ok(sums.into_iter().collect())
    }

    /// Try to absorb the single successor node into this op's fused pipeline:
    /// binary ops (scalar / per-row / per-col), QScale, LeakyRelu, output
    /// casts, axis removal, and unicast additions; also swaps cast/binop past
    /// an intervening reshape.
    fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
        use crate::ops;
        // Only fuse when this node has exactly one consumer and is not a model output.
        rule_if!(node.outputs.len() == 1);
        rule_if!(node.outputs[0].successors.len() == 1);
        rule_if!(!model.output_outlets()?.contains(&node.id.into()));
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut patch = TypedModelPatch::new(format!("fusing {succ}"));

        if let Some(op) = succ.op_as::<ops::binary::TypedBinOp>() {
            rule_if_some!(mut binop = op.0.as_linalg_binop());
            // If this node feeds the binop's left side, flip the op so the
            // matmul output is always the right operand.
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }
        if let Some(op) = succ.op_as::<ops::binary::OptBinByScalar>() {
            rule_if_some!(mut binop = op.binop.as_linalg_binop());
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }

        if let Some(op) = succ.op_as::<ops::element_wise::ElementWiseOp>().map(|ew| ew.0.as_ref()) {
            if let Some(op) = op.downcast_ref::<ops::math::QScale>() {
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::Scaler(op.scaler)],
                    &[],
                );
            }
            if let Some(op) = op.downcast_ref::<LeakyRelu>() {
                // All candidate kernels must support a fused leaky-relu.
                rule_if!(
                    self.mmm
                        .iter()
                        .all(|mmm| mmm.can_fuse(&FusedSpec::LeakyRelu(&tensor0(op.alpha))))
                );
                // Alpha becomes an extra constant input, cast to the kernel's type.
                let alpha = patch.add_const(
                    node.name.to_string() + ".alpha",
                    tensor0(op.alpha).cast_to_dt(self.mmm[0].internal_type())?.into_owned(),
                )?;
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::LeakyRelu(node.inputs.len())],
                    &[alpha],
                );
            }
        }
        // Fold a following Cast into the store when the kernels can store that
        // type directly (or the classic i32 -> i8/u8 quantized case).
        if let Some(cast_to) = succ.op_as::<ops::cast::Cast>().map(|cast| cast.to)
            && (((cast_to.unquantized() == i8::datum_type()
                || cast_to.unquantized() == u8::datum_type())
                && self.c_fact.datum_type == i32::datum_type())
                || self.mmm.iter().all(|m| m.stores().contains(&cast_to)))
            && let Some(ProtoFusedSpec::Store(stores)) = self.micro_ops.last()
        {
            if stores.iter().any(|s| matches!(s, OutputStoreSpec::Strides { .. })) {
                return Ok(None);
            }
            let c_fact = cast_to.fact(self.c_fact.shape.clone());
            let mut patch =
                TypedModelPatch::fuse_with_next(model, node, Self { c_fact, ..self.clone() })?;
            patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
            return Ok(Some(patch));
        }
        // Absorb a following axis removal by renumbering all internal axes.
        if let Some(AxisOp::Rm(axis)) = succ.op_as::<ops::AxisOp>() {
            rule_if!(Some(*axis) != self.c_m_axis);
            rule_if!(Some(*axis) != self.c_n_axis);
            let mut new_op = self.clone();
            new_op.c_fact.shape.remove_axis(*axis)?;
            if let Some(c_m_axis) = &mut new_op.c_m_axis {
                *c_m_axis -= (*c_m_axis > *axis) as usize;
            }
            if let Some(c_n_axis) = &mut new_op.c_n_axis {
                *c_n_axis -= (*c_n_axis > *axis) as usize;
            }
            for uop in &mut new_op.micro_ops {
                uop.rm_c_axis(*axis);
            }
            let mut patch = TypedModelPatch::fuse_with_next(model, node, new_op)?;
            patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
            return Ok(Some(patch));
        }
        // Swap a cast or scalar binop past an intervening reshape/axis-op so a
        // later fuse pass can absorb it next to this node.
        if (succ.op_is::<AxisOp>() || succ.op_is::<IntoShape>())
            && let &[next] = &*succ.outputs[0].successors
        {
            let next_node = model.node(next.node);
            if let Some(cast) = next_node.op_as::<Cast>() {
                let mut patch = TypedModelPatch::default();
                let mut wire = patch.tap_model(model, node.id.into())?;
                wire = patch.wire_node(&next_node.name, cast.clone(), &[wire])?[0];
                wire = patch.wire_node(&succ.name, succ.op.clone(), &[wire])?[0];
                patch.shunt_outside(model, next_node.id.into(), wire)?;
                return Ok(Some(patch));
            } else if let Some(op) = next_node.op_as::<ops::binary::TypedBinOp>() {
                rule_if!(op.0.as_linalg_binop().is_some());
                // NOTE(review): `flipped` is derived from succ (the axis op,
                // whose single input is this node), while the binop being
                // rewired is next_node — confirm next_node.inputs was not the
                // intended source here.
                let flipped = succ.inputs[0].node == node.id;
                let other_outlet = next_node.inputs[flipped as usize];
                if let Some(uni) = &model.outlet_fact(other_outlet)?.uniform {
                    let mut patch = TypedModelPatch::default();
                    let cst = patch.add_const(&model.node(other_outlet.node).name, uni.clone())?;
                    let output = patch.tap_model(model, node.id.into())?;
                    let wire = wire_with_rank_broadcast(
                        &next_node.name,
                        &mut patch,
                        op.clone(),
                        &if flipped { [output, cst] } else { [cst, output] },
                    )?;
                    let wire = patch.wire_node(&succ.name, succ.op.clone(), &wire)?[0];
                    patch.shunt_outside(model, next_node.id.into(), wire)?;
                    return Ok(Some(patch));
                }
            }
        }
        if let Some(op) = succ.op_as::<ops::binary::OptBinUnicast>() {
            let in_1_fact = model.outlet_fact(succ.inputs[0])?;
            let in_2_fact = model.outlet_fact(succ.inputs[1])?;
            // Full-matrix addition fuses as AddUnicast when both sides share a fact.
            if op.binop.is::<ops::math::Add>()
                && self.mmm.len() == 1
                && in_1_fact.without_value() == in_2_fact.without_value()
            {
                let other_slot = 1 - node.outputs[0].successors[0].slot;
                let other_input = succ.inputs[other_slot];
                let other_input = patch.tap_model(model, other_input)?;
                let other_fact = patch.outlet_fact(other_input)?;

                if other_fact.shape == self.c_fact.shape {
                    let other_storage = unsafe { self.mmm[0].c_view(self.c_m_axis, self.c_n_axis) };
                    // Identity mapping: the added matrix has the same layout as C.
                    let mapping =
                        MapOutputAxisToInput((0..other_fact.rank()).map(|x| (x, x)).collect());
                    return self.fuse_op(
                        model,
                        node,
                        patch,
                        vec![ProtoFusedSpec::AddUnicast(other_storage, node.inputs.len(), mapping)],
                        &[other_input],
                    );
                }
            } else {
                rule_if_some!(mut binop = op.binop.as_linalg_binop());
                let flipped = succ.inputs[0].node == node.id;
                if flipped {
                    binop = binop.flip();
                }
                let other_outlet = succ.inputs[flipped as usize];
                return self.fuse_binary(model, node, patch, other_outlet, binop);
            }
        };
        Ok(None)
    }

    as_op!();
}
599
impl OptMatMul {
    /// Build an `OptMatMul`, validating axis bounds and computing the cached
    /// `trivial_path` flag.
    pub fn new(
        mmm: Vec<Box<dyn MatMatMul>>,
        mode_picker: ModePicker,
        c_fact: TypedFact,
        c_m_axis: Option<usize>,
        c_n_axis: Option<usize>,
        micro_ops: Vec<ProtoFusedSpec>,
        trivial_packing: bool,
    ) -> TractResult<Self> {
        if let Some(m) = c_m_axis {
            ensure!(m < c_fact.rank());
        }
        if let Some(n) = c_n_axis {
            ensure!(n < c_fact.rank());
        }
        let mut it = OptMatMul {
            mmm,
            mode_picker,
            c_fact,
            c_m_axis,
            c_n_axis,
            micro_ops,
            trivial_path: false,
            trivial_packing,
        };
        it.update_trivial_path();
        Ok(it)
    }

    // for auditing only (may return None if no AddMatMul is found)
    pub fn guess_k(&self) -> Option<TDim> {
        self.micro_ops
            .iter()
            .find_map(
                |o| {
                    if let ProtoFusedSpec::AddMatMul { geo, .. } = o { Some(geo) } else { None }
                },
            )
            .map(|geo| geo.k.clone())
    }

    /// The m dimension of the output (1 when there is no m axis).
    #[inline]
    pub fn m(&self) -> &TDim {
        self.c_m_axis.map(|ax| &self.c_fact.shape[ax]).unwrap_or(&TDim::Val(1))
    }

    /// The n dimension of the output (1 when there is no n axis).
    #[inline]
    pub fn n(&self) -> &TDim {
        self.c_n_axis.map(|ax| &self.c_fact.shape[ax]).unwrap_or(&TDim::Val(1))
    }

    /// Refresh the cached `trivial_path` flag; call after mutating the op.
    fn update_trivial_path(&mut self) {
        self.trivial_path = self.can_use_trivial_path();
    }

    /// True when eval can skip the batch loop: concrete output shape, every
    /// non-m/n axis of size 1, trivial packing, and all micro-ops trivial.
    fn can_use_trivial_path(&self) -> bool {
        self.c_fact.shape.is_concrete()
            && self.c_fact.shape.iter().enumerate().all(|(ax, dim)| {
                Some(ax) == self.c_m_axis || Some(ax) == self.c_n_axis || dim.is_one()
            })
            && self.trivial_packing
            && self.micro_ops.iter().all(|o| o.is_trivial())
    }

    /// Replace this node and its successor with a clone of self extended with
    /// `fused_micro_op` (inserted just before the final Store) and
    /// `additional_inputs` appended to the node's inputs.
    fn fuse_op(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        fused_micro_op: Vec<ProtoFusedSpec>,
        additional_inputs: &[OutletId],
    ) -> TractResult<Option<TypedModelPatch>> {
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut new_op = self.clone();
        // Empty range at len-1: splice inserts before the trailing Store op.
        let before_last = new_op.micro_ops.len() - 1..new_op.micro_ops.len() - 1;
        new_op.micro_ops.splice(before_last, fused_micro_op);
        new_op.c_fact = succ.outputs[0].fact.clone();
        new_op.update_trivial_path();
        let mut inputs = patch.taps(model, &node.inputs)?;
        inputs.extend(additional_inputs.iter().cloned());
        let output = patch.wire_node(&succ.name, new_op, &inputs)?;
        patch.shunt_outside(model, succ.id.into(), output[0])?;
        Ok(Some(patch))
    }

    /// Fuse a binary op whose other operand is `value`: as BinScalar when the
    /// operand is a single value, as BinPerRow/BinPerCol when it spans exactly
    /// the m or n axis. Returns Ok(None) when no form applies.
    fn fuse_binary(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        value: OutletId,
        binop: BinOp,
    ) -> TractResult<Option<TypedModelPatch>> {
        let fact = model.outlet_fact(value)?;
        let mut v = patch.tap_model(model, value)?;
        // The operand must be in the kernel's internal type; cast if needed.
        if fact.datum_type != self.mmm[0].internal_type() {
            v = patch.wire_node(
                format!("{}.cast-input-{}", node.name, node.inputs.len()),
                cast(self.mmm[0].internal_type()),
                &[v],
            )?[0];
        }
        // The operand becomes a new input at the next free slot.
        let value = node.inputs.len();
        let additional_input = tvec!(v);
        if fact.shape.volume() == 1.to_dim() {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinScalar(value, binop)],
                &additional_input,
            );
        }
        let other_shape = fact.shape.to_owned();
        // Per-row: the operand's whole volume lies on the m axis.
        if self.c_m_axis.is_some_and(|ax| {
            other_shape[ax] == self.c_fact.shape[ax] && other_shape[ax] == other_shape.volume()
        }) {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerRow(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_m_axis.unwrap(), self.c_m_axis.unwrap()))),
                )],
                &additional_input,
            );
        }
        // Per-col: the operand's whole volume lies on the n axis.
        if self.c_n_axis.is_some_and(|ax| {
            other_shape[ax] == self.c_fact.shape[ax] && other_shape[ax] == other_shape.volume()
        }) {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerCol(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_n_axis.unwrap(), self.c_n_axis.unwrap()))),
                )],
                &additional_input,
            );
        }
        Ok(None)
    }
}
747}