tract_core/ops/einsum/optimize.rs

use std::fmt::Formatter;
use std::ops::Deref;

use dyn_clone::clone_box;
use kernel_selection::wire_packing;
use tract_itertools::{izip, multiunzip};
use tract_linalg::block_quant::BlockQuantValue;
use tract_linalg::mmm::MMMInputFormat;
use tract_linalg::WeightType;

use super::*;
use crate::ops::cast::cast;
use crate::ops::math::add;
use crate::ops::matmul::optimized::{
    AddMatMulGeometry, MapOutputAxisToInput, OptMatMul, ProtoFusedSpec,
};
use crate::ops::matmul::quant::{
    combine_scales, compensate_zero_points, requant, wire_ensure_q8_flavour,
};
use crate::ops::nn::{Reduce, Reducer};

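/// Outcome of analysing an einsum as a candidate matrix multiplication:
/// either a fully annotated matmul view, a patch rewriting the graph so a
/// later optimisation pass can retry, or a diagnostic naming the axes that
/// prevent the translation.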
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub enum AxesOrPatch<'a> {
    Annotated(EinSumAnnotatedAsMatMul<'a>),
    Patch(TypedModelPatch),
    NotAMatMul(&'static str, Vec<&'a Axis>),
}

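/// An einsum together with the axes identified as playing the m, k and n
/// roles of a matrix multiplication, and the corresponding dimensions.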
pub struct EinSumAnnotatedAsMatMul<'a> {
    pub op: &'a EinSum,
    pub m_axis: &'a Axis,
    pub k_axis: &'a Axis,
    pub n_axis: &'a Axis,
    pub m: TDim,
    pub k: TDim,
    pub n: TDim,
}

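// Positions of the role axes in A (input 0), B (input 1) and C (the output).
// c_m() and c_n() fall back to the position in A or B when the axis does not
// appear in the output.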
impl EinSumAnnotatedAsMatMul<'_> {
    pub fn a_m(&self) -> usize {
        self.m_axis.inputs[0][0]
    }
    pub fn a_k(&self) -> usize {
        self.k_axis.inputs[0][0]
    }
    pub fn b_k(&self) -> usize {
        self.k_axis.inputs[1][0]
    }
    pub fn b_n(&self) -> usize {
        self.n_axis.inputs[1][0]
    }
    pub fn c_m(&self) -> usize {
        *self.m_axis.outputs[0].first().unwrap_or(&self.a_m())
    }
    pub fn c_n(&self) -> usize {
        *self.n_axis.outputs[0].first().unwrap_or(&self.b_n())
    }
}

impl Debug for EinSumAnnotatedAsMatMul<'_> {
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        write!(
            f,
            "EinsumAsMatMul: {} {:?} m: {}={}; k: {}={}; n: {}={}",
            self.op.axes,
            self.op.operating_dt,
            self.m_axis.repr,
            self.m,
            self.k_axis.repr,
            self.k,
            self.n_axis.repr,
            self.n
        )
    }
}

impl Deref for EinSumAnnotatedAsMatMul<'_> {
    type Target = EinSum;
    fn deref(&self) -> &Self::Target {
        self.op
    }
}

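/// An einsum viewed as a linear layer: input 0 is a constant weight matrix
/// (possibly block-quantized), input 1 holds the activations, with a single
/// m and k axis and any number of n axes.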
pub struct EinSumAnnotatedAsLinear<'a> {
    pub op: &'a EinSum,
    pub m_axis: &'a Axis,
    pub k_axis: &'a Axis,
    pub n_axes: Vec<&'a Axis>,
    pub m: usize,
    pub k: usize,
    pub ns: Vec<&'a TDim>,
    pub act_dt: DatumType,
    pub weight_type: WeightType,
}

impl Debug for EinSumAnnotatedAsLinear<'_> {
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        write!(
            f,
            "EinsumAsLinear: {} w:{:?} acc:{:?} m: {}={}; k: {}={}; n: {}={}",
            self.op.axes,
            self.weight_type,
            self.op.operating_dt,
            self.m_axis.repr,
            self.m,
            self.k_axis.repr,
            self.k,
            self.n_axes.iter().map(|ax| ax.repr).join(","),
            self.ns.iter().map(|d| d.to_string()).join("•"),
        )
    }
}

impl<'a> EinSumAnnotatedAsLinear<'a> {
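    /// Tries to read the einsum as a linear layer: two inputs, the first one
    /// constant, one m axis (in the weights and the output only), one k axis
    /// (in both inputs but not the output), and any number of n axes (in the
    /// activations and the output only).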
    pub fn from(
        model: &'a TypedModel,
        node: &'a TypedNode,
        op: &'a EinSum,
    ) -> TractResult<Option<Self>> {
        if node.inputs.len() != 2 {
            return Ok(None);
        }
        let input_facts = model.node_input_facts(node.id)?;
        if input_facts[0].konst.is_none() {
            return Ok(None);
        }
        let mut n_axes = vec![];
        let mut ns = Vec::<&'a TDim>::new();

        let Some(m_axis) = op.axes.iter_all_axes().find(|axis| {
            axis.inputs[0].len() == 1 && axis.inputs[1].len() == 0 && axis.outputs[0].len() == 1
        }) else {
            return Ok(None);
        };
        let Some(k_axis) = op.axes.iter_all_axes().find(|axis| {
            axis.inputs[0].len() == 1 && axis.inputs[1].len() == 1 && axis.outputs[0].len() == 0
        }) else {
            return Ok(None);
        };
        for axis in op.axes.iter_all_axes() {
            if axis != k_axis
                && axis != m_axis
                && axis.inputs[0].len() == 0
                && axis.inputs[1].len() == 1
                && axis.outputs[0].len() == 1
            {
                n_axes.push(axis);
                ns.push(&node.outputs[0].fact.shape[axis.outputs[0][0]]);
            }
        }
        let act_dt = input_facts[1].datum_type;
        let bqv = input_facts[0]
            .konst
            .as_ref()
            .unwrap()
            .to_scalar::<Opaque>()
            .ok()
            .and_then(|a| a.downcast_ref::<BlockQuantValue>());
        let weight_type = if let Some(a_payload) = bqv {
            WeightType::BlockQuant(a_payload.fact.format.clone())
        } else {
            input_facts[0].datum_type.into()
        };
        let weight_shape = block_quant_aware_input_shape(input_facts[0])?;
        let m = weight_shape[m_axis.inputs[0][0]].to_usize()?;
        let k = weight_shape[k_axis.inputs[0][0]].to_usize()?;
        Ok(Some(EinSumAnnotatedAsLinear {
            op,
            m_axis,
            k_axis,
            n_axes,
            m,
            k,
            ns,
            act_dt,
            weight_type,
        }))
    }

    pub fn weight_m_axis(&self) -> usize {
        self.m_axis.inputs[0][0]
    }

    pub fn weight_k_axis(&self) -> usize {
        self.k_axis.inputs[0][0]
    }

    pub fn input_k_axis(&self) -> usize {
        self.k_axis.inputs[1][0]
    }

    pub fn output_m_axis(&self) -> usize {
        self.m_axis.outputs[0][0]
    }

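    /// A matrix*vector kernel may be needed: there is no n axis at all, or
    /// some n dimension may be 1 (symbolic dimensions count as "maybe").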
    pub fn need_mmv(&self) -> bool {
        self.ns.is_empty() || self.ns.iter().any(|n| n.as_i64().map(|n| n == 1).unwrap_or(true))
    }

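    /// A matrix*matrix kernel may be needed: some n dimension may be greater
    /// than 1 (symbolic dimensions count as "maybe").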
    pub fn need_mmm(&self) -> bool {
        self.ns.iter().any(|n| n.as_i64().map(|n| n > 1).unwrap_or(true))
    }

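    /// Scores the kernels able to consume weights in `format` for this layer
    /// (lower is better), or None when no kernel matches. The heuristic
    /// favors kernel quality first, then tall tiles for the matrix*vector
    /// case and large tiles for the matrix*matrix case.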
    pub fn cost_for_weights(&self, format: &dyn MMMInputFormat) -> Option<usize> {
        let acc = self.op.acceptable_accumulators();
        let able = tract_linalg::ops()
            .filter_impls(format, &acc, self.act_dt, self.op.operating_dt)
            .collect_vec();
        if able.is_empty() {
            return None;
        }
        let mut cost = 0;
        if self.need_mmv() {
            cost += able
                .iter()
                .map(|(mmm, _, _, pe, _)| {
                    1_000_000 + mmm.quality().cost() * 1000 + mmm.nr() * 10 - mmm.mr() * 10
                        + pe.is_some() as usize
                })
                .min()
                .unwrap();
        };
        if self.need_mmm() {
            cost += able
                .iter()
                .map(|(mmm, _, _, pe, _)| {
                    1_000_000 + mmm.quality().cost() * 1000 - mmm.nr() * 10 - mmm.mr() * 10
                        + pe.is_some() as usize
                })
                .min()
                .unwrap();
        };
        Some(cost)
    }

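    /// Picks a packing format for the weights: the first packing of the
    /// kernel selected for the concrete (m, k, n) problem when weights and
    /// activations already share the accumulator type (or a common integer
    /// type), falling back to the cheapest packing over all candidate
    /// kernels according to cost_for_weights.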
    pub fn preferred_packing(&self) -> Box<dyn MMMInputFormat> {
        if self.act_dt == self.acceptable_accumulators()[0]
            && self.weight_type == self.act_dt.into()
        {
            if let Ok(n) = self.ns.iter().cloned().product::<TDim>().to_usize() {
                let mmm = tract_linalg::ops()
                    .mmm(self.acceptable_accumulators()[0], Some(self.m), Some(self.k), Some(n))
                    .unwrap();
                return mmm.packings()[0].0.clone();
            }
        }
        if self.act_dt.is_integer() && self.weight_type == self.act_dt.into() {
            if let Ok(n) = self.ns.iter().cloned().product::<TDim>().to_usize() {
                let mmm = tract_linalg::ops()
                    .mmm(i32::datum_type(), Some(self.m), Some(self.k), Some(n))
                    .unwrap();
                if let Some(packing) =
                    mmm.packings().iter().find(|(a, _)| a.precursor() == self.weight_type)
                {
                    return packing.0.clone();
                }
            }
        }
        clone_box(
            tract_linalg::ops()
                .all_possible_packing(self.weight_type.clone())
                .filter_map(|p| self.cost_for_weights(p).map(|cost| (p, cost)))
                .min_by_key(|(_p, cost)| *cost)
                .unwrap()
                .0,
        )
    }
}

impl Deref for EinSumAnnotatedAsLinear<'_> {
    type Target = EinSum;
    fn deref(&self) -> &Self::Target {
        self.op
    }
}

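/// Entry point of the rewrite: einsums with two inputs (nine in the
/// quantized case) are turned into an optimized matmul. A constant second
/// input is first swapped into the first position, then the m, k and n axes
/// are identified before translating to OptMatMul or going through the
/// dequantization pipeline.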
pub(crate) fn optimize(
    op: &EinSum,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    if (op.q_params.is_none() && node.inputs.len() != 2)
        || (op.q_params.is_some() && node.inputs.len() != 9)
    {
        return Ok(None);
    }

    let input_facts = model.node_input_facts(node.id)?;
    if node.inputs.len() == 2 && input_facts[1].konst.is_some() {
        return Ok(Some(transpose(op, model, node)?));
    }

    let annotated = match ensure_mkn_axes(op, model, node)? {
        AxesOrPatch::Annotated(op) => op,
        AxesOrPatch::Patch(p) => return Ok(Some(p)),
        AxesOrPatch::NotAMatMul(_, _) => return Ok(None),
    };
    if op.q_params.is_none() {
        optimized_mat_mul(model, node, &annotated).context("Translating to OptMatMul")
    } else {
        dequant(model, node, annotated).context("Dequantize")
    }
}

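/// Builds a patch replacing the einsum by the equivalent expression with its
/// two inputs (and the axis mapping accordingly) swapped.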
fn transpose(op: &EinSum, model: &TypedModel, node: &TypedNode) -> TractResult<TypedModelPatch> {
    let mut patch = TypedModelPatch::default();
    let mut taps = patch.taps(model, &node.inputs)?;
    taps.swap(0, 1);
    let mut op = op.clone();
    op.axes.iter_all_axes_mut().for_each(|axis| axis.inputs.swap(0, 1));
    let wire = patch.wire_node(&node.name, op, &taps)?[0];
    patch.shunt_outside(model, node.id.into(), wire)?;
    Ok(patch)
}

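/// Identifies the m, k and n axes of the einsum, or returns a patch injecting
/// a missing trivial axis so the next optimisation round can retry, or a
/// diagnostic when the expression is not a matmul.
///
/// Using "mk,kn->mn" as an illustration: k occurs once in each input and not
/// in the output, m occurs in input 0 and the output, n in input 1 and the
/// output.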
pub(crate) fn ensure_mkn_axes<'a>(
    op: &'a EinSum,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<AxesOrPatch<'a>> {
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    let output_shape = super::eval::output_shape(&op.axes, &input_shapes)?;
    let k_axes: TVec<&Axis> = op
        .axes
        .iter_all_axes()
        // Filter k-axis candidates: an axis must appear exactly once in each input and not in the output.
        .filter(|a| a.inputs[0].len() == 1 && a.inputs[1].len() == 1 && a.outputs[0].is_empty())
        .collect();

    let non_trivial_k_axis = k_axes
        .iter()
        .filter(|a| {
            !input_shapes[0][a.inputs[0][0]].is_one() || !input_shapes[1][a.inputs[1][0]].is_one()
        })
        .collect::<TVec<_>>();

    let k_axis = if non_trivial_k_axis.len() > 1 {
        // TODO: handle the case of multiple consecutive k axes appearing in the same order in both inputs.
        return Ok(AxesOrPatch::NotAMatMul(
            "multiple k-axis candidates found",
            non_trivial_k_axis.into_iter().cloned().collect_vec(),
        ));
    } else {
        non_trivial_k_axis.first().copied().or_else(|| k_axes.first()).copied()
    };
    let Some(k_axis) = k_axis else {
        return Ok(AxesOrPatch::Patch(inject_k_axis(op, model, node)?));
    };

    let mut possible_m_axes: Vec<_> = op
        .axes
        .iter_all_axes()
        .filter(|a| {
            a.inputs[0].len() == 1
                && (a.inputs[1].is_empty() || input_shapes[1][a.inputs[1][0]].is_one())
                && (a.outputs[0].len() == 1
                    || (input_shapes[0][a.inputs[0][0]].is_one() && a.inputs[1].is_empty()))
        })
        .collect();

    // Prioritize obvious m axes: those actually present in the output.
    if possible_m_axes.iter().any(|a| !a.outputs[0].is_empty()) {
        possible_m_axes.retain(|a| !a.outputs[0].is_empty());
    }

    let m_axis = possible_m_axes
        .into_iter()
        .max_by_key(|a| input_shapes[0][a.inputs[0][0]].as_i64().unwrap_or(i64::MAX));

    let Some(m_axis) = m_axis else {
        return Ok(AxesOrPatch::Patch(inject_m_or_n_axis(op, model, node, false)?));
    };

    let n_axis = op
        .axes
        .iter_all_axes()
        .filter(|a| {
            (a.inputs[0].is_empty() || input_shapes[0][a.inputs[0][0]].is_one())
                && a.inputs[1].len() == 1
                && a.outputs[0].len() == 1
                && *a != m_axis
        })
        .max_by_key(|a| input_shapes[1][a.inputs[1][0]].as_i64().unwrap_or(i64::MAX));
    let Some(n_axis) = n_axis else {
        return Ok(AxesOrPatch::Patch(inject_m_or_n_axis(op, model, node, true)?));
    };
    for axis in op.axes.iter_all_axes() {
        let one = TDim::one();
        let in_left =
            axis.inputs[0].first().map(|pos| &input_shapes[0][*pos]).unwrap_or(&one) != &one;
        let in_right =
            axis.inputs[1].first().map(|pos| &input_shapes[1][*pos]).unwrap_or(&one) != &one;
        let in_out = axis.outputs[0].first().map(|pos| &output_shape[*pos]).unwrap_or(&one) != &one;
        if (in_left ^ in_right) && !in_out {
            return Ok(AxesOrPatch::NotAMatMul(
                "non-trivial single-side disappearing axis",
                vec![axis],
            ));
        }
    }
    let m = input_shapes[0][m_axis.inputs[0][0]].clone();
    let k = input_shapes[0][k_axis.inputs[0][0]].clone();
    let n = input_shapes[1][n_axis.inputs[1][0]].clone();
    Ok(AxesOrPatch::Annotated(EinSumAnnotatedAsMatMul { op, m_axis, k_axis, n_axis, m, k, n }))
}

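/// When no k-axis candidate exists, adds a trivial axis of size 1 at position
/// 0 of both inputs, turning the expression into a degenerate matmul with
/// k=1.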
pub(super) fn inject_k_axis(
    op: &EinSum,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<TypedModelPatch> {
    let mut new_axes = op.axes.clone();
    let name = &node.name;
    let mut patch = TypedModelPatch::new("inject k axis");
    let mut wire = patch.taps(model, &node.inputs)?;
    let repr = new_axes.available_label();
    new_axes = new_axes.with_extra_axis(repr, InOut::In(0), 0)?.with_extra_axis_occurency(
        repr,
        InOut::In(1),
        0,
    )?;
    wire[0] = patch.wire_node(format!("{name}.add_k.0"), AxisOp::Add(0), &[wire[0]])?[0];
    wire[1] = patch.wire_node(format!("{name}.add_k.1"), AxisOp::Add(0), &[wire[1]])?[0];
    wire = patch.wire_node(&node.name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
    patch.shunt_outside(model, node.id.into(), wire[0])?;
    Ok(patch)
}

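/// Adds a trivial m axis (or n axis, when `is_n` is set) of size 1 at
/// position 0 of the relevant input and of the output, then removes it from
/// the final result so the overall expression is unchanged.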
pub(super) fn inject_m_or_n_axis(
    op: &EinSum,
    model: &TypedModel,
    node: &TypedNode,
    is_n: bool,
) -> TractResult<TypedModelPatch> {
    let input_to_fix = is_n as usize;
    let label = if is_n { "n" } else { "m" };
    let name = &node.name;
    let mut patch = TypedModelPatch::new("Injecting m or n axis");
    let mut wire = patch.taps(model, &node.inputs)?;
    let repr = op.axes.available_label();
    let new_axes = op
        .axes
        .clone()
        .with_extra_axis(repr, InOut::In(input_to_fix), 0)?
        .with_extra_axis_occurency(repr, InOut::Out(0), 0)?;
    wire[input_to_fix] =
        patch.wire_node(format!("{name}.add_{label}"), AxisOp::Add(0), &[wire[input_to_fix]])?[0];
    wire = patch.wire_node(name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
    wire = patch.wire_node(&node.name, AxisOp::Rm(0), &wire)?;
    patch.shunt_outside(model, node.id.into(), wire[0])?;
    Ok(patch)
}

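/// Wires the sequence of axis operations that translates `mapping`, threading
/// `outlet` through each step and naming the nodes after `name` and `var`.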
fn wire_axes_fix(
    patch: &mut TypedModelPatch,
    name: &str,
    var: &str,
    mapping: &AxesMapping,
    mut outlet: TVec<OutletId>,
) -> TractResult<TVec<OutletId>> {
    for (ix, axis_op) in mapping.translate_to_axis_ops()?.into_iter().enumerate() {
        outlet = patch.wire_node(format!("{name}.fix_{var}.{ix}"), axis_op, &outlet)?;
    }
    Ok(outlet)
}

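/// Rewrites a quantized einsum (nine inputs: a, b, bias, a0, a_scale, b0,
/// b_scale, c0, c_scale) as an integer einsum followed by bias addition,
/// zero-point compensation and requantization. The compensation relies on
/// the identity:
///
/// sum_k (a - a0)(b - b0) = sum_k a*b - a0 * sum_k b - b0 * sum_k a + k * a0 * b0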
fn dequant(
    model: &TypedModel,
    node: &TypedNode,
    op: EinSumAnnotatedAsMatMul,
) -> TractResult<Option<TypedModelPatch>> {
    let name = &node.name;
    let mut patch = TypedModelPatch::new("Dequantizing einsum");

    let mut taps = patch.taps(model, &node.inputs)?;
    for ab in [0, 1] {
        let scale_input = 4 + ab * 2;
        if !patch.outlet_fact(taps[scale_input])?.shape.volume().is_one() {
            let q_axis_in_output = op.axes.axis((InOut::In(scale_input), 0))?.outputs[0][0];
            let output_rank = node.outputs[0].fact.rank();
            for i in 1..(output_rank - q_axis_in_output) {
                taps[scale_input] = patch.wire_node(
                    format!("{name}.scale_input{ab}_axis_fix_{i}"),
                    AxisOp::Add(i),
                    &[taps[scale_input]],
                )?[0];
            }
        }
    }

    let [mut a, mut b, bias, mut a0, a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
        bail!("Expect exactly 9 inputs")
    };

    wire_ensure_q8_flavour(&mut patch, &node.name, &mut a, "a", &mut a0, i8::datum_type())?;
    wire_ensure_q8_flavour(&mut patch, &node.name, &mut b, "b", &mut b0, i8::datum_type())?;

    let mut output = patch.wire_node(
        &node.name,
        EinSum {
            q_params: None,
            axes: op.axes.extract_sub_mapping(&[0, 1], &[0])?,
            operating_dt: op.operating_dt,
        },
        &[a, b],
    )?;

    let a_i32 = patch.wire_node(format!("{name}.a_as_i32"), cast(i32::datum_type()), &[a])?[0];
    let b_i32 = patch.wire_node(format!("{name}.b_as_i32"), cast(i32::datum_type()), &[b])?[0];
    let sum_a = patch.wire_node(
        format!("{name}.sum_a"),
        Reduce::new(tvec!(op.k_axis.inputs[0][0]), Reducer::Sum),
        &[a_i32],
    )?;
    let sum_b = patch.wire_node(
        format!("{name}.sum_b"),
        Reduce::new(tvec!(op.k_axis.inputs[1][0]), Reducer::Sum),
        &[b_i32],
    )?;

    let sum_a =
        wire_axes_fix(&mut patch, name, "sum_a", &op.axes.extract_sub_mapping(&[0], &[0])?, sum_a)?;
    let sum_b =
        wire_axes_fix(&mut patch, name, "sum_b", &op.axes.extract_sub_mapping(&[1], &[0])?, sum_b)?;
    let bias = tvec!(bias);
    let bias =
        wire_axes_fix(&mut patch, name, "bias", &op.axes.extract_sub_mapping(&[2], &[0])?, bias)?;

    let abc_scale = combine_scales(&mut patch, name, a_scale, b_scale, c_scale)?;

    output = patch.wire_node(format!("{name}.add_bias"), add(), &[output[0], bias[0]])?;

    let k = model.outlet_fact(node.inputs[0])?.shape[op.k_axis.inputs[0][0]].clone();
    let output = compensate_zero_points(&mut patch, name, output[0], k, a0, b0, sum_a[0], sum_b[0])
        .context("Zero point compensation")?;
    let output = requant(&mut patch, name, output, op.q_params.unwrap(), abc_scale, c0)?;
    patch.shunt_outside(model, node.id.into(), output)?;
    Ok(Some(patch))
}

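/// Translates an annotated einsum to an OptMatMul node: transposes first
/// when A is not constant and n is known to exceed m (or m is unknown and n
/// >= 8), wires the packing of both operands, maps the remaining batch-like
/// output axes back to input axes, and assembles the fused kernel program
/// (matmul then store).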
fn optimized_mat_mul(
    model: &TypedModel,
    node: &TypedNode,
    op: &EinSumAnnotatedAsMatMul,
) -> TractResult<Option<TypedModelPatch>> {
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    let must_transpose = input_facts[0].konst.is_none()
        && match (op.m.as_i64(), op.n.as_i64()) {
            (Some(m), Some(n)) => m < n,
            (None, Some(n)) => n >= 8,
            _ => false,
        };
    if must_transpose {
        return Ok(Some(transpose(op, model, node)?));
    }

    if input_facts[0].konst.is_some()
        && (input_facts[0].datum_type.is_number()
            || input_facts[0].opaque_fact().is_some_and(|of| of.is::<BlockQuantFact>()))
    {
        return Ok(None);
    }

    let mut patch = TypedModelPatch::new("Einsum to OptMatMul");
    let name = &node.name;
    let taps = patch.taps(model, &node.inputs)?;
    let (a, b, mmms, mode_picker) =
        wire_packing(&mut patch, name, &taps[0..2], op).context("Wiring packing")?;

    let mut c_to_a_axis_mapping = tvec!();
    let mut c_to_b_axis_mapping = tvec!();
    for axis in op
        .op
        .axes
        .iter_all_axes()
        .filter(|&axis| ![op.m_axis, op.k_axis, op.n_axis].contains(&axis))
    {
        if let (&[c], &[a]) = (&*axis.outputs[0], &*axis.inputs[0]) {
            if input_shapes[0][a] != 1.to_dim() {
                let a = a - (a > op.a_m()) as usize - (a > op.a_k()) as usize;
                c_to_a_axis_mapping.push((c, a));
            }
        }
        if let (&[c], &[b]) = (&*axis.outputs[0], &*axis.inputs[1]) {
            if input_shapes[1][b] != 1.to_dim() {
                let b = b - (b > op.b_n()) as usize - (b > op.b_k()) as usize;
                c_to_b_axis_mapping.push((c, b));
            }
        }
    }

    let c_fact = op.output_facts(&input_facts)?.remove(0);
    let geo = AddMatMulGeometry {
        k: op.k.clone(),
        c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
        c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
    };
    let (mmms, packings, extractor): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(mmms);
    let outputs = mmms.iter().map(|mmm| unsafe { mmm.c_view(op.c_m(), op.c_n()) }).collect();
    let trivial_packing =
        mmms.len() == 1 && packings[0] == 0 && patch.outlet_fact(a)?.opaque_fact.is_none();
    let opt = OptMatMul::new(
        mmms,
        mode_picker,
        c_fact,
        op.c_m(),
        op.c_n(),
        vec![
            ProtoFusedSpec::AddMatMul {
                geo,
                a: 0,
                b: 1,
                packings: izip!(packings, extractor).collect_vec(),
            },
            ProtoFusedSpec::Store(outputs),
        ],
        trivial_packing,
    )
    .context("Creating OptMatMul")?;
    let output = patch.wire_node(name, opt, &[a, b])?[0];
    patch.shunt_outside(model, node.id.into(), output)?;
    Ok(Some(patch))
}