// tract_core/ops/einsum/einsum_matmul.rs

1use std::fmt::Formatter;
2use std::ops::Deref;
3
4use tract_itertools::{izip, multiunzip};
5use tract_linalg::block_quant::PackedBlockQuantFormat;
6use tract_linalg::pack::PackedFormat;
7
8use super::*;
9use crate::ops::cast::cast;
10use crate::ops::math::add;
11use crate::ops::matmul::optimized::{
12    AddMatMulGeometry, MapOutputAxisToInput, OptMatMul, ProtoFusedSpec,
13};
14use crate::ops::matmul::pack::{OptMatMulPack, OptSimpleMatMulPack};
15use crate::ops::matmul::quant::{
16    combine_scales, compensate_zero_points, requant, wire_ensure_q8_flavour,
17};
18use crate::ops::matmul::ModePicker;
19use crate::ops::nn::{Reduce, Reducer};
20
21pub fn detect_all(model: &mut TypedModel) -> TractResult<()> {
22    Rewriter::default().with_rule_for("detect-matmul-einsum", detect_rule).rewrite(&(), model)
23}
24
25pub fn flatten_all(model: &mut TypedModel) -> TractResult<()> {
26    Rewriter::default().with_rule_for("flatten-matmul-einsum", flatten_rule).rewrite(&(), model)
27}
28
/// An [`EinSum`] recognised by `detect_rule` as a matrix multiplication, annotated with the
/// labels and extents of the axes playing the m, k and n roles.
///
/// Input 0 plays the role of A (holding m and k) and input 1 the role of B (holding k and n),
/// as reflected by the `a_m` / `a_k` / `b_k` / `b_n` accessors.
#[derive(Clone, Hash, PartialEq)]
pub struct EinSumMatMul {
    /// The wrapped einsum; evaluation and output facts are delegated to it.
    pub op: EinSum,
    /// Label of the m axis (found in input 0).
    pub m_axis: char,
    /// Label of the contracted k axis (found in both inputs).
    pub k_axis: char,
    /// Label of the n axis (found in input 1).
    pub n_axis: char,
    /// Extent of the m axis in input 0.
    pub m: TDim,
    /// Extent of the k axis, as read from input 0 when the op is built.
    pub k: TDim,
    /// Extent of the n axis in input 1.
    pub n: TDim,
}
39
40impl EinSumMatMul {
41    pub fn m_axis(&self) -> &Axis {
42        self.op.axes.axis(self.m_axis).unwrap()
43    }
44    pub fn k_axis(&self) -> &Axis {
45        self.op.axes.axis(self.k_axis).unwrap()
46    }
47    pub fn n_axis(&self) -> &Axis {
48        self.op.axes.axis(self.n_axis).unwrap()
49    }
50    pub fn a_m(&self) -> usize {
51        self.m_axis().inputs[0][0]
52    }
53    pub fn a_k(&self) -> usize {
54        self.k_axis().inputs[0][0]
55    }
56    pub fn b_k(&self) -> usize {
57        self.k_axis().inputs[1][0]
58    }
59    pub fn b_n(&self) -> usize {
60        self.n_axis().inputs[1][0]
61    }
62    pub fn c_m(&self) -> Option<usize> {
63        self.m_axis().outputs[0].first().cloned()
64    }
65    pub fn c_n(&self) -> Option<usize> {
66        self.n_axis().outputs[0].first().cloned()
67    }
68
69    fn new(
70        op: EinSum,
71        m_axis: char,
72        k_axis: char,
73        n_axis: char,
74        m: TDim,
75        k: TDim,
76        n: TDim,
77    ) -> Self {
78        Self { op, m_axis, k_axis, n_axis, m, k, n }
79    }
80}
81
82impl Debug for EinSumMatMul {
83    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
84        write!(
85            f,
86            "EinsumMatMul: {} {:?} m: {}={}; k: {}={}; n: {}={}",
87            self.op.axes,
88            self.op.operating_dt,
89            self.m_axis,
90            self.m,
91            self.k_axis,
92            self.k,
93            self.n_axis,
94            self.n
95        )
96    }
97}
98
99impl Deref for EinSumMatMul {
100    type Target = EinSum;
101    fn deref(&self) -> &Self::Target {
102        &self.op
103    }
104}
105
106impl Op for EinSumMatMul {
107    fn name(&self) -> Cow<str> {
108        "EinSumMatMul".into()
109    }
110
111    op_as_typed_op!();
112    impl_op_same_as!();
113}
114
115impl EvalOp for EinSumMatMul {
116    fn is_stateless(&self) -> bool {
117        true
118    }
119    fn eval_with_session(
120        &self,
121        session: &SessionState,
122        inputs: TVec<TValue>,
123    ) -> TractResult<TVec<TValue>> {
124        self.op.eval_with_session(session, inputs)
125    }
126}
127
impl TypedOp for EinSumMatMul {
    /// Output facts are exactly those of the wrapped einsum.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        self.op.output_facts(inputs)
    }

    /// Lowers the annotated einsum, one rewriting step per call:
    /// - 9 inputs (quantized einsum with its parameter inputs): `dequant` rewrites it as a
    ///   non-quantized einsum followed by bias addition, zero-point compensation and requant;
    /// - otherwise exactly 2 inputs are required; the op may first be re-emitted with A and B
    ///   swapped (see `must_transpose`);
    /// - otherwise, if m or n appears in the output, it is replaced by `optimized_mat_mul`.
    ///
    /// Returns `Ok(None)` when none of these applies (neither m nor n in the output).
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // deal with parametric quantization extra inputs
        if node.inputs.len() == 9 {
            ensure!(self.op.q_params.is_some());
            return dequant(model, node, self).map(Some);
        }
        ensure!(node.inputs.len() == 2);
        let (a, b) = model.node_input_facts(node.id)?.into_iter().collect_tuple().unwrap();
        // at this stage a and b must NOT be packed yet. if they are Opaque, we can assume it's just compression
        // An opaque (block-quantized) operand is kept on, or moved to, the A side (input 0).
        // Otherwise: with concrete extents, swap when m < n so the larger one becomes m; with a
        // symbolic m and concrete n, swap only when n >= 8; with concrete m and symbolic n, keep;
        // with both symbolic, swap when n - m is provably >= 0.
        // NOTE(review): the rationale for the `8` threshold is not visible here.
        let must_transpose = if let Some(of) = a.opaque_fact() {
            ensure!(of.is::<BlockQuantFact>());
            false
        } else if let Some(of) = b.opaque_fact() {
            ensure!(of.is::<BlockQuantFact>());
            true
        } else {
            match (self.m.as_i64(), self.n.as_i64()) {
                (Some(m), Some(n)) => m < n,
                (None, Some(n)) => n >= 8,
                (Some(_), _) => false,
                _ => (self.n.clone() - &self.m).prove_positive_or_zero(),
            }
        };
        if must_transpose {
            // Swap the roles of the two inputs in every axis, exchange the m/n annotations,
            // and rewire with the inputs in reverse order.
            let mut op = self.clone();
            op.op.axes.iter_all_axes_mut().for_each(|axis| axis.inputs.swap(0, 1));
            std::mem::swap(&mut op.m_axis, &mut op.n_axis);
            std::mem::swap(&mut op.m, &mut op.n);
            return TypedModelPatch::replace_single_op(
                model,
                node,
                &[node.inputs[1], node.inputs[0]],
                op,
            )
            .map(|p| Some(p.with_context("transposing")));
        }
        // opt mat mul assumes we have at least one m or n
        if self.c_m().is_some() || self.c_n().is_some() {
            return optimized_mat_mul(model, node, self)
                .map(|opt| opt.map(|p| p.with_context("optimizing")));
        }
        Ok(None)
    }

    as_op!();
}
183
/// Rewrite rule recognising an [`EinSum`] as a matrix multiplication and replacing it with an
/// [`EinSumMatMul`] recording which axes play the m, k and n roles.
///
/// Only einsums with 2 inputs (9 when `q_params` is set: 2 operands + 7 extra inputs) are
/// considered. When the einsum is not directly matmul-shaped, a preparatory patch is returned
/// instead of the replacement:
/// - several non-trivial k axes: `regroup_k_axes` folds them into one;
/// - no k axis at all: `inject_k_axis` adds one;
/// - no m (resp. n) axis: `inject_m_or_n_axis` adds one.
///
/// Returns `Ok(None)` when the einsum is left alone: wrong input count, or an axis with a
/// non-unit extent in exactly one input that does not survive (non-unit) in the output.
pub(crate) fn detect_rule(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    _name: &str,
    op: &EinSum,
) -> TractResult<Option<TypedModelPatch>> {
    // Two operands, plus 7 extra inputs when the einsum carries quantization parameters.
    if node.inputs.len() != (2 + op.q_params.is_some() as usize * 7) {
        return Ok(None);
    }
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    let output_shape = super::eval::output_shape(&op.axes, &input_shapes)?;
    let k_axes: TVec<&Axis> = op
        .axes
        .iter_all_axes()
        // Filter possible candidates (should be one time in each inputs but not in output)
        .filter(|a| a.inputs[0].len() == 1 && a.inputs[1].len() == 1 && a.outputs[0].is_empty())
        .collect();

    // k candidates whose extent is not 1 in at least one of the two operands.
    let non_trivial_k_axis = k_axes
        .iter()
        .filter(|a| {
            !input_shapes[0][a.inputs[0][0]].is_one() || !input_shapes[1][a.inputs[1][0]].is_one()
        })
        .copied()
        .collect::<TVec<_>>();

    // More than one real contraction axis: merge them first. Otherwise prefer the non-trivial
    // k axis, falling back to any (unit) k candidate.
    let k_axis = if non_trivial_k_axis.len() > 1 {
        return regroup_k_axes(op, model, node, non_trivial_k_axis);
    } else {
        non_trivial_k_axis.first().or_else(|| k_axes.first()).copied()
    };
    let Some(k_axis) = k_axis else { return inject_k_axis(op, model, node).map(Some) };

    // m candidates: exactly once in input 0, absent (or of extent 1) in input 1, and either
    // exactly once in the output, or a unit axis of input 0 that input 1 does not have.
    let mut possible_m_axes: Vec<_> = op
        .axes
        .iter_all_axes()
        .filter(|a| {
            a.inputs[0].len() == 1
                && (a.inputs[1].is_empty() || input_shapes[1][a.inputs[1][0]].is_one())
                && (a.outputs[0].len() == 1
                    || (input_shapes[0][a.inputs[0][0]].is_one() && a.inputs[1].is_empty()))
        })
        .collect();

    // Prioritize obvious m-axes
    if possible_m_axes.iter().any(|a| !a.outputs[0].is_empty()) {
        possible_m_axes.retain(|a| !a.outputs[0].is_empty());
    }

    // Pick the candidate with the largest extent in input 0; symbolic extents rank highest.
    let m_axis = possible_m_axes
        .into_iter()
        .max_by_key(|a| input_shapes[0][a.inputs[0][0]].as_i64().unwrap_or(i64::MAX));

    let Some(m_axis) = m_axis else {
        return inject_m_or_n_axis(op, model, node, false).map(Some);
    };

    // n candidates: absent (or of extent 1) in input 0, exactly once in input 1 and in the
    // output, and distinct from m; the largest extent in input 1 wins, symbolic ranking highest.
    let n_axis = op
        .axes
        .iter_all_axes()
        .filter(|a| {
            (a.inputs[0].is_empty() || input_shapes[0][a.inputs[0][0]].is_one())
                && a.inputs[1].len() == 1
                && a.outputs[0].len() == 1
                && *a != m_axis
        })
        .max_by_key(|a| input_shapes[1][a.inputs[1][0]].as_i64().unwrap_or(i64::MAX));
    let Some(n_axis) = n_axis else {
        return inject_m_or_n_axis(op, model, node, true).map(Some);
    };
    // An axis that is missing or of extent 1 counts as absent from that side.
    for axis in op.axes.iter_all_axes() {
        let one = TDim::one();
        let in_left =
            axis.inputs[0].first().map(|pos| &input_shapes[0][*pos]).unwrap_or(&one) != &one;
        let in_right =
            axis.inputs[1].first().map(|pos| &input_shapes[1][*pos]).unwrap_or(&one) != &one;
        let in_out = axis.outputs[0].first().map(|pos| &output_shape[*pos]).unwrap_or(&one) != &one;
        if (in_left ^ in_right) && !in_out {
            // Not a matmul: non-trivial single-side disappearing axis.
            return Ok(None);
        }
    }
    // Extents: m and k are read from input 0, n from input 1.
    let m = input_shapes[0][m_axis.inputs[0][0]].clone();
    let k = input_shapes[0][k_axis.inputs[0][0]].clone();
    let n = input_shapes[1][n_axis.inputs[1][0]].clone();
    TypedModelPatch::replace_single_op(
        model,
        node,
        &node.inputs,
        EinSumMatMul::new(op.clone(), m_axis.repr, k_axis.repr, n_axis.repr, m, k, n),
    )
    .map(Some)
}
282
283pub(super) fn inject_k_axis(
284    op: &EinSum,
285    model: &TypedModel,
286    node: &TypedNode,
287) -> TractResult<TypedModelPatch> {
288    let mut new_axes = op.axes.clone();
289    let name = &node.name;
290    let mut patch = TypedModelPatch::new("inject k axis");
291    let mut wire = patch.taps(model, &node.inputs)?;
292    let repr = new_axes.available_label();
293    new_axes = new_axes.with_extra_axis(repr, InOut::In(0), 0)?.with_extra_axis_occurency(
294        repr,
295        InOut::In(1),
296        0,
297    )?;
298    wire[0] = patch.wire_node(format!("{name}.add_k.0"), AxisOp::Add(0), &[wire[0]])?[0];
299    wire[1] = patch.wire_node(format!("{name}.add_k.1"), AxisOp::Add(0), &[wire[1]])?[0];
300    wire = patch.wire_node(&node.name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
301    patch.shunt_outside(model, node.id.into(), wire[0])?;
302    Ok(patch)
303}
304
/// Folds several contraction (k) axes of an einsum into a single k axis so it can be mapped
/// to a matmul.
///
/// The k axes are merged in the order of their positions in input 0 when they are contiguous
/// there, otherwise in the order of their positions in input 1. For each of the two operands,
/// if the k axes are not consecutive in that order, they are first moved to the end (the other
/// axes keeping their relative order). They are then merged by a single `AxisOp::Reshape`
/// into one axis that reuses the first k axis's label. The einsum expression is rebuilt
/// accordingly; extra inputs beyond the first two keep their original expressions.
///
/// Returns `Ok(Some(patch))` on success.
pub(super) fn regroup_k_axes(
    op: &EinSum,
    model: &TypedModel,
    node: &TypedNode,
    mut k_axes: TVec<&Axis>,
) -> TractResult<Option<TypedModelPatch>> {
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    // Do the k axes occupy consecutive positions in input 0?
    let contig_in_a = k_axes
        .iter()
        .map(|axis| axis.inputs[0][0])
        .sorted()
        .tuple_windows()
        .all(|(a, b)| a + 1 == b);
    // Choose the merge order: input 0's order if it already holds them contiguously (it then
    // needs no transposition), input 1's order otherwise.
    if contig_in_a {
        k_axes.sort_by_key(|ax| ax.inputs[0][0]);
    } else {
        k_axes.sort_by_key(|ax| ax.inputs[1][0]);
    }
    // NOTE(review): k extents are read from input 0 only and reused to reshape input 1 too;
    // this assumes each k axis has the same extent in both inputs. Confirm that a broadcast k
    // axis (extent 1 on one side only) cannot reach this point.
    let k_dims: TVec<_> =
        k_axes.iter().map(|ax| input_shapes[0][ax.inputs[0][0]].clone()).collect();
    let k: TDim = k_dims.iter().product();
    let mut patch = TypedModelPatch::default();
    let mut wires = patch.taps(model, &node.inputs)?;
    // Current axis expression (one label per axis) of each operand, kept in sync with the ops
    // wired below.
    let mut exprs: Vec<String> =
        (0..2).map(|slot| op.axes.axes(InOut::In(slot)).map(|ax| ax.repr).join("")).collect();
    for slot in 0..2 {
        // k axes not consecutive (in merge order) in this operand: move them to the end.
        if k_axes.iter().map(|ax| ax.inputs[slot][0]).tuple_windows().any(|(a, b)| a + 1 != b) {
            let after = op
                .axes
                .axes(InOut::In(slot))
                .filter(|ax| !k_axes.contains(ax))
                .chain(k_axes.iter().copied())
                .map(|ax| ax.repr)
                .join("");
            let transpose =
                AxesMapping::from_strs(&[&exprs[slot]], &[&after])?.translate_to_axis_ops()?;
            for (ix, op) in transpose.into_iter().enumerate() {
                wires[slot] = patch.wire_node(
                    format!("{}.transpose_input_{}.{}", &node.name, slot, ix),
                    op,
                    &[wires[slot]],
                )?[0];
            }
            exprs[slot] = after;
        }
        // The k axes are now consecutive starting at `pos`: merge them into one axis of
        // extent `k`, labelled like the first k axis.
        // NOTE(review): `pos` is a char index reused as a byte index by `String::insert`; this
        // is only equivalent when every label is ASCII.
        let pos = exprs[slot].chars().position(|c| k_axes[0].repr == c).unwrap();
        wires[slot] = patch.wire_node(
            format!("{}.fold_k_in_input_{}", &node.name, slot),
            AxisOp::Reshape(pos, k_dims.clone(), tvec!(k.clone())),
            &[wires[slot]],
        )?[0];
        exprs[slot] =
            exprs[slot].chars().filter(|c| !k_axes.iter().any(|k| k.repr == *c)).collect();
        exprs[slot].insert(pos, k_axes[0].repr);
    }
    // Rebuild the einsum expression: the two rewritten operands, any extra inputs unchanged,
    // and the unchanged output.
    let old = op.axes.to_string();
    let (iexpr, oexpr) = old.split_once("->").unwrap();
    let mut expr: String = exprs.iter().join(",");
    if node.inputs.len() > 2 {
        expr = expr + "," + &iexpr.split(",").skip(2).join(",");
    }
    expr = expr + "->" + oexpr;
    let wire = patch.wire_node(
        &node.name,
        EinSum { axes: expr.parse().unwrap(), ..op.clone() },
        &wires,
    )?[0];
    patch.shunt_outside(model, node.id.into(), wire)?;
    Ok(Some(patch))
}
376
377pub(super) fn inject_m_or_n_axis(
378    op: &EinSum,
379    model: &TypedModel,
380    node: &TypedNode,
381    is_n: bool,
382) -> TractResult<TypedModelPatch> {
383    let input_to_fix = is_n as usize;
384    let label = if is_n { "n" } else { "m" };
385    let name = &node.name;
386    let mut patch = TypedModelPatch::new("Injecting m or n axis");
387    let mut wire = patch.taps(model, &node.inputs)?;
388    let repr = op.axes.available_label();
389    let new_axes = op
390        .axes
391        .clone()
392        .with_extra_axis(repr, InOut::In(input_to_fix), 0)?
393        .with_extra_axis_occurency(repr, InOut::Out(0), 0)?;
394    wire[input_to_fix] =
395        patch.wire_node(format!("{name}.add_{label}"), AxisOp::Add(0), &[wire[input_to_fix]])?[0];
396    wire = patch.wire_node(name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
397    wire = patch.wire_node(&node.name, AxisOp::Rm(0), &wire)?;
398    patch.shunt_outside(model, node.id.into(), wire[0])?;
399    Ok(patch)
400}
401
402fn wire_axes_fix(
403    patch: &mut TypedModelPatch,
404    name: &str,
405    var: &str,
406    mapping: &AxesMapping,
407    mut outlet: TVec<OutletId>,
408) -> TractResult<TVec<OutletId>> {
409    for (ix, axis_op) in mapping.translate_to_axis_ops()?.into_iter().enumerate() {
410        outlet = patch.wire_node(format!("{name}.fix_{var}.{ix})"), axis_op, &outlet)?;
411    }
412    Ok(outlet)
413}
414
/// Lowers a quantized [`EinSumMatMul`] (9 inputs) into non-quantized einsum arithmetic.
///
/// The inputs are taken in the order `[a, b, bias, a0, a_scale, b0, b_scale, c0, c_scale]`
/// (see the destructuring below). The patch:
/// 1. inserts axes into non-scalar `a_scale` / `b_scale` so they line up with the output;
/// 2. routes a/a0 and b/b0 through `wire_ensure_q8_flavour` with the i8 datum type;
/// 3. computes the raw product with an einsum over `[a, b]` only, with `q_params: None`;
/// 4. sums a and b (as i32) along k and aligns those sums and the bias with the output;
/// 5. adds the bias, compensates zero points (`compensate_zero_points`), then requantizes
///    (`requant`) with the combined a/b/c scale and c0.
///
/// Errors if there are not exactly 9 inputs. Panics if `op.q_params` is `None`.
fn dequant(
    model: &TypedModel,
    node: &TypedNode,
    op: &EinSumMatMul,
) -> TractResult<TypedModelPatch> {
    let name = &node.name;
    let mut patch = TypedModelPatch::new("Dequantizing einsum");

    let k_axis = op.k_axis();

    let mut taps = patch.taps(model, &node.inputs)?;
    // For a non-scalar a_scale (input 4) or b_scale (input 6), find where the scale's axis 0
    // lands in the output, then insert axes at positions 1..(rank - that position). This gives
    // the scale rank - q_axis_in_output dims, with its own axis first.
    // NOTE(review): presumably a rank-1 per-axis scale meant to broadcast against the output
    // from the right — confirm.
    for ab in [0, 1] {
        let scale_input = 4 + ab * 2;
        if !patch.outlet_fact(taps[scale_input])?.shape.volume().is_one() {
            let q_axis_in_output = op.axes.axis((InOut::In(scale_input), 0))?.outputs[0][0];
            let output_rank = node.outputs[0].fact.rank();
            for i in 1..(output_rank - q_axis_in_output) {
                taps[scale_input] = patch.wire_node(
                    format!("{name}.scale_input{ab}_axis_fix_{i}"),
                    AxisOp::Add(i),
                    &[taps[scale_input]],
                )?[0];
            }
        }
    }

    let [mut a, mut b, bias, mut a0, a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
        bail!("Expect exactly 9 inputs")
    };

    // Bring a (with zero point a0) and b (with zero point b0) to the i8 flavour.
    wire_ensure_q8_flavour(&mut patch, &node.name, &mut a, "a", &mut a0, i8::datum_type())?;
    wire_ensure_q8_flavour(&mut patch, &node.name, &mut b, "b", &mut b0, i8::datum_type())?;

    // Raw product: the same contraction restricted to the two operands, without quantization.
    let mut output = patch.wire_node(
        &node.name,
        EinSum {
            q_params: None,
            axes: op.axes.extract_sub_mapping(&[0, 1], &[0])?,
            operating_dt: op.operating_dt,
        },
        &[a, b],
    )?;

    // Sums of a and b along their k axis, in i32, used for zero-point compensation.
    let a_i32 = patch.wire_node(format!("{name}.a_as_i32"), cast(i32::datum_type()), &[a])?[0];
    let b_i32 = patch.wire_node(format!("{name}.b_as_i32"), cast(i32::datum_type()), &[b])?[0];
    let sum_a = patch.wire_node(
        format!("{name}.sum_a"),
        Reduce::new(tvec!(k_axis.inputs[0][0]), Reducer::Sum),
        &[a_i32],
    )?;
    let sum_b = patch.wire_node(
        format!("{name}.sum_b"),
        Reduce::new(tvec!(k_axis.inputs[1][0]), Reducer::Sum),
        &[b_i32],
    )?;

    // Re-arrange the sums and the bias (input 2) to match the output axes.
    let sum_a =
        wire_axes_fix(&mut patch, name, "sum_a", &op.axes.extract_sub_mapping(&[0], &[0])?, sum_a)?;
    let sum_b =
        wire_axes_fix(&mut patch, name, "sum_b", &op.axes.extract_sub_mapping(&[1], &[0])?, sum_b)?;
    let bias = tvec!(bias);
    let bias =
        wire_axes_fix(&mut patch, name, "bias", &op.axes.extract_sub_mapping(&[2], &[0])?, bias)?;

    let abc_scale = combine_scales(&mut patch, name, a_scale, b_scale, c_scale)?;

    output = patch.wire_node(format!("{name}.add_bias"), add(), &[output[0], bias[0]])?;

    // k extent, read from the original (pre-patch) input 0.
    let k = model.outlet_fact(node.inputs[0])?.shape[k_axis.inputs[0][0]].clone();
    let output = compensate_zero_points(&mut patch, name, output[0], k, a0, b0, sum_a[0], sum_b[0])
        .context("Zero point compensation")?;
    let output = requant(&mut patch, name, output, op.q_params.unwrap(), abc_scale, c0)?;
    patch.shunt_outside(model, node.id.into(), output)?;
    Ok(patch)
}
490
491fn flatten_rule(
492    _ctx: &(),
493    model: &TypedModel,
494    node: &TypedNode,
495    _name: &str,
496    op: &EinSumMatMul,
497) -> TractResult<Option<TypedModelPatch>> {
498    TypedModelPatch::replace_single_op(model, node, &node.inputs, op.op.clone()).map(Some)
499}
500
/// Replaces an [`EinSumMatMul`] by packing ops for A and B feeding an [`OptMatMul`].
///
/// Kernel implementations and the mode picker come from `kernel_selection::strategize`. A is
/// packed by one of the following:
/// - `OptMatMulPack` (single mode, first impl only), when it is a constant in `PackedFormat`;
/// - `OptSimpleMatMulPack`, when it is a constant in `PackedBlockQuantFormat` (k and m must be
///   concrete);
/// - `OptMatMulPack` with one packer per impl, otherwise.
///
/// B always gets an `OptMatMulPack` with one packer per impl. Other axes shared by the output
/// and an operand are described in the `AddMatMulGeometry`.
///
/// Returns `Ok(Some(patch))` on success. Errors on an unexpected static A format. Panics if a
/// packing that must be a `PackedFormat` is not.
fn optimized_mat_mul(
    model: &TypedModel,
    node: &TypedNode,
    op: &EinSumMatMul,
) -> TractResult<Option<TypedModelPatch>> {
    let (mode_picker, impls) = kernel_selection::strategize(model, node, op)?;
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    let prefix = &node.name;

    let mut patch = TypedModelPatch::new("Einsum to OptMatMul");
    let taps = patch.taps(model, &node.inputs)?;
    let name = &node.name;

    // Strategy is either one impl, or two impl with the same packing for A
    let (mmm, pack, pe) = &impls[0];
    // A's packing format for the first impl: the extractor's source format when there is one,
    // else the kernel's own A packing.
    let a_static_pack = if let Some(pe) = pe { &pe.from } else { &mmm.packings()[*pack].0 };
    let pack_a: Box<dyn TypedOp> = if input_facts[0].konst.is_some() {
        if let Some(pf) = a_static_pack.downcast_ref::<PackedFormat>() {
            Box::new(OptMatMulPack {
                packers: vec![pf.clone()],
                mode_picker: ModePicker::Single,
                k_axis: op.a_k(),
                mn_axis: op.a_m(),
            })
        } else if let Some(packed_format) =
            a_static_pack.downcast_ref::<PackedBlockQuantFormat>().cloned()
        {
            // Requires concrete k and m extents (unwrap panics on symbolic dims).
            Box::new(OptSimpleMatMulPack {
                packed_format,
                k: input_shapes[0][op.a_k()].to_usize().unwrap(),
                m: input_shapes[0][op.a_m()].to_usize().unwrap(),
            })
        } else {
            bail!("Unexpected static input format {a_static_pack:?}");
        }
    } else {
        // Non-constant A: one packer per impl, each required to be a plain `PackedFormat`.
        Box::new(OptMatMulPack {
            packers: impls
                .iter()
                .map(|(mmm, p, pe)| {
                    pe.as_ref()
                        .map(|pe| &pe.from)
                        .unwrap_or(&mmm.packings()[*p].0)
                        .downcast_ref::<PackedFormat>()
                        .unwrap()
                        .clone()
                })
                .collect(),
            mode_picker: mode_picker.clone(),
            k_axis: op.a_k(),
            mn_axis: op.a_m(),
        })
    };
    let pa = patch.wire_node(format!("{prefix}.pack_a"), pack_a, &[taps[0]])?[0];

    let pb = patch.wire_node(
        format!("{prefix}.pack_b"),
        OptMatMulPack {
            k_axis: op.b_k(),
            mn_axis: op.b_n(),
            packers: impls
                .iter()
                .map(|(mmm, p, _)| {
                    mmm.packings()[*p].1.downcast_ref::<PackedFormat>().unwrap().clone()
                })
                .collect(),
            mode_picker: mode_picker.clone(),
        },
        &[taps[1]],
    )?[0];

    // For every axis other than m/k/n that appears once in the output and once in A (resp. B)
    // with a non-unit extent there, record (output position, input position). The input
    // position is counted as if A's m and k (resp. B's n and k) axes were removed.
    let mut c_to_a_axis_mapping = tvec!();
    let mut c_to_b_axis_mapping = tvec!();
    for axis in op
        .op
        .axes
        .iter_all_axes()
        .filter(|&axis| ![op.m_axis, op.k_axis, op.n_axis].contains(&axis.repr))
    {
        if let (&[c], &[a]) = (&*axis.outputs[0], &*axis.inputs[0]) {
            if input_shapes[0][a] != 1.to_dim() {
                let a = a - (a > op.a_m()) as usize - (a > op.a_k()) as usize;
                c_to_a_axis_mapping.push((c, a));
            }
        }
        if let (&[c], &[b]) = (&*axis.outputs[0], &*axis.inputs[1]) {
            if input_shapes[1][b] != 1.to_dim() {
                let b = b - (b > op.b_n()) as usize - (b > op.b_k()) as usize;
                c_to_b_axis_mapping.push((c, b));
            }
        }
    }

    let c_fact = op.output_facts(&input_facts)?.remove(0);
    let geo = AddMatMulGeometry {
        k: op.k.clone(),
        c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
        c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
    };
    let (mmms, packings, extractor): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(impls);
    // NOTE(review): `c_view` is `unsafe`, but the invariants it relies on are not visible here.
    // They should be stated in a `// SAFETY:` comment (presumably that `c_m`/`c_n` are valid
    // positions in the output fact `c_fact`).
    let outputs = mmms.iter().map(|mmm| unsafe { mmm.c_view(op.c_m(), op.c_n()) }).collect();
    // Packing is flagged trivial only for a single impl using packing 0, with no extractor
    // and a non-opaque A.
    let trivial_packing = mmms.len() == 1
        && packings[0] == 0
        && extractor[0].is_none()
        && input_facts[0].opaque_fact.is_none();
    let opt = OptMatMul::new(
        mmms,
        mode_picker,
        c_fact,
        op.c_m(),
        op.c_n(),
        vec![
            ProtoFusedSpec::AddMatMul {
                geo,
                a: 0,
                b: 1,
                packings: izip!(packings, extractor).collect_vec(),
            },
            ProtoFusedSpec::Store(outputs),
        ],
        trivial_packing,
    )
    .context("Creating OptMatMul")?;
    let output = patch.wire_node(name, opt, &[pa, pb])?[0];
    patch.shunt_outside(model, node.id.into(), output)?;
    Ok(Some(patch))
}
628}