1use crate::internal::*;
2use crate::ops::cast::cast;
3use crate::ops::change_axes::wire_with_rank_broadcast;
4use crate::ops::nn::LeakyRelu;
5use ndarray::*;
6use tract_itertools::Itertools;
7
8use tract_linalg::mmm::{
9 AsInputValue, EagerPackedInput, FusedSpec, MMMInputValue, MatMatMul, OutputStoreSpec,
10 PanelExtractInput, PanelExtractor,
11};
12use tract_linalg::pack::PackedFormat;
13use tract_linalg::{BinOp, Scaler};
14use tract_smallvec::ToSmallVec;
15
16use super::ModePicker;
17
/// One stage of an `OptMatMul` fused pipeline, before it is bound to actual tensors.
///
/// Operands are referred to by their input slot on the `OptMatMul` node. `resolve` and
/// `resolve_trivial` turn each variant into a `FusedSpec` for the kernel. `fuse_op`
/// inserts new stages right before the trailing `Store`.
#[derive(Clone, Debug)]
pub enum ProtoFusedSpec {
    /// Matrix product of the packed operands found in input slots `a` and `b`.
    AddMatMul {
        /// Reduction dimension and output→operand axis mappings.
        geo: AddMatMulGeometry,
        /// Input slot holding the A operand (an opaque `Box<dyn MMMInputValue>`).
        a: usize,
        /// Input slot holding the B operand (an opaque `Box<dyn MMMInputValue>`).
        b: usize,
        /// One entry per mode: index into `mmm.packings()`, plus an optional panel
        /// extractor applied to A when resolving on the non-trivial path.
        packings: Vec<(usize, Option<PanelExtractor>)>,
    },
    /// Binary op with the scalar held in the given input slot.
    BinScalar(usize, BinOp),
    /// Leaky ReLU; the `usize` is the input slot holding alpha (not alpha itself).
    LeakyRelu(usize),
    /// Binary op with an operand varying only along the m axis
    /// (input slot, op, output→input axis map).
    BinPerRow(usize, BinOp, MapOutputAxisToInput),
    /// Binary op with an operand varying only along the n axis
    /// (input slot, op, output→input axis map).
    BinPerCol(usize, BinOp, MapOutputAxisToInput),
    /// Input slots of a row operand and a column operand. Presumably adds their
    /// outer product; TODO confirm against linalg's `FusedSpec::AddRowColProducts`.
    AddRowColProducts(usize, usize),
    /// Adds a full operand laid out per the store spec
    /// (store spec, input slot, output→input axis map).
    AddUnicast(OutputStoreSpec, usize, MapOutputAxisToInput),
    /// Applies a `Scaler` (produced when fusing a `QScale` element-wise op).
    Scaler(Scaler),
    /// Writes the result into the output tensor; one store spec per mode.
    Store(Vec<OutputStoreSpec>),
}
35
36impl ProtoFusedSpec {
37 pub fn format(&self, mmm: &dyn MatMatMul, mode: usize) -> String {
38 use ProtoFusedSpec::*;
39 match self {
40 AddMatMul { geo, packings: packing, .. } => {
41 let (a, b) = &mmm.packings()[packing[mode].0];
42 format!("matmul(k={}, {a:?}•{b:?})", geo.k)
43 }
44 BinScalar(_, op) => format!("scalar{op:?}"),
45 LeakyRelu(alpha) => format!("leaky_relu({alpha:?})"),
46 BinPerRow(_, op, _) => format!("row{op:?}"),
47 BinPerCol(_, op, _) => format!("col{op:?}"),
48 AddRowColProducts(_, _) => "add_row_col_product".to_string(),
49 AddUnicast(_, _, _) => "add_to_matrix".to_string(),
50 Scaler(s) => format!("scale({})", 1f32 * *s),
51 Store(_oss) => "store".to_string(),
52 }
53 }
54
55 pub fn resolve<'t>(
56 &'t self,
57 inputs: &'t [TValue],
58 output_coords: &[usize],
59 output: &Tensor,
60 mmm: &dyn MatMatMul,
61 mode: usize,
62 ) -> FusedSpec<'t> {
63 let fs = match self {
64 ProtoFusedSpec::AddMatMul { geo, a, b, packings } => {
65 let mut a = inputs[*a].view();
66 let mut b = inputs[*b].view();
67 unsafe {
68 geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a);
69 }
70 let a = a.as_slice::<Opaque>().unwrap()[0]
71 .downcast_ref::<Box<dyn MMMInputValue>>()
72 .unwrap();
73 unsafe {
74 geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b);
75 }
76 let b = b.as_slice::<Opaque>().unwrap()[0]
77 .downcast_ref::<Box<dyn MMMInputValue>>()
78 .unwrap();
79 let (_a_packing, b_packing) = &mmm.packings()[packings[mode].0];
80 let pa = if let Some(extractor) = &packings[mode].1 {
81 let data = a.downcast_ref::<EagerPackedInput>().unwrap();
82 AsInputValue::Owned(Box::new(PanelExtractInput {
83 format: extractor.clone(),
84 data: data.clone(),
85 }))
86 } else {
87 AsInputValue::Borrowed(&**a)
88 };
89 assert!(
90 b_packing.same_as(b.format())
91 || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
92 );
93 debug_assert!(pa.k().to_dim().compatible_with(&geo.k.to_dim()));
94 debug_assert!(b.k().to_dim().compatible_with(&geo.k.to_dim()));
95 FusedSpec::AddMatMul {
96 a: pa,
97 b: AsInputValue::Borrowed(&**b),
98 packing: packings[mode].0,
99 }
100 }
101 ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
102 ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
103 ProtoFusedSpec::BinPerRow(v, op, map) => {
104 let mut v = inputs[*v].view();
105 unsafe { map.translate_view(output_coords, &mut v) }
106 FusedSpec::BinPerRow(v, *op)
107 }
108 ProtoFusedSpec::BinPerCol(v, op, map) => {
109 let mut v = inputs[*v].view();
110 unsafe { map.translate_view(output_coords, &mut v) }
111 FusedSpec::BinPerCol(v, *op)
112 }
113 ProtoFusedSpec::AddRowColProducts(row, col) => {
114 FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
115 }
116 ProtoFusedSpec::AddUnicast(store, v, map) => unsafe {
117 let mut view = inputs[*v].view();
118 map.translate_view(output_coords, &mut view);
119 FusedSpec::AddUnicast(store.wrap(&view))
120 },
121 ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
122 ProtoFusedSpec::Store(oss) => unsafe {
123 let view = output.view_offsetting_unchecked(output_coords);
124 FusedSpec::Store(oss[mode].wrap(&view))
125 },
126 };
127 fs
128 }
129
130 pub fn is_trivial(&self) -> bool {
131 match self {
132 ProtoFusedSpec::AddMatMul { geo, .. } => geo.k.as_i64().is_some(),
133 _ => true,
134 }
135 }
136
137 pub fn resolve_trivial<'t>(
138 &'t self,
139 inputs: &'t [TValue],
140 output: &mut Tensor,
141 _mmm: &dyn MatMatMul,
142 mode: usize,
143 ) -> FusedSpec<'t> {
144 let fs = match self {
145 ProtoFusedSpec::AddMatMul { a, b, packings, .. } => unsafe {
146 debug_assert!(inputs.get(*a).is_some());
147 debug_assert!(inputs.get(*b).is_some());
148 let a = inputs.get_unchecked(*a);
149 let b = inputs.get_unchecked(*b);
150 debug_assert!(a.datum_type().is_opaque());
151 debug_assert!(a.len() == 1);
152 debug_assert!(b.datum_type().is_opaque());
153 debug_assert!(b.len() == 1);
154 let a = a.as_slice_unchecked::<Opaque>().get_unchecked(0);
155 let b = b.as_slice_unchecked::<Opaque>().get_unchecked(0);
156 debug_assert!(a.is::<Box<dyn MMMInputValue>>());
157 debug_assert!(b.is::<Box<dyn MMMInputValue>>());
158 let a = a.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
159 let b = b.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
160 #[cfg(debug_assertions)]
161 {
162 let (a_packing, b_packing) = &_mmm.packings()[packings[mode].0];
163 debug_assert!(
164 a_packing.same_as(a.format())
165 || (a_packing.is::<PackedFormat>() && a_packing.r() == a.format().r())
166 );
167 debug_assert!(
168 b_packing.same_as(b.format())
169 || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
170 );
171 }
172 FusedSpec::AddMatMul {
173 a: AsInputValue::Borrowed(&**a),
174 b: AsInputValue::Borrowed(&**b),
175 packing: packings[mode].0,
176 }
177 },
178 ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
179 ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
180 ProtoFusedSpec::BinPerRow(v, op, _) => {
181 let v = inputs[*v].view();
182 FusedSpec::BinPerRow(v, *op)
183 }
184 ProtoFusedSpec::BinPerCol(v, op, _) => {
185 let v = inputs[*v].view();
186 FusedSpec::BinPerCol(v, *op)
187 }
188 ProtoFusedSpec::AddRowColProducts(row, col) => {
189 FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
190 }
191 ProtoFusedSpec::AddUnicast(store, v, _) => unsafe {
192 let view = inputs[*v].view();
193 FusedSpec::AddUnicast(store.wrap(&view))
194 },
195 ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
196 ProtoFusedSpec::Store(oss) => unsafe {
197 FusedSpec::Store(oss[mode].wrap(&output.view_mut()))
198 },
199 };
200 fs
201 }
202
203 fn check_inputs(&self, inputs: &[&TypedFact]) -> TractResult<()> {
204 use ProtoFusedSpec::*;
205 match self {
206 AddMatMul { a, b, .. } => {
207 ensure!(inputs[*a].datum_type == Opaque::datum_type());
208 ensure!(inputs[*b].datum_type == Opaque::datum_type());
209 }
210 BinScalar(v, _)
211 | LeakyRelu(v)
212 | BinPerCol(v, _, _)
213 | BinPerRow(v, _, _)
214 | AddUnicast(_, v, _) => {
215 ensure!(inputs[*v].datum_type.is_number());
216 }
217 AddRowColProducts(row, col) => {
218 ensure!(inputs[*row].datum_type.is_number());
219 ensure!(inputs[*col].datum_type.is_number());
220 }
221 _ => (),
222 };
223 Ok(())
224 }
225
226 fn cost(&self, m: &TDim, n: &TDim, idt: DatumType) -> TVec<(Cost, TDim)> {
227 match self {
228 ProtoFusedSpec::AddMatMul { geo, .. } => {
229 tvec!((Cost::FMA(idt), m.clone() * n * &geo.k))
230 }
231 _ => tvec!(), }
233 }
234
235 fn rm_c_axis(&mut self, axis: usize) {
236 use ProtoFusedSpec::*;
237 match self {
238 AddMatMul { geo, .. } => {
239 geo.c_to_a_axis_mapping.rm_c_axis(axis);
240 geo.c_to_b_axis_mapping.rm_c_axis(axis);
241 }
242 BinScalar(..) | Scaler(..) | AddRowColProducts(_, _) | LeakyRelu(_) => {}
243 BinPerRow(_, _, map) | BinPerCol(_, _, map) => map.rm_c_axis(axis),
244 AddUnicast(_, _, map) => {
245 map.rm_c_axis(axis);
246 }
247 Store(oss, ..) => {
248 for oss in oss {
249 match oss {
250 OutputStoreSpec::View { m_axis, n_axis, .. } => {
251 *m_axis -= (*m_axis > axis) as usize;
252 *n_axis -= (*n_axis > axis) as usize;
253 }
254 OutputStoreSpec::Strides { .. } => {}
255 }
256 }
257 }
258 }
259 }
260}
261
262#[derive(Clone, Debug)]
263pub struct MapOutputAxisToInput(pub TVec<(usize, usize)>);
264
265impl MapOutputAxisToInput {
266 #[inline]
267 unsafe fn translate_view(&self, output_coords: &[usize], v: &mut TensorView) {
268 for &(out_axis, in_axis) in &self.0 {
269 v.offset_axis(in_axis, output_coords[out_axis] as isize)
270 }
271 }
272
273 #[inline]
274 fn rm_c_axis(&mut self, axis: usize) {
275 for (c, _) in &mut self.0 {
276 *c -= (*c > axis) as usize;
277 }
278 }
279}
280
/// Shape information needed to resolve an `AddMatMul` micro-op.
#[derive(Clone, Debug)]
pub struct AddMatMulGeometry {
    /// Reduction dimension: the op's cost is m·n·k FMAs, and it can only take the
    /// trivial path when `k` is a known integer.
    pub k: TDim,
    /// Output axes whose coordinates offset the view of the A operand.
    pub c_to_a_axis_mapping: MapOutputAxisToInput,
    /// Output axes whose coordinates offset the view of the B operand.
    pub c_to_b_axis_mapping: MapOutputAxisToInput,
}
287
/// Optimized matrix multiplication, with subsequent element-wise work fused in as
/// micro-ops.
#[derive(Clone, Debug)]
pub struct OptMatMul {
    /// Fact of the (single) output tensor c.
    pub c_fact: TypedFact,
    /// Pipeline of micro-ops; new ones are inserted before the last entry, which is
    /// expected to be a `Store`.
    pub micro_ops: Vec<ProtoFusedSpec>,
    /// One kernel per mode; all must share the same internal type.
    pub mmm: Vec<Box<dyn MatMatMul>>,
    /// Picks the mode (index into `mmm`) from the concrete n dimension.
    pub mode_picker: ModePicker,
    /// Axis of c holding the m dimension.
    pub c_m_axis: usize,
    /// Axis of c holding the n dimension.
    pub c_n_axis: usize,
    /// Set by the builder; one of the preconditions of the trivial path.
    pub trivial_packing: bool,
    /// Cached result of `can_use_trivial_path`, checked again in `output_facts`.
    pub trivial_path: bool,
}
299
300impl Op for OptMatMul {
301 fn name(&self) -> Cow<str> {
302 "OptMatMul".into()
303 }
304
305 fn info(&self) -> TractResult<Vec<String>> {
306 let m = &self.c_fact.shape[self.c_m_axis];
307 let n = &self.c_fact.shape[self.c_n_axis];
308 let mut infos = vec![format!(
309 "c_shape:{:?}, c_m_axis:{} c_n_axis:{} m:{} n:{}",
310 self.c_fact, self.c_m_axis, self.c_n_axis, m, n,
311 )];
312 if let Some(k) = self.guess_k() {
313 infos.push(format!("Mult: m:{} k:{} n:{} with {:?}", m, k, n, self.mmm));
314 } else {
315 infos.push(format!("Mult: {:?}", self.mmm));
316 }
317 for (mode, mmm) in self.mmm.iter().enumerate() {
318 infos.push(format!(
319 "Ops: {}",
320 self.micro_ops.iter().map(|o| o.format(&**mmm, mode)).join(" >>> ")
321 ));
322 }
323 Ok(infos)
324 }
325
326 op_as_typed_op!();
327}
328
329impl EvalOp for OptMatMul {
330 fn is_stateless(&self) -> bool {
331 true
332 }
333
334 fn eval_with_session(
335 &self,
336 session: &SessionState,
337 inputs: TVec<TValue>,
338 ) -> TractResult<TVec<TValue>> {
339 unsafe {
340 let c_shape = self.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
341 let mut c = Tensor::uninitialized_dt(self.c_fact.datum_type, &c_shape)?;
342 let mode = self.mode_picker.pick(c_shape[self.c_n_axis])?;
343 let mmm = &*self.mmm[mode];
344 let mut cell = session.cached_mmm_scratch_space.borrow_mut();
345 if !cell.as_ref().is_some_and(|scratch| mmm.can_use_scratch_space(&**scratch)) {
346 *cell = None
347 }
348 let scratch = cell.get_or_insert_with(|| mmm.allocate_scratch_space());
349
350 if self.trivial_path {
351 let uops: Vec<FusedSpec> = self
352 .micro_ops
353 .iter()
354 .map(|o| o.resolve_trivial(&inputs, &mut c, mmm, mode))
355 .collect();
356 mmm.run_with_scratch_space(
357 *c_shape.get_unchecked(self.c_m_axis),
358 *c_shape.get_unchecked(self.c_n_axis),
359 scratch.as_mut(),
360 &uops,
361 )?;
362 Ok(tvec!(c.into_tvalue()))
363 } else {
364 let mut uops = vec![FusedSpec::ShiftLeft(0); self.micro_ops.len()];
365 let mut looping_shape: TVec<usize> = c_shape.to_smallvec();
366 looping_shape[self.c_m_axis] = 1;
367 looping_shape[self.c_n_axis] = 1;
368 let m = c_shape[self.c_m_axis];
369 let n = c_shape[self.c_n_axis];
370 for c_coords in indices(&*looping_shape) {
371 for ix in 0..self.micro_ops.len() {
372 *uops.get_unchecked_mut(ix) = self.micro_ops.get_unchecked(ix).resolve(
373 &inputs,
374 c_coords.slice(),
375 &c,
376 mmm,
377 mode,
378 );
379 }
380 mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)
381 .context("In mmm.run_with_scratch_space")?;
382 }
383 Ok(tvec!(c.into_tvalue()))
384 }
385 }
386 }
387}
388
389impl TypedOp for OptMatMul {
390 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
391 ensure!(self.c_m_axis < self.c_fact.rank());
392 ensure!(self.c_n_axis < self.c_fact.rank());
393 ensure!(self.trivial_path == self.can_use_trivial_path());
394 ensure!(self.mmm.iter().map(|mmm| mmm.internal_type()).all_equal());
395 for op in &self.micro_ops {
396 op.check_inputs(inputs)?;
397 }
398 Ok(tvec!(self.c_fact.clone()))
399 }
400
401 fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
402 let mut sums = HashMap::new();
403 let m = &self.c_fact.shape[self.c_m_axis];
404 let n = &self.c_fact.shape[self.c_n_axis];
405 for op in &self.micro_ops {
406 for (cost, count) in op.cost(m, n, self.mmm[0].internal_type()) {
407 *sums.entry(cost).or_default() += count;
408 }
409 }
410 let loops =
411 self.c_fact
412 .shape
413 .iter()
414 .enumerate()
415 .map(|(ix, d)| {
416 if ix == self.c_m_axis || ix == self.c_n_axis {
417 1.to_dim()
418 } else {
419 d.clone()
420 }
421 })
422 .product::<TDim>();
423 for s in &mut sums.values_mut() {
424 *s *= &loops;
425 }
426 Ok(sums.into_iter().collect())
427 }
428
429 fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
430 use crate::ops;
431 if node.outputs.len() != 1
432 || node.outputs[0].successors.len() != 1
433 || model.output_outlets()?.contains(&node.id.into())
434 {
435 return Ok(None);
436 }
437 let succ = model.node(node.outputs[0].successors[0].node);
438 let mut patch = TypedModelPatch::new(format!("fusing {succ}"));
439
440 if let Some(op) = succ.op_as::<ops::binary::TypedBinOp>() {
441 let mut binop = if let Some(op) = op.0.as_linalg_binop() {
442 op
443 } else {
444 return Ok(None);
445 };
446 let flipped = succ.inputs[0].node == node.id;
447 if flipped {
448 binop = binop.flip();
449 }
450 let other_outlet = succ.inputs[flipped as usize];
451 return self.fuse_binary(model, node, patch, other_outlet, binop);
452 }
453 if let Some(op) = succ.op_as::<ops::binary::OptBinByScalar>() {
454 let mut binop = if let Some(op) = op.binop.as_linalg_binop() {
455 op
456 } else {
457 return Ok(None);
458 };
459 let flipped = succ.inputs[0].node == node.id;
460 if flipped {
461 binop = binop.flip();
462 }
463 let other_outlet = succ.inputs[flipped as usize];
464 return self.fuse_binary(model, node, patch, other_outlet, binop);
465 }
466
467 if let Some(op) = succ.op_as::<ops::element_wise::ElementWiseOp>().map(|ew| ew.0.as_ref()) {
468 if let Some(op) = op.downcast_ref::<ops::math::QScale>() {
469 return self.fuse_op(
470 model,
471 node,
472 patch,
473 vec![ProtoFusedSpec::Scaler(op.scaler)],
474 &[],
475 );
476 }
477 if let Some(op) = op.downcast_ref::<LeakyRelu>() {
478 if !self
479 .mmm
480 .iter()
481 .all(|mmm| mmm.can_fuse(&FusedSpec::LeakyRelu(&tensor0(op.alpha))))
482 {
483 return Ok(None);
484 }
485 let alpha = patch.add_const(
486 node.name.to_string() + ".alpha",
487 tensor0(op.alpha).cast_to_dt(self.mmm[0].internal_type())?.into_owned(),
488 )?;
489 return self.fuse_op(
490 model,
491 node,
492 patch,
493 vec![ProtoFusedSpec::LeakyRelu(node.inputs.len())],
494 &[alpha],
495 );
496 }
497 }
498 if let Some(cast_to) = succ.op_as::<ops::cast::Cast>().map(|cast| cast.to) {
499 if (cast_to.unquantized() == i8::datum_type()
500 || cast_to.unquantized() == u8::datum_type())
501 && self.c_fact.datum_type == i32::datum_type()
502 {
503 if let Some(ProtoFusedSpec::Store(stores)) = self.micro_ops.last() {
504 if stores.iter().any(|s| matches!(s, OutputStoreSpec::Strides { .. })) {
505 return Ok(None);
506 }
507 let c_fact = cast_to.fact(self.c_fact.shape.clone());
508 let mut patch = TypedModelPatch::fuse_with_next(
509 model,
510 node,
511 Self { c_fact, ..self.clone() },
512 )?;
513 patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
514 return Ok(Some(patch));
515 }
516 }
517 }
518 if let Some(AxisOp::Rm(axis)) = succ.op_as::<ops::AxisOp>() {
519 if *axis == self.c_m_axis || *axis == self.c_n_axis {
520 return Ok(None);
521 }
522 let mut new_op = self.clone();
523 new_op.c_fact.shape.remove_axis(*axis)?;
524 new_op.c_m_axis -= (new_op.c_m_axis > *axis) as usize;
525 new_op.c_n_axis -= (new_op.c_n_axis > *axis) as usize;
526 for uop in &mut new_op.micro_ops {
527 uop.rm_c_axis(*axis);
528 }
529 let mut patch = TypedModelPatch::fuse_with_next(model, node, new_op)?;
530 patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
531 return Ok(Some(patch));
532 }
533 if succ.op_is::<AxisOp>() {
534 if let &[next] = &*succ.outputs[0].successors {
535 let bin = model.node(next.node);
536 if let Some(op) = bin.op_as::<ops::binary::TypedBinOp>() {
537 if op.0.as_linalg_binop().is_none() {
538 return Ok(None);
539 };
540 let flipped = succ.inputs[0].node == node.id;
541 let other_outlet = bin.inputs[flipped as usize];
542 if let Some(uni) = &model.outlet_fact(other_outlet)?.uniform {
543 let mut patch = TypedModelPatch::default();
544 let cst =
545 patch.add_const(&model.node(other_outlet.node).name, uni.clone())?;
546 let output = patch.tap_model(model, node.id.into())?;
547 let wire = wire_with_rank_broadcast(
548 &bin.name,
549 &mut patch,
550 op.clone(),
551 &if flipped { [output, cst] } else { [cst, output] },
552 )?;
553 let wire = patch.wire_node(&succ.name, succ.op.clone(), &wire)?[0];
554 patch.shunt_outside(model, bin.id.into(), wire)?;
555 return Ok(Some(patch));
556 }
557 }
558 }
559 }
560 if let Some(op) = succ.op_as::<ops::binary::OptBinUnicast>() {
561 let in_1_fact = model.outlet_fact(succ.inputs[0])?;
562 let in_2_fact = model.outlet_fact(succ.inputs[1])?;
563 if op.binop.is::<ops::math::Add>()
564 && self.mmm.len() == 1
565 && in_1_fact.without_value() == in_2_fact.without_value()
566 {
567 let other_slot = 1 - node.outputs[0].successors[0].slot;
568 let other_input = succ.inputs[other_slot];
569 let other_input = patch.tap_model(model, other_input)?;
570 let other_fact = patch.outlet_fact(other_input)?;
571
572 if other_fact.shape == self.c_fact.shape {
573 let other_storage = unsafe { self.mmm[0].c_view(self.c_m_axis, self.c_n_axis) };
574 let mapping =
575 MapOutputAxisToInput((0..other_fact.rank()).map(|x| (x, x)).collect());
576 return self.fuse_op(
577 model,
578 node,
579 patch,
580 vec![ProtoFusedSpec::AddUnicast(other_storage, node.inputs.len(), mapping)],
581 &[other_input],
582 );
583 }
584 } else {
585 let mut binop = if let Some(op) = op.binop.as_linalg_binop() {
586 op
587 } else {
588 return Ok(None);
589 };
590 let flipped = succ.inputs[0].node == node.id;
591 if flipped {
592 binop = binop.flip();
593 }
594 let other_outlet = succ.inputs[flipped as usize];
595 return self.fuse_binary(model, node, patch, other_outlet, binop);
596 }
597 };
598 Ok(None)
599 }
600
601 as_op!();
602}
603
604impl OptMatMul {
605 pub fn new(
606 mmm: Vec<Box<dyn MatMatMul>>,
607 mode_picker: ModePicker,
608 c_fact: TypedFact,
609 c_m_axis: usize,
610 c_n_axis: usize,
611 micro_ops: Vec<ProtoFusedSpec>,
612 trivial_packing: bool,
613 ) -> TractResult<Self> {
614 ensure!(c_m_axis < c_fact.rank());
615 ensure!(c_n_axis < c_fact.rank());
616 let mut it = OptMatMul {
617 mmm,
618 mode_picker,
619 c_fact,
620 c_m_axis,
621 c_n_axis,
622 micro_ops,
623 trivial_path: false,
624 trivial_packing,
625 };
626 it.update_trivial_path();
627 Ok(it)
628 }
629
630 pub fn guess_k(&self) -> Option<TDim> {
632 self.micro_ops
633 .iter()
634 .find_map(
635 |o| {
636 if let ProtoFusedSpec::AddMatMul { geo, .. } = o {
637 Some(geo)
638 } else {
639 None
640 }
641 },
642 )
643 .map(|geo| geo.k.clone())
644 }
645
646 pub fn m(&self) -> &TDim {
647 &self.c_fact.shape[self.c_m_axis]
648 }
649
650 pub fn n(&self) -> &TDim {
651 &self.c_fact.shape[self.c_n_axis]
652 }
653
654 fn update_trivial_path(&mut self) {
655 self.trivial_path = self.can_use_trivial_path();
656 }
657
658 fn can_use_trivial_path(&self) -> bool {
659 self.c_fact.shape.is_concrete()
660 && self
661 .c_fact
662 .shape
663 .iter()
664 .enumerate()
665 .all(|(ax, dim)| ax == self.c_m_axis || ax == self.c_n_axis || dim.is_one())
666 && self.trivial_packing
667 && self.micro_ops.iter().all(|o| o.is_trivial())
668 }
669
670 fn fuse_op(
671 &self,
672 model: &TypedModel,
673 node: &TypedNode,
674 mut patch: TypedModelPatch,
675 fused_micro_op: Vec<ProtoFusedSpec>,
676 additional_inputs: &[OutletId],
677 ) -> TractResult<Option<TypedModelPatch>> {
678 let succ = model.node(node.outputs[0].successors[0].node);
679 let mut new_op = self.clone();
680 let before_last = new_op.micro_ops.len() - 1..new_op.micro_ops.len() - 1;
681 new_op.micro_ops.splice(before_last, fused_micro_op);
682 new_op.c_fact = succ.outputs[0].fact.clone();
683 new_op.update_trivial_path();
684 let mut inputs = patch.taps(model, &node.inputs)?;
685 inputs.extend(additional_inputs.iter().cloned());
686 let output = patch.wire_node(&succ.name, new_op, &inputs)?;
687 patch.shunt_outside(model, succ.id.into(), output[0])?;
688 Ok(Some(patch))
689 }
690
691 fn fuse_binary(
692 &self,
693 model: &TypedModel,
694 node: &TypedNode,
695 mut patch: TypedModelPatch,
696 value: OutletId,
697 binop: BinOp,
698 ) -> TractResult<Option<TypedModelPatch>> {
699 let fact = model.outlet_fact(value)?;
700 let mut v = patch.tap_model(model, value)?;
701 if fact.datum_type != self.mmm[0].internal_type() {
702 v = patch.wire_node(
703 format!("{}.cast-input-{}", node.name, node.inputs.len()),
704 cast(self.mmm[0].internal_type()),
705 &[v],
706 )?[0];
707 }
708 let value = node.inputs.len();
709 let additional_input = tvec!(v);
710 if fact.shape.volume() == 1.to_dim() {
711 return self.fuse_op(
712 model,
713 node,
714 patch,
715 vec![ProtoFusedSpec::BinScalar(value, binop)],
716 &additional_input,
717 );
718 }
719 let other_shape = fact.shape.to_owned();
720 if other_shape[self.c_m_axis] == self.c_fact.shape[self.c_m_axis]
721 && other_shape[self.c_m_axis] == other_shape.volume()
722 {
723 return self.fuse_op(
724 model,
725 node,
726 patch,
727 vec![ProtoFusedSpec::BinPerRow(
728 value,
729 binop,
730 MapOutputAxisToInput(tvec!((self.c_m_axis, self.c_m_axis))),
731 )],
732 &additional_input,
733 );
734 }
735 if other_shape[self.c_n_axis] == self.c_fact.shape[self.c_n_axis]
736 && other_shape[self.c_n_axis] == other_shape.volume()
737 {
738 return self.fuse_op(
739 model,
740 node,
741 patch,
742 vec![ProtoFusedSpec::BinPerCol(
743 value,
744 binop,
745 MapOutputAxisToInput(tvec!((self.c_n_axis, self.c_n_axis))),
746 )],
747 &additional_input,
748 );
749 }
750 Ok(None)
751 }
752}