1use crate::internal::*;
2use crate::ops::cast::{Cast, cast};
3use crate::ops::change_axes::wire_with_rank_broadcast;
4use crate::ops::nn::LeakyRelu;
5use ndarray::*;
6use tract_itertools::Itertools;
7
8use tract_linalg::mmm::{
9 AsInputValue, EagerPackedInput, FusedSpec, MatMatMul, OutputStoreSpec, PackedMatrixStorage,
10 PanelExtractInput, PanelExtractor,
11};
12use tract_linalg::pack::PackedFormat;
13use tract_linalg::{BinOp, Scaler};
14use tract_smallvec::ToSmallVec;
15
16use super::ModePicker;
17
/// Graph-time description of one micro-operation of an [`OptMatMul`] pipeline.
///
/// A `ProtoFusedSpec` is the symbolic counterpart of `tract_linalg`'s
/// `FusedSpec`: instead of live tensor views it stores *input slot indices*
/// (into the node's runtime inputs) and output-to-input axis mappings, and is
/// turned into a concrete `FusedSpec` at eval time by `resolve` /
/// `resolve_trivial`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ProtoFusedSpec {
    /// The matrix product itself. `a` and `b` are input slots holding packed
    /// matrices; `packings` gives, for each kernel mode, the index into
    /// `mmm.packings()` plus an optional panel extractor applied to the A side.
    AddMatMul {
        geo: AddMatMulGeometry,
        a: usize,
        b: usize,
        packings: Vec<(usize, Option<PanelExtractor>)>,
    },
    /// Binary op against the scalar tensor found in the given input slot.
    BinScalar(usize, BinOp),
    /// LeakyRelu; the usize is the input slot holding the alpha tensor.
    LeakyRelu(usize),
    /// Row-wise binary op; the map translates output coords to operand coords.
    BinPerRow(usize, BinOp, MapOutputAxisToInput),
    /// Column-wise binary op; the map translates output coords to operand coords.
    BinPerCol(usize, BinOp, MapOutputAxisToInput),
    /// Row/column products addition, fed from the two given input slots.
    AddRowColProducts(usize, usize),
    /// Element-wise addition of a same-shaped operand from the given input slot.
    AddUnicast(OutputStoreSpec, usize, MapOutputAxisToInput),
    /// Accumulator scaling.
    Scaler(Scaler),
    /// Final store of the accumulator into the output (one spec per kernel mode).
    Store(Vec<OutputStoreSpec>),
}
35
impl ProtoFusedSpec {
    /// Short human-readable description of this micro-op for the kernel
    /// selected by `mode` (used by `OptMatMul::info`).
    pub fn format(&self, mmm: &dyn MatMatMul, mode: usize) -> String {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, packings: packing, .. } => {
                let (a, b) = &mmm.packings()[packing[mode].0];
                format!("matmul(k={}, {a:?}•{b:?})", geo.k)
            }
            BinScalar(_, op) => format!("scalar{op:?}"),
            // NOTE(review): the bound field is the *input slot index*, so this
            // prints a slot id, not the alpha value itself.
            LeakyRelu(alpha) => format!("leaky_relu({alpha:?})"),
            BinPerRow(_, op, _) => format!("row{op:?}"),
            BinPerCol(_, op, _) => format!("col{op:?}"),
            AddRowColProducts(_, _) => "add_row_col_product".to_string(),
            AddUnicast(_, _, _) => "add_to_matrix".to_string(),
            Scaler(s) => format!("scale({})", 1f32 * *s),
            Store(_oss) => "store".to_string(),
        }
    }

    /// Resolve this proto-spec into a live `FusedSpec` for one tile of the
    /// output loop (general, non-trivial path).
    ///
    /// * `inputs` — the node's runtime inputs (indexed by the slots in `self`).
    /// * `output_coords` — current coordinates in the output tensor; looped
    ///   axes select the matching batch element of A/B and of broadcast
    ///   operands via the stored axis mappings.
    /// * `output` — the output tensor being filled (for `Store`).
    /// * `mode` — which kernel of the multi-kernel op is running; selects the
    ///   packing and store spec to use.
    pub fn resolve<'t>(
        &'t self,
        inputs: &'t [TValue],
        output_coords: &[usize],
        output: &Tensor,
        mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        #[allow(clippy::let_and_return)]
        let fs = match self {
            ProtoFusedSpec::AddMatMul { geo, a, b, packings } => {
                // Locate the packed A value for the current batch coordinates.
                let a_tensor = &inputs[*a];
                let a_storage = a_tensor.try_storage_as::<PackedMatrixStorage>().unwrap();
                let a_idx =
                    geo.c_to_a_axis_mapping.flat_index(output_coords, a_storage.batch_strides());
                let a = a_storage.value_at_flat(a_idx);

                // Same lookup for B.
                let b_tensor = &inputs[*b];
                let b_storage = b_tensor.try_storage_as::<PackedMatrixStorage>().unwrap();
                let b_idx =
                    geo.c_to_b_axis_mapping.flat_index(output_coords, b_storage.batch_strides());
                let b = b_storage.value_at_flat(b_idx);

                let (_a_packing, b_packing) = &mmm.packings()[packings[mode].0];
                // If this mode requires re-extracting A's panels, wrap the
                // eager packed input in a PanelExtractInput; otherwise borrow.
                let pa = if let Some(extractor) = &packings[mode].1 {
                    let data = a.downcast_ref::<EagerPackedInput>().unwrap();
                    AsInputValue::Owned(Box::new(PanelExtractInput {
                        format: extractor.clone(),
                        data: data.clone(),
                    }))
                } else {
                    AsInputValue::Borrowed(a)
                };
                // B must match the kernel's expected packing, or at least be a
                // PackedFormat with the same panel width (r).
                assert!(
                    b_packing.dyn_eq(b.format())
                        || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                );
                debug_assert!(pa.k().to_dim().compatible_with(&geo.k.to_dim()));
                debug_assert!(b.k().to_dim().compatible_with(&geo.k.to_dim()));
                FusedSpec::AddMatMul {
                    a: pa,
                    b: AsInputValue::Borrowed(b),
                    packing: packings[mode].0,
                }
            }
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, map) => {
                // Shift the operand view to the slice matching output_coords.
                let mut v = inputs[*v].view();
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, map) => {
                let mut v = inputs[*v].view();
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, map) => unsafe {
                let mut view = inputs[*v].view();
                map.translate_view(output_coords, &mut view);
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                // Offset the output view to the tile selected by output_coords.
                let view = output.view_offsetting_unchecked(output_coords);
                FusedSpec::Store(oss[mode].wrap(&view))
            },
        };
        fs
    }

    /// A micro-op is "trivial" when it can be resolved without per-tile
    /// coordinate arithmetic; only AddMatMul with a non-concrete (symbolic) k
    /// is non-trivial.
    pub fn is_trivial(&self) -> bool {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => geo.k.as_i64().is_some(),
            _ => true,
        }
    }

    /// Fast-path counterpart of [`Self::resolve`], used when the whole output
    /// is produced by a single kernel call (no batch loop). Skips coordinate
    /// translation and panel extraction; preconditions (slots present, single
    /// packing, no extractor, matching formats) are checked only in debug
    /// builds.
    pub fn resolve_trivial<'t>(
        &'t self,
        inputs: &'t [TValue],
        output: &mut Tensor,
        _mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        #[allow(clippy::let_and_return)]
        let fs = match self {
            ProtoFusedSpec::AddMatMul { a, b, packings, .. } => unsafe {
                debug_assert!(inputs.get(*a).is_some());
                debug_assert!(inputs.get(*b).is_some());
                // SAFETY: slot presence asserted above in debug builds; the
                // planner wires these slots when building the op.
                let a = inputs.get_unchecked(*a);
                let b = inputs.get_unchecked(*b);
                debug_assert!(a.is_exotic());
                debug_assert!(a.len() == 1);
                debug_assert!(b.is_exotic());
                debug_assert!(b.len() == 1);
                let a_storage = a.try_storage_as::<PackedMatrixStorage>().unwrap_unchecked();
                let b_storage = b.try_storage_as::<PackedMatrixStorage>().unwrap_unchecked();
                let a = a_storage.value();
                let b = b_storage.value();
                // Trivial path requires a single packing and no panel extractor.
                debug_assert!(packings.len() == 1);
                debug_assert!(packings[0].1.is_none());
                #[cfg(debug_assertions)]
                {
                    let (a_packing, b_packing) = &_mmm.packings()[packings[mode].0];
                    debug_assert!(
                        a_packing.dyn_eq(a.format())
                            || (a_packing.is::<PackedFormat>() && a_packing.r() == a.format().r())
                    );
                    debug_assert!(
                        b_packing.dyn_eq(b.format())
                            || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                    );
                }
                FusedSpec::AddMatMul {
                    a: AsInputValue::Borrowed(a),
                    b: AsInputValue::Borrowed(b),
                    packing: packings[mode].0,
                }
            },
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            // No axis mapping needed: there is a single output tile.
            ProtoFusedSpec::BinPerRow(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, _) => unsafe {
                let view = inputs[*v].view();
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                FusedSpec::Store(oss[mode].wrap(&output.view_mut()))
            },
        };
        fs
    }

    /// Validate the input facts referenced by this micro-op's slots.
    fn check_inputs(&self, inputs: &[&TypedFact]) -> TractResult<()> {
        use ProtoFusedSpec::*;
        match self {
            // A and B must be opaque/packed ("exotic") values, not plain tensors.
            AddMatMul { a, b, .. } => {
                ensure!(inputs[*a].is_exotic());
                ensure!(inputs[*b].is_exotic());
            }
            BinScalar(v, _)
            | LeakyRelu(v)
            | BinPerCol(v, _, _)
            | BinPerRow(v, _, _)
            | AddUnicast(_, v, _) => {
                ensure!(inputs[*v].datum_type.is_number());
            }
            AddRowColProducts(row, col) => {
                ensure!(inputs[*row].datum_type.is_number());
                ensure!(inputs[*col].datum_type.is_number());
            }
            _ => (),
        };
        Ok(())
    }

    /// Cost contribution of this micro-op. Only the matrix product is counted:
    /// m * n * k fused multiply-adds in the kernel's internal type.
    fn cost(&self, m: &TDim, n: &TDim, idt: DatumType) -> TVec<(Cost, TDim)> {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => {
                tvec!((Cost::FMA(idt), m.clone() * n * &geo.k))
            }
            _ => tvec!(),
        }
    }

    /// Propagate the removal of output axis `axis` into every stored axis
    /// mapping and store spec.
    fn rm_c_axis(&mut self, axis: usize) {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, .. } => {
                geo.c_to_a_axis_mapping.rm_c_axis(axis);
                geo.c_to_b_axis_mapping.rm_c_axis(axis);
            }
            // These carry no output-axis information.
            BinScalar(..) | Scaler(..) | AddRowColProducts(_, _) | LeakyRelu(_) => {}
            BinPerRow(_, _, map) | BinPerCol(_, _, map) => map.rm_c_axis(axis),
            AddUnicast(_, _, map) => {
                map.rm_c_axis(axis);
            }
            Store(oss, ..) => {
                for oss in oss {
                    match oss {
                        OutputStoreSpec::View { m_axis, n_axis, .. } => {
                            // Shift m/n axes down when they sit after the removed axis.
                            if let Some(m) = m_axis {
                                *m -= (*m > axis) as usize
                            };
                            if let Some(n) = n_axis {
                                *n -= (*n > axis) as usize
                            }
                        }
                        OutputStoreSpec::Strides { .. } => {}
                    }
                }
            }
        }
    }
}
265
/// Mapping from output (C) axes to the axes of one operand, stored as
/// `(output_axis, input_axis)` pairs. Used to translate output coordinates
/// into operand views (`translate_view`) or flat batch offsets (`flat_index`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MapOutputAxisToInput(pub TVec<(usize, usize)>);
268
269impl MapOutputAxisToInput {
270 #[inline]
271 unsafe fn translate_view(&self, output_coords: &[usize], v: &mut TensorView) {
272 for &(out_axis, in_axis) in &self.0 {
273 unsafe { v.offset_axis(in_axis, output_coords[out_axis] as isize) }
274 }
275 }
276
277 #[inline]
278 fn rm_c_axis(&mut self, axis: usize) {
279 for (c, _) in &mut self.0 {
280 *c -= (*c > axis) as usize;
281 }
282 }
283
284 #[inline]
286 pub fn flat_index(&self, output_coords: &[usize], batch_strides: &[isize]) -> usize {
287 self.0
288 .iter()
289 .map(|&(out_axis, in_axis)| output_coords[out_axis] * batch_strides[in_axis] as usize)
290 .sum()
291 }
292}
293
/// Static geometry of an `AddMatMul` micro-op.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AddMatMulGeometry {
    /// The reduction (inner) dimension k, possibly symbolic.
    pub k: TDim,
    /// Output-axis to A batch-axis mapping.
    pub c_to_a_axis_mapping: MapOutputAxisToInput,
    /// Output-axis to B batch-axis mapping.
    pub c_to_b_axis_mapping: MapOutputAxisToInput,
}
300
/// Optimized, fused matrix-multiplication operator.
///
/// Carries one or more candidate kernels (`mmm`) — `mode_picker` selects one
/// at eval time from the runtime n dimension — and a pipeline of
/// [`ProtoFusedSpec`] micro-ops ending with a store.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OptMatMul {
    /// Fact (type and shape) of the output tensor C.
    pub c_fact: TypedFact,
    /// Fused pipeline, resolved per output tile (or once, on the trivial path).
    pub micro_ops: Vec<ProtoFusedSpec>,
    /// Candidate kernels, indexed by mode.
    pub mmm: Vec<Box<dyn MatMatMul>>,
    pub mode_picker: ModePicker,
    /// Output axis carrying m, if any (None means m == 1).
    pub c_m_axis: Option<usize>,
    /// Output axis carrying n, if any (None means n == 1).
    pub c_n_axis: Option<usize>,
    pub trivial_packing: bool,
    /// Cached `can_use_trivial_path()` — kept in sync via `update_trivial_path`.
    pub trivial_path: bool,
}
312
313impl Op for OptMatMul {
314 fn name(&self) -> StaticName {
315 "OptMatMul".into()
316 }
317
318 fn info(&self) -> TractResult<Vec<String>> {
319 let m = self.c_m_axis.map(|ix| &self.c_fact.shape[ix]).unwrap_or(&TDim::Val(1));
320 let n = self.c_n_axis.map(|ix| &self.c_fact.shape[ix]).unwrap_or(&TDim::Val(1));
321 let mut infos = vec![format!(
322 "c_shape:{:?}, c_m_axis:{:?} c_n_axis:{:?} m:{} n:{}",
323 self.c_fact, self.c_m_axis, self.c_n_axis, m, n,
324 )];
325 if let Some(k) = self.guess_k() {
326 infos.push(format!("Mult: m:{} k:{} n:{} with {:?}", m, k, n, self.mmm));
327 } else {
328 infos.push(format!("Mult: {:?}", self.mmm));
329 }
330 for (mode, mmm) in self.mmm.iter().enumerate() {
331 infos.push(format!(
332 "Ops: {}",
333 self.micro_ops.iter().map(|o| o.format(&**mmm, mode)).join(" >>> ")
334 ));
335 }
336 Ok(infos)
337 }
338
339 op_as_typed_op!();
340}
341
impl EvalOp for OptMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Run the fused matmul pipeline.
    ///
    /// Picks the kernel `mode` from the runtime n dimension, (re)uses the
    /// session-cached scratch space when compatible, then either resolves the
    /// whole pipeline once and runs a single kernel call (trivial path), or
    /// loops over every non-m/n output coordinate, re-resolving each micro-op
    /// per tile.
    fn eval_with_session(
        &self,
        _node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        unsafe {
            let c_shape = self.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
            // Output is allocated uninitialized; the pipeline's Store fills it.
            let mut c = Tensor::uninitialized_dt(self.c_fact.datum_type, &c_shape)?;
            let m = self.c_m_axis.map(|c_m| c.shape()[c_m]).unwrap_or(1);
            let n = self.c_n_axis.map(|c_n| c.shape()[c_n]).unwrap_or(1);
            let mode = self.mode_picker.pick(n)?;
            let mmm = &*self.mmm[mode];
            // Drop the cached scratch space if this kernel cannot reuse it.
            let mut cell = session.cached_mmm_scratch_space.borrow_mut();
            if !cell.as_ref().is_some_and(|scratch| mmm.can_use_scratch_space(&**scratch)) {
                *cell = None
            }
            let scratch = cell.get_or_insert_with(|| mmm.allocate_scratch_space());
            if self.trivial_path {
                // Single kernel call covering the whole output.
                let uops: Vec<FusedSpec> = self
                    .micro_ops
                    .iter()
                    .map(|o| o.resolve_trivial(&inputs, &mut c, mmm, mode))
                    .collect();
                mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)?;
                Ok(tvec!(c.into_tvalue()))
            } else {
                // Placeholder specs; each is overwritten before every kernel call.
                let mut uops = vec![FusedSpec::ShiftLeft(0); self.micro_ops.len()];
                // Loop over all output axes except m and n (those are handled
                // inside the kernel call itself).
                let mut looping_shape: TVec<usize> = c_shape.to_smallvec();
                if let Some(ax) = self.c_m_axis {
                    looping_shape[ax] = 1;
                }
                if let Some(ax) = self.c_n_axis {
                    looping_shape[ax] = 1;
                }
                for c_coords in indices(&*looping_shape) {
                    for ix in 0..self.micro_ops.len() {
                        // SAFETY: uops and micro_ops have the same length; ix
                        // is bounded by that length.
                        *uops.get_unchecked_mut(ix) = self.micro_ops.get_unchecked(ix).resolve(
                            &inputs,
                            c_coords.slice(),
                            &c,
                            mmm,
                            mode,
                        );
                    }
                    mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)
                        .context("In mmm.run_with_scratch_space")?;
                }
                Ok(tvec!(c.into_tvalue()))
            }
        }
    }
}
400
impl TypedOp for OptMatMul {
    /// Output fact is `c_fact`. Also re-validates the op's internal
    /// invariants (axes in range, trivial_path cache coherent, homogeneous
    /// kernel internal types) and each micro-op's input facts.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(self.c_m_axis.map(|ax| ax < self.c_fact.rank()).unwrap_or(true));
        ensure!(self.c_n_axis.map(|ax| ax < self.c_fact.rank()).unwrap_or(true));
        ensure!(self.trivial_path == self.can_use_trivial_path());
        ensure!(self.mmm.iter().map(|mmm| mmm.internal_type()).all_equal());
        for op in &self.micro_ops {
            op.check_inputs(inputs)?;
        }
        Ok(tvec!(self.c_fact.clone()))
    }

    /// Total cost: per-micro-op costs (FMA counts) multiplied by the number
    /// of batch iterations (product of all non-m/n output dimensions).
    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let mut sums = HashMap::new();
        for op in &self.micro_ops {
            for (cost, count) in op.cost(self.m(), self.n(), self.mmm[0].internal_type()) {
                *sums.entry(cost).or_default() += count;
            }
        }
        let loops = self
            .c_fact
            .shape
            .iter()
            .enumerate()
            .map(|(ix, d)| {
                if Some(ix) == self.c_m_axis || Some(ix) == self.c_n_axis {
                    1.to_dim()
                } else {
                    d.clone()
                }
            })
            .product::<TDim>();
        for s in &mut sums.values_mut() {
            *s *= &loops;
        }
        Ok(sums.into_iter().collect())
    }

    /// Try to absorb our single successor into the fused pipeline.
    ///
    /// Handled cases, in order: binary ops (fused as scalar/row/col micro-op),
    /// QScale and LeakyRelu element-wise ops, a Cast of the output, removal of
    /// a unit output axis, swapping a Cast or uniform binary past an
    /// AxisOp/IntoShape, and element-wise addition of a same-shaped tensor.
    fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
        use crate::ops;
        // Only fuse when we have exactly one consumer and are not a model output.
        rule_if!(node.outputs.len() == 1);
        rule_if!(node.outputs[0].successors.len() == 1);
        rule_if!(!model.output_outlets()?.contains(&node.id.into()));
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut patch = TypedModelPatch::new(format!("fusing {succ}"));

        if let Some(op) = succ.op_as::<ops::binary::TypedBinOp>() {
            rule_if_some!(mut binop = op.0.as_linalg_binop());
            // If we feed the left operand, flip the op so the other operand
            // plays the constant-side role.
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }
        if let Some(op) = succ.op_as::<ops::binary::OptBinByScalar>() {
            rule_if_some!(mut binop = op.binop.as_linalg_binop());
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }

        if let Some(op) = succ.op_as::<ops::element_wise::ElementWiseOp>().map(|ew| ew.0.as_ref()) {
            if let Some(op) = op.downcast_ref::<ops::math::QScale>() {
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::Scaler(op.scaler)],
                    &[],
                );
            }
            if let Some(op) = op.downcast_ref::<LeakyRelu>() {
                // All candidate kernels must support a fused LeakyRelu.
                rule_if!(
                    self.mmm
                        .iter()
                        .all(|mmm| mmm.can_fuse(&FusedSpec::LeakyRelu(&tensor0(op.alpha))))
                );
                // Alpha becomes an extra constant input, cast to the kernel type.
                let alpha = patch.add_const(
                    node.name.to_string() + ".alpha",
                    tensor0(op.alpha).cast_to_dt(self.mmm[0].internal_type())?.into_owned(),
                )?;
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::LeakyRelu(node.inputs.len())],
                    &[alpha],
                );
            }
        }
        // Cast of our output: absorb it by retyping c_fact, when the kernels
        // can store that type (or it is the classic i32 -> i8/u8 requant case).
        if let Some(cast_to) = succ.op_as::<ops::cast::Cast>().map(|cast| cast.to)
            && (((cast_to.unquantized() == i8::datum_type()
                || cast_to.unquantized() == u8::datum_type())
                && self.c_fact.datum_type == i32::datum_type())
                || self.mmm.iter().all(|m| m.stores().contains(&cast_to)))
            && let Some(ProtoFusedSpec::Store(stores)) = self.micro_ops.last()
        {
            if stores.iter().any(|s| matches!(s, OutputStoreSpec::Strides { .. })) {
                return Ok(None);
            }
            let c_fact = cast_to.fact(self.c_fact.shape.clone());
            let mut patch =
                TypedModelPatch::fuse_with_next(model, node, Self { c_fact, ..self.clone() })?;
            patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
            return Ok(Some(patch));
        }
        // Removal of a (non-m, non-n) output axis: shift all stored axis maps.
        if let Some(AxisOp::Rm(axis)) = succ.op_as::<ops::AxisOp>() {
            rule_if!(Some(*axis) != self.c_m_axis);
            rule_if!(Some(*axis) != self.c_n_axis);
            let mut new_op = self.clone();
            new_op.c_fact.shape.remove_axis(*axis)?;
            if let Some(c_m_axis) = &mut new_op.c_m_axis {
                *c_m_axis -= (*c_m_axis > *axis) as usize;
            }
            if let Some(c_n_axis) = &mut new_op.c_n_axis {
                *c_n_axis -= (*c_n_axis > *axis) as usize;
            }
            for uop in &mut new_op.micro_ops {
                uop.rm_c_axis(*axis);
            }
            let mut patch = TypedModelPatch::fuse_with_next(model, node, new_op)?;
            patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
            return Ok(Some(patch));
        }
        // AxisOp/IntoShape between us and a Cast or uniform binary: swap the
        // shape-op past the other op so a later pass can fuse it into us.
        if (succ.op_is::<AxisOp>() || succ.op_is::<IntoShape>())
            && let &[next] = &*succ.outputs[0].successors
        {
            let next_node = model.node(next.node);
            if let Some(cast) = next_node.op_as::<Cast>() {
                let mut patch = TypedModelPatch::default();
                let mut wire = patch.tap_model(model, node.id.into())?;
                wire = patch.wire_node(&next_node.name, cast.clone(), &[wire])?[0];
                wire = patch.wire_node(&succ.name, succ.op.clone(), &[wire])?[0];
                patch.shunt_outside(model, next_node.id.into(), wire)?;
                return Ok(Some(patch));
            } else if let Some(op) = next_node.op_as::<ops::binary::TypedBinOp>() {
                rule_if!(op.0.as_linalg_binop().is_some());
                // NOTE(review): `flipped` is derived from `succ.inputs` (the
                // shape op, whose single input is always us) while the binary
                // lives on `next_node` — check whether this should inspect
                // `next_node.inputs` to find which slot `succ` feeds.
                let flipped = succ.inputs[0].node == node.id;
                let other_outlet = next_node.inputs[flipped as usize];
                if let Some(uni) = &model.outlet_fact(other_outlet)?.uniform {
                    let mut patch = TypedModelPatch::default();
                    let cst = patch.add_const(&model.node(other_outlet.node).name, uni.clone())?;
                    let output = patch.tap_model(model, node.id.into())?;
                    let wire = wire_with_rank_broadcast(
                        &next_node.name,
                        &mut patch,
                        op.clone(),
                        &if flipped { [output, cst] } else { [cst, output] },
                    )?;
                    let wire = patch.wire_node(&succ.name, succ.op.clone(), &wire)?[0];
                    patch.shunt_outside(model, next_node.id.into(), wire)?;
                    return Ok(Some(patch));
                }
            }
        }
        if let Some(op) = succ.op_as::<ops::binary::OptBinUnicast>() {
            let in_1_fact = model.outlet_fact(succ.inputs[0])?;
            let in_2_fact = model.outlet_fact(succ.inputs[1])?;
            // Same-shaped addition with a single kernel: fuse as AddUnicast.
            if op.binop.is::<ops::math::Add>()
                && self.mmm.len() == 1
                && in_1_fact.without_value() == in_2_fact.without_value()
            {
                let other_slot = 1 - node.outputs[0].successors[0].slot;
                let other_input = succ.inputs[other_slot];
                let other_input = patch.tap_model(model, other_input)?;
                let other_fact = patch.outlet_fact(other_input)?;

                if other_fact.shape == self.c_fact.shape {
                    let other_storage = unsafe { self.mmm[0].c_view(self.c_m_axis, self.c_n_axis) };
                    // Identity axis mapping: operand has exactly C's shape.
                    let mapping =
                        MapOutputAxisToInput((0..other_fact.rank()).map(|x| (x, x)).collect());
                    return self.fuse_op(
                        model,
                        node,
                        patch,
                        vec![ProtoFusedSpec::AddUnicast(other_storage, node.inputs.len(), mapping)],
                        &[other_input],
                    );
                }
            } else {
                rule_if_some!(mut binop = op.binop.as_linalg_binop());
                let flipped = succ.inputs[0].node == node.id;
                if flipped {
                    binop = binop.flip();
                }
                let other_outlet = succ.inputs[flipped as usize];
                return self.fuse_binary(model, node, patch, other_outlet, binop);
            }
        };
        Ok(None)
    }

    as_op!();
}
599
impl OptMatMul {
    /// Build an `OptMatMul`, checking that m/n axes are within the output
    /// rank, and precomputing the `trivial_path` flag.
    pub fn new(
        mmm: Vec<Box<dyn MatMatMul>>,
        mode_picker: ModePicker,
        c_fact: TypedFact,
        c_m_axis: Option<usize>,
        c_n_axis: Option<usize>,
        micro_ops: Vec<ProtoFusedSpec>,
        trivial_packing: bool,
    ) -> TractResult<Self> {
        if let Some(m) = c_m_axis {
            ensure!(m < c_fact.rank());
        }
        if let Some(n) = c_n_axis {
            ensure!(n < c_fact.rank());
        }
        let mut it = OptMatMul {
            mmm,
            mode_picker,
            c_fact,
            c_m_axis,
            c_n_axis,
            micro_ops,
            trivial_path: false,
            trivial_packing,
        };
        it.update_trivial_path();
        Ok(it)
    }

    /// The k dimension of the first `AddMatMul` micro-op, if any.
    pub fn guess_k(&self) -> Option<TDim> {
        self.micro_ops
            .iter()
            .find_map(
                |o| {
                    if let ProtoFusedSpec::AddMatMul { geo, .. } = o { Some(geo) } else { None }
                },
            )
            .map(|geo| geo.k.clone())
    }

    /// Output m dimension (1 when there is no m axis).
    #[inline]
    pub fn m(&self) -> &TDim {
        self.c_m_axis.map(|ax| &self.c_fact.shape[ax]).unwrap_or(&TDim::Val(1))
    }

    /// Output n dimension (1 when there is no n axis).
    #[inline]
    pub fn n(&self) -> &TDim {
        self.c_n_axis.map(|ax| &self.c_fact.shape[ax]).unwrap_or(&TDim::Val(1))
    }

    /// Refresh the cached `trivial_path` flag; call after mutating the op.
    fn update_trivial_path(&mut self) {
        self.trivial_path = self.can_use_trivial_path();
    }

    /// The trivial (single kernel call) path applies when the output shape is
    /// concrete, every axis other than m/n is 1, packing is trivial, and every
    /// micro-op is trivially resolvable.
    fn can_use_trivial_path(&self) -> bool {
        self.c_fact.shape.is_concrete()
            && self.c_fact.shape.iter().enumerate().all(|(ax, dim)| {
                Some(ax) == self.c_m_axis || Some(ax) == self.c_n_axis || dim.is_one()
            })
            && self.trivial_packing
            && self.micro_ops.iter().all(|o| o.is_trivial())
    }

    /// Build a patch replacing this node and its successor by a copy of self
    /// with `fused_micro_op` spliced in just before the final store, taking
    /// `additional_inputs` as extra node inputs.
    fn fuse_op(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        fused_micro_op: Vec<ProtoFusedSpec>,
        additional_inputs: &[OutletId],
    ) -> TractResult<Option<TypedModelPatch>> {
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut new_op = self.clone();
        // Empty range just before the last micro-op: splice inserts there,
        // keeping the Store at the end of the pipeline.
        let before_last = new_op.micro_ops.len() - 1..new_op.micro_ops.len() - 1;
        new_op.micro_ops.splice(before_last, fused_micro_op);
        new_op.c_fact = succ.outputs[0].fact.clone();
        new_op.update_trivial_path();
        let mut inputs = patch.taps(model, &node.inputs)?;
        inputs.extend(additional_inputs.iter().cloned());
        let output = patch.wire_node(&succ.name, new_op, &inputs)?;
        patch.shunt_outside(model, succ.id.into(), output[0])?;
        Ok(Some(patch))
    }

    /// Fuse a binary op whose other operand is `value`: as a scalar micro-op
    /// when the operand has a single element, else per-row (payload along m)
    /// or per-column (payload along n). Returns `Ok(None)` when none of the
    /// shapes match. The operand is cast to the kernel's internal type first
    /// when needed.
    fn fuse_binary(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        value: OutletId,
        binop: BinOp,
    ) -> TractResult<Option<TypedModelPatch>> {
        let fact = model.outlet_fact(value)?;
        let mut v = patch.tap_model(model, value)?;
        if fact.datum_type != self.mmm[0].internal_type() {
            v = patch.wire_node(
                format!("{}.cast-input-{}", node.name, node.inputs.len()),
                cast(self.mmm[0].internal_type()),
                &[v],
            )?[0];
        }
        // The operand becomes a new input at the next free slot.
        let value = node.inputs.len();
        let additional_input = tvec!(v);
        if fact.shape.volume() == 1.to_dim() {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinScalar(value, binop)],
                &additional_input,
            );
        }
        let other_shape = fact.shape.to_owned();
        // All of the operand's payload lies along the m axis -> per-row op.
        if self.c_m_axis.is_some_and(|ax| {
            other_shape[ax] == self.c_fact.shape[ax] && other_shape[ax] == other_shape.volume()
        }) {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerRow(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_m_axis.unwrap(), self.c_m_axis.unwrap()))),
                )],
                &additional_input,
            );
        }
        // All of the operand's payload lies along the n axis -> per-column op.
        if self.c_n_axis.is_some_and(|ax| {
            other_shape[ax] == self.c_fact.shape[ax] && other_shape[ax] == other_shape.volume()
        }) {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerCol(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_n_axis.unwrap(), self.c_n_axis.unwrap()))),
                )],
                &additional_input,
            );
        }
        Ok(None)
    }
}