1use crate::internal::*;
2use crate::ops::cast::cast;
3use crate::ops::change_axes::wire_with_rank_broadcast;
4use crate::ops::nn::LeakyRelu;
5use ndarray::*;
6use tract_itertools::Itertools;
7
8use tract_linalg::frame::PackedFormat;
9use tract_linalg::mmm::panel_extract::{PanelExtractInput, PanelExtractor};
10use tract_linalg::mmm::{
11 AsInputValue, EagerPackedInput, FusedSpec, MMMInputValue, MatMatMul, OutputStoreSpec,
12};
13use tract_linalg::{BinOp, Scaler};
14use tract_smallvec::ToSmallVec;
15
16use super::ModePicker;
17
/// Blueprint for a runtime `FusedSpec`.
///
/// `FusedSpec` borrows tensors, so it cannot be stored in the op. Instead the
/// op stores these "proto" specs, which hold input slot indices (`usize`) and
/// static metadata, and materializes actual `FusedSpec`s at eval time via
/// `resolve` / `resolve_trivial`.
#[derive(Clone, Debug)]
pub enum ProtoFusedSpec {
    /// The matrix product itself. `a` and `b` are input slot indices pointing
    /// at Opaque tensors wrapping `Box<dyn MMMInputValue>`. `packings` is
    /// indexed by kernel mode: each entry selects a packing pair of the chosen
    /// `MatMatMul`, with an optional `PanelExtractor` used to repack A.
    AddMatMul {
        geo: AddMatMulGeometry,
        a: usize,
        b: usize,
        packings: Vec<(usize, Option<PanelExtractor>)>,
    },
    /// Apply `BinOp` with a scalar operand found at the given input slot.
    BinScalar(usize, BinOp),
    /// LeakyRelu with alpha tensor at the given input slot.
    LeakyRelu(usize),
    /// Apply `BinOp` with a per-row (M axis) operand at the given input slot.
    BinPerRow(usize, BinOp, MapOutputAxisToInput),
    /// Apply `BinOp` with a per-column (N axis) operand at the given input slot.
    BinPerCol(usize, BinOp, MapOutputAxisToInput),
    /// Add outer product of a row vector and a column vector (input slots).
    AddRowColProducts(usize, usize),
    /// Element-wise addition of a full matrix input (given slot) to C.
    AddUnicast(OutputStoreSpec, usize, MapOutputAxisToInput),
    /// Multiply accumulator by a static scaling factor.
    Scaler(Scaler),
    /// Final store of the accumulator to the output tensor, one spec per mode.
    Store(Vec<OutputStoreSpec>),
}
35
impl ProtoFusedSpec {
    /// Human-readable description of this micro-op for the given kernel and
    /// packing mode (used by `OptMatMul::info`).
    pub fn format(&self, mmm: &dyn MatMatMul, mode: usize) -> String {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, packings: packing, .. } => {
                let (a, b) = &mmm.packings()[packing[mode].0];
                format!("matmul(k={}, {a:?}•{b:?})", geo.k)
            }
            BinScalar(_, op) => format!("scalar{op:?}"),
            // NOTE(review): the usize here is an input slot index, not the
            // alpha value itself — the display shows the slot.
            LeakyRelu(alpha) => format!("leaky_relu({alpha:?})"),
            BinPerRow(_, op, _) => format!("row{op:?}"),
            BinPerCol(_, op, _) => format!("col{op:?}"),
            AddRowColProducts(_, _) => "add_row_col_product".to_string(),
            AddUnicast(_, _, _) => "add_to_matrix".to_string(),
            // 1f32 * s evaluates the scaler against 1.0, displaying its
            // effective multiplicative factor.
            Scaler(s) => format!("scale({})", 1f32 * *s),
            Store(_oss) => "store".to_string(),
        }
    }

    /// Materialize this proto-spec into a `FusedSpec` for one output tile.
    ///
    /// `output_coords` are the coordinates of the current tile in C (all axes
    /// other than m/n); input views are offset accordingly through the stored
    /// axis mappings. `mode` selects which kernel/packing variant is in use.
    pub fn resolve<'t>(
        &'t self,
        inputs: &'t [TValue],
        output_coords: &[usize],
        output: &Tensor,
        mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        let fs = match self {
            ProtoFusedSpec::AddMatMul { geo, a, b, packings } => {
                let mut a = inputs[*a].view();
                let mut b = inputs[*b].view();
                // Shift the A view to the slice matching the current output tile.
                unsafe {
                    geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a);
                }
                // A and B are Opaque scalars wrapping pre-packed MMM inputs.
                let a = a.as_slice::<Opaque>().unwrap()[0]
                    .downcast_ref::<Box<dyn MMMInputValue>>()
                    .unwrap();
                unsafe {
                    geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b);
                }
                let b = b.as_slice::<Opaque>().unwrap()[0]
                    .downcast_ref::<Box<dyn MMMInputValue>>()
                    .unwrap();
                let (_a_packing, b_packing) = &mmm.packings()[packings[mode].0];
                // If this mode requires a panel extractor, wrap A so panels are
                // extracted (repacked) on the fly; otherwise borrow it as-is.
                let pa = if let Some(extractor) = &packings[mode].1 {
                    let data = a.downcast_ref::<EagerPackedInput>().unwrap();
                    AsInputValue::Owned(Box::new(PanelExtractInput {
                        format: extractor.clone(),
                        data: data.clone(),
                    }))
                } else {
                    AsInputValue::Borrowed(&**a)
                };
                // B must either match the expected packing exactly, or at least
                // be a PackedFormat with the same panel width r.
                assert!(
                    b_packing.same_as(b.format())
                        || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                );
                FusedSpec::AddMatMul {
                    a: pa,
                    b: AsInputValue::Borrowed(&**b),
                    packing: packings[mode].0,
                }
            }
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, map) => {
                let mut v = inputs[*v].view();
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, map) => {
                let mut v = inputs[*v].view();
                unsafe { map.translate_view(output_coords, &mut v) }
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, map) => unsafe {
                let mut view = inputs[*v].view();
                map.translate_view(output_coords, &mut view);
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                // Store into the output at the tile's offset, per-mode spec.
                let view = output.view_offsetting_unchecked(output_coords);
                FusedSpec::Store(oss[mode].wrap(&view))
            },
        };
        fs
    }

    /// A micro-op is "trivial" when it can be resolved without per-tile
    /// coordinate translation. Only AddMatMul with a non-concrete k is
    /// excluded (its k must be a resolved integer).
    pub fn is_trivial(&self) -> bool {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => geo.k.as_i64().is_some(),
            _ => true,
        }
    }

    /// Fast-path variant of `resolve` for the single-tile case: no coordinate
    /// translation, unchecked accesses (guarded by debug_asserts), and the
    /// whole output tensor as the store target.
    pub fn resolve_trivial<'t>(
        &'t self,
        inputs: &'t [TValue],
        output: &mut Tensor,
        _mmm: &dyn MatMatMul,
        mode: usize,
    ) -> FusedSpec<'t> {
        let fs = match self {
            ProtoFusedSpec::AddMatMul { a, b, packings, .. } => unsafe {
                // Debug-only validation of what the unchecked accesses assume:
                // slots exist, both inputs are single Opaque MMMInputValues.
                debug_assert!(inputs.get(*a).is_some());
                debug_assert!(inputs.get(*b).is_some());
                let a = inputs.get_unchecked(*a);
                let b = inputs.get_unchecked(*b);
                debug_assert!(a.datum_type().is_opaque());
                debug_assert!(a.len() == 1);
                debug_assert!(b.datum_type().is_opaque());
                debug_assert!(b.len() == 1);
                let a = a.as_slice_unchecked::<Opaque>().get_unchecked(0);
                let b = b.as_slice_unchecked::<Opaque>().get_unchecked(0);
                debug_assert!(a.is::<Box<dyn MMMInputValue>>());
                debug_assert!(b.is::<Box<dyn MMMInputValue>>());
                let a = a.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
                let b = b.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
                // On the trivial path no panel extraction happens, so both
                // packings must already be compatible (checked in debug only).
                #[cfg(debug_assertions)]
                {
                    let (a_packing, b_packing) = &_mmm.packings()[packings[mode].0];
                    debug_assert!(
                        a_packing.same_as(a.format())
                            || (a_packing.is::<PackedFormat>() && a_packing.r() == a.format().r())
                    );
                    debug_assert!(
                        b_packing.same_as(b.format())
                            || (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
                    );
                }
                FusedSpec::AddMatMul {
                    a: AsInputValue::Borrowed(&**a),
                    b: AsInputValue::Borrowed(&**b),
                    packing: packings[mode].0,
                }
            },
            ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
            ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
            ProtoFusedSpec::BinPerRow(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerRow(v, *op)
            }
            ProtoFusedSpec::BinPerCol(v, op, _) => {
                let v = inputs[*v].view();
                FusedSpec::BinPerCol(v, *op)
            }
            ProtoFusedSpec::AddRowColProducts(row, col) => {
                FusedSpec::AddRowColProducts(&inputs[*row], &inputs[*col])
            }
            ProtoFusedSpec::AddUnicast(store, v, _) => unsafe {
                let view = inputs[*v].view();
                FusedSpec::AddUnicast(store.wrap(&view))
            },
            ProtoFusedSpec::Scaler(scaler) => scaler.as_fused_spec(),
            ProtoFusedSpec::Store(oss) => unsafe {
                // Single tile: store over the entire output tensor.
                FusedSpec::Store(oss[mode].wrap(&output.view_mut()))
            },
        };
        fs
    }

    /// Type-check the input facts this micro-op will read at eval time:
    /// matmul operands must be Opaque, everything else numeric.
    fn check_inputs(&self, inputs: &[&TypedFact]) -> TractResult<()> {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { a, b, .. } => {
                ensure!(inputs[*a].datum_type == Opaque::datum_type());
                ensure!(inputs[*b].datum_type == Opaque::datum_type());
            }
            BinScalar(v, _)
            | LeakyRelu(v)
            | BinPerCol(v, _, _)
            | BinPerRow(v, _, _)
            | AddUnicast(_, v, _) => {
                ensure!(inputs[*v].datum_type.is_number());
            }
            AddRowColProducts(row, col) => {
                ensure!(inputs[*row].datum_type.is_number());
                ensure!(inputs[*col].datum_type.is_number());
            }
            _ => (),
        };
        Ok(())
    }

    /// Cost of this micro-op for one m×n tile: only the matmul contributes
    /// (m*n*k fused multiply-adds); other micro-ops are counted as free.
    fn cost(&self, m: &TDim, n: &TDim, idt: DatumType) -> TVec<(Cost, TDim)> {
        match self {
            ProtoFusedSpec::AddMatMul { geo, .. } => {
                tvec!((Cost::FMA(idt), m.clone() * n * &geo.k))
            }
            _ => tvec!(), }
    }

    /// Adjust stored output-axis indices after output axis `axis` has been
    /// removed from C (used when fusing an `AxisOp::Rm` into the op).
    fn rm_c_axis(&mut self, axis: usize) {
        use ProtoFusedSpec::*;
        match self {
            AddMatMul { geo, .. } => {
                geo.c_to_a_axis_mapping.rm_c_axis(axis);
                geo.c_to_b_axis_mapping.rm_c_axis(axis);
            }
            // These hold no output-axis information.
            BinScalar(..) | Scaler(..) | AddRowColProducts(_, _) | LeakyRelu(_) => {}
            BinPerRow(_, _, map) | BinPerCol(_, _, map) => map.rm_c_axis(axis),
            AddUnicast(_, _, map) => {
                map.rm_c_axis(axis);
            }
            Store(oss, ..) => {
                for oss in oss {
                    match oss {
                        OutputStoreSpec::View { m_axis, n_axis, .. } => {
                            // Shift m/n axis indices down when they sit after
                            // the removed axis.
                            *m_axis -= (*m_axis > axis) as usize;
                            *n_axis -= (*n_axis > axis) as usize;
                        }
                        OutputStoreSpec::Strides { .. } => {}
                    }
                }
            }
        }
    }
}
259
/// Mapping from output (C) axes to axes of one input, stored as
/// (output_axis, input_axis) pairs, so an input view can be offset to the
/// slice matching the current output tile.
#[derive(Clone, Debug)]
pub struct MapOutputAxisToInput(pub TVec<(usize, usize)>);
262
263impl MapOutputAxisToInput {
264 #[inline]
265 unsafe fn translate_view(&self, output_coords: &[usize], v: &mut TensorView) {
266 for &(out_axis, in_axis) in &self.0 {
267 v.offset_axis(in_axis, output_coords[out_axis] as isize)
268 }
269 }
270
271 #[inline]
272 fn rm_c_axis(&mut self, axis: usize) {
273 for (c, _) in &mut self.0 {
274 *c -= (*c > axis) as usize;
275 }
276 }
277}
278
/// Static geometry for the AddMatMul micro-op.
#[derive(Clone, Debug)]
pub struct AddMatMulGeometry {
    /// The reduction (K) dimension of the product, possibly symbolic.
    pub k: TDim,
    /// How output axes map onto axes of the A operand.
    pub c_to_a_axis_mapping: MapOutputAxisToInput,
    /// How output axes map onto axes of the B operand.
    pub c_to_b_axis_mapping: MapOutputAxisToInput,
}
285
/// Optimized matrix-multiplication op: a pipeline of `ProtoFusedSpec`
/// micro-ops executed by one of several candidate `MatMatMul` kernels,
/// selected at eval time by `mode_picker`.
#[derive(Clone, Debug)]
pub struct OptMatMul {
    /// Fact (type and shape) of the output tensor C.
    pub c_fact: TypedFact,
    /// The fused pipeline; by convention ends with a `Store` micro-op.
    pub micro_ops: Vec<ProtoFusedSpec>,
    /// Candidate kernels, indexed by the mode returned by `mode_picker`.
    pub mmm: Vec<Box<dyn MatMatMul>>,
    /// Chooses which kernel (mode) to run for a given concrete n.
    pub mode_picker: ModePicker,
    /// Index of the M axis in C's shape.
    pub c_m_axis: usize,
    /// Index of the N axis in C's shape.
    pub c_n_axis: usize,
    /// True when inputs are packed in a form usable without repacking.
    pub trivial_packing: bool,
    /// Cached result of `can_use_trivial_path()` (single-tile fast path).
    pub trivial_path: bool,
}
297
298impl Op for OptMatMul {
299 fn name(&self) -> Cow<str> {
300 "OptMatMul".into()
301 }
302
303 fn info(&self) -> TractResult<Vec<String>> {
304 let m = &self.c_fact.shape[self.c_m_axis];
305 let n = &self.c_fact.shape[self.c_n_axis];
306 let mut infos = vec![format!(
307 "c_shape:{:?}, c_m_axis:{} c_n_axis:{} m:{} n:{}",
308 self.c_fact, self.c_m_axis, self.c_n_axis, m, n,
309 )];
310 if let Some(k) = self.guess_k() {
311 infos.push(format!("Mult: m:{} k:{} n:{} with {:?}", m, k, n, self.mmm));
312 } else {
313 infos.push(format!("Mult: {:?}", self.mmm));
314 }
315 for (mode, mmm) in self.mmm.iter().enumerate() {
316 infos.push(format!(
317 "Ops: {}",
318 self.micro_ops.iter().map(|o| o.format(&**mmm, mode)).join(" >>> ")
319 ));
320 }
321 Ok(infos)
322 }
323
324 op_as_typed_op!();
325}
326
impl EvalOp for OptMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Run the fused matmul pipeline.
    ///
    /// Allocates an uninitialized output, picks the kernel mode from the
    /// concrete n dimension, then either runs the whole pipeline once (trivial
    /// path: C is a single m×n tile) or loops over every output tile along the
    /// non-m/n axes, re-resolving the micro-ops per tile.
    fn eval_with_session(
        &self,
        session: &SessionState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        unsafe {
            // Concretize the symbolic output shape with the session's symbols.
            let c_shape = self.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
            let mut c = Tensor::uninitialized_dt(self.c_fact.datum_type, &c_shape)?;
            let mode = self.mode_picker.pick(c_shape[self.c_n_axis])?;
            let mmm = &*self.mmm[mode];
            // Reuse the session-cached scratch space when this kernel accepts
            // it; otherwise drop it and allocate a fresh one.
            let mut cell = session.cached_mmm_scratch_space.borrow_mut();
            if !cell.as_ref().is_some_and(|scratch| mmm.can_use_scratch_space(&**scratch)) {
                *cell = None
            }
            let scratch = cell.get_or_insert_with(|| mmm.allocate_scratch_space());

            if self.trivial_path {
                // Single-tile fast path: resolve once, run once.
                let uops: Vec<FusedSpec> = self
                    .micro_ops
                    .iter()
                    .map(|o| o.resolve_trivial(&inputs, &mut c, mmm, mode))
                    .collect();
                mmm.run_with_scratch_space(
                    *c_shape.get_unchecked(self.c_m_axis),
                    *c_shape.get_unchecked(self.c_n_axis),
                    scratch.as_mut(),
                    &uops,
                )?;
                Ok(tvec!(c.into_tvalue()))
            } else {
                // Placeholder specs, overwritten in place for each tile.
                let mut uops = vec![FusedSpec::ShiftLeft(0); self.micro_ops.len()];
                // Iterate over all axes except m and n (forced to extent 1).
                let mut looping_shape: TVec<usize> = c_shape.to_smallvec();
                looping_shape[self.c_m_axis] = 1;
                looping_shape[self.c_n_axis] = 1;
                let m = c_shape[self.c_m_axis];
                let n = c_shape[self.c_n_axis];
                for c_coords in indices(&*looping_shape) {
                    for ix in 0..self.micro_ops.len() {
                        *uops.get_unchecked_mut(ix) = self.micro_ops.get_unchecked(ix).resolve(
                            &inputs,
                            c_coords.slice(),
                            &c,
                            mmm,
                            mode,
                        );
                    }
                    mmm.run_with_scratch_space(m, n, scratch.as_mut(), &uops)
                        .context("In mmm.run_with_scratch_space")?;
                }
                Ok(tvec!(c.into_tvalue()))
            }
        }
    }
}
386
impl TypedOp for OptMatMul {
    /// Validate internal invariants and input types, then return C's fact.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(self.c_m_axis < self.c_fact.rank());
        ensure!(self.c_n_axis < self.c_fact.rank());
        // trivial_path must stay consistent with the current configuration.
        ensure!(self.trivial_path == self.can_use_trivial_path());
        // All candidate kernels must agree on the accumulator type.
        ensure!(self.mmm.iter().map(|mmm| mmm.internal_type()).all_equal());
        for op in &self.micro_ops {
            op.check_inputs(inputs)?;
        }
        Ok(tvec!(self.c_fact.clone()))
    }

    /// Sum per-tile micro-op costs, then scale by the number of tiles (the
    /// product of all output dims except m and n).
    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let mut sums = HashMap::new();
        let m = &self.c_fact.shape[self.c_m_axis];
        let n = &self.c_fact.shape[self.c_n_axis];
        for op in &self.micro_ops {
            for (cost, count) in op.cost(m, n, self.mmm[0].internal_type()) {
                *sums.entry(cost).or_default() += count;
            }
        }
        let loops =
            self.c_fact
                .shape
                .iter()
                .enumerate()
                .map(|(ix, d)| {
                    if ix == self.c_m_axis || ix == self.c_n_axis {
                        1.to_dim()
                    } else {
                        d.clone()
                    }
                })
                .product::<TDim>();
        for s in &mut sums.values_mut() {
            *s *= &loops;
        }
        Ok(sums.into_iter().collect())
    }

    /// Try to fuse our single successor into this op's micro-op pipeline.
    /// Handled successors: linalg-compatible binary ops, QScale, LeakyRelu,
    /// cast of an i32 output to (q)i8/u8, axis removal, an AxisOp followed by
    /// a uniform-constant binop (swapped through), and unicast add.
    fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
        use crate::ops;
        // Fusion only applies when we have exactly one consumer and are not a
        // model output.
        if node.outputs.len() != 1
            || node.outputs[0].successors.len() != 1
            || model.output_outlets()?.contains(&node.id.into())
        {
            return Ok(None);
        }
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut patch = TypedModelPatch::new(format!("fusing {succ}"));

        // Generic binary op: fuse as scalar/row/col op, flipping the operator
        // when this node feeds the left-hand side.
        if let Some(op) = succ.op_as::<ops::binary::TypedBinOp>() {
            let mut binop =
                if let Some(op) = op.0.as_linalg_binop() { op } else { return Ok(None) };
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }
        // Same treatment for the optimized by-scalar binary form.
        if let Some(op) = succ.op_as::<ops::binary::OptBinByScalar>() {
            let mut binop =
                if let Some(op) = op.binop.as_linalg_binop() { op } else { return Ok(None) };
            let flipped = succ.inputs[0].node == node.id;
            if flipped {
                binop = binop.flip();
            }
            let other_outlet = succ.inputs[flipped as usize];
            return self.fuse_binary(model, node, patch, other_outlet, binop);
        }

        if let Some(op) = succ.op_as::<ops::element_wise::ElementWiseOp>().map(|ew| ew.0.as_ref()) {
            // Quantization rescale: becomes a Scaler micro-op (no extra input).
            if let Some(op) = op.downcast_ref::<ops::math::QScale>() {
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::Scaler(op.scaler)],
                    &[],
                );
            }
            // LeakyRelu: only if every candidate kernel can fuse it. Alpha is
            // materialized as an extra constant input, cast to the kernel's
            // internal type.
            if let Some(op) = op.downcast_ref::<LeakyRelu>() {
                if !self
                    .mmm
                    .iter()
                    .all(|mmm| mmm.can_fuse(&FusedSpec::LeakyRelu(&tensor0(op.alpha))))
                {
                    return Ok(None);
                }
                let alpha = patch.add_const(
                    node.name.to_string() + ".alpha",
                    tensor0(op.alpha).cast_to_dt(self.mmm[0].internal_type())?.into_owned(),
                )?;
                return self.fuse_op(
                    model,
                    node,
                    patch,
                    vec![ProtoFusedSpec::LeakyRelu(node.inputs.len())],
                    &[alpha],
                );
            }
        }
        // Cast of an i32 output down to (possibly quantized) i8/u8: absorbed
        // by just retyping c_fact, as long as the store is not strided.
        if let Some(cast_to) = succ.op_as::<ops::cast::Cast>().map(|cast| cast.to) {
            if (cast_to.unquantized() == i8::datum_type()
                || cast_to.unquantized() == u8::datum_type())
                && self.c_fact.datum_type == i32::datum_type()
            {
                if let Some(ProtoFusedSpec::Store(stores)) = self.micro_ops.last() {
                    if stores.iter().any(|s| matches!(s, OutputStoreSpec::Strides { .. })) {
                        return Ok(None);
                    }
                    let c_fact = cast_to.fact(self.c_fact.shape.clone());
                    let mut patch = TypedModelPatch::fuse_with_next(
                        model,
                        node,
                        Self { c_fact, ..self.clone() },
                    )?;
                    patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
                    return Ok(Some(patch));
                }
            }
        }
        // Axis removal: absorbed by shifting all stored axis indices, unless
        // the removed axis is m or n.
        if let Some(AxisOp::Rm(axis)) = succ.op_as::<ops::AxisOp>() {
            if *axis == self.c_m_axis || *axis == self.c_n_axis {
                return Ok(None);
            }
            let mut new_op = self.clone();
            new_op.c_fact.shape.remove_axis(*axis)?;
            new_op.c_m_axis -= (new_op.c_m_axis > *axis) as usize;
            new_op.c_n_axis -= (new_op.c_n_axis > *axis) as usize;
            for uop in &mut new_op.micro_ops {
                uop.rm_c_axis(*axis);
            }
            let mut patch = TypedModelPatch::fuse_with_next(model, node, new_op)?;
            patch.dont_apply_twice = Some(format!("Fuse {succ} into {node}"));
            return Ok(Some(patch));
        }
        // AxisOp followed by a binop with a uniform constant: swap the binop
        // before the AxisOp so a later fuse pass can absorb it.
        if succ.op_is::<AxisOp>() {
            if let &[next] = &*succ.outputs[0].successors {
                let bin = model.node(next.node);
                if let Some(op) = bin.op_as::<ops::binary::TypedBinOp>() {
                    if op.0.as_linalg_binop().is_none() {
                        return Ok(None);
                    };
                    // NOTE(review): `flipped` is computed from succ (the
                    // AxisOp's) inputs, but indexes into bin.inputs below —
                    // since succ is our direct successor this looks always
                    // true; confirm intent.
                    let flipped = succ.inputs[0].node == node.id;
                    let other_outlet = bin.inputs[flipped as usize];
                    if let Some(uni) = &model.outlet_fact(other_outlet)?.uniform {
                        let mut patch = TypedModelPatch::default();
                        let cst =
                            patch.add_const(&model.node(other_outlet.node).name, uni.clone())?;
                        let output = patch.tap_model(model, node.id.into())?;
                        let wire = wire_with_rank_broadcast(
                            &bin.name,
                            &mut patch,
                            op.clone(),
                            &if flipped { [output, cst] } else { [cst, output] },
                        )?;
                        let wire = patch.wire_node(&succ.name, succ.op.clone(), &wire)?[0];
                        patch.shunt_outside(model, bin.id.into(), wire)?;
                        return Ok(Some(patch));
                    }
                }
            }
        }
        // Unicast binary: a full-shape Add becomes an AddUnicast micro-op
        // (single-kernel case only); other linalg binops go through
        // fuse_binary like TypedBinOp.
        if let Some(op) = succ.op_as::<ops::binary::OptBinUnicast>() {
            let in_1_fact = model.outlet_fact(succ.inputs[0])?;
            let in_2_fact = model.outlet_fact(succ.inputs[1])?;
            if op.binop.is::<ops::math::Add>()
                && self.mmm.len() == 1
                && in_1_fact.without_value() == in_2_fact.without_value()
            {
                let other_slot = 1 - node.outputs[0].successors[0].slot;
                let other_input = succ.inputs[other_slot];
                let other_input = patch.tap_model(model, other_input)?;
                let other_fact = patch.outlet_fact(other_input)?;

                if other_fact.shape == self.c_fact.shape {
                    let other_storage = unsafe { self.mmm[0].c_view(self.c_m_axis, self.c_n_axis) };
                    // Identity axis mapping: the added tensor has C's shape.
                    let mapping =
                        MapOutputAxisToInput((0..other_fact.rank()).map(|x| (x, x)).collect());
                    return self.fuse_op(
                        model,
                        node,
                        patch,
                        vec![ProtoFusedSpec::AddUnicast(other_storage, node.inputs.len(), mapping)],
                        &[other_input],
                    );
                }
            } else {
                let mut binop =
                    if let Some(op) = op.binop.as_linalg_binop() { op } else { return Ok(None) };
                let flipped = succ.inputs[0].node == node.id;
                if flipped {
                    binop = binop.flip();
                }
                let other_outlet = succ.inputs[flipped as usize];
                return self.fuse_binary(model, node, patch, other_outlet, binop);
            }
        };
        Ok(None)
    }

    as_op!();
}
592
impl OptMatMul {
    /// Build an OptMatMul, validating the m/n axes against C's rank and
    /// computing the trivial-path flag from the configuration.
    pub fn new(
        mmm: Vec<Box<dyn MatMatMul>>,
        mode_picker: ModePicker,
        c_fact: TypedFact,
        c_m_axis: usize,
        c_n_axis: usize,
        micro_ops: Vec<ProtoFusedSpec>,
        trivial_packing: bool,
    ) -> TractResult<Self> {
        ensure!(c_m_axis < c_fact.rank());
        ensure!(c_n_axis < c_fact.rank());
        let mut it = OptMatMul {
            mmm,
            mode_picker,
            c_fact,
            c_m_axis,
            c_n_axis,
            micro_ops,
            trivial_path: false,
            trivial_packing,
        };
        it.update_trivial_path();
        Ok(it)
    }
    /// K dimension of the first AddMatMul micro-op, if any (display only).
    fn guess_k(&self) -> Option<TDim> {
        self.micro_ops
            .iter()
            .find_map(
                |o| {
                    if let ProtoFusedSpec::AddMatMul { geo, .. } = o {
                        Some(geo)
                    } else {
                        None
                    }
                },
            )
            .map(|geo| geo.k.clone())
    }

    /// Recompute and cache the trivial-path flag (call after any mutation of
    /// the configuration).
    fn update_trivial_path(&mut self) {
        self.trivial_path = self.can_use_trivial_path();
    }

    /// The trivial (single-run) path applies when C's shape is concrete, all
    /// axes other than m/n are degenerate, packing needs no repacking, and
    /// every micro-op is resolvable without per-tile translation.
    fn can_use_trivial_path(&self) -> bool {
        self.c_fact.shape.is_concrete()
            && self
                .c_fact
                .shape
                .iter()
                .enumerate()
                .all(|(ax, dim)| ax == self.c_m_axis || ax == self.c_n_axis || dim.is_one())
            && self.trivial_packing
            && self.micro_ops.iter().all(|o| o.is_trivial())
    }

    /// Produce a patch replacing `node` and its successor by a copy of this
    /// op with `fused_micro_op` spliced in just before the final Store, and
    /// `additional_inputs` appended to the node's inputs.
    fn fuse_op(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        fused_micro_op: Vec<ProtoFusedSpec>,
        additional_inputs: &[OutletId],
    ) -> TractResult<Option<TypedModelPatch>> {
        let succ = model.node(node.outputs[0].successors[0].node);
        let mut new_op = self.clone();
        // Empty range just before the last micro-op (the Store): splice
        // inserts the fused ops there without removing anything.
        let before_last = new_op.micro_ops.len() - 1..new_op.micro_ops.len() - 1;
        new_op.micro_ops.splice(before_last, fused_micro_op);
        // The fused op now produces the successor's output fact.
        new_op.c_fact = succ.outputs[0].fact.clone();
        new_op.update_trivial_path();
        let mut inputs = patch.taps(model, &node.inputs)?;
        inputs.extend(additional_inputs.iter().cloned());
        let output = patch.wire_node(&succ.name, new_op, &inputs)?;
        patch.shunt_outside(model, succ.id.into(), output[0])?;
        Ok(Some(patch))
    }

    /// Fuse a binary operation with operand `value` into the pipeline as a
    /// BinScalar, BinPerRow or BinPerCol micro-op, depending on the operand's
    /// shape. Returns Ok(None) when the shape fits none of these patterns.
    fn fuse_binary(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        mut patch: TypedModelPatch,
        value: OutletId,
        binop: BinOp,
    ) -> TractResult<Option<TypedModelPatch>> {
        let fact = model.outlet_fact(value)?;
        let mut v = patch.tap_model(model, value)?;
        // The kernel works in its internal type: cast the operand if needed.
        if fact.datum_type != self.mmm[0].internal_type() {
            v = patch.wire_node(
                format!("{}.cast-input-{}", node.name, node.inputs.len()),
                cast(self.mmm[0].internal_type()),
                &[v],
            )?[0];
        }
        // The operand becomes a new trailing input slot.
        let value = node.inputs.len();
        let additional_input = tvec!(v);
        // Single-element operand: scalar fusion.
        if fact.shape.volume() == 1.to_dim() {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinScalar(value, binop)],
                &additional_input,
            );
        }
        let other_shape = fact.shape.to_owned();
        // Operand is a vector along the m axis only: per-row fusion.
        if other_shape[self.c_m_axis] == self.c_fact.shape[self.c_m_axis]
            && other_shape[self.c_m_axis] == other_shape.volume()
        {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerRow(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_m_axis, self.c_m_axis))),
                )],
                &additional_input,
            );
        }
        // Operand is a vector along the n axis only: per-column fusion.
        if other_shape[self.c_n_axis] == self.c_fact.shape[self.c_n_axis]
            && other_shape[self.c_n_axis] == other_shape.volume()
        {
            return self.fuse_op(
                model,
                node,
                patch,
                vec![ProtoFusedSpec::BinPerCol(
                    value,
                    binop,
                    MapOutputAxisToInput(tvec!((self.c_n_axis, self.c_n_axis))),
                )],
                &additional_input,
            );
        }
        Ok(None)
    }
}