1use crate::internal::*;
2use crate::ops::cast::cast;
3use crate::ops::change_axes::wire_with_rank_broadcast;
4use crate::ops::nn::LeakyRelu;
5use ndarray::*;
6use tract_itertools::Itertools;
7
8use tract_linalg::mmm::{
9 AsInputValue, EagerPackedInput, FusedSpec, MMMInputValue, MatMatMul, OutputStoreSpec,
10 PanelExtractInput, PanelExtractor,
11};
12use tract_linalg::pack::PackedFormat;
13use tract_linalg::{BinOp, Scaler};
14use tract_smallvec::ToSmallVec;
15
16use super::ModePicker;
17
/// Symbolic description of one fused micro-operation of an [`OptMatMul`].
///
/// The `usize` fields are indices into the op's inputs. Each variant is
/// turned into a concrete `tract_linalg` `FusedSpec` at eval time by
/// `resolve` (general, per-tile path) or `resolve_trivial` (single-tile path).
#[derive(Clone, Debug)]
pub enum ProtoFusedSpec {
    /// The matrix product itself. `a` and `b` index opaque inputs wrapping
    /// pre-packed `MMMInputValue` operands; `packings` holds one
    /// (packing index, optional A-panel extractor) pair per kernel mode.
    AddMatMul {
        geo: AddMatMulGeometry,
        a: usize,
        b: usize,
        packings: Vec<(usize, Option<PanelExtractor>)>,
    },
    /// Binary op against a scalar operand taken from the indexed input.
    BinScalar(usize, BinOp),
    /// Leaky-relu; the index designates the input holding alpha.
    LeakyRelu(usize),
    /// Binary op with a per-row operand; the map translates output coords
    /// into offsets in that operand.
    BinPerRow(usize, BinOp, MapOutputAxisToInput),
    /// Binary op with a per-column operand (see `BinPerRow`).
    BinPerCol(usize, BinOp, MapOutputAxisToInput),
    /// Delegates to `FusedSpec::AddRowColProducts` with the two indexed
    /// inputs — presumably row and column vectors; see tract-linalg docs.
    AddRowColProducts(usize, usize),
    /// Element-wise addition of the indexed input, stored through the given
    /// output store spec after coordinate translation.
    AddUnicast(OutputStoreSpec, usize, MapOutputAxisToInput),
    /// Multiply the accumulator by a (quantization) scaler.
    Scaler(Scaler),
    /// Final store of the accumulator into the output; one spec per kernel mode.
    Store(Vec<OutputStoreSpec>),
}
35
impl ProtoFusedSpec {
    /// Render a short human-readable description of this micro-op (used by
    /// `OptMatMul::info`). `mode` selects which packing of `mmm` to display.
    pub fn format(&self, mmm: &dyn MatMatMul, mode: usize) -> String {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, packings: packing, .. } => {
                let (a, b) = &mmm.packings()[packing[mode].0];
                format!("matmul(k={}, {a:?}•{b:?})", geo.k)
            }
            BinScalar(_, op) => format!("scalar{op:?}"),
            // NOTE(review): `alpha` here is the input slot *index*, not the
            // alpha value itself (see the LeakyRelu variant).
            LeakyRelu(alpha) => format!("leaky_relu({alpha:?})"),
            BinPerRow(_, op, _) => format!("row{op:?}"),
            BinPerCol(_, op, _) => format!("col{op:?}"),
            AddRowColProducts(_, _) => "add_row_col_product".to_string(),
            AddUnicast(_, _, _) => "add_to_matrix".to_string(),
            Scaler(s) => format!("scale({})", 1f32 * *s),
            Store(_oss) => "store".to_string(),
        }
    }

    /// General path: build the concrete `FusedSpec` for the output tile at
    /// `output_coords`, offsetting input views through the axis mappings.
    ///
    /// * `inputs` — the op inputs; the variants' `usize` fields index into it.
    /// * `output` — the output tensor, only used by the `Store` variant.
    /// * `mode` — selects the kernel packing / store spec when several exist.
    pub fn resolve<'t>(
        &'t self,
        inputs: &'t [TValue],
        output_coords: &[usize],
        output: &Tensor,
        mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        let fs = match self {
            ProtoFusedSpec::AddMatMul { geo, a, b, packings } => {
                let mut a = inputs[*a].view();
                let mut b = inputs[*b].view();
                // SAFETY: offsets are derived from the current tile's output
                // coordinates; the mapping is assumed in-bounds — TODO confirm.
                unsafe {
                    geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a);
                }
                // A and B are Opaque tensors wrapping pre-packed matmul operands.
                let a = a.as_slice::<Opaque>().unwrap()[0]
                    .downcast_ref::<Box<dyn MMMInputValue>>()
                    .unwrap();
                unsafe {
                    geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b);
                }
                let b = b.as_slice::<Opaque>().unwrap()[0]
                    .downcast_ref::<Box<dyn MMMInputValue>>()
                    .unwrap();
                let (_a_packing, b_packing) = &mmm.packings()[packings[mode].0];
                // When a panel extractor is attached to this mode, wrap A so
                // panels get converted to the kernel's expected layout on the fly.
                let pa = if let Some(extractor) = &packings[mode].1 {
                    let data = a.downcast_ref::<EagerPackedInput>().unwrap();
                    AsInputValue::Owned(Box::new(PanelExtractInput {
                        format: extractor.clone(),
                        data: data.clone(),
                    }))
                } else {
                    AsInputValue::Borrowed(&**a)
                };
                assert!(
                    b_packing.same_as(b.format())
                        || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                );
                debug_assert!(pa.k().to_dim().compatible_with(&geo.k.to_dim()));
                debug_assert!(b.k().to_dim().compatible_with(&geo.k.to_dim()));
                FusedSpec::AddMatMul {
                    a: pa,
                    b: AsInputValue::Borrowed(&**b),
                    packing: packings[mode].0,
                }
            }
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, map) => {
                let mut v = inputs[*v].view();
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, map) => {
                let mut v = inputs[*v].view();
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, map) => unsafe {
                let mut view = inputs[*v].view();
                map.translate_view(output_coords, &mut view);
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                // Store into the output tile at output_coords, using the
                // store spec matching the selected mode.
                let view = output.view_offsetting_unchecked(output_coords);
                FusedSpec::Store(oss[mode].wrap(&view))
            },
        };
        fs
    }

    /// A micro-op is "trivial" when it can take the single-tile fast path:
    /// everything but `AddMatMul` qualifies; `AddMatMul` needs a concrete k.
    pub fn is_trivial(&self) -> bool {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => geo.k.as_i64().is_some(),
            _ => true,
        }
    }

    /// Fast path used when the whole output is a single tile (see
    /// `OptMatMul::can_use_trivial_path`): skips coordinate translation and
    /// uses unchecked accessors; the `debug_assert`s spell out the invariants
    /// relied upon by the unchecked code.
    pub fn resolve_trivial<'t>(
        &'t self,
        inputs: &'t [TValue],
        output: &mut Tensor,
        _mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        let fs = match self {
            ProtoFusedSpec::AddMatMul { a, b, packings, .. } => unsafe {
                debug_assert!(inputs.get(*a).is_some());
                debug_assert!(inputs.get(*b).is_some());
                // SAFETY: indices checked by the debug_asserts above.
                let a = inputs.get_unchecked(*a);
                let b = inputs.get_unchecked(*b);
                debug_assert!(a.datum_type().is_opaque());
                debug_assert!(a.len() == 1);
                debug_assert!(b.datum_type().is_opaque());
                debug_assert!(b.len() == 1);
                let a = a.as_slice_unchecked::<Opaque>().get_unchecked(0);
                let b = b.as_slice_unchecked::<Opaque>().get_unchecked(0);
                debug_assert!(a.is::<Box<dyn MMMInputValue>>());
                debug_assert!(b.is::<Box<dyn MMMInputValue>>());
                let a = a.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
                let b = b.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
                // Trivial path implies a single packing and no panel extractor.
                debug_assert!(packings.len() == 1);
                debug_assert!(packings[0].1.is_none());
                #[cfg(debug_assertions)]
                {
                    let (a_packing, b_packing) = &_mmm.packings()[packings[mode].0];
                    debug_assert!(
                        a_packing.same_as(a.format())
                            || (a_packing.is::<PackedFormat>() && a_packing.r() == a.format().r())
                    );
                    debug_assert!(
                        b_packing.same_as(b.format())
                            || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                    );
                }
                FusedSpec::AddMatMul {
                    a: AsInputValue::Borrowed(&**a),
                    b: AsInputValue::Borrowed(&**b),
                    packing: packings[mode].0,
                }
            },
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, _) => unsafe {
                let view = inputs[*v].view();
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                FusedSpec::Store(oss[mode].wrap(&output.view_mut()))
            },
        };
        fs
    }

    /// Validate that the op inputs referenced by this micro-op have the
    /// expected datum types: opaque packed operands for the matmul, plain
    /// numbers everywhere else.
    fn check_inputs(&self, inputs: &[&TypedFact]) -> TractResult<()> {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { a, b, .. } => {
                ensure!(inputs[*a].datum_type == Opaque::datum_type());
                ensure!(inputs[*b].datum_type == Opaque::datum_type());
            }
            BinScalar(v, _)
            | LeakyRelu(v)
            | BinPerCol(v, _, _)
            | BinPerRow(v, _, _)
            | AddUnicast(_, v, _) => {
                ensure!(inputs[*v].datum_type.is_number());
            }
            AddRowColProducts(row, col) => {
                ensure!(inputs[*row].datum_type.is_number());
                ensure!(inputs[*col].datum_type.is_number());
            }
            _ => (),
        };
        Ok(())
    }

    /// Cost contribution for one (m, n) tile: only the matmul is counted,
    /// as m * n * k fused multiply-adds; other micro-ops report nothing.
    fn cost(&self, m: &TDim, n: &TDim, idt: DatumType) -> TVec<(Cost, TDim)> {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => {
                tvec!((Cost::FMA(idt), m.clone() * n * &geo.k))
            }
            _ => tvec!(),
        }
    }

    /// Fix up internal axis mappings and store specs after output axis
    /// `axis` has been removed from the op's output shape.
    fn rm_c_axis(&mut self, axis: usize) {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, .. } => {
                geo.c_to_a_axis_mapping.rm_c_axis(axis);
                geo.c_to_b_axis_mapping.rm_c_axis(axis);
            }
            BinScalar(..) | Scaler(..) | AddRowColProducts(_, _) | LeakyRelu(_) => {}
            BinPerRow(_, _, map) | BinPerCol(_, _, map) => map.rm_c_axis(axis),
            AddUnicast(_, _, map) => {
                map.rm_c_axis(axis);
            }
            Store(oss, ..) => {
                for oss in oss {
                    match oss {
                        OutputStoreSpec::View { m_axis, n_axis, .. } => {
                            // Shift m/n axes down when they sat past the removed axis.
                            if let Some(m) = m_axis {
                                *m -= (*m > axis) as usize
                            };
                            if let Some(n) = n_axis {
                                *n -= (*n > axis) as usize
                            }
                        }
                        OutputStoreSpec::Strides { .. } => {}
                    }
                }
            }
        }
    }
}
267
/// List of (output axis, input axis) pairs used to translate a coordinate in
/// the output tensor into an offset inside an input tensor view.
#[derive(Clone, Debug)]
pub struct MapOutputAxisToInput(pub TVec<(usize, usize)>);
270
271impl MapOutputAxisToInput {
272 #[inline]
273 unsafe fn translate_view(&self, output_coords: &[usize], v: &mut TensorView) {
274 for &(out_axis, in_axis) in &self.0 {
275 v.offset_axis(in_axis, output_coords[out_axis] as isize)
276 }
277 }
278
279 #[inline]
280 fn rm_c_axis(&mut self, axis: usize) {
281 for (c, _) in &mut self.0 {
282 *c -= (*c > axis) as usize;
283 }
284 }
285}
286
/// Geometry of an `AddMatMul` micro-op: the shared reduction dimension plus
/// the mappings from output coordinates to offsets in the A and B operands.
#[derive(Clone, Debug)]
pub struct AddMatMulGeometry {
    /// The k (inner/reduction) dimension of the product, possibly symbolic.
    pub k: TDim,
    /// Output-coords → A-operand offsets.
    pub c_to_a_axis_mapping: MapOutputAxisToInput,
    /// Output-coords → B-operand offsets.
    pub c_to_b_axis_mapping: MapOutputAxisToInput,
}
293
/// Optimized matrix-multiplication op: one or more candidate kernels and a
/// pipeline of fused micro-ops resolved and run at eval time.
#[derive(Clone, Debug)]
pub struct OptMatMul {
    /// Fact (type and shape) of the output tensor.
    pub c_fact: TypedFact,
    /// Fused micro-ops, in execution order. The cast-fusing code assumes the
    /// last one is a `Store` — TODO confirm this is a hard invariant.
    pub micro_ops: Vec<ProtoFusedSpec>,
    /// Candidate kernels; `mode_picker` chooses one (the "mode") from n.
    pub mmm: Vec<Box<dyn MatMatMul>>,
    pub mode_picker: ModePicker,
    /// Position of the m (rows) axis in the output shape, if any.
    pub c_m_axis: Option<usize>,
    /// Position of the n (cols) axis in the output shape, if any.
    pub c_n_axis: Option<usize>,
    /// Whether packings allow the single-tile path (see `can_use_trivial_path`).
    pub trivial_packing: bool,
    /// Cached result of `can_use_trivial_path()`, maintained by `update_trivial_path`.
    pub trivial_path: bool,
}
305
306impl Op for OptMatMul {
307 fn name(&self) -> Cow<str> {
308 "OptMatMul".into()
309 }
310
311 fn info(&self) -> TractResult<Vec<String>> {
312 let m = self.c_m_axis.map(|ix| &self.c_fact.shape[ix]).unwrap_or(&TDim::Val(1));
313 let n = self.c_n_axis.map(|ix| &self.c_fact.shape[ix]).unwrap_or(&TDim::Val(1));
314 let mut infos = vec![format!(
315 "c_shape:{:?}, c_m_axis:{:?} c_n_axis:{:?} m:{} n:{}",
316 self.c_fact, self.c_m_axis, self.c_n_axis, m, n,
317 )];
318 if let Some(k) = self.guess_k() {
319 infos.push(format!("Mult: m:{} k:{} n:{} with {:?}", m, k, n, self.mmm));
320 } else {
321 infos.push(format!("Mult: {:?}", self.mmm));
322 }
323 for (mode, mmm) in self.mmm.iter().enumerate() {
324 infos.push(format!(
325 "Ops: {}",
326 self.micro_ops.iter().map(|o| o.format(&**mmm, mode)).join(" >>> ")
327 ));
328 }
329 Ok(infos)
330 }
331
332 op_as_typed_op!();
333}
334
impl EvalOp for OptMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Allocate the output, pick the kernel mode from n, then resolve and run
    /// the fused micro-ops: once on the trivial single-tile path, otherwise
    /// once per output tile (iterating all non-m/n axes).
    fn eval_with_session(
        &self,
        session: &SessionState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        unsafe {
            let c_shape = self.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
            // Output deliberately left uninitialized: the Store micro-op fills it.
            let mut c = Tensor::uninitialized_dt(self.c_fact.datum_type, &c_shape)?;
            let m = self.c_m_axis.map(|c_m| c.shape()[c_m]).unwrap_or(1);
            let n = self.c_n_axis.map(|c_n| c.shape()[c_n]).unwrap_or(1);
            let mode = self.mode_picker.pick(n)?;
            let mmm = &*self.mmm[mode];
            // Reuse the session-cached scratch space when the selected kernel
            // can still use it; otherwise drop it and reallocate lazily.
            let mut cell = session.cached_mmm_scratch_space.borrow_mut();
            if !cell.as_ref().is_some_and(|scratch| mmm.can_use_scratch_space(&**scratch)) {
                *cell = None
            }
            let scratch = cell.get_or_insert_with(|| mmm.allocate_scratch_space());
            if self.trivial_path {
                // Single tile: resolve each micro-op once and run the kernel once.
                let uops: Vec<FusedSpec> = self
                    .micro_ops
                    .iter()
                    .map(|o| o.resolve_trivial(&inputs, &mut c, mmm, mode))
                    .collect();
                mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)?;
                Ok(tvec!(c.into_tvalue()))
            } else {
                // Placeholder specs, overwritten in place for every tile below.
                let mut uops = vec![FusedSpec::ShiftLeft(0); self.micro_ops.len()];
                // Iterate every output coordinate except along the m and n axes.
                let mut looping_shape: TVec<usize> = c_shape.to_smallvec();
                if let Some(ax) = self.c_m_axis {
                    looping_shape[ax] = 1;
                }
                if let Some(ax) = self.c_n_axis {
                    looping_shape[ax] = 1;
                }
                for c_coords in indices(&*looping_shape) {
                    for ix in 0..self.micro_ops.len() {
                        // SAFETY: uops and micro_ops have the same length by construction.
                        *uops.get_unchecked_mut(ix) = self.micro_ops.get_unchecked(ix).resolve(
                            &inputs,
                            c_coords.slice(),
                            &c,
                            mmm,
                            mode,
                        );
                    }
                    mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)
                        .context("In mmm.run_with_scratch_space")?;
                }
                Ok(tvec!(c.into_tvalue()))
            }
        }
    }
}
392
impl TypedOp for OptMatMul {
    /// Output fact is `c_fact`; on the way, re-validate internal invariants
    /// (axes in range, trivial_path cache coherent, homogeneous kernel types,
    /// micro-op input datum types).
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(self.c_m_axis.map(|ax| ax < self.c_fact.rank()).unwrap_or(true));
        ensure!(self.c_n_axis.map(|ax| ax < self.c_fact.rank()).unwrap_or(true));
        ensure!(self.trivial_path == self.can_use_trivial_path());
        ensure!(self.mmm.iter().map(|mmm| mmm.internal_type()).all_equal());
        for op in &self.micro_ops {
            op.check_inputs(inputs)?;
        }
        Ok(tvec!(self.c_fact.clone()))
    }

    /// Sum per-micro-op costs for a single (m, n) tile, then scale by the
    /// number of tiles (product of all non-m/n output dims).
    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let mut sums = HashMap::new();
        for op in &self.micro_ops {
            for (cost, count) in op.cost(self.m(), self.n(), self.mmm[0].internal_type()) {
                *sums.entry(cost).or_default() += count;
            }
        }
        let loops = self
            .c_fact
            .shape
            .iter()
            .enumerate()
            .map(|(ix, d)| {
                if Some(ix) == self.c_m_axis || Some(ix) == self.c_n_axis {
                    1.to_dim()
                } else {
                    d.clone()
                }
            })
            .product::<TDim>();
        for s in &mut sums.values_mut() {
            *s *= &loops;
        }
        Ok(sums.into_iter().collect())
    }

    /// Try to absorb our single successor into this op as extra fused
    /// micro-ops: linalg binary ops, QScale, leaky-relu, i32→i8/u8 cast,
    /// axis removal, or a same-shape unicast addition.
    fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
        use crate::ops;
        // Only fuse when there is exactly one consumer and we are not a model output.
        if node.outputs.len() != 1
            || node.outputs[0].successors.len() != 1
            || model.output_outlets()?.contains(&node.id.into())
        {
            return Ok(None);
        }
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut patch = TypedModelPatch::new(format!("fusing {succ}"));

        if let Some(op) = succ.op_as::<ops::binary::TypedBinOp>() {
            let mut binop = if let Some(op) = op.0.as_linalg_binop() {
                op
            } else {
                return Ok(None);
            };
            // If we feed the binary's first slot, flip the operator so the
            // fused spec sees the other outlet as right-hand side.
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }
        if let Some(op) = succ.op_as::<ops::binary::OptBinByScalar>() {
            let mut binop = if let Some(op) = op.binop.as_linalg_binop() {
                op
            } else {
                return Ok(None);
            };
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }

        if let Some(op) = succ.op_as::<ops::element_wise::ElementWiseOp>().map(|ew| ew.0.as_ref()) {
            if let Some(op) = op.downcast_ref::<ops::math::QScale>() {
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::Scaler(op.scaler)],
                    &[],
                );
            }
            if let Some(op) = op.downcast_ref::<LeakyRelu>() {
                // Every candidate kernel must be able to fuse leaky-relu.
                if !self
                    .mmm
                    .iter()
                    .all(|mmm| mmm.can_fuse(&FusedSpec::LeakyRelu(&tensor0(op.alpha))))
                {
                    return Ok(None);
                }
                // Alpha becomes an extra constant input, cast to the kernel type.
                let alpha = patch.add_const(
                    node.name.to_string() + ".alpha",
                    tensor0(op.alpha).cast_to_dt(self.mmm[0].internal_type())?.into_owned(),
                )?;
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::LeakyRelu(node.inputs.len())],
                    &[alpha],
                );
            }
        }
        if let Some(cast_to) = succ.op_as::<ops::cast::Cast>().map(|cast| cast.to) {
            // A downcast of an i32 accumulator can be folded into the store.
            if (cast_to.unquantized() == i8::datum_type()
                || cast_to.unquantized() == u8::datum_type())
                && self.c_fact.datum_type == i32::datum_type()
            {
                if let Some(ProtoFusedSpec::Store(stores)) = self.micro_ops.last() {
                    if stores.iter().any(|s| matches!(s, OutputStoreSpec::Strides { .. })) {
                        return Ok(None);
                    }
                    let c_fact = cast_to.fact(self.c_fact.shape.clone());
                    let mut patch = TypedModelPatch::fuse_with_next(
                        model,
                        node,
                        Self { c_fact, ..self.clone() },
                    )?;
                    patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
                    return Ok(Some(patch));
                }
            }
        }
        if let Some(AxisOp::Rm(axis)) = succ.op_as::<ops::AxisOp>() {
            // Cannot remove the m or n axis; anything else just renumbers axes.
            if Some(*axis) == self.c_m_axis || Some(*axis) == self.c_n_axis {
                return Ok(None);
            }
            let mut new_op = self.clone();
            new_op.c_fact.shape.remove_axis(*axis)?;
            if let Some(c_m_axis) = &mut new_op.c_m_axis {
                *c_m_axis -= (*c_m_axis > *axis) as usize;
            }
            if let Some(c_n_axis) = &mut new_op.c_n_axis {
                *c_n_axis -= (*c_n_axis > *axis) as usize;
            }
            for uop in &mut new_op.micro_ops {
                uop.rm_c_axis(*axis);
            }
            let mut patch = TypedModelPatch::fuse_with_next(model, node, new_op)?;
            patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
            return Ok(Some(patch));
        }
        if succ.op_is::<AxisOp>() {
            // matmul → axis-op → binary-with-uniform: swap the binary before
            // the axis op so a later pass may fuse it into the matmul.
            if let &[next] = &*succ.outputs[0].successors {
                let bin = model.node(next.node);
                if let Some(op) = bin.op_as::<ops::binary::TypedBinOp>() {
                    if op.0.as_linalg_binop().is_none() {
                        return Ok(None);
                    };
                    let flipped = succ.inputs[0].node == node.id;
                    let other_outlet = bin.inputs[flipped as usize];
                    if let Some(uni) = &model.outlet_fact(other_outlet)?.uniform {
                        let mut patch = TypedModelPatch::default();
                        let cst =
                            patch.add_const(&model.node(other_outlet.node).name, uni.clone())?;
                        let output = patch.tap_model(model, node.id.into())?;
                        let wire = wire_with_rank_broadcast(
                            &bin.name,
                            &mut patch,
                            op.clone(),
                            &if flipped { [output, cst] } else { [cst, output] },
                        )?;
                        let wire = patch.wire_node(&succ.name, succ.op.clone(), &wire)?[0];
                        patch.shunt_outside(model, bin.id.into(), wire)?;
                        return Ok(Some(patch));
                    }
                }
            }
        }
        if let Some(op) = succ.op_as::<ops::binary::OptBinUnicast>() {
            let in_1_fact = model.outlet_fact(succ.inputs[0])?;
            let in_2_fact = model.outlet_fact(succ.inputs[1])?;
            // Same-shape addition fuses as an AddUnicast over the output view;
            // anything else falls back to the generic binary fusing.
            if op.binop.is::<ops::math::Add>()
                && self.mmm.len() == 1
                && in_1_fact.without_value() == in_2_fact.without_value()
            {
                let other_slot = 1 - node.outputs[0].successors[0].slot;
                let other_input = succ.inputs[other_slot];
                let other_input = patch.tap_model(model, other_input)?;
                let other_fact = patch.outlet_fact(other_input)?;

                if other_fact.shape == self.c_fact.shape {
                    let other_storage = unsafe { self.mmm[0].c_view(self.c_m_axis, self.c_n_axis) };
                    // Identity mapping: output axis i reads input axis i.
                    let mapping =
                        MapOutputAxisToInput((0..other_fact.rank()).map(|x| (x, x)).collect());
                    return self.fuse_op(
                        model,
                        node,
                        patch,
                        vec![ProtoFusedSpec::AddUnicast(other_storage, node.inputs.len(), mapping)],
                        &[other_input],
                    );
                }
            } else {
                let mut binop = if let Some(op) = op.binop.as_linalg_binop() {
                    op
                } else {
                    return Ok(None);
                };
                let flipped = succ.inputs[0].node == node.id;
                if flipped {
                    binop = binop.flip();
                }
                let other_outlet = succ.inputs[flipped as usize];
                return self.fuse_binary(model, node, patch, other_outlet, binop);
            }
        };
        Ok(None)
    }

    as_op!();
}
609
impl OptMatMul {
    /// Build the op, validating the m/n axes against the output rank and
    /// pre-computing whether the trivial (single-tile) eval path applies.
    pub fn new(
        mmm: Vec<Box<dyn MatMatMul>>,
        mode_picker: ModePicker,
        c_fact: TypedFact,
        c_m_axis: Option<usize>,
        c_n_axis: Option<usize>,
        micro_ops: Vec<ProtoFusedSpec>,
        trivial_packing: bool,
    ) -> TractResult<Self> {
        if let Some(m) = c_m_axis {
            ensure!(m < c_fact.rank());
        }
        if let Some(n) = c_n_axis {
            ensure!(n < c_fact.rank());
        }
        let mut it = OptMatMul {
            mmm,
            mode_picker,
            c_fact,
            c_m_axis,
            c_n_axis,
            micro_ops,
            trivial_path: false,
            trivial_packing,
        };
        it.update_trivial_path();
        Ok(it)
    }

    /// The k dimension of the first AddMatMul micro-op, if there is one.
    pub fn guess_k(&self) -> Option<TDim> {
        self.micro_ops
            .iter()
            .find_map(
                |o| {
                    if let ProtoFusedSpec::AddMatMul { geo, .. } = o {
                        Some(geo)
                    } else {
                        None
                    }
                },
            )
            .map(|geo| geo.k.clone())
    }

    /// The m dimension of the output (1 when no m axis is present).
    #[inline]
    pub fn m(&self) -> &TDim {
        self.c_m_axis.map(|ax| &self.c_fact.shape[ax]).unwrap_or(&TDim::Val(1))
    }

    /// The n dimension of the output (1 when no n axis is present).
    #[inline]
    pub fn n(&self) -> &TDim {
        self.c_n_axis.map(|ax| &self.c_fact.shape[ax]).unwrap_or(&TDim::Val(1))
    }

    /// Refresh the cached `trivial_path` flag; must be called after any
    /// mutation of the fields `can_use_trivial_path` depends on.
    fn update_trivial_path(&mut self) {
        self.trivial_path = self.can_use_trivial_path();
    }

    /// The trivial path applies when the output is a single tile: concrete
    /// shape, every non-m/n dim is 1, trivial packing, and every micro-op
    /// reports itself trivial.
    fn can_use_trivial_path(&self) -> bool {
        self.c_fact.shape.is_concrete()
            && self.c_fact.shape.iter().enumerate().all(|(ax, dim)| {
                Some(ax) == self.c_m_axis || Some(ax) == self.c_n_axis || dim.is_one()
            })
            && self.trivial_packing
            && self.micro_ops.iter().all(|o| o.is_trivial())
    }

    /// Rewire the graph so the successor node is replaced by a clone of this
    /// op with `fused_micro_op` spliced in just before the last micro-op
    /// (the store), taking `additional_inputs` as extra op inputs.
    fn fuse_op(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        fused_micro_op: Vec<ProtoFusedSpec>,
        additional_inputs: &[OutletId],
    ) -> TractResult<Option<TypedModelPatch>> {
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut new_op = self.clone();
        // Empty range just before the last element: splice inserts there.
        let before_last = new_op.micro_ops.len() - 1..new_op.micro_ops.len() - 1;
        new_op.micro_ops.splice(before_last, fused_micro_op);
        new_op.c_fact = succ.outputs[0].fact.clone();
        new_op.update_trivial_path();
        let mut inputs = patch.taps(model, &node.inputs)?;
        inputs.extend(additional_inputs.iter().cloned());
        let output = patch.wire_node(&succ.name, new_op, &inputs)?;
        patch.shunt_outside(model, succ.id.into(), output[0])?;
        Ok(Some(patch))
    }

    /// Fuse a linalg binary op whose other operand is `value`: as a scalar op
    /// when the operand has volume 1, else as a per-row or per-column op when
    /// its shape lines up with the m or n axis. Casts the operand to the
    /// kernel's internal type first when needed.
    fn fuse_binary(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        value: OutletId,
        binop: BinOp,
    ) -> TractResult<Option<TypedModelPatch>> {
        let fact = model.outlet_fact(value)?;
        let mut v = patch.tap_model(model, value)?;
        if fact.datum_type != self.mmm[0].internal_type() {
            v = patch.wire_node(
                format!("{}.cast-input-{}", node.name, node.inputs.len()),
                cast(self.mmm[0].internal_type()),
                &[v],
            )?[0];
        }
        // `value` becomes the slot index of the new extra input.
        let value = node.inputs.len();
        let additional_input = tvec!(v);
        if fact.shape.volume() == 1.to_dim() {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinScalar(value, binop)],
                &additional_input,
            );
        }
        let other_shape = fact.shape.to_owned();
        // Per-row: operand's whole volume lives on the m axis.
        if self.c_m_axis.is_some_and(|ax| {
            other_shape[ax] == self.c_fact.shape[ax] && other_shape[ax] == other_shape.volume()
        }) {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerRow(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_m_axis.unwrap(), self.c_m_axis.unwrap()))),
                )],
                &additional_input,
            );
        }
        // Per-column: operand's whole volume lives on the n axis.
        if self.c_n_axis.is_some_and(|ax| {
            other_shape[ax] == self.c_fact.shape[ax] && other_shape[ax] == other_shape.volume()
        }) {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerCol(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_n_axis.unwrap(), self.c_n_axis.unwrap()))),
                )],
                &additional_input,
            );
        }
        Ok(None)
    }
}
761}