//! tract_core/ops/matmul/optimized.rs — the optimized matrix-multiplication
//! operator (`OptMatMul`) and its fused micro-operation pipeline
//! (`ProtoFusedSpec`).

1use crate::internal::*;
2use crate::ops::cast::cast;
3use crate::ops::change_axes::wire_with_rank_broadcast;
4use crate::ops::nn::LeakyRelu;
5use ndarray::*;
6use tract_itertools::Itertools;
7
8use tract_linalg::mmm::{
9    AsInputValue, EagerPackedInput, FusedSpec, MMMInputValue, MatMatMul, OutputStoreSpec,
10    PanelExtractInput, PanelExtractor,
11};
12use tract_linalg::pack::PackedFormat;
13use tract_linalg::{BinOp, Scaler};
14use tract_smallvec::ToSmallVec;
15
16use super::ModePicker;
17
/// A fused micro-operation in symbolic form: operands are input-slot indices
/// instead of borrowed tensors, so the spec can be stored inside `OptMatMul`
/// and resolved into a concrete `FusedSpec` at eval time.
#[derive(Clone, Debug)]
pub enum ProtoFusedSpec {
    /// The matrix product itself. `a` and `b` index Opaque inputs wrapping
    /// `MMMInputValue`s; `packings` gives, per kernel mode, the packing index
    /// and an optional panel extractor used to repack A on the fly.
    AddMatMul {
        geo: AddMatMulGeometry,
        a: usize,
        b: usize,
        packings: Vec<(usize, Option<PanelExtractor>)>,
    },
    /// Apply `BinOp` with the scalar stored in the given input slot.
    BinScalar(usize, BinOp),
    /// Leaky ReLU; the slot holds the alpha tensor (cast to internal type).
    LeakyRelu(usize),
    /// Row-wise binary op; the map translates output coords to input coords.
    BinPerRow(usize, BinOp, MapOutputAxisToInput),
    /// Column-wise binary op; the map translates output coords to input coords.
    BinPerCol(usize, BinOp, MapOutputAxisToInput),
    /// Add the outer product of a row vector and a col vector (input slots).
    AddRowColProducts(usize, usize),
    /// Element-wise addition of the tensor in the given input slot.
    AddUnicast(OutputStoreSpec, usize, MapOutputAxisToInput),
    /// Multiply the accumulator by a constant scaler.
    Scaler(Scaler),
    /// Final store of the accumulator; one store spec per kernel mode.
    Store(Vec<OutputStoreSpec>),
}
35
impl ProtoFusedSpec {
    /// Short human-readable description of this micro-op (used by
    /// `OptMatMul::info`). `mode` selects the kernel-mode-specific variant.
    pub fn format(&self, mmm: &dyn MatMatMul, mode: usize) -> String {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, packings: packing, .. } => {
                let (a, b) = &mmm.packings()[packing[mode].0];
                format!("matmul(k={}, {a:?}•{b:?})", geo.k)
            }
            BinScalar(_, op) => format!("scalar{op:?}"),
            // Note: `alpha` is the input-slot index, not the alpha value.
            LeakyRelu(alpha) => format!("leaky_relu({alpha:?})"),
            BinPerRow(_, op, _) => format!("row{op:?}"),
            BinPerCol(_, op, _) => format!("col{op:?}"),
            AddRowColProducts(_, _) => "add_row_col_product".to_string(),
            AddUnicast(_, _, _) => "add_to_matrix".to_string(),
            Scaler(s) => format!("scale({})", 1f32 * *s),
            Store(_oss) => "store".to_string(),
        }
    }

    /// Resolve this proto-spec into a concrete `FusedSpec` for the output
    /// tile at `output_coords`. This is the general (looping) path: input
    /// views are offset along the axes mapped from the current output
    /// coordinates before being handed to the kernel.
    pub fn resolve<'t>(
        &'t self,
        inputs: &'t [TValue],
        output_coords: &[usize],
        output: &Tensor,
        mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        let fs = match self {
            ProtoFusedSpec::AddMatMul { geo, a, b, packings } => {
                let mut a = inputs[*a].view();
                let mut b = inputs[*b].view();
                // Position the A view on the tile being computed.
                unsafe {
                    geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a);
                }
                // Both operands are Opaque scalars wrapping pre-packed input values.
                let a = a.as_slice::<Opaque>().unwrap()[0]
                    .downcast_ref::<Box<dyn MMMInputValue>>()
                    .unwrap();
                unsafe {
                    geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b);
                }
                let b = b.as_slice::<Opaque>().unwrap()[0]
                    .downcast_ref::<Box<dyn MMMInputValue>>()
                    .unwrap();
                let (_a_packing, b_packing) = &mmm.packings()[packings[mode].0];
                // If this mode requires repacking A, wrap the eagerly packed
                // data in a panel extractor; otherwise borrow it as-is.
                let pa = if let Some(extractor) = &packings[mode].1 {
                    let data = a.downcast_ref::<EagerPackedInput>().unwrap();
                    AsInputValue::Owned(Box::new(PanelExtractInput {
                        format: extractor.clone(),
                        data: data.clone(),
                    }))
                } else {
                    AsInputValue::Borrowed(&**a)
                };
                // B must be packed exactly as the kernel expects, or at least
                // be a PackedFormat with the same panel width (r).
                assert!(
                    b_packing.same_as(b.format())
                        || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                );
                debug_assert!(pa.k().to_dim().compatible_with(&geo.k.to_dim()));
                debug_assert!(b.k().to_dim().compatible_with(&geo.k.to_dim()));
                FusedSpec::AddMatMul {
                    a: pa,
                    b: AsInputValue::Borrowed(&**b),
                    packing: packings[mode].0,
                }
            }
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, map) => {
                let mut v = inputs[*v].view();
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, map) => {
                let mut v = inputs[*v].view();
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, map) => unsafe {
                let mut view = inputs[*v].view();
                map.translate_view(output_coords, &mut view);
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                // Offset the output view to the current tile before wrapping.
                let view = output.view_offsetting_unchecked(output_coords);
                FusedSpec::Store(oss[mode].wrap(&view))
            },
        };
        fs
    }

    /// A micro-op is "trivial" when it can be resolved without per-tile
    /// coordinate translation. Only the matmul can be non-trivial: it
    /// requires a concrete (non-symbolic) k.
    pub fn is_trivial(&self) -> bool {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => geo.k.as_i64().is_some(),
            _ => true,
        }
    }

    /// Fast-path counterpart of `resolve`: assumes a single tile (all output
    /// dims outside m/n are 1), so no view translation is needed. The checks
    /// done eagerly in `resolve` become `debug_assert!`s here, compiled out
    /// in release builds.
    pub fn resolve_trivial<'t>(
        &'t self,
        inputs: &'t [TValue],
        output: &mut Tensor,
        _mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        let fs = match self {
            ProtoFusedSpec::AddMatMul { a, b, packings, .. } => unsafe {
                debug_assert!(inputs.get(*a).is_some());
                debug_assert!(inputs.get(*b).is_some());
                let a = inputs.get_unchecked(*a);
                let b = inputs.get_unchecked(*b);
                debug_assert!(a.datum_type().is_opaque());
                debug_assert!(a.len() == 1);
                debug_assert!(b.datum_type().is_opaque());
                debug_assert!(b.len() == 1);
                let a = a.as_slice_unchecked::<Opaque>().get_unchecked(0);
                let b = b.as_slice_unchecked::<Opaque>().get_unchecked(0);
                debug_assert!(a.is::<Box<dyn MMMInputValue>>());
                debug_assert!(b.is::<Box<dyn MMMInputValue>>());
                let a = a.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
                let b = b.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
                debug_assert!(packings.len() == 1);
                debug_assert!(packings[0].1.is_none()); // no panel extraction
                #[cfg(debug_assertions)]
                {
                    let (a_packing, b_packing) = &_mmm.packings()[packings[mode].0];
                    debug_assert!(
                        a_packing.same_as(a.format())
                            || (a_packing.is::<PackedFormat>() && a_packing.r() == a.format().r())
                    );
                    debug_assert!(
                        b_packing.same_as(b.format())
                            || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                    );
                }
                FusedSpec::AddMatMul {
                    a: AsInputValue::Borrowed(&**a),
                    b: AsInputValue::Borrowed(&**b),
                    packing: packings[mode].0,
                }
            },
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, _) => unsafe {
                let view = inputs[*v].view();
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                FusedSpec::Store(oss[mode].wrap(&output.view_mut()))
            },
        };
        fs
    }

    /// Validate input facts against what this micro-op expects: Opaque
    /// (packed) operands for the matmul, plain numeric tensors elsewhere.
    fn check_inputs(&self, inputs: &[&TypedFact]) -> TractResult<()> {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { a, b, .. } => {
                ensure!(inputs[*a].datum_type == Opaque::datum_type());
                ensure!(inputs[*b].datum_type == Opaque::datum_type());
            }
            BinScalar(v, _)
            | LeakyRelu(v)
            | BinPerCol(v, _, _)
            | BinPerRow(v, _, _)
            | AddUnicast(_, v, _) => {
                ensure!(inputs[*v].datum_type.is_number());
            }
            AddRowColProducts(row, col) => {
                ensure!(inputs[*row].datum_type.is_number());
                ensure!(inputs[*col].datum_type.is_number());
            }
            _ => (),
        };
        Ok(())
    }

    /// Cost estimate: only the matmul contributes (m·n·k FMAs in `idt`).
    fn cost(&self, m: &TDim, n: &TDim, idt: DatumType) -> TVec<(Cost, TDim)> {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => {
                tvec!((Cost::FMA(idt), m.clone() * n * &geo.k))
            }
            _ => tvec!(), /* FIXME maybe */
        }
    }

    /// Renumber every output-axis reference after removal of output axis
    /// `axis` (used when fusing a downstream `AxisOp::Rm`).
    fn rm_c_axis(&mut self, axis: usize) {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, .. } => {
                geo.c_to_a_axis_mapping.rm_c_axis(axis);
                geo.c_to_b_axis_mapping.rm_c_axis(axis);
            }
            // These variants hold no output-axis references.
            BinScalar(..) | Scaler(..) | AddRowColProducts(_, _) | LeakyRelu(_) => {}
            BinPerRow(_, _, map) | BinPerCol(_, _, map) => map.rm_c_axis(axis),
            AddUnicast(_, _, map) => {
                map.rm_c_axis(axis);
            }
            Store(oss, ..) => {
                for oss in oss {
                    match oss {
                        OutputStoreSpec::View { m_axis, n_axis, .. } => {
                            // Shift m/n axis indices down past the removed axis.
                            if let Some(m) = m_axis {
                                *m -= (*m > axis) as usize
                            };
                            if let Some(n) = n_axis {
                                *n -= (*n > axis) as usize
                            }
                        }
                        OutputStoreSpec::Strides { .. } => {}
                    }
                }
            }
        }
    }
}
267
/// List of (output axis, input axis) pairs used to offset an input view
/// according to the current output-loop coordinates.
#[derive(Clone, Debug)]
pub struct MapOutputAxisToInput(pub TVec<(usize, usize)>);
270
271impl MapOutputAxisToInput {
272    #[inline]
273    unsafe fn translate_view(&self, output_coords: &[usize], v: &mut TensorView) {
274        for &(out_axis, in_axis) in &self.0 {
275            v.offset_axis(in_axis, output_coords[out_axis] as isize)
276        }
277    }
278
279    #[inline]
280    fn rm_c_axis(&mut self, axis: usize) {
281        for (c, _) in &mut self.0 {
282            *c -= (*c > axis) as usize;
283        }
284    }
285}
286
/// Geometry of one `AddMatMul` micro-op: the (possibly symbolic) shared
/// dimension k, and the mappings from output-loop axes to axes of the A and
/// B operands.
#[derive(Clone, Debug)]
pub struct AddMatMulGeometry {
    /// Shared dimension of the product; may be symbolic.
    pub k: TDim,
    /// Maps output coordinates to offsets into the A operand.
    pub c_to_a_axis_mapping: MapOutputAxisToInput,
    /// Maps output coordinates to offsets into the B operand.
    pub c_to_b_axis_mapping: MapOutputAxisToInput,
}
293
/// The optimized, fused matrix-multiplication operator.
#[derive(Clone, Debug)]
pub struct OptMatMul {
    /// Output fact (datum type and symbolic shape).
    pub c_fact: TypedFact,
    /// Fused pipeline; the last entry is normally a `ProtoFusedSpec::Store`.
    pub micro_ops: Vec<ProtoFusedSpec>,
    /// Candidate kernels; one is selected per eval by `mode_picker`.
    pub mmm: Vec<Box<dyn MatMatMul>>,
    /// Picks the kernel mode from the runtime value of n.
    pub mode_picker: ModePicker,
    /// Position of the m axis in the output shape, if any.
    pub c_m_axis: Option<usize>,
    /// Position of the n axis in the output shape, if any.
    pub c_n_axis: Option<usize>,
    /// One of the preconditions of the trivial path (see
    /// `can_use_trivial_path`).
    pub trivial_packing: bool,
    /// Cached result of `can_use_trivial_path()`; kept in sync by
    /// `update_trivial_path()` and re-checked in `output_facts`.
    pub trivial_path: bool,
}
305
306impl Op for OptMatMul {
307    fn name(&self) -> Cow<str> {
308        "OptMatMul".into()
309    }
310
311    fn info(&self) -> TractResult<Vec<String>> {
312        let m = self.c_m_axis.map(|ix| &self.c_fact.shape[ix]).unwrap_or(&TDim::Val(1));
313        let n = self.c_n_axis.map(|ix| &self.c_fact.shape[ix]).unwrap_or(&TDim::Val(1));
314        let mut infos = vec![format!(
315            "c_shape:{:?}, c_m_axis:{:?} c_n_axis:{:?} m:{} n:{}",
316            self.c_fact, self.c_m_axis, self.c_n_axis, m, n,
317        )];
318        if let Some(k) = self.guess_k() {
319            infos.push(format!("Mult: m:{} k:{} n:{} with {:?}", m, k, n, self.mmm));
320        } else {
321            infos.push(format!("Mult: {:?}", self.mmm));
322        }
323        for (mode, mmm) in self.mmm.iter().enumerate() {
324            infos.push(format!(
325                "Ops: {}",
326                self.micro_ops.iter().map(|o| o.format(&**mmm, mode)).join(" >>> ")
327            ));
328        }
329        Ok(infos)
330    }
331
332    op_as_typed_op!();
333}
334
impl EvalOp for OptMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval_with_session(
        &self,
        session: &SessionState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        unsafe {
            // Concretize the symbolic output shape with the session's symbol
            // values and allocate the output uninitialized: the Store
            // micro-op is expected to write every cell.
            let c_shape = self.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
            let mut c = Tensor::uninitialized_dt(self.c_fact.datum_type, &c_shape)?;
            let m = self.c_m_axis.map(|c_m| c.shape()[c_m]).unwrap_or(1);
            let n = self.c_n_axis.map(|c_n| c.shape()[c_n]).unwrap_or(1);
            // Pick the kernel variant from the concrete n value.
            let mode = self.mode_picker.pick(n)?;
            let mmm = &*self.mmm[mode];
            // Reuse the session-cached scratch space when this kernel accepts
            // it; otherwise drop it and allocate a fresh one lazily.
            let mut cell = session.cached_mmm_scratch_space.borrow_mut();
            if !cell.as_ref().is_some_and(|scratch| mmm.can_use_scratch_space(&**scratch)) {
                *cell = None
            }
            let scratch = cell.get_or_insert_with(|| mmm.allocate_scratch_space());
            if self.trivial_path {
                // Fast path: a single kernel invocation, no per-tile view
                // translation (see ProtoFusedSpec::resolve_trivial).
                let uops: Vec<FusedSpec> = self
                    .micro_ops
                    .iter()
                    .map(|o| o.resolve_trivial(&inputs, &mut c, mmm, mode))
                    .collect();
                mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)?;
                Ok(tvec!(c.into_tvalue()))
            } else {
                // General path: loop over all output coordinates outside the
                // m and n axes, running the kernel once per tile.
                // ShiftLeft(0) is just a placeholder, overwritten below.
                let mut uops = vec![FusedSpec::ShiftLeft(0); self.micro_ops.len()];
                let mut looping_shape: TVec<usize> = c_shape.to_smallvec();
                if let Some(ax) = self.c_m_axis {
                    looping_shape[ax] = 1;
                }
                if let Some(ax) = self.c_n_axis {
                    looping_shape[ax] = 1;
                }
                for c_coords in indices(&*looping_shape) {
                    // Re-resolve each micro-op in place for this tile.
                    for ix in 0..self.micro_ops.len() {
                        *uops.get_unchecked_mut(ix) = self.micro_ops.get_unchecked(ix).resolve(
                            &inputs,
                            c_coords.slice(),
                            &c,
                            mmm,
                            mode,
                        );
                    }
                    mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)
                        .context("In mmm.run_with_scratch_space")?;
                }
                Ok(tvec!(c.into_tvalue()))
            }
        }
    }
}
392
impl TypedOp for OptMatMul {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        // Structural invariants: axes in range, cached trivial_path flag in
        // sync, all kernels agreeing on their accumulator type, and each
        // micro-op seeing the input datum types it expects.
        ensure!(self.c_m_axis.map(|ax| ax < self.c_fact.rank()).unwrap_or(true));
        ensure!(self.c_n_axis.map(|ax| ax < self.c_fact.rank()).unwrap_or(true));
        ensure!(self.trivial_path == self.can_use_trivial_path());
        ensure!(self.mmm.iter().map(|mmm| mmm.internal_type()).all_equal());
        for op in &self.micro_ops {
            op.check_inputs(inputs)?;
        }
        Ok(tvec!(self.c_fact.clone()))
    }

    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        // Sum per-micro-op costs, then scale by the number of kernel
        // invocations (product of output dims outside the m and n axes).
        let mut sums = HashMap::new();
        for op in &self.micro_ops {
            for (cost, count) in op.cost(self.m(), self.n(), self.mmm[0].internal_type()) {
                *sums.entry(cost).or_default() += count;
            }
        }
        let loops = self
            .c_fact
            .shape
            .iter()
            .enumerate()
            .map(|(ix, d)| {
                if Some(ix) == self.c_m_axis || Some(ix) == self.c_n_axis {
                    1.to_dim()
                } else {
                    d.clone()
                }
            })
            .product::<TDim>();
        // NOTE(review): the outer `&mut` is redundant (`values_mut()`
        // already yields mutable references) but harmless.
        for s in &mut sums.values_mut() {
            *s *= &loops;
        }
        Ok(sums.into_iter().collect())
    }

    /// Try to absorb the single successor node into this op's fused
    /// micro-op pipeline.
    fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
        use crate::ops;
        // Only fuse when we feed exactly one consumer and are not a model
        // output ourselves.
        if node.outputs.len() != 1
            || node.outputs[0].successors.len() != 1
            || model.output_outlets()?.contains(&node.id.into())
        {
            return Ok(None);
        }
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut patch = TypedModelPatch::new(format!("fusing {succ}"));

        // Successor is a plain binary op: absorb it as a fused binop.
        if let Some(op) = succ.op_as::<ops::binary::TypedBinOp>() {
            let mut binop = if let Some(op) = op.0.as_linalg_binop() {
                op
            } else {
                return Ok(None);
            };
            // Flip the op when our output feeds the left operand, so the
            // matmul result is always on the same side.
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }
        // Same treatment for the already-optimized by-scalar binary.
        if let Some(op) = succ.op_as::<ops::binary::OptBinByScalar>() {
            let mut binop = if let Some(op) = op.binop.as_linalg_binop() {
                op
            } else {
                return Ok(None);
            };
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }

        // Element-wise successors: quantization scaling and leaky ReLU.
        if let Some(op) = succ.op_as::<ops::element_wise::ElementWiseOp>().map(|ew| ew.0.as_ref()) {
            if let Some(op) = op.downcast_ref::<ops::math::QScale>() {
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::Scaler(op.scaler)],
                    &[],
                );
            }
            if let Some(op) = op.downcast_ref::<LeakyRelu>() {
                // Every candidate kernel must support the fused LeakyRelu.
                if !self
                    .mmm
                    .iter()
                    .all(|mmm| mmm.can_fuse(&FusedSpec::LeakyRelu(&tensor0(op.alpha))))
                {
                    return Ok(None);
                }
                // Alpha becomes an extra constant input, cast to the
                // kernel's accumulator type.
                let alpha = patch.add_const(
                    node.name.to_string() + ".alpha",
                    tensor0(op.alpha).cast_to_dt(self.mmm[0].internal_type())?.into_owned(),
                )?;
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::LeakyRelu(node.inputs.len())],
                    &[alpha],
                );
            }
        }
        // i32 accumulator followed by a cast to (q)i8/u8: requantize at
        // store time by just retyping the output fact.
        if let Some(cast_to) = succ.op_as::<ops::cast::Cast>().map(|cast| cast.to) {
            if (cast_to.unquantized() == i8::datum_type()
                || cast_to.unquantized() == u8::datum_type())
                && self.c_fact.datum_type == i32::datum_type()
            {
                if let Some(ProtoFusedSpec::Store(stores)) = self.micro_ops.last() {
                    // Strided stores can't be retyped this way.
                    if stores.iter().any(|s| matches!(s, OutputStoreSpec::Strides { .. })) {
                        return Ok(None);
                    }
                    let c_fact = cast_to.fact(self.c_fact.shape.clone());
                    let mut patch = TypedModelPatch::fuse_with_next(
                        model,
                        node,
                        Self { c_fact, ..self.clone() },
                    )?;
                    patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
                    return Ok(Some(patch));
                }
            }
        }
        // Axis removal: drop the axis from the output fact and renumber all
        // axis references in the micro-ops.
        if let Some(AxisOp::Rm(axis)) = succ.op_as::<ops::AxisOp>() {
            if Some(*axis) == self.c_m_axis || Some(*axis) == self.c_n_axis {
                return Ok(None);
            }
            let mut new_op = self.clone();
            new_op.c_fact.shape.remove_axis(*axis)?;
            if let Some(c_m_axis) = &mut new_op.c_m_axis {
                *c_m_axis -= (*c_m_axis > *axis) as usize;
            }
            if let Some(c_n_axis) = &mut new_op.c_n_axis {
                *c_n_axis -= (*c_n_axis > *axis) as usize;
            }
            for uop in &mut new_op.micro_ops {
                uop.rm_c_axis(*axis);
            }
            let mut patch = TypedModelPatch::fuse_with_next(model, node, new_op)?;
            patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
            return Ok(Some(patch));
        }
        // Generic AxisOp followed by a binary with a uniform constant: swap
        // their order (binary first, AxisOp after) so the binary can be
        // fused into this op on a later pass.
        if succ.op_is::<AxisOp>() {
            if let &[next] = &*succ.outputs[0].successors {
                let bin = model.node(next.node);
                if let Some(op) = bin.op_as::<ops::binary::TypedBinOp>() {
                    if op.0.as_linalg_binop().is_none() {
                        return Ok(None);
                    };
                    // NOTE(review): `flipped` tests `succ.inputs` (the
                    // AxisOp, whose single input is always this node) but is
                    // then used to index `bin.inputs` — verify this should
                    // not test `bin.inputs[0].node == succ.id` instead.
                    let flipped = succ.inputs[0].node == node.id;
                    let other_outlet = bin.inputs[flipped as usize];
                    if let Some(uni) = &model.outlet_fact(other_outlet)?.uniform {
                        let mut patch = TypedModelPatch::default();
                        let cst =
                            patch.add_const(&model.node(other_outlet.node).name, uni.clone())?;
                        let output = patch.tap_model(model, node.id.into())?;
                        let wire = wire_with_rank_broadcast(
                            &bin.name,
                            &mut patch,
                            op.clone(),
                            &if flipped { [output, cst] } else { [cst, output] },
                        )?;
                        let wire = patch.wire_node(&succ.name, succ.op.clone(), &wire)?[0];
                        patch.shunt_outside(model, bin.id.into(), wire)?;
                        return Ok(Some(patch));
                    }
                }
            }
        }
        // Unicast binary: a full-tensor Add over an identically-shaped
        // operand becomes an AddUnicast micro-op; anything else falls back
        // to the generic binop fusion.
        if let Some(op) = succ.op_as::<ops::binary::OptBinUnicast>() {
            let in_1_fact = model.outlet_fact(succ.inputs[0])?;
            let in_2_fact = model.outlet_fact(succ.inputs[1])?;
            if op.binop.is::<ops::math::Add>()
                && self.mmm.len() == 1
                && in_1_fact.without_value() == in_2_fact.without_value()
            {
                // The operand that is not our own output.
                let other_slot = 1 - node.outputs[0].successors[0].slot;
                let other_input = succ.inputs[other_slot];
                let other_input = patch.tap_model(model, other_input)?;
                let other_fact = patch.outlet_fact(other_input)?;

                if other_fact.shape == self.c_fact.shape {
                    let other_storage = unsafe { self.mmm[0].c_view(self.c_m_axis, self.c_n_axis) };
                    // Identity mapping: operand axes line up with output axes.
                    let mapping =
                        MapOutputAxisToInput((0..other_fact.rank()).map(|x| (x, x)).collect());
                    return self.fuse_op(
                        model,
                        node,
                        patch,
                        vec![ProtoFusedSpec::AddUnicast(other_storage, node.inputs.len(), mapping)],
                        &[other_input],
                    );
                }
            } else {
                let mut binop = if let Some(op) = op.binop.as_linalg_binop() {
                    op
                } else {
                    return Ok(None);
                };
                let flipped = succ.inputs[0].node == node.id;
                if flipped {
                    binop = binop.flip();
                }
                let other_outlet = succ.inputs[flipped as usize];
                return self.fuse_binary(model, node, patch, other_outlet, binop);
            }
        };
        Ok(None)
    }

    as_op!();
}
609
impl OptMatMul {
    /// Build the op, validating that the m/n axes are within the output
    /// rank, and precompute the cached `trivial_path` flag.
    pub fn new(
        mmm: Vec<Box<dyn MatMatMul>>,
        mode_picker: ModePicker,
        c_fact: TypedFact,
        c_m_axis: Option<usize>,
        c_n_axis: Option<usize>,
        micro_ops: Vec<ProtoFusedSpec>,
        trivial_packing: bool,
    ) -> TractResult<Self> {
        if let Some(m) = c_m_axis {
            ensure!(m < c_fact.rank());
        }
        if let Some(n) = c_n_axis {
            ensure!(n < c_fact.rank());
        }
        let mut it = OptMatMul {
            mmm,
            mode_picker,
            c_fact,
            c_m_axis,
            c_n_axis,
            micro_ops,
            trivial_path: false,
            trivial_packing,
        };
        it.update_trivial_path();
        Ok(it)
    }

    // for auditing only (may return None if no AddMatMul is found)
    pub fn guess_k(&self) -> Option<TDim> {
        self.micro_ops
            .iter()
            .find_map(
                |o| {
                    if let ProtoFusedSpec::AddMatMul { geo, .. } = o {
                        Some(geo)
                    } else {
                        None
                    }
                },
            )
            .map(|geo| geo.k.clone())
    }

    /// Extent of the m axis in the output shape, or 1 if there is none.
    #[inline]
    pub fn m(&self) -> &TDim {
        self.c_m_axis.map(|ax| &self.c_fact.shape[ax]).unwrap_or(&TDim::Val(1))
    }

    /// Extent of the n axis in the output shape, or 1 if there is none.
    #[inline]
    pub fn n(&self) -> &TDim {
        self.c_n_axis.map(|ax| &self.c_fact.shape[ax]).unwrap_or(&TDim::Val(1))
    }

    /// Refresh the cached `trivial_path` flag (must be called after any
    /// mutation of the fields it depends on).
    fn update_trivial_path(&mut self) {
        self.trivial_path = self.can_use_trivial_path();
    }

    /// The trivial (single kernel call) path applies when the output shape
    /// is concrete, every dim outside m/n is 1, packing is trivial, and all
    /// micro-ops are trivially resolvable (concrete k).
    fn can_use_trivial_path(&self) -> bool {
        self.c_fact.shape.is_concrete()
            && self.c_fact.shape.iter().enumerate().all(|(ax, dim)| {
                Some(ax) == self.c_m_axis || Some(ax) == self.c_n_axis || dim.is_one()
            })
            && self.trivial_packing
            && self.micro_ops.iter().all(|o| o.is_trivial())
    }

    /// Replace this node and its single successor by a clone of self with
    /// `fused_micro_op` spliced into the pipeline just before the final
    /// Store, wiring `additional_inputs` as extra inputs.
    fn fuse_op(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        fused_micro_op: Vec<ProtoFusedSpec>,
        additional_inputs: &[OutletId],
    ) -> TractResult<Option<TypedModelPatch>> {
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut new_op = self.clone();
        // Empty range at len-1: splice inserts before the last micro-op
        // (the Store) without removing anything.
        let before_last = new_op.micro_ops.len() - 1..new_op.micro_ops.len() - 1;
        new_op.micro_ops.splice(before_last, fused_micro_op);
        // The fused node takes over the successor's output fact and name.
        new_op.c_fact = succ.outputs[0].fact.clone();
        new_op.update_trivial_path();
        let mut inputs = patch.taps(model, &node.inputs)?;
        inputs.extend(additional_inputs.iter().cloned());
        let output = patch.wire_node(&succ.name, new_op, &inputs)?;
        patch.shunt_outside(model, succ.id.into(), output[0])?;
        Ok(Some(patch))
    }

    /// Fuse a binary op whose other operand comes from `value`, choosing
    /// between scalar, per-row and per-column broadcast forms.
    fn fuse_binary(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        value: OutletId,
        binop: BinOp,
    ) -> TractResult<Option<TypedModelPatch>> {
        let fact = model.outlet_fact(value)?;
        let mut v = patch.tap_model(model, value)?;
        // The kernel works in its internal accumulator type: cast if needed.
        if fact.datum_type != self.mmm[0].internal_type() {
            v = patch.wire_node(
                format!("{}.cast-input-{}", node.name, node.inputs.len()),
                cast(self.mmm[0].internal_type()),
                &[v],
            )?[0];
        }
        // The operand becomes a new input in the next free slot.
        let value = node.inputs.len();
        let additional_input = tvec!(v);
        // Single element: fuse as a scalar op.
        if fact.shape.volume() == 1.to_dim() {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinScalar(value, binop)],
                &additional_input,
            );
        }
        let other_shape = fact.shape.to_owned();
        // Vector along the m axis (its m extent matches the output and
        // accounts for the whole volume): fuse as a per-row op.
        if self.c_m_axis.is_some_and(|ax| {
            other_shape[ax] == self.c_fact.shape[ax] && other_shape[ax] == other_shape.volume()
        }) {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerRow(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_m_axis.unwrap(), self.c_m_axis.unwrap()))),
                )],
                &additional_input,
            );
        }
        // Vector along the n axis: fuse as a per-column op.
        if self.c_n_axis.is_some_and(|ax| {
            other_shape[ax] == self.c_fact.shape[ax] && other_shape[ax] == other_shape.volume()
        }) {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerCol(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_n_axis.unwrap(), self.c_n_axis.unwrap()))),
                )],
                &additional_input,
            );
        }
        // Any other broadcast pattern: give up, no fusion.
        Ok(None)
    }
}