1use crate::internal::*;
2use crate::ops::cast::{Cast, cast};
3use crate::ops::change_axes::wire_with_rank_broadcast;
4use crate::ops::nn::LeakyRelu;
5use ndarray::*;
6use tract_itertools::Itertools;
7
8use tract_linalg::mmm::{
9 AsInputValue, EagerPackedInput, FusedSpec, MMMInputValue, MatMatMul, OutputStoreSpec,
10 PanelExtractInput, PanelExtractor,
11};
12use tract_linalg::pack::PackedFormat;
13use tract_linalg::{BinOp, Scaler};
14use tract_smallvec::ToSmallVec;
15
16use super::ModePicker;
17
/// One step of the fused matrix-multiplication micro-op pipeline, in "proto"
/// form: operands are referenced by their index into the node's input list and
/// only resolved into concrete `FusedSpec`s at eval time (see `resolve` /
/// `resolve_trivial`).
#[derive(Clone, Debug)]
pub enum ProtoFusedSpec {
    /// The matrix product itself.
    AddMatMul {
        geo: AddMatMulGeometry,
        /// Index of the A operand in the node inputs (an Opaque MMMInputValue).
        a: usize,
        /// Index of the B operand in the node inputs (an Opaque MMMInputValue).
        b: usize,
        /// One entry per kernel mode: index into `MatMatMul::packings`, plus an
        /// optional extractor used to repack A panels on the fly.
        packings: Vec<(usize, Option<PanelExtractor>)>,
    },
    /// Apply `BinOp` with a scalar operand (input index).
    BinScalar(usize, BinOp),
    /// Leaky-relu; the payload is the input index of the alpha tensor.
    LeakyRelu(usize),
    /// Apply `BinOp` with a per-row operand; the map translates output
    /// coordinates into offsets in the operand.
    BinPerRow(usize, BinOp, MapOutputAxisToInput),
    /// Apply `BinOp` with a per-column operand.
    BinPerCol(usize, BinOp, MapOutputAxisToInput),
    /// Add outer product of a row vector and a column vector (input indices).
    AddRowColProducts(usize, usize),
    /// Element-wise addition of an input tensor laid out like the output.
    AddUnicast(OutputStoreSpec, usize, MapOutputAxisToInput),
    /// Multiply accumulators by a (pre-computed) scale factor.
    Scaler(Scaler),
    /// Final store of the accumulators to the output; one spec per kernel mode.
    Store(Vec<OutputStoreSpec>),
}
35
impl ProtoFusedSpec {
    /// Short human-readable description of this micro-op (used by
    /// `OptMatMul::info`). `mode` selects which packing variant to display.
    pub fn format(&self, mmm: &dyn MatMatMul, mode: usize) -> String {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, packings: packing, .. } => {
                let (a, b) = &mmm.packings()[packing[mode].0];
                format!("matmul(k={}, {a:?}•{b:?})", geo.k)
            }
            BinScalar(_, op) => format!("scalar{op:?}"),
            // NOTE(review): the payload is the alpha *input index*, not the
            // alpha value (see `resolve`), so the index is what gets printed.
            LeakyRelu(alpha) => format!("leaky_relu({alpha:?})"),
            BinPerRow(_, op, _) => format!("row{op:?}"),
            BinPerCol(_, op, _) => format!("col{op:?}"),
            AddRowColProducts(_, _) => "add_row_col_product".to_string(),
            AddUnicast(_, _, _) => "add_to_matrix".to_string(),
            Scaler(s) => format!("scale({})", 1f32 * *s),
            Store(_oss) => "store".to_string(),
        }
    }

    /// Resolve this proto spec into a concrete `FusedSpec` for one output tile
    /// on the general (looping) path: `output_coords` locates the tile in the
    /// output tensor and the axis mappings translate those coordinates into
    /// offsets inside each operand view.
    pub fn resolve<'t>(
        &'t self,
        inputs: &'t [TValue],
        output_coords: &[usize],
        output: &Tensor,
        mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        #[allow(clippy::let_and_return)]
        let fs = match self {
            ProtoFusedSpec::AddMatMul { geo, a, b, packings } => {
                let mut a = inputs[*a].view();
                let mut b = inputs[*b].view();
                // SAFETY: output_coords comes from iterating the (clamped)
                // output shape, so the per-axis offsets stay in range.
                unsafe {
                    geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a);
                }
                // Operands are Opaque tensors wrapping a boxed MMMInputValue.
                let a = a.as_slice::<Opaque>().unwrap()[0]
                    .downcast_ref::<Box<dyn MMMInputValue>>()
                    .unwrap();
                unsafe {
                    geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b);
                }
                let b = b.as_slice::<Opaque>().unwrap()[0]
                    .downcast_ref::<Box<dyn MMMInputValue>>()
                    .unwrap();
                let (_a_packing, b_packing) = &mmm.packings()[packings[mode].0];
                // When this mode carries a panel extractor, A must be repacked
                // on the fly from its eagerly packed representation.
                let pa = if let Some(extractor) = &packings[mode].1 {
                    let data = a.downcast_ref::<EagerPackedInput>().unwrap();
                    AsInputValue::Owned(Box::new(PanelExtractInput {
                        format: extractor.clone(),
                        data: data.clone(),
                    }))
                } else {
                    AsInputValue::Borrowed(&**a)
                };
                // B's packing must match the kernel's expectation, either
                // exactly or at least in panel width for plain packed formats.
                assert!(
                    b_packing.same_as(b.format())
                        || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                );
                debug_assert!(pa.k().to_dim().compatible_with(&geo.k.to_dim()));
                debug_assert!(b.k().to_dim().compatible_with(&geo.k.to_dim()));
                FusedSpec::AddMatMul {
                    a: pa,
                    b: AsInputValue::Borrowed(&**b),
                    packing: packings[mode].0,
                }
            }
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, map) => {
                let mut v = inputs[*v].view();
                // SAFETY: same in-range argument as for the matmul operands.
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, map) => {
                let mut v = inputs[*v].view();
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, map) => unsafe {
                let mut view = inputs[*v].view();
                map.translate_view(output_coords, &mut view);
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                // Store targets the output tile at the current coordinates.
                let view = output.view_offsetting_unchecked(output_coords);
                FusedSpec::Store(oss[mode].wrap(&view))
            },
        };
        fs
    }

    /// A micro-op is "trivial" when it can be resolved without per-tile
    /// coordinates; only a matmul with a symbolic k is non-trivial.
    pub fn is_trivial(&self) -> bool {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => geo.k.as_i64().is_some(),
            _ => true,
        }
    }

    /// Fast-path resolution: the whole output is a single tile, so no
    /// coordinate translation is needed. Bounds and type checks are demoted to
    /// debug assertions; callers must only use this when
    /// `OptMatMul::trivial_path` holds.
    pub fn resolve_trivial<'t>(
        &'t self,
        inputs: &'t [TValue],
        output: &mut Tensor,
        _mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        #[allow(clippy::let_and_return)]
        let fs = match self {
            ProtoFusedSpec::AddMatMul { a, b, packings, .. } => unsafe {
                // SAFETY: indices and Opaque payloads are validated at model
                // building time (`check_inputs`); re-checked here in debug.
                debug_assert!(inputs.get(*a).is_some());
                debug_assert!(inputs.get(*b).is_some());
                let a = inputs.get_unchecked(*a);
                let b = inputs.get_unchecked(*b);
                debug_assert!(a.datum_type().is_opaque());
                debug_assert!(a.len() == 1);
                debug_assert!(b.datum_type().is_opaque());
                debug_assert!(b.len() == 1);
                let a = a.as_slice_unchecked::<Opaque>().get_unchecked(0);
                let b = b.as_slice_unchecked::<Opaque>().get_unchecked(0);
                debug_assert!(a.is::<Box<dyn MMMInputValue>>());
                debug_assert!(b.is::<Box<dyn MMMInputValue>>());
                let a = a.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
                let b = b.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
                // Trivial path implies a single packing and no panel extractor.
                debug_assert!(packings.len() == 1);
                debug_assert!(packings[0].1.is_none());
                #[cfg(debug_assertions)]
                {
                    let (a_packing, b_packing) = &_mmm.packings()[packings[mode].0];
                    debug_assert!(
                        a_packing.same_as(a.format())
                            || (a_packing.is::<PackedFormat>() && a_packing.r() == a.format().r())
                    );
                    debug_assert!(
                        b_packing.same_as(b.format())
                            || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                    );
                }
                FusedSpec::AddMatMul {
                    a: AsInputValue::Borrowed(&**a),
                    b: AsInputValue::Borrowed(&**b),
                    packing: packings[mode].0,
                }
            },
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, _) => unsafe {
                let view = inputs[*v].view();
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                FusedSpec::Store(oss[mode].wrap(&output.view_mut()))
            },
        };
        fs
    }

    /// Validate the datum types of the inputs this micro-op references.
    /// Matmul operands must be Opaque (packed payloads); numeric operands must
    /// be numbers. Scaler/Store reference no input, hence the catch-all.
    fn check_inputs(&self, inputs: &[&TypedFact]) -> TractResult<()> {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { a, b, .. } => {
                ensure!(inputs[*a].datum_type == Opaque::datum_type());
                ensure!(inputs[*b].datum_type == Opaque::datum_type());
            }
            BinScalar(v, _)
            | LeakyRelu(v)
            | BinPerCol(v, _, _)
            | BinPerRow(v, _, _)
            | AddUnicast(_, v, _) => {
                ensure!(inputs[*v].datum_type.is_number());
            }
            AddRowColProducts(row, col) => {
                ensure!(inputs[*row].datum_type.is_number());
                ensure!(inputs[*col].datum_type.is_number());
            }
            _ => (),
        };
        Ok(())
    }

    /// Cost contribution of this micro-op for one (m, n) tile: only the matmul
    /// itself is counted (m * n * k fused multiply-adds), everything else is
    /// considered free.
    fn cost(&self, m: &TDim, n: &TDim, idt: DatumType) -> TVec<(Cost, TDim)> {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => {
                tvec!((Cost::FMA(idt), m.clone() * n * &geo.k))
            }
            _ => tvec!(),
        }
    }

    /// Adjust internal axis references after output axis `axis` is removed
    /// (used when an `AxisOp::Rm` successor is fused into the op).
    fn rm_c_axis(&mut self, axis: usize) {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, .. } => {
                geo.c_to_a_axis_mapping.rm_c_axis(axis);
                geo.c_to_b_axis_mapping.rm_c_axis(axis);
            }
            // These variants hold no output-axis references.
            BinScalar(..) | Scaler(..) | AddRowColProducts(_, _) | LeakyRelu(_) => {}
            BinPerRow(_, _, map) | BinPerCol(_, _, map) => map.rm_c_axis(axis),
            AddUnicast(_, _, map) => {
                map.rm_c_axis(axis);
            }
            Store(oss, ..) => {
                for oss in oss {
                    match oss {
                        OutputStoreSpec::View { m_axis, n_axis, .. } => {
                            // Shift m/n axis indices down when they sit after
                            // the removed axis.
                            if let Some(m) = m_axis {
                                *m -= (*m > axis) as usize
                            };
                            if let Some(n) = n_axis {
                                *n -= (*n > axis) as usize
                            }
                        }
                        OutputStoreSpec::Strides { .. } => {}
                    }
                }
            }
        }
    }
}
269
/// Mapping from output (C) axes to an operand's axes, as a list of
/// `(output_axis, input_axis)` pairs: walking a coordinate along `output_axis`
/// walks the operand view along `input_axis`.
#[derive(Clone, Debug)]
pub struct MapOutputAxisToInput(pub TVec<(usize, usize)>);
272
273impl MapOutputAxisToInput {
274 #[inline]
275 unsafe fn translate_view(&self, output_coords: &[usize], v: &mut TensorView) {
276 for &(out_axis, in_axis) in &self.0 {
277 unsafe { v.offset_axis(in_axis, output_coords[out_axis] as isize) }
278 }
279 }
280
281 #[inline]
282 fn rm_c_axis(&mut self, axis: usize) {
283 for (c, _) in &mut self.0 {
284 *c -= (*c > axis) as usize;
285 }
286 }
287}
288
/// Geometry of the matmul micro-op: the shared inner dimension and the
/// output-to-operand axis mappings used to locate each tile's A and B slices.
#[derive(Clone, Debug)]
pub struct AddMatMulGeometry {
    /// Inner (reduction) dimension; may be symbolic.
    pub k: TDim,
    /// How output coordinates translate into offsets in the A operand.
    pub c_to_a_axis_mapping: MapOutputAxisToInput,
    /// How output coordinates translate into offsets in the B operand.
    pub c_to_b_axis_mapping: MapOutputAxisToInput,
}
295
/// Optimized matrix-multiplication operator: a pipeline of fused micro-ops
/// executed by one of several candidate `MatMatMul` kernels ("modes").
#[derive(Clone, Debug)]
pub struct OptMatMul {
    /// Fact (type and shape) of the output tensor C.
    pub c_fact: TypedFact,
    /// The fused pipeline; by construction the last op is a `Store`.
    pub micro_ops: Vec<ProtoFusedSpec>,
    /// Candidate kernels, indexed by mode.
    pub mmm: Vec<Box<dyn MatMatMul>>,
    /// Chooses the kernel mode from the runtime value of n.
    pub mode_picker: ModePicker,
    /// Output axis carrying m, if any (None means m == 1).
    pub c_m_axis: Option<usize>,
    /// Output axis carrying n, if any (None means n == 1).
    pub c_n_axis: Option<usize>,
    /// Whether inputs are packed in a way that needs no per-tile repacking.
    pub trivial_packing: bool,
    /// Cached `can_use_trivial_path()`; kept in sync via `update_trivial_path`
    /// and re-checked by `output_facts`.
    pub trivial_path: bool,
}
307
308impl Op for OptMatMul {
309 fn name(&self) -> StaticName {
310 "OptMatMul".into()
311 }
312
313 fn info(&self) -> TractResult<Vec<String>> {
314 let m = self.c_m_axis.map(|ix| &self.c_fact.shape[ix]).unwrap_or(&TDim::Val(1));
315 let n = self.c_n_axis.map(|ix| &self.c_fact.shape[ix]).unwrap_or(&TDim::Val(1));
316 let mut infos = vec![format!(
317 "c_shape:{:?}, c_m_axis:{:?} c_n_axis:{:?} m:{} n:{}",
318 self.c_fact, self.c_m_axis, self.c_n_axis, m, n,
319 )];
320 if let Some(k) = self.guess_k() {
321 infos.push(format!("Mult: m:{} k:{} n:{} with {:?}", m, k, n, self.mmm));
322 } else {
323 infos.push(format!("Mult: {:?}", self.mmm));
324 }
325 for (mode, mmm) in self.mmm.iter().enumerate() {
326 infos.push(format!(
327 "Ops: {}",
328 self.micro_ops.iter().map(|o| o.format(&**mmm, mode)).join(" >>> ")
329 ));
330 }
331 Ok(infos)
332 }
333
334 op_as_typed_op!();
335}
336
impl EvalOp for OptMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Run the fused matmul pipeline: evaluate the concrete output shape from
    /// the session's resolved symbols, allocate the output, then either run
    /// the pipeline once (trivial path) or loop over every output tile,
    /// re-resolving the micro-ops for each tile's coordinates.
    fn eval_with_session(
        &self,
        _node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        unsafe {
            let c_shape = self.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
            // SAFETY: `c` starts uninitialized; it is filled by the pipeline's
            // Store micro-op before being returned — assumes the pipeline
            // covers every cell (TODO confirm for all store specs).
            let mut c = Tensor::uninitialized_dt(self.c_fact.datum_type, &c_shape)?;
            let m = self.c_m_axis.map(|c_m| c.shape()[c_m]).unwrap_or(1);
            let n = self.c_n_axis.map(|c_n| c.shape()[c_n]).unwrap_or(1);
            // Kernel selection depends on the runtime value of n.
            let mode = self.mode_picker.pick(n)?;
            let mmm = &*self.mmm[mode];
            // Reuse the session-cached scratch space when this kernel accepts
            // it; otherwise drop it and allocate a fresh one.
            let mut cell = session.cached_mmm_scratch_space.borrow_mut();
            if !cell.as_ref().is_some_and(|scratch| mmm.can_use_scratch_space(&**scratch)) {
                *cell = None
            }
            let scratch = cell.get_or_insert_with(|| mmm.allocate_scratch_space());
            if self.trivial_path {
                // Fast path: single tile, no coordinate bookkeeping.
                let uops: Vec<FusedSpec> = self
                    .micro_ops
                    .iter()
                    .map(|o| o.resolve_trivial(&inputs, &mut c, mmm, mode))
                    .collect();
                mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)?;
                Ok(tvec!(c.into_tvalue()))
            } else {
                // General path: iterate over all output coordinates outside the
                // m and n axes (clamped to 1), one kernel run per tile. The
                // uops buffer is reused across iterations.
                let mut uops = vec![FusedSpec::ShiftLeft(0); self.micro_ops.len()];
                let mut looping_shape: TVec<usize> = c_shape.to_smallvec();
                if let Some(ax) = self.c_m_axis {
                    looping_shape[ax] = 1;
                }
                if let Some(ax) = self.c_n_axis {
                    looping_shape[ax] = 1;
                }
                for c_coords in indices(&*looping_shape) {
                    for ix in 0..self.micro_ops.len() {
                        // SAFETY: ix < micro_ops.len() == uops.len() by
                        // construction of `uops` just above.
                        *uops.get_unchecked_mut(ix) = self.micro_ops.get_unchecked(ix).resolve(
                            &inputs,
                            c_coords.slice(),
                            &c,
                            mmm,
                            mode,
                        );
                    }
                    mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)
                        .context("In mmm.run_with_scratch_space")?;
                }
                Ok(tvec!(c.into_tvalue()))
            }
        }
    }
}
395
impl TypedOp for OptMatMul {
    /// Output is `c_fact`, unchanged. Also used as an invariant check: axes in
    /// range, cached trivial_path in sync, kernels agree on internal type, and
    /// every micro-op references well-typed inputs.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(self.c_m_axis.map(|ax| ax < self.c_fact.rank()).unwrap_or(true));
        ensure!(self.c_n_axis.map(|ax| ax < self.c_fact.rank()).unwrap_or(true));
        ensure!(self.trivial_path == self.can_use_trivial_path());
        ensure!(self.mmm.iter().map(|mmm| mmm.internal_type()).all_equal());
        for op in &self.micro_ops {
            op.check_inputs(inputs)?;
        }
        Ok(tvec!(self.c_fact.clone()))
    }

    /// Total cost: per-tile micro-op costs multiplied by the number of outer
    /// loop iterations (the product of all output dims except m and n axes).
    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let mut sums = HashMap::new();
        for op in &self.micro_ops {
            for (cost, count) in op.cost(self.m(), self.n(), self.mmm[0].internal_type()) {
                *sums.entry(cost).or_default() += count;
            }
        }
        let loops = self
            .c_fact
            .shape
            .iter()
            .enumerate()
            .map(|(ix, d)| {
                if Some(ix) == self.c_m_axis || Some(ix) == self.c_n_axis {
                    1.to_dim()
                } else {
                    d.clone()
                }
            })
            .product::<TDim>();
        for s in &mut sums.values_mut() {
            *s *= &loops;
        }
        Ok(sums.into_iter().collect())
    }

    /// Try to absorb the single successor of this node into the fused
    /// pipeline: binary ops, leaky-relu, quantization scaling, output casts,
    /// axis removal, cast/binop re-ordering through axis changes, and unicast
    /// additions. Returns a patch when a fusion applies, None otherwise.
    fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
        use crate::ops;
        // Only fuse when we have exactly one consumer and are not a model output.
        rule_if!(node.outputs.len() == 1);
        rule_if!(node.outputs[0].successors.len() == 1);
        rule_if!(!model.output_outlets()?.contains(&node.id.into()));
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut patch = TypedModelPatch::new(format!("fusing {succ}"));

        // Successor is a general binary op: fuse as scalar/row/col op.
        if let Some(op) = succ.op_as::<ops::binary::TypedBinOp>() {
            rule_if_some!(mut binop = op.0.as_linalg_binop());
            // If our output is the binary's *first* input, flip the op so the
            // matmul result is always the right-hand side.
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }
        // Same logic for the optimized by-scalar binary form.
        if let Some(op) = succ.op_as::<ops::binary::OptBinByScalar>() {
            rule_if_some!(mut binop = op.binop.as_linalg_binop());
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }

        // Element-wise successors: quantization scaling and leaky-relu.
        if let Some(op) = succ.op_as::<ops::element_wise::ElementWiseOp>().map(|ew| ew.0.as_ref()) {
            if let Some(op) = op.downcast_ref::<ops::math::QScale>() {
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::Scaler(op.scaler)],
                    &[],
                );
            }
            if let Some(op) = op.downcast_ref::<LeakyRelu>() {
                // Every candidate kernel must support a fused leaky-relu.
                rule_if!(
                    self.mmm
                        .iter()
                        .all(|mmm| mmm.can_fuse(&FusedSpec::LeakyRelu(&tensor0(op.alpha))))
                );
                // Alpha becomes an extra constant input, cast to the kernels'
                // internal type.
                let alpha = patch.add_const(
                    node.name.to_string() + ".alpha",
                    tensor0(op.alpha).cast_to_dt(self.mmm[0].internal_type())?.into_owned(),
                )?;
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::LeakyRelu(node.inputs.len())],
                    &[alpha],
                );
            }
        }
        // Successor is a cast: absorb it by retyping our output fact, when the
        // kernels can store that type (or it is the classic i32 -> i8/u8
        // requantization case).
        if let Some(cast_to) = succ.op_as::<ops::cast::Cast>().map(|cast| cast.to) {
            if ((cast_to.unquantized() == i8::datum_type()
                || cast_to.unquantized() == u8::datum_type())
                && self.c_fact.datum_type == i32::datum_type())
                || self.mmm.iter().all(|m| m.stores().contains(&cast_to))
            {
                if let Some(ProtoFusedSpec::Store(stores)) = self.micro_ops.last() {
                    // Strided stores cannot be retyped this way.
                    if stores.iter().any(|s| matches!(s, OutputStoreSpec::Strides { .. })) {
                        return Ok(None);
                    }
                    let c_fact = cast_to.fact(self.c_fact.shape.clone());
                    let mut patch = TypedModelPatch::fuse_with_next(
                        model,
                        node,
                        Self { c_fact, ..self.clone() },
                    )?;
                    patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
                    return Ok(Some(patch));
                }
            }
        }
        // Successor removes an output axis (not m or n): rewrite our own
        // geometry instead of keeping the Rm node around.
        if let Some(AxisOp::Rm(axis)) = succ.op_as::<ops::AxisOp>() {
            rule_if!(Some(*axis) != self.c_m_axis);
            rule_if!(Some(*axis) != self.c_n_axis);
            let mut new_op = self.clone();
            new_op.c_fact.shape.remove_axis(*axis)?;
            if let Some(c_m_axis) = &mut new_op.c_m_axis {
                *c_m_axis -= (*c_m_axis > *axis) as usize;
            }
            if let Some(c_n_axis) = &mut new_op.c_n_axis {
                *c_n_axis -= (*c_n_axis > *axis) as usize;
            }
            for uop in &mut new_op.micro_ops {
                uop.rm_c_axis(*axis);
            }
            let mut patch = TypedModelPatch::fuse_with_next(model, node, new_op)?;
            patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
            return Ok(Some(patch));
        }
        // Successor is an axis/shape change followed by a cast or a uniform
        // binop: swap the order so the cast/binop lands next to us and may be
        // fused on a later pass.
        if succ.op_is::<AxisOp>() || succ.op_is::<IntoShape>() {
            if let &[next] = &*succ.outputs[0].successors {
                let next_node = model.node(next.node);
                if let Some(cast) = next_node.op_as::<Cast>() {
                    let mut patch = TypedModelPatch::default();
                    let mut wire = patch.tap_model(model, node.id.into())?;
                    wire = patch.wire_node(&next_node.name, cast.clone(), &[wire])?[0];
                    wire = patch.wire_node(&succ.name, succ.op.clone(), &[wire])?[0];
                    patch.shunt_outside(model, next_node.id.into(), wire)?;
                    return Ok(Some(patch));
                } else if let Some(op) = next_node.op_as::<ops::binary::TypedBinOp>() {
                    rule_if!(op.0.as_linalg_binop().is_some());
                    // NOTE(review): `flipped` is derived from `succ.inputs`
                    // (the axis-change node, whose input is always this node)
                    // but indexes `next_node.inputs` — confirm this is the
                    // intended source of the flip flag.
                    let flipped = succ.inputs[0].node == node.id;
                    let other_outlet = next_node.inputs[flipped as usize];
                    if let Some(uni) = &model.outlet_fact(other_outlet)?.uniform {
                        let mut patch = TypedModelPatch::default();
                        let cst =
                            patch.add_const(&model.node(other_outlet.node).name, uni.clone())?;
                        let output = patch.tap_model(model, node.id.into())?;
                        let wire = wire_with_rank_broadcast(
                            &next_node.name,
                            &mut patch,
                            op.clone(),
                            &if flipped { [output, cst] } else { [cst, output] },
                        )?;
                        let wire = patch.wire_node(&succ.name, succ.op.clone(), &wire)?[0];
                        patch.shunt_outside(model, next_node.id.into(), wire)?;
                        return Ok(Some(patch));
                    }
                }
            }
        }
        // Successor is an optimized unicast binary: a full-tensor Add can be
        // fused as AddUnicast; other binops fall back to fuse_binary.
        if let Some(op) = succ.op_as::<ops::binary::OptBinUnicast>() {
            let in_1_fact = model.outlet_fact(succ.inputs[0])?;
            let in_2_fact = model.outlet_fact(succ.inputs[1])?;
            if op.binop.is::<ops::math::Add>()
                && self.mmm.len() == 1
                && in_1_fact.without_value() == in_2_fact.without_value()
            {
                let other_slot = 1 - node.outputs[0].successors[0].slot;
                let other_input = succ.inputs[other_slot];
                let other_input = patch.tap_model(model, other_input)?;
                let other_fact = patch.outlet_fact(other_input)?;

                if other_fact.shape == self.c_fact.shape {
                    // SAFETY: c_m_axis/c_n_axis are validated against the rank
                    // in output_facts / the constructor.
                    let other_storage = unsafe { self.mmm[0].c_view(self.c_m_axis, self.c_n_axis) };
                    let mapping =
                        MapOutputAxisToInput((0..other_fact.rank()).map(|x| (x, x)).collect());
                    return self.fuse_op(
                        model,
                        node,
                        patch,
                        vec![ProtoFusedSpec::AddUnicast(other_storage, node.inputs.len(), mapping)],
                        &[other_input],
                    );
                }
            } else {
                rule_if_some!(mut binop = op.binop.as_linalg_binop());
                let flipped = succ.inputs[0].node == node.id;
                if flipped {
                    binop = binop.flip();
                }
                let other_outlet = succ.inputs[flipped as usize];
                return self.fuse_binary(model, node, patch, other_outlet, binop);
            }
        };
        Ok(None)
    }

    as_op!();
}
600
impl OptMatMul {
    /// Build an `OptMatMul`, validating that the m/n axes fit within the
    /// output rank and computing the cached `trivial_path` flag.
    pub fn new(
        mmm: Vec<Box<dyn MatMatMul>>,
        mode_picker: ModePicker,
        c_fact: TypedFact,
        c_m_axis: Option<usize>,
        c_n_axis: Option<usize>,
        micro_ops: Vec<ProtoFusedSpec>,
        trivial_packing: bool,
    ) -> TractResult<Self> {
        if let Some(m) = c_m_axis {
            ensure!(m < c_fact.rank());
        }
        if let Some(n) = c_n_axis {
            ensure!(n < c_fact.rank());
        }
        let mut it = OptMatMul {
            mmm,
            mode_picker,
            c_fact,
            c_m_axis,
            c_n_axis,
            micro_ops,
            trivial_path: false,
            trivial_packing,
        };
        it.update_trivial_path();
        Ok(it)
    }

    /// The inner dimension k, read from the first AddMatMul micro-op, if any.
    pub fn guess_k(&self) -> Option<TDim> {
        self.micro_ops
            .iter()
            .find_map(
                |o| {
                    if let ProtoFusedSpec::AddMatMul { geo, .. } = o { Some(geo) } else { None }
                },
            )
            .map(|geo| geo.k.clone())
    }

    /// m dimension of the output (1 when there is no m axis).
    #[inline]
    pub fn m(&self) -> &TDim {
        self.c_m_axis.map(|ax| &self.c_fact.shape[ax]).unwrap_or(&TDim::Val(1))
    }

    /// n dimension of the output (1 when there is no n axis).
    #[inline]
    pub fn n(&self) -> &TDim {
        self.c_n_axis.map(|ax| &self.c_fact.shape[ax]).unwrap_or(&TDim::Val(1))
    }

    /// Recompute the cached trivial-path flag; must be called after any
    /// mutation of the fields it depends on.
    fn update_trivial_path(&mut self) {
        self.trivial_path = self.can_use_trivial_path();
    }

    /// The trivial (single-tile) eval path applies when the output shape is
    /// concrete, every non-m/n output dim is 1, packing is trivial, and every
    /// micro-op can be resolved without coordinates.
    fn can_use_trivial_path(&self) -> bool {
        self.c_fact.shape.is_concrete()
            && self.c_fact.shape.iter().enumerate().all(|(ax, dim)| {
                Some(ax) == self.c_m_axis || Some(ax) == self.c_n_axis || dim.is_one()
            })
            && self.trivial_packing
            && self.micro_ops.iter().all(|o| o.is_trivial())
    }

    /// Build a patch replacing `node` and its single successor by a clone of
    /// self with `fused_micro_op` spliced in just before the final Store, the
    /// successor's output fact, and `additional_inputs` appended.
    fn fuse_op(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        fused_micro_op: Vec<ProtoFusedSpec>,
        additional_inputs: &[OutletId],
    ) -> TractResult<Option<TypedModelPatch>> {
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut new_op = self.clone();
        // Insert before the last micro-op, which is the Store.
        let before_last = new_op.micro_ops.len() - 1..new_op.micro_ops.len() - 1;
        new_op.micro_ops.splice(before_last, fused_micro_op);
        new_op.c_fact = succ.outputs[0].fact.clone();
        new_op.update_trivial_path();
        let mut inputs = patch.taps(model, &node.inputs)?;
        inputs.extend(additional_inputs.iter().cloned());
        let output = patch.wire_node(&succ.name, new_op, &inputs)?;
        patch.shunt_outside(model, succ.id.into(), output[0])?;
        Ok(Some(patch))
    }

    /// Fuse a binary operation whose other operand is `value`: as a scalar op
    /// when the operand has a single element, as a per-row or per-column op
    /// when it spans exactly the m or n axis. The operand is cast to the
    /// kernels' internal type when needed. Returns None when no form applies.
    fn fuse_binary(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        value: OutletId,
        binop: BinOp,
    ) -> TractResult<Option<TypedModelPatch>> {
        let fact = model.outlet_fact(value)?;
        let mut v = patch.tap_model(model, value)?;
        if fact.datum_type != self.mmm[0].internal_type() {
            v = patch.wire_node(
                format!("{}.cast-input-{}", node.name, node.inputs.len()),
                cast(self.mmm[0].internal_type()),
                &[v],
            )?[0];
        }
        // The operand becomes a new input at the end of the input list.
        let value = node.inputs.len();
        let additional_input = tvec!(v);
        if fact.shape.volume() == 1.to_dim() {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinScalar(value, binop)],
                &additional_input,
            );
        }
        let other_shape = fact.shape.to_owned();
        // Per-row: the operand's whole volume lives on the m axis.
        if self.c_m_axis.is_some_and(|ax| {
            other_shape[ax] == self.c_fact.shape[ax] && other_shape[ax] == other_shape.volume()
        }) {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerRow(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_m_axis.unwrap(), self.c_m_axis.unwrap()))),
                )],
                &additional_input,
            );
        }
        // Per-column: the operand's whole volume lives on the n axis.
        if self.c_n_axis.is_some_and(|ax| {
            other_shape[ax] == self.c_fact.shape[ax] && other_shape[ax] == other_shape.volume()
        }) {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerCol(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_n_axis.unwrap(), self.c_n_axis.unwrap()))),
                )],
                &additional_input,
            );
        }
        Ok(None)
    }
}