1use std::fmt::Formatter;
2use std::ops::Deref;
3
4use tract_itertools::{izip, multiunzip};
5use tract_linalg::block_quant::PackedBlockQuantFormat;
6use tract_linalg::pack::PackedFormat;
7
8use super::*;
9use crate::ops::cast::cast;
10use crate::ops::math::add;
11use crate::ops::matmul::optimized::{
12 AddMatMulGeometry, MapOutputAxisToInput, OptMatMul, ProtoFusedSpec,
13};
14use crate::ops::matmul::pack::{OptMatMulPack, OptSimpleMatMulPack};
15use crate::ops::matmul::quant::{
16 combine_scales, compensate_zero_points, requant, wire_ensure_q8_flavour,
17};
18use crate::ops::matmul::ModePicker;
19use crate::ops::nn::{Reduce, Reducer};
20
21pub fn detect_all(model: &mut TypedModel) -> TractResult<()> {
22 Rewriter::default().with_rule_for("detect-matmul-einsum", detect_rule).rewrite(&(), model)
23}
24
25pub fn flatten_all(model: &mut TypedModel) -> TractResult<()> {
26 Rewriter::default().with_rule_for("flatten-matmul-einsum", flatten_rule).rewrite(&(), model)
27}
28
/// An [`EinSum`] that `detect_rule` has identified as a matrix multiplication.
///
/// Wraps the original einsum together with the labels (in `op.axes`) of the axes playing
/// the m, k and n roles, and the sizes of those axes.
#[derive(Clone, Hash, PartialEq)]
pub struct EinSumMatMul {
    /// The wrapped einsum operator.
    pub op: EinSum,
    /// Label of the m axis; it is read from input 0 (`a`).
    pub m_axis: char,
    /// Label of the k axis; it is read from both inputs.
    pub k_axis: char,
    /// Label of the n axis; it is read from input 1 (`b`).
    pub n_axis: char,
    /// Size of the m axis (taken from input 0's shape by `detect_rule`).
    pub m: TDim,
    /// Size of the k axis (taken from input 0's shape by `detect_rule`).
    pub k: TDim,
    /// Size of the n axis (taken from input 1's shape by `detect_rule`).
    pub n: TDim,
}
39
40impl EinSumMatMul {
41 pub fn m_axis(&self) -> &Axis {
42 self.op.axes.axis(self.m_axis).unwrap()
43 }
44 pub fn k_axis(&self) -> &Axis {
45 self.op.axes.axis(self.k_axis).unwrap()
46 }
47 pub fn n_axis(&self) -> &Axis {
48 self.op.axes.axis(self.n_axis).unwrap()
49 }
50 pub fn a_m(&self) -> usize {
51 self.m_axis().inputs[0][0]
52 }
53 pub fn a_k(&self) -> usize {
54 self.k_axis().inputs[0][0]
55 }
56 pub fn b_k(&self) -> usize {
57 self.k_axis().inputs[1][0]
58 }
59 pub fn b_n(&self) -> usize {
60 self.n_axis().inputs[1][0]
61 }
62 pub fn c_m(&self) -> Option<usize> {
63 self.m_axis().outputs[0].first().cloned()
64 }
65 pub fn c_n(&self) -> Option<usize> {
66 self.n_axis().outputs[0].first().cloned()
67 }
68
69 fn new(
70 op: EinSum,
71 m_axis: char,
72 k_axis: char,
73 n_axis: char,
74 m: TDim,
75 k: TDim,
76 n: TDim,
77 ) -> Self {
78 Self { op, m_axis, k_axis, n_axis, m, k, n }
79 }
80}
81
82impl Debug for EinSumMatMul {
83 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
84 write!(
85 f,
86 "EinsumMatMul: {} {:?} m: {}={}; k: {}={}; n: {}={}",
87 self.op.axes,
88 self.op.operating_dt,
89 self.m_axis,
90 self.m,
91 self.k_axis,
92 self.k,
93 self.n_axis,
94 self.n
95 )
96 }
97}
98
99impl Deref for EinSumMatMul {
100 type Target = EinSum;
101 fn deref(&self) -> &Self::Target {
102 &self.op
103 }
104}
105
impl Op for EinSumMatMul {
    /// Operator name reported for this node.
    fn name(&self) -> StaticName {
        "EinSumMatMul".into()
    }

    // Boilerplate provided by tract macros. NOTE(review): by their names these expose the
    // TypedOp view of this op and a `same_as` based on the derived PartialEq — see the macro
    // definitions to confirm.
    op_as_typed_op!();
    impl_op_same_as!();
}
114
115impl EvalOp for EinSumMatMul {
116 fn is_stateless(&self) -> bool {
117 true
118 }
119 fn eval_with_session(
120 &self,
121 node_id: usize,
122 session: &SessionState,
123 inputs: TVec<TValue>,
124 ) -> TractResult<TVec<TValue>> {
125 self.op.eval_with_session(node_id, session, inputs)
126 }
127}
128
impl TypedOp for EinSumMatMul {
    /// Output facts are those computed by the wrapped [`EinSum`].
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        self.op.output_facts(inputs)
    }

    /// Progressively lowers this matmul.
    ///
    /// * With 9 inputs (quantized form, `q_params` must be set) the node is rewritten by
    ///   `dequant`.
    /// * Otherwise exactly 2 inputs are required. If the operands should be swapped (see
    ///   below), a patch swapping them together with the m/n roles is returned.
    /// * When no swap is needed and the output has an m or n axis, `optimized_mat_mul` is
    ///   tried.
    /// * Returns `Ok(None)` otherwise.
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if node.inputs.len() == 9 {
            ensure!(self.op.q_params.is_some());
            return dequant(model, node, self).map(Some);
        }
        ensure!(node.inputs.len() == 2);
        let (a, b) = model.node_input_facts(node.id)?.into_iter().collect_tuple().unwrap();
        // Decide whether to swap the operands:
        // - an opaque operand (only BlockQuantFact is accepted) must end up as `a`: it is
        //   left in place when it is already `a`, and swapped in when it is `b`;
        // - otherwise swap when n looks larger than m, so the larger dimension ends up on
        //   the m side. With symbolic sizes: symbolic m vs concrete n swaps only if n >= 8,
        //   concrete m vs symbolic n never swaps, two symbolic sizes swap when n - m is
        //   provably non-negative.
        //   NOTE(review): presumably this favours how kernels pack m vs n — confirm.
        let must_transpose = if let Some(of) = a.opaque_fact() {
            ensure!(of.is::<BlockQuantFact>());
            false
        } else if let Some(of) = b.opaque_fact() {
            ensure!(of.is::<BlockQuantFact>());
            true
        } else if self.m == self.n {
            false
        } else {
            match (self.m.as_i64(), self.n.as_i64()) {
                (Some(m), Some(n)) => m < n,
                (None, Some(n)) => n >= 8,
                (Some(_), _) => false,
                _ => (self.n.clone() - &self.m).prove_positive_or_zero(),
            }
        };
        if must_transpose {
            // Swap the roles of inputs 0 and 1 in every axis, exchange the m and n labels
            // and sizes, and rewire the same op on the swapped inputs.
            let mut op = self.clone();
            op.op.axes.iter_all_axes_mut().for_each(|axis| axis.inputs.swap(0, 1));
            std::mem::swap(&mut op.m_axis, &mut op.n_axis);
            std::mem::swap(&mut op.m, &mut op.n);
            return TypedModelPatch::replace_single_op(
                model,
                node,
                &[node.inputs[1], node.inputs[0]],
                op,
            )
            .map(|p| Some(p.with_context("transposing")));
        }
        // The output keeps an m or n axis: attempt the optimized kernel path.
        if self.c_m().is_some() || self.c_n().is_some() {
            return optimized_mat_mul(model, node, self)
                .map(|opt| opt.map(|p| p.with_context("optimizing")));
        }
        Ok(None)
    }

    as_op!();
}
186
187pub(crate) fn detect_rule(
188 _ctx: &(),
189 model: &TypedModel,
190 node: &TypedNode,
191 _name: &str,
192 op: &EinSum,
193) -> TractResult<Option<TypedModelPatch>> {
194 if node.inputs.len() != (2 + op.q_params.is_some() as usize * 7) {
195 return Ok(None);
196 }
197 let input_facts = model.node_input_facts(node.id)?;
198 let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
199 let output_shape = super::eval::output_shape(&op.axes, &input_shapes)?;
200 let k_axes: TVec<&Axis> = op
201 .axes
202 .iter_all_axes()
203 .filter(|a| a.inputs[0].len() == 1 && a.inputs[1].len() == 1 && a.outputs[0].is_empty())
205 .collect();
206
207 let non_trivial_k_axis = k_axes
208 .iter()
209 .filter(|a| {
210 !input_shapes[0][a.inputs[0][0]].is_one() || !input_shapes[1][a.inputs[1][0]].is_one()
211 })
212 .copied()
213 .collect::<TVec<_>>();
214
215 let k_axis = if non_trivial_k_axis.len() > 1 {
216 return regroup_k_axes(op, model, node, non_trivial_k_axis);
217 } else {
218 non_trivial_k_axis.first().or_else(|| k_axes.first()).copied()
219 };
220 let Some(k_axis) = k_axis else { return inject_k_axis(op, model, node).map(Some) };
221
222 let mut possible_m_axes: Vec<_> = op
223 .axes
224 .iter_all_axes()
225 .filter(|a| {
226 a.inputs[0].len() == 1
227 && (a.inputs[1].is_empty() || input_shapes[1][a.inputs[1][0]].is_one())
228 && (a.outputs[0].len() == 1
229 || (input_shapes[0][a.inputs[0][0]].is_one() && a.inputs[1].is_empty()))
230 })
231 .collect();
232
233 if possible_m_axes.iter().any(|a| !a.outputs[0].is_empty()) {
235 possible_m_axes.retain(|a| !a.outputs[0].is_empty());
236 }
237
238 let m_axis = possible_m_axes
239 .into_iter()
240 .max_by_key(|a| input_shapes[0][a.inputs[0][0]].as_i64().unwrap_or(i64::MAX));
241
242 let Some(m_axis) = m_axis else {
243 return inject_m_or_n_axis(op, model, node, false).map(Some);
244 };
245
246 let n_axis = op
247 .axes
248 .iter_all_axes()
249 .filter(|a| {
250 (a.inputs[0].is_empty() || input_shapes[0][a.inputs[0][0]].is_one())
251 && a.inputs[1].len() == 1
252 && a.outputs[0].len() == 1
253 && *a != m_axis
254 })
255 .max_by_key(|a| input_shapes[1][a.inputs[1][0]].as_i64().unwrap_or(i64::MAX));
256 let Some(n_axis) = n_axis else {
257 return inject_m_or_n_axis(op, model, node, true).map(Some);
258 };
259 for axis in op.axes.iter_all_axes() {
260 let one = TDim::one();
261 let in_left =
262 axis.inputs[0].first().map(|pos| &input_shapes[0][*pos]).unwrap_or(&one) != &one;
263 let in_right =
264 axis.inputs[1].first().map(|pos| &input_shapes[1][*pos]).unwrap_or(&one) != &one;
265 let in_out = axis.outputs[0].first().map(|pos| &output_shape[*pos]).unwrap_or(&one) != &one;
266 if (in_left ^ in_right) && !in_out {
267 return Ok(None);
268 }
273 }
274 let m = input_shapes[0][m_axis.inputs[0][0]].clone();
275 let k = input_shapes[0][k_axis.inputs[0][0]].clone();
276 let n = input_shapes[1][n_axis.inputs[1][0]].clone();
277 TypedModelPatch::replace_single_op(
278 model,
279 node,
280 &node.inputs,
281 EinSumMatMul::new(op.clone(), m_axis.repr, k_axis.repr, n_axis.repr, m, k, n),
282 )
283 .map(Some)
284}
285
286pub(super) fn inject_k_axis(
287 op: &EinSum,
288 model: &TypedModel,
289 node: &TypedNode,
290) -> TractResult<TypedModelPatch> {
291 let mut new_axes = op.axes.clone();
292 let name = &node.name;
293 let mut patch = TypedModelPatch::new("inject k axis");
294 let mut wire = patch.taps(model, &node.inputs)?;
295 let repr = new_axes.available_label();
296 new_axes = new_axes.with_extra_axis(repr, InOut::In(0), 0)?.with_extra_axis_occurency(
297 repr,
298 InOut::In(1),
299 0,
300 )?;
301 wire[0] = patch.wire_node(format!("{name}.add_k.0"), AxisOp::Add(0), &[wire[0]])?[0];
302 wire[1] = patch.wire_node(format!("{name}.add_k.1"), AxisOp::Add(0), &[wire[1]])?[0];
303 wire = patch.wire_node(&node.name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
304 patch.shunt_outside(model, node.id.into(), wire[0])?;
305 Ok(patch)
306}
307
308pub(super) fn regroup_k_axes(
309 op: &EinSum,
310 model: &TypedModel,
311 node: &TypedNode,
312 mut k_axes: TVec<&Axis>,
313) -> TractResult<Option<TypedModelPatch>> {
314 let input_facts = model.node_input_facts(node.id)?;
315 let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
316 let contig_in_a = k_axes
317 .iter()
318 .map(|axis| axis.inputs[0][0])
319 .sorted()
320 .tuple_windows()
321 .all(|(a, b)| a + 1 == b);
322 if contig_in_a {
323 k_axes.sort_by_key(|ax| ax.inputs[0][0]);
324 } else {
325 k_axes.sort_by_key(|ax| ax.inputs[1][0]);
326 }
327 let k_dims: TVec<_> =
328 k_axes.iter().map(|ax| input_shapes[0][ax.inputs[0][0]].clone()).collect();
329 let k: TDim = k_dims.iter().product();
330 let mut patch = TypedModelPatch::default();
331 let mut wires = patch.taps(model, &node.inputs)?;
332 let mut exprs: Vec<String> =
333 (0..2).map(|slot| op.axes.axes(InOut::In(slot)).map(|ax| ax.repr).join("")).collect();
334 for slot in 0..2 {
335 if k_axes.iter().map(|ax| ax.inputs[slot][0]).tuple_windows().any(|(a, b)| a + 1 != b) {
336 let after = op
337 .axes
338 .axes(InOut::In(slot))
339 .filter(|ax| !k_axes.contains(ax))
340 .chain(k_axes.iter().copied())
341 .map(|ax| ax.repr)
342 .join("");
343 let transpose =
344 AxesMapping::from_strs(&[&exprs[slot]], &[&after])?.translate_to_axis_ops()?;
345 for (ix, op) in transpose.into_iter().enumerate() {
346 wires[slot] = patch.wire_node(
347 format!("{}.transpose_input_{}.{}", &node.name, slot, ix),
348 op,
349 &[wires[slot]],
350 )?[0];
351 }
352 exprs[slot] = after;
353 }
354 let pos = exprs[slot].chars().position(|c| k_axes[0].repr == c).unwrap();
355 wires[slot] = patch.wire_node(
356 format!("{}.fold_k_in_input_{}", &node.name, slot),
357 AxisOp::Reshape(pos, k_dims.clone(), tvec!(k.clone())),
358 &[wires[slot]],
359 )?[0];
360 exprs[slot] =
361 exprs[slot].chars().filter(|c| !k_axes.iter().any(|k| k.repr == *c)).collect();
362 exprs[slot].insert(pos, k_axes[0].repr);
363 }
364 let old = op.axes.to_string();
365 let (iexpr, oexpr) = old.split_once("->").unwrap();
366 let mut expr: String = exprs.iter().join(",");
367 if node.inputs.len() > 2 {
368 expr = expr + "," + &iexpr.split(",").skip(2).join(",");
369 }
370 expr = expr + "->" + oexpr;
371 let wire = patch.wire_node(
372 &node.name,
373 EinSum { axes: expr.parse().unwrap(), ..op.clone() },
374 &wires,
375 )?[0];
376 patch.shunt_outside(model, node.id.into(), wire)?;
377 Ok(Some(patch))
378}
379
380pub(super) fn inject_m_or_n_axis(
381 op: &EinSum,
382 model: &TypedModel,
383 node: &TypedNode,
384 is_n: bool,
385) -> TractResult<TypedModelPatch> {
386 let input_to_fix = is_n as usize;
387 let label = if is_n { "n" } else { "m" };
388 let name = &node.name;
389 let mut patch = TypedModelPatch::new("Injecting m or n axis");
390 let mut wire = patch.taps(model, &node.inputs)?;
391 let repr = op.axes.available_label();
392 let new_axes = op
393 .axes
394 .clone()
395 .with_extra_axis(repr, InOut::In(input_to_fix), 0)?
396 .with_extra_axis_occurency(repr, InOut::Out(0), 0)?;
397 wire[input_to_fix] =
398 patch.wire_node(format!("{name}.add_{label}"), AxisOp::Add(0), &[wire[input_to_fix]])?[0];
399 wire = patch.wire_node(name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
400 wire = patch.wire_node(&node.name, AxisOp::Rm(0), &wire)?;
401 patch.shunt_outside(model, node.id.into(), wire[0])?;
402 Ok(patch)
403}
404
405fn wire_axes_fix(
406 patch: &mut TypedModelPatch,
407 name: &str,
408 var: &str,
409 mapping: &AxesMapping,
410 mut outlet: TVec<OutletId>,
411) -> TractResult<TVec<OutletId>> {
412 for (ix, axis_op) in mapping.translate_to_axis_ops()?.into_iter().enumerate() {
413 outlet = patch.wire_node(format!("{name}.fix_{var}.{ix})"), axis_op, &outlet)?;
414 }
415 Ok(outlet)
416}
417
/// Rewrites a quantized [`EinSumMatMul`] (9 inputs) into an unquantized integer einsum
/// followed by bias addition, zero-point compensation and requantization.
///
/// Inputs are destructured in the order
/// `a, b, bias, a0, a_scale, b0, b_scale, c0, c_scale`.
///
/// # Errors
/// Fails if there are not exactly 9 inputs, or if any wiring helper fails.
fn dequant(
    model: &TypedModel,
    node: &TypedNode,
    op: &EinSumMatMul,
) -> TractResult<TypedModelPatch> {
    let name = &node.name;
    let mut patch = TypedModelPatch::new("Dequantizing einsum");

    let k_axis = op.k_axis();

    let mut taps = patch.taps(model, &node.inputs)?;
    // For a non-scalar a_scale (input 4) / b_scale (input 6), insert unit axes after its
    // first axis so that it has `output_rank - q_axis_in_output` axes: under trailing-aligned
    // broadcasting its first axis then lines up with output axis `q_axis_in_output`.
    // NOTE(review): assumes the scale is 1-D and that its axis reaches the output — confirm.
    for ab in [0, 1] {
        let scale_input = 4 + ab * 2;
        if !patch.outlet_fact(taps[scale_input])?.shape.volume().is_one() {
            let q_axis_in_output = op.axes.axis((InOut::In(scale_input), 0))?.outputs[0][0];
            let output_rank = node.outputs[0].fact.rank();
            for i in 1..(output_rank - q_axis_in_output) {
                taps[scale_input] = patch.wire_node(
                    format!("{name}.scale_input{ab}_axis_fix_{i}"),
                    AxisOp::Add(i),
                    &[taps[scale_input]],
                )?[0];
            }
        }
    }

    let [mut a, mut b, bias, mut a0, a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
        bail!("Expect exactly 9 inputs")
    };

    // Bring a/b and their zero points to the i8 flavour (helper may rewrite all four wires).
    wire_ensure_q8_flavour(&mut patch, &node.name, &mut a, "a", &mut a0, i8::datum_type())?;
    wire_ensure_q8_flavour(&mut patch, &node.name, &mut b, "b", &mut b0, i8::datum_type())?;

    // Core product: same axes restricted to inputs 0 and 1, without quantization params.
    let mut output = patch.wire_node(
        &node.name,
        EinSum {
            q_params: None,
            axes: op.axes.extract_sub_mapping(&[0, 1], &[0])?,
            operating_dt: op.operating_dt,
        },
        &[a, b],
    )?;

    // Sums of a and b over k (in i32), later consumed by zero-point compensation.
    let a_i32 = patch.wire_node(format!("{name}.a_as_i32"), cast(i32::datum_type()), &[a])?[0];
    let b_i32 = patch.wire_node(format!("{name}.b_as_i32"), cast(i32::datum_type()), &[b])?[0];
    let sum_a = patch.wire_node(
        format!("{name}.sum_a"),
        Reduce::new(tvec!(k_axis.inputs[0][0]), Reducer::Sum),
        &[a_i32],
    )?;
    let sum_b = patch.wire_node(
        format!("{name}.sum_b"),
        Reduce::new(tvec!(k_axis.inputs[1][0]), Reducer::Sum),
        &[b_i32],
    )?;

    // Re-lay sums and bias out according to their input-to-output sub-mappings.
    let sum_a =
        wire_axes_fix(&mut patch, name, "sum_a", &op.axes.extract_sub_mapping(&[0], &[0])?, sum_a)?;
    let sum_b =
        wire_axes_fix(&mut patch, name, "sum_b", &op.axes.extract_sub_mapping(&[1], &[0])?, sum_b)?;
    let bias = tvec!(bias);
    let bias =
        wire_axes_fix(&mut patch, name, "bias", &op.axes.extract_sub_mapping(&[2], &[0])?, bias)?;

    let abc_scale = combine_scales(&mut patch, name, a_scale, b_scale, c_scale)?;

    output = patch.wire_node(format!("{name}.add_bias"), add(), &[output[0], bias[0]])?;

    // k is read from the model fact of input 0 along the k axis.
    let k = model.outlet_fact(node.inputs[0])?.shape[k_axis.inputs[0][0]].clone();
    let output = compensate_zero_points(&mut patch, name, output[0], k, a0, b0, sum_a[0], sum_b[0])
        .context("Zero point compensation")?;
    // q_params is guaranteed Some by the caller (codegen ensures it for 9-input nodes).
    let output = requant(&mut patch, name, output, op.q_params.unwrap(), abc_scale, c0)?;
    patch.shunt_outside(model, node.id.into(), output)?;
    Ok(patch)
}
493
494fn flatten_rule(
495 _ctx: &(),
496 model: &TypedModel,
497 node: &TypedNode,
498 _name: &str,
499 op: &EinSumMatMul,
500) -> TractResult<Option<TypedModelPatch>> {
501 TypedModelPatch::replace_single_op(model, node, &node.inputs, op.op.clone()).map(Some)
502}
503
504fn optimized_mat_mul(
505 model: &TypedModel,
506 node: &TypedNode,
507 op: &EinSumMatMul,
508) -> TractResult<Option<TypedModelPatch>> {
509 let (mode_picker, left_pack, impls) = kernel_selection::strategize(model, node, op)?;
510 let input_facts = model.node_input_facts(node.id)?;
511 let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
512 let prefix = &node.name;
513
514 let mut patch = TypedModelPatch::new("Einsum to OptMatMul");
515 let taps = patch.taps(model, &node.inputs)?;
516 let name = &node.name;
517
518 let pack_a: Box<dyn TypedOp> = if input_facts[0].konst.is_some() {
519 if let Some(pf) = left_pack.downcast_ref::<PackedFormat>() {
520 Box::new(OptMatMulPack {
521 packers: vec![pf.clone()],
522 mode_picker: ModePicker::Single,
523 k_axis: op.a_k(),
524 mn_axis: op.a_m(),
525 })
526 } else if let Some(packed_format) =
527 left_pack.downcast_ref::<PackedBlockQuantFormat>().cloned()
528 {
529 Box::new(OptSimpleMatMulPack {
530 packed_format,
531 k: input_shapes[0][op.a_k()].to_usize().unwrap(),
532 m: input_shapes[0][op.a_m()].to_usize().unwrap(),
533 })
534 } else {
535 bail!("Unexpected static input format {left_pack:?}");
536 }
537 } else {
538 Box::new(OptMatMulPack {
539 packers: impls
540 .iter()
541 .map(|(mmm, p, pe)| {
542 pe.as_ref()
543 .map(|pe| &pe.from)
544 .unwrap_or(&mmm.packings()[*p].0)
545 .downcast_ref::<PackedFormat>()
546 .unwrap()
547 .clone()
548 })
549 .collect(),
550 mode_picker: mode_picker.clone(),
551 k_axis: op.a_k(),
552 mn_axis: op.a_m(),
553 })
554 };
555 let pa = patch.wire_node(format!("{prefix}.pack_a"), pack_a, &[taps[0]])?[0];
556
557 let pb = patch.wire_node(
558 format!("{prefix}.pack_b"),
559 OptMatMulPack {
560 k_axis: op.b_k(),
561 mn_axis: op.b_n(),
562 packers: impls
563 .iter()
564 .map(|(mmm, p, _)| {
565 mmm.packings()[*p].1.downcast_ref::<PackedFormat>().unwrap().clone()
566 })
567 .collect(),
568 mode_picker: mode_picker.clone(),
569 },
570 &[taps[1]],
571 )?[0];
572
573 let mut c_to_a_axis_mapping = tvec!();
574 let mut c_to_b_axis_mapping = tvec!();
575 for axis in op
576 .op
577 .axes
578 .iter_all_axes()
579 .filter(|&axis| ![op.m_axis, op.k_axis, op.n_axis].contains(&axis.repr))
580 {
581 if let (&[c], &[a]) = (&*axis.outputs[0], &*axis.inputs[0]) {
582 if input_shapes[0][a] != 1.to_dim() {
583 let a = a - (a > op.a_m()) as usize - (a > op.a_k()) as usize;
584 c_to_a_axis_mapping.push((c, a));
585 }
586 }
587 if let (&[c], &[b]) = (&*axis.outputs[0], &*axis.inputs[1]) {
588 if input_shapes[1][b] != 1.to_dim() {
589 let b = b - (b > op.b_n()) as usize - (b > op.b_k()) as usize;
590 c_to_b_axis_mapping.push((c, b));
591 }
592 }
593 }
594
595 let c_fact = op.output_facts(&input_facts)?.remove(0);
596 let geo = AddMatMulGeometry {
597 k: op.k.clone(),
598 c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
599 c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
600 };
601 let (mmms, packings, extractor): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(impls);
602 let outputs = mmms.iter().map(|mmm| unsafe { mmm.c_view(op.c_m(), op.c_n()) }).collect();
603 let trivial_packing = mmms.len() == 1
604 && packings[0] == 0
605 && extractor[0].is_none()
606 && input_facts[0].opaque_fact.is_none();
607 let opt = OptMatMul::new(
608 mmms,
609 mode_picker,
610 c_fact,
611 op.c_m(),
612 op.c_n(),
613 vec![
614 ProtoFusedSpec::AddMatMul {
615 geo,
616 a: 0,
617 b: 1,
618 packings: izip!(packings, extractor).collect_vec(),
619 },
620 ProtoFusedSpec::Store(outputs),
621 ],
622 trivial_packing,
623 )
624 .context("Creating OptMatMul")?;
625 let output = patch.wire_node(name, opt, &[pa, pb])?[0];
626 patch.shunt_outside(model, node.id.into(), output)?;
627 Ok(Some(patch))
628}