1use std::fmt::Formatter;
2use std::ops::Deref;
3
4use dyn_clone::clone_box;
5use kernel_selection::wire_packing;
6use tract_itertools::{izip, multiunzip};
7use tract_linalg::block_quant::BlockQuantValue;
8use tract_linalg::mmm::MMMInputFormat;
9use tract_linalg::WeightType;
10
11use super::*;
12use crate::ops::cast::cast;
13use crate::ops::math::add;
14use crate::ops::matmul::optimized::{
15 AddMatMulGeometry, MapOutputAxisToInput, OptMatMul, ProtoFusedSpec,
16};
17use crate::ops::matmul::quant::{
18 combine_scales, compensate_zero_points, requant, wire_ensure_q8_flavour,
19};
20use crate::ops::nn::{Reduce, Reducer};
21
/// Outcome of analyzing an EinSum node as a potential matrix multiplication.
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub enum AxesOrPatch<'a> {
    /// The einsum maps directly to a matmul: m, k and n axes were identified.
    Annotated(EinSumAnnotatedAsMatMul<'a>),
    /// The einsum must first be rewritten (e.g. injecting a trivial axis);
    /// the patch performs that rewrite.
    Patch(TypedModelPatch),
    /// The einsum cannot be expressed as a matmul; carries a reason and the
    /// offending axes.
    NotAMatMul(&'static str, Vec<&'a Axis>),
}
29
/// An EinSum recognized as a matrix multiplication, with the m, k and n axes
/// identified and their dimensions resolved.
pub struct EinSumAnnotatedAsMatMul<'a> {
    pub op: &'a EinSum,
    // Axis appearing in input A and the output, absent from B.
    pub m_axis: &'a Axis,
    // Contracted axis: appears in both inputs, absent from the output.
    pub k_axis: &'a Axis,
    // Axis appearing in input B and the output, absent from A.
    pub n_axis: &'a Axis,
    pub m: TDim,
    pub k: TDim,
    pub n: TDim,
}
39
40impl EinSumAnnotatedAsMatMul<'_> {
41 pub fn a_m(&self) -> usize {
42 self.m_axis.inputs[0][0]
43 }
44 pub fn a_k(&self) -> usize {
45 self.k_axis.inputs[0][0]
46 }
47 pub fn b_k(&self) -> usize {
48 self.k_axis.inputs[1][0]
49 }
50 pub fn b_n(&self) -> usize {
51 self.n_axis.inputs[1][0]
52 }
53 pub fn c_m(&self) -> usize {
54 *self.m_axis.outputs[0].first().unwrap_or(&self.a_m())
55 }
56 pub fn c_n(&self) -> usize {
57 *self.n_axis.outputs[0].first().unwrap_or(&self.b_n())
58 }
59}
60
61impl Debug for EinSumAnnotatedAsMatMul<'_> {
62 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
63 write!(
64 f,
65 "EinsumAsMatMul: {} {:?} m: {}={}; k: {}={}; n: {}={}",
66 self.op.axes,
67 self.op.operating_dt,
68 self.m_axis.repr,
69 self.m,
70 self.k_axis.repr,
71 self.k,
72 self.n_axis.repr,
73 self.n
74 )
75 }
76}
77
impl Deref for EinSumAnnotatedAsMatMul<'_> {
    type Target = EinSum;
    // Gives the annotation transparent access to the wrapped EinSum's fields
    // and methods (axes, operating_dt, q_params, ...).
    fn deref(&self) -> &Self::Target {
        self.op
    }
}
84
/// An EinSum recognized as a linear layer: constant weights (input 0)
/// applied to activations (input 1), with one m axis, one k axis and any
/// number of n (batch-like) axes.
pub struct EinSumAnnotatedAsLinear<'a> {
    pub op: &'a EinSum,
    pub m_axis: &'a Axis,
    pub k_axis: &'a Axis,
    // Axes present in the activations and the output, absent from weights.
    pub n_axes: Vec<&'a Axis>,
    pub m: usize,
    pub k: usize,
    // Output dimension of each n axis (may be symbolic).
    pub ns: Vec<&'a TDim>,
    // Activation datum type (input 1).
    pub act_dt: DatumType,
    // Weight representation: plain datum type or block-quantized format.
    pub weight_type: WeightType,
}
96
impl Debug for EinSumAnnotatedAsLinear<'_> {
    // One-line summary: axes mapping, weight type, accumulator type, then m,
    // k and the n axes (labels ","-joined, dims "•"-joined).
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        write!(
            f,
            "EinsumAsLinear: {} w:{:?} acc:{:?} m: {}={}; k: {}={}; n: {}={}",
            self.op.axes,
            self.weight_type,
            self.op.operating_dt,
            self.m_axis.repr,
            self.m,
            self.k_axis.repr,
            self.k,
            self.n_axes.iter().map(|ax| ax.repr).join(","),
            self.ns.iter().map(|d| d.to_string()).join("•"),
        )
    }
}
114
115impl<'a> EinSumAnnotatedAsLinear<'a> {
116 pub fn from(
117 model: &'a TypedModel,
118 node: &'a TypedNode,
119 op: &'a EinSum,
120 ) -> TractResult<Option<Self>> {
121 if node.inputs.len() != 2 {
122 return Ok(None);
123 }
124 let input_facts = model.node_input_facts(node.id)?;
125 if input_facts[0].konst.is_none() {
126 return Ok(None);
127 }
128 let mut n_axes = vec![];
129 let mut ns = Vec::<&'a TDim>::new();
130
131 let Some(m_axis) = op.axes.iter_all_axes().find(|axis| {
132 axis.inputs[0].len() == 1 && axis.inputs[1].len() == 0 && axis.outputs[0].len() == 1
133 }) else {
134 return Ok(None);
135 };
136 let Some(k_axis) = op.axes.iter_all_axes().find(|axis| {
137 axis.inputs[0].len() == 1 && axis.inputs[1].len() == 1 && axis.outputs[0].len() == 0
138 }) else {
139 return Ok(None);
140 };
141 for axis in op.axes.iter_all_axes() {
142 if axis != k_axis
143 && axis != m_axis
144 && axis.inputs[0].len() == 0
145 && axis.inputs[1].len() == 1
146 && axis.outputs[0].len() == 1
147 {
148 n_axes.push(axis);
149 ns.push(&node.outputs[0].fact.shape[axis.outputs[0][0]]);
150 }
151 }
152 let act_dt = input_facts[1].datum_type;
153 let bqv = input_facts[0]
154 .konst
155 .as_ref()
156 .unwrap()
157 .to_scalar::<Opaque>()
158 .ok()
159 .and_then(|a| a.downcast_ref::<BlockQuantValue>());
160 let weight_type = if let Some(a_payload) = bqv {
161 WeightType::BlockQuant(a_payload.fact.format.clone())
162 } else {
163 input_facts[0].datum_type.into()
164 };
165 let weight_shape = block_quant_aware_input_shape(input_facts[0])?;
166 let m = weight_shape[m_axis.inputs[0][0]].to_usize()?;
167 let k = weight_shape[k_axis.inputs[0][0]].to_usize()?;
168 Ok(Some(EinSumAnnotatedAsLinear {
169 op,
170 m_axis,
171 k_axis,
172 n_axes,
173 m,
174 k,
175 ns,
176 act_dt,
177 weight_type,
178 }))
179 }
180
181 pub fn weight_m_axis(&self) -> usize {
182 self.m_axis.inputs[0][0]
183 }
184
185 pub fn weight_k_axis(&self) -> usize {
186 self.k_axis.inputs[0][0]
187 }
188
189 pub fn input_k_axis(&self) -> usize {
190 self.k_axis.inputs[1][0]
191 }
192
193 pub fn output_m_axis(&self) -> usize {
194 self.m_axis.outputs[0][0]
195 }
196
197 pub fn need_mmv(&self) -> bool {
198 self.ns.is_empty() || self.ns.iter().any(|n| n.as_i64().map(|n| n == 1).unwrap_or(true))
199 }
200
201 pub fn need_mmm(&self) -> bool {
202 self.ns.iter().any(|n| n.as_i64().map(|n| n > 1).unwrap_or(true))
203 }
204
205 pub fn cost_for_weights(&self, format: &dyn MMMInputFormat) -> Option<usize> {
206 let acc = self.op.acceptable_accumulators();
207 let able = tract_linalg::ops()
208 .filter_impls(format, &acc, self.act_dt, self.op.operating_dt)
209 .collect_vec();
210 if able.len() == 0 {
211 return None;
212 }
213 let mut cost = 0;
214 if self.need_mmv() {
215 cost += able
216 .iter()
217 .map(|(mmm, _, _, pe, _)| {
218 1_000_000 + mmm.quality().cost() * 1000 + mmm.nr() * 10 - mmm.mr() * 10
219 + pe.is_some() as usize
220 })
221 .min()
222 .unwrap();
223 };
224 if self.need_mmm() {
225 cost += able
226 .iter()
227 .map(|(mmm, _, _, pe, _)| {
228 1_000_000 + mmm.quality().cost() * 1000 - mmm.nr() * 10 - mmm.mr() * 10
229 + pe.is_some() as usize
230 })
231 .min()
232 .unwrap();
233 };
234 Some(cost)
235 }
236
237 pub fn preferred_packing(&self) -> Box<dyn MMMInputFormat> {
238 if self.act_dt == self.acceptable_accumulators()[0]
239 && self.weight_type == self.act_dt.into()
240 {
241 if let Ok(n) = self.ns.iter().cloned().product::<TDim>().to_usize() {
242 let mmm = tract_linalg::ops()
243 .mmm(self.acceptable_accumulators()[0], Some(self.m), Some(self.k), Some(n))
244 .unwrap();
245 return mmm.packings()[0].0.clone();
246 }
247 }
248 if self.act_dt.is_integer() && self.weight_type == self.act_dt.into() {
249 if let Ok(n) = self.ns.iter().cloned().product::<TDim>().to_usize() {
250 let mmm = tract_linalg::ops()
251 .mmm(i32::datum_type(), Some(self.m), Some(self.k), Some(n))
252 .unwrap();
253 if let Some(packing) =
254 mmm.packings().iter().find(|(a, _)| a.precursor() == self.weight_type)
255 {
256 return packing.0.clone();
257 }
258 }
259 }
260 clone_box(
261 tract_linalg::ops()
262 .all_possible_packing(self.weight_type.clone())
263 .filter_map(|p| self.cost_for_weights(p).map(|cost| (p, cost)))
264 .min_by_key(|(_p, cost)| *cost)
265 .unwrap()
266 .0,
267 )
268 }
269}
270
impl Deref for EinSumAnnotatedAsLinear<'_> {
    type Target = EinSum;
    // Gives the annotation transparent access to the wrapped EinSum's fields
    // and methods (axes, operating_dt, acceptable_accumulators, ...).
    fn deref(&self) -> &Self::Target {
        self.op
    }
}
277
278pub(crate) fn optimize(
279 op: &EinSum,
280 model: &TypedModel,
281 node: &TypedNode,
282) -> TractResult<Option<TypedModelPatch>> {
283 if (op.q_params.is_none() && node.inputs.len() != 2)
284 || (op.q_params.is_some() && node.inputs.len() != 9)
285 {
286 return Ok(None);
287 }
288
289 let input_facts = model.node_input_facts(node.id)?;
290 if node.inputs.len() == 2 && input_facts[1].konst.is_some() {
291 return Ok(Some(transpose(op, model, node)?));
292 }
293
294 let annotated = match ensure_mkn_axes(op, model, node)? {
295 AxesOrPatch::Annotated(op) => op,
296 AxesOrPatch::Patch(p) => return Ok(Some(p)),
297 AxesOrPatch::NotAMatMul(_, _) => return Ok(None),
298 };
299 if op.q_params.is_none() {
300 optimized_mat_mul(model, node, &annotated).context("Translating to OptMatMul")
301 } else {
302 dequant(model, node, annotated).context("Dequantize")
303 }
304}
305
306fn transpose(op: &EinSum, model: &TypedModel, node: &TypedNode) -> TractResult<TypedModelPatch> {
307 let mut patch = TypedModelPatch::default();
308 let mut taps = patch.taps(model, &node.inputs)?;
309 taps.swap(0, 1);
310 let mut op = op.clone();
311 op.axes.iter_all_axes_mut().for_each(|axis| axis.inputs.swap(0, 1));
312 let wire = patch.wire_node(&node.name, op, &taps)?[0];
313 patch.shunt_outside(model, node.id.into(), wire)?;
314 Ok(patch)
315}
316
/// Identify the m, k and n axes of an einsum so it can be treated as a
/// matmul.
///
/// Returns:
/// * `Annotated` when the three axes are identified (with their dims);
/// * `Patch` when a trivial axis must first be injected (missing k, m or n);
/// * `NotAMatMul` when the einsum cannot be mapped to a matmul, with a reason
///   and the offending axes.
pub(crate) fn ensure_mkn_axes<'a>(
    op: &'a EinSum,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<AxesOrPatch<'a>> {
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    let output_shape = super::eval::output_shape(&op.axes, &input_shapes)?;
    // k candidates: one occurrence in each input, none in the output.
    let k_axes: TVec<&Axis> = op
        .axes
        .iter_all_axes()
        .filter(|a| a.inputs[0].len() == 1 && a.inputs[1].len() == 1 && a.outputs[0].is_empty())
        .collect();

    // Keep candidates with a real extent (dim != 1 on at least one side).
    let non_trivial_k_axis = k_axes
        .iter()
        .filter(|a| {
            !input_shapes[0][a.inputs[0][0]].is_one() || !input_shapes[1][a.inputs[1][0]].is_one()
        })
        .collect::<TVec<_>>();

    let k_axis = if non_trivial_k_axis.len() > 1 {
        // More than one real contraction axis: not a plain matmul.
        return Ok(AxesOrPatch::NotAMatMul(
            "multiple k-axis candidate found",
            non_trivial_k_axis.into_iter().cloned().collect_vec(),
        ));
    } else {
        // Prefer a non-trivial candidate, else settle for a trivial one.
        non_trivial_k_axis.first().copied().or_else(|| k_axes.first()).copied()
    };
    let Some(k_axis) = k_axis else {
        // No contraction axis at all: inject a dummy one and let the
        // optimizer revisit the patched einsum.
        return Ok(AxesOrPatch::Patch(inject_k_axis(op, model, node)?));
    };

    // m candidates: an axis of input 0, absent or trivial in input 1, that
    // either survives in the output or is itself trivial.
    let mut possible_m_axes: Vec<_> = op
        .axes
        .iter_all_axes()
        .filter(|a| {
            a.inputs[0].len() == 1
                && (a.inputs[1].is_empty() || input_shapes[1][a.inputs[1][0]].is_one())
                && (a.outputs[0].len() == 1
                    || (input_shapes[0][a.inputs[0][0]].is_one() && a.inputs[1].is_empty()))
        })
        .collect();

    // Prefer candidates that actually appear in the output.
    if possible_m_axes.iter().any(|a| !a.outputs[0].is_empty()) {
        possible_m_axes.retain(|a| !a.outputs[0].is_empty());
    }

    // Pick the largest m (symbolic dims count as the largest possible).
    let m_axis = possible_m_axes
        .into_iter()
        .max_by_key(|a| input_shapes[0][a.inputs[0][0]].as_i64().unwrap_or(i64::MAX));

    let Some(m_axis) = m_axis else {
        return Ok(AxesOrPatch::Patch(inject_m_or_n_axis(op, model, node, false)?));
    };

    // n: the symmetrical choice on input 1's side; largest candidate wins.
    let n_axis = op
        .axes
        .iter_all_axes()
        .filter(|a| {
            (a.inputs[0].is_empty() || input_shapes[0][a.inputs[0][0]].is_one())
                && a.inputs[1].len() == 1
                && a.outputs[0].len() == 1
                && *a != m_axis
        })
        .max_by_key(|a| input_shapes[1][a.inputs[1][0]].as_i64().unwrap_or(i64::MAX));
    let Some(n_axis) = n_axis else {
        return Ok(AxesOrPatch::Patch(inject_m_or_n_axis(op, model, node, true)?));
    };
    // A non-trivial axis present on exactly one input side and absent from
    // the output would be implicitly summed away: not expressible as matmul.
    for axis in op.axes.iter_all_axes() {
        let one = TDim::one();
        let in_left =
            axis.inputs[0].first().map(|pos| &input_shapes[0][*pos]).unwrap_or(&one) != &one;
        let in_right =
            axis.inputs[1].first().map(|pos| &input_shapes[1][*pos]).unwrap_or(&one) != &one;
        let in_out = axis.outputs[0].first().map(|pos| &output_shape[*pos]).unwrap_or(&one) != &one;
        if (in_left ^ in_right) && !in_out {
            return Ok(AxesOrPatch::NotAMatMul(
                "non trivial single-side disappearing axis",
                vec![axis],
            ));
        }
    }
    let m = input_shapes[0][m_axis.inputs[0][0]].clone();
    let k = input_shapes[0][k_axis.inputs[0][0]].clone();
    let n = input_shapes[1][n_axis.inputs[1][0]].clone();
    Ok(AxesOrPatch::Annotated(EinSumAnnotatedAsMatMul { op, m_axis, k_axis, n_axis, m, k, n }))
}
408
409pub(super) fn inject_k_axis(
410 op: &EinSum,
411 model: &TypedModel,
412 node: &TypedNode,
413) -> TractResult<TypedModelPatch> {
414 let mut new_axes = op.axes.clone();
415 let name = &node.name;
416 let mut patch = TypedModelPatch::new("inject k axis");
417 let mut wire = patch.taps(model, &node.inputs)?;
418 let repr = new_axes.available_label();
419 new_axes = new_axes.with_extra_axis(repr, InOut::In(0), 0)?.with_extra_axis_occurency(
420 repr,
421 InOut::In(1),
422 0,
423 )?;
424 wire[0] = patch.wire_node(format!("{name}.add_k.0"), AxisOp::Add(0), &[wire[0]])?[0];
425 wire[1] = patch.wire_node(format!("{name}.add_k.1"), AxisOp::Add(0), &[wire[1]])?[0];
426 wire = patch.wire_node(&node.name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
427 patch.shunt_outside(model, node.id.into(), wire[0])?;
428 Ok(patch)
429}
430
431pub(super) fn inject_m_or_n_axis(
432 op: &EinSum,
433 model: &TypedModel,
434 node: &TypedNode,
435 is_n: bool,
436) -> TractResult<TypedModelPatch> {
437 let input_to_fix = is_n as usize;
438 let label = if is_n { "n" } else { "m" };
439 let name = &node.name;
440 let mut patch = TypedModelPatch::new("Injecting m or n axis");
441 let mut wire = patch.taps(model, &node.inputs)?;
442 let repr = op.axes.available_label();
443 let new_axes = op
444 .axes
445 .clone()
446 .with_extra_axis(repr, InOut::In(input_to_fix), 0)?
447 .with_extra_axis_occurency(repr, InOut::Out(0), 0)?;
448 wire[input_to_fix] =
449 patch.wire_node(format!("{name}.add_{label}"), AxisOp::Add(0), &[wire[input_to_fix]])?[0];
450 wire = patch.wire_node(name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
451 wire = patch.wire_node(&node.name, AxisOp::Rm(0), &wire)?;
452 patch.shunt_outside(model, node.id.into(), wire[0])?;
453 Ok(patch)
454}
455
456fn wire_axes_fix(
457 patch: &mut TypedModelPatch,
458 name: &str,
459 var: &str,
460 mapping: &AxesMapping,
461 mut outlet: TVec<OutletId>,
462) -> TractResult<TVec<OutletId>> {
463 for (ix, axis_op) in mapping.translate_to_axis_ops()?.into_iter().enumerate() {
464 outlet = patch.wire_node(format!("{name}.fix_{var}.{ix})"), axis_op, &outlet)?;
465 }
466 Ok(outlet)
467}
468
/// Rewrite a quantized einsum (9 inputs: a, b, bias, a0, a_scale, b0,
/// b_scale, c0, c_scale) as a plain integer einsum followed by bias addition,
/// zero-point compensation and requantization.
fn dequant(
    model: &TypedModel,
    node: &TypedNode,
    op: EinSumAnnotatedAsMatMul,
) -> TractResult<Option<TypedModelPatch>> {
    let name = &node.name;
    let mut patch = TypedModelPatch::new("Dequantizing einsum");

    let mut taps = patch.taps(model, &node.inputs)?;
    // Non-scalar a/b scales (per-axis quantization) get trailing axes added
    // so they broadcast correctly against the output rank.
    for ab in [0, 1] {
        let scale_input = 4 + ab * 2;
        if !patch.outlet_fact(taps[scale_input])?.shape.volume().is_one() {
            let q_axis_in_output = op.axes.axis((InOut::In(scale_input), 0))?.outputs[0][0];
            let output_rank = node.outputs[0].fact.rank();
            for i in 1..(output_rank - q_axis_in_output) {
                taps[scale_input] = patch.wire_node(
                    format!("{name}.scale_input{ab}_axis_fix_{i}"),
                    AxisOp::Add(i),
                    &[taps[scale_input]],
                )?[0];
            }
        }
    }

    let [mut a, mut b, bias, mut a0, a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
        bail!("Expect exactly 9 inputs")
    };

    // Normalize u8-flavoured operands to i8, adjusting zero points in step.
    wire_ensure_q8_flavour(&mut patch, &node.name, &mut a, "a", &mut a0, i8::datum_type())?;
    wire_ensure_q8_flavour(&mut patch, &node.name, &mut b, "b", &mut b0, i8::datum_type())?;

    // The core product, computed without q_params on the a/b sub-mapping.
    let mut output = patch.wire_node(
        &node.name,
        EinSum {
            q_params: None,
            axes: op.axes.extract_sub_mapping(&[0, 1], &[0])?,
            operating_dt: op.operating_dt,
        },
        &[a, b],
    )?;

    // Per-operand sums over the k axis, needed by zero-point compensation.
    let a_i32 = patch.wire_node(format!("{name}.a_as_i32"), cast(i32::datum_type()), &[a])?[0];
    let b_i32 = patch.wire_node(format!("{name}.b_as_i32"), cast(i32::datum_type()), &[b])?[0];
    let sum_a = patch.wire_node(
        format!("{name}.sum_a"),
        Reduce::new(tvec!(op.k_axis.inputs[0][0]), Reducer::Sum),
        &[a_i32],
    )?;
    let sum_b = patch.wire_node(
        format!("{name}.sum_b"),
        Reduce::new(tvec!(op.k_axis.inputs[1][0]), Reducer::Sum),
        &[b_i32],
    )?;

    // Realign sums and bias with the output axes layout.
    let sum_a =
        wire_axes_fix(&mut patch, name, "sum_a", &op.axes.extract_sub_mapping(&[0], &[0])?, sum_a)?;
    let sum_b =
        wire_axes_fix(&mut patch, name, "sum_b", &op.axes.extract_sub_mapping(&[1], &[0])?, sum_b)?;
    let bias = tvec!(bias);
    let bias =
        wire_axes_fix(&mut patch, name, "bias", &op.axes.extract_sub_mapping(&[2], &[0])?, bias)?;

    let abc_scale = combine_scales(&mut patch, name, a_scale, b_scale, c_scale)?;

    output = patch.wire_node(format!("{name}.add_bias"), add(), &[output[0], bias[0]])?;

    let k = model.outlet_fact(node.inputs[0])?.shape[op.k_axis.inputs[0][0]].clone();
    let output = compensate_zero_points(&mut patch, name, output[0], k, a0, b0, sum_a[0], sum_b[0])
        .context("Zero point compensation")?;
    // Scale back and saturate into the quantized output type.
    let output = requant(&mut patch, name, output, op.q_params.unwrap(), abc_scale, c0)?;
    patch.shunt_outside(model, node.id.into(), output)?;
    Ok(Some(patch))
}
542
/// Lower an annotated (non-quantized) einsum to an OptMatMul node, wiring
/// operand packing and the fused Store spec.
fn optimized_mat_mul(
    model: &TypedModel,
    node: &TypedNode,
    op: &EinSumAnnotatedAsMatMul,
) -> TractResult<Option<TypedModelPatch>> {
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    // Heuristic: with a non-constant A, prefer m >= n; swap operands
    // otherwise. NOTE(review): when m is symbolic, swapping triggers at
    // n >= 8 — confirm threshold intent.
    let must_transpose = input_facts[0].konst.is_none()
        && match (op.m.as_i64(), op.n.as_i64()) {
            (Some(m), Some(n)) => m < n,
            (None, Some(n)) => n >= 8,
            _ => false,
        };
    if must_transpose {
        return Ok(Some(transpose(op, model, node)?));
    }

    // Constant numeric or block-quantized weights are left to another
    // optimization path; bail out here.
    if input_facts[0].konst.is_some()
        && (input_facts[0].datum_type.is_number()
            || input_facts[0].opaque_fact().is_some_and(|of| of.is::<BlockQuantFact>()))
    {
        return Ok(None);
    }

    let mut patch = TypedModelPatch::new("Einsum to OptMatMul");
    let name = &node.name;
    let taps = patch.taps(model, &node.inputs)?;
    // Select kernels and wire both operands' packing.
    let (a, b, mmms, mode_picker) =
        wire_packing(&mut patch, name, &taps[0..2], op).context("Wiring packing")?;

    // Map the remaining (batch) output axes back to their positions in the
    // packed operands, skipping trivial (dim 1, broadcast) axes.
    let mut c_to_a_axis_mapping = tvec!();
    let mut c_to_b_axis_mapping = tvec!();
    for axis in op
        .op
        .axes
        .iter_all_axes()
        .filter(|&axis| ![op.m_axis, op.k_axis, op.n_axis].contains(&axis))
    {
        if let (&[c], &[a]) = (&*axis.outputs[0], &*axis.inputs[0]) {
            if input_shapes[0][a] != 1.to_dim() {
                // Shift left to account for m and k disappearing from the
                // packed operand's axes.
                let a = a - (a > op.a_m()) as usize - (a > op.a_k()) as usize;
                c_to_a_axis_mapping.push((c, a));
            }
        }
        if let (&[c], &[b]) = (&*axis.outputs[0], &*axis.inputs[1]) {
            if input_shapes[1][b] != 1.to_dim() {
                // Same shift on B's side, for n and k.
                let b = b - (b > op.b_n()) as usize - (b > op.b_k()) as usize;
                c_to_b_axis_mapping.push((c, b));
            }
        }
    }

    let c_fact = op.output_facts(&input_facts)?.remove(0);
    let geo = AddMatMulGeometry {
        k: op.k.clone(),
        c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
        c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
    };
    let (mmms, packings, extractor): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(mmms);
    // NOTE(review): c_view is unsafe; this assumes c_m()/c_n() are valid
    // output axis positions for each kernel's view — confirm upstream.
    let outputs = mmms.iter().map(|mmm| unsafe { mmm.c_view(op.c_m(), op.c_n()) }).collect();
    // Single kernel using its default packing, with no opaque payload on A:
    // lets OptMatMul skip packing dispatch.
    let trivial_packing =
        mmms.len() == 1 && packings[0] == 0 && patch.outlet_fact(a)?.opaque_fact.is_none();
    let opt = OptMatMul::new(
        mmms,
        mode_picker,
        c_fact,
        op.c_m(),
        op.c_n(),
        vec![
            ProtoFusedSpec::AddMatMul {
                geo,
                a: 0,
                b: 1,
                packings: izip!(packings, extractor).collect_vec(),
            },
            ProtoFusedSpec::Store(outputs),
        ],
        trivial_packing,
    )
    .context("Creating OptMatMul")?;
    let output = patch.wire_node(name, opt, &[a, b])?[0];
    patch.shunt_outside(model, node.id.into(), output)?;
    Ok(Some(patch))
}