use crate::infer::*;
use crate::internal::*;

use tract_core::ops::einsum::EinSum;
use tract_core::tract_data::itertools::Itertools;

/// Inference-time matrix multiplication, with optional transposition of the
/// A and B operands and of the C output. Wires down to an `EinSum` node in
/// the typed model.
#[derive(Debug, Clone, Default, Hash)]
pub struct MatMulInference {
    pub a_trans: bool,
    pub b_trans: bool,
    pub c_trans: bool,
}

impl MatMulInference {
    pub fn with_a_trans(self, a_trans: bool) -> MatMulInference {
        MatMulInference { a_trans, ..self }
    }

    pub fn with_b_trans(self, b_trans: bool) -> MatMulInference {
        MatMulInference { b_trans, ..self }
    }

    pub fn with_c_trans(self, c_trans: bool) -> MatMulInference {
        MatMulInference { c_trans, ..self }
    }
}

impl Expansion for MatMulInference {
    fn name(&self) -> StaticName {
        "MatMulInference".into()
    }

    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        check_input_arity(inputs, 2)?;
        check_output_arity(outputs, 1)?;
        // Inputs and output all share one datum type.
        s.equals(&inputs[0].datum_type, &inputs[1].datum_type)?;
        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
        // Once both input shapes are known, derive the output shape from them.
        s.given_2(&inputs[0].shape, &inputs[1].shape, move |s, ashape, bshape| {
            let (_, _, _, cshape) =
                compute_shapes(ashape, bshape, self.a_trans, self.b_trans, self.c_trans)?;
            s.equals(&outputs[0].shape, cshape)
        })?;
        Ok(())
    }

    fn wire(
        &self,
        prefix: &str,
        target: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let a_rank = target.outlet_fact(inputs[0])?.rank();
        let b_rank = target.outlet_fact(inputs[1])?.rank();
        ensure!(a_rank > 1 || b_rank > 1);
        // Pick the axis labels of each operand according to the transposition flags.
        let mk = if self.a_trans { "km" } else { "mk" };
        let kn = if self.b_trans { "nk" } else { "kn" };
        let mn = if self.c_trans { "nm" } else { "mn" };
        let axes: AxesMapping = if a_rank == 1 {
            // A is a vector: m is implicit and dropped from the output.
            let prefix: String = ('a'..).take(b_rank - 2).collect();
            format!("k,{prefix}{kn}->{prefix}n").parse()?
        } else if b_rank == 1 {
            // B is a vector: n is implicit and dropped from the output.
            let prefix: String = ('a'..).take(a_rank - 2).collect();
            format!("{prefix}{mk},k->{prefix}m").parse()?
        } else {
            // General case: leading axes broadcast. The lower-rank operand only
            // names the trailing prefix axes it actually has (e.g. "mk,akn->amn").
            let c_rank = b_rank.max(a_rank);
            let a_prefix: String =
                ('a'..).take(c_rank - 2).skip(b_rank.saturating_sub(a_rank)).collect();
            let b_prefix: String =
                ('a'..).take(c_rank - 2).skip(a_rank.saturating_sub(b_rank)).collect();
            let c_prefix: String = ('a'..).take(c_rank - 2).collect();
            format!("{a_prefix}{mk},{b_prefix}{kn}->{c_prefix}{mn}").parse()?
        };
        let dt = target.outlet_fact(inputs[0])?.datum_type;
        target.wire_node(prefix, EinSum { axes, operating_dt: dt, q_params: None }, inputs)
    }
}

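/// Computes the shapes involved in a matmul with the given transposition flags.
/// Returns the (possibly rank-extended) A and B shapes, the fully broadcast C
/// shape, and the final C shape with any implicit vector axis dropped.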
#[allow(clippy::type_complexity)]
pub fn compute_shapes<D: DimLike>(
    mut ashape: TVec<D>,
    mut bshape: TVec<D>,
    a_trans: bool,
    b_trans: bool,
    c_trans: bool,
) -> TractResult<(TVec<D>, TVec<D>, TVec<D>, TVec<D>)> {
    let mut implicit_m = false;
    let mut implicit_n = false;
    // Promote vectors to matrices, remembering which axis was made up so it
    // can be dropped from the final output shape.
    if ashape.len() < 2 {
        implicit_m = true;
        ashape.insert(a_trans as usize, D::one());
    }
    if bshape.len() < 2 {
        implicit_n = true;
        bshape.insert(!b_trans as usize, D::one());
    }
    // Pad the lower-rank operand with leading 1s so both operands have the same rank.
    while ashape.len() < bshape.len() {
        ashape.insert(0, D::one());
    }
    while bshape.len() < ashape.len() {
        bshape.insert(0, D::one());
    }
    // Broadcast the leading (non-matrix) axes together.
    let c_bc_shape_prefix = tract_core::broadcast::multi_broadcast(&[
        &ashape[..(ashape.len() - 2)],
        &bshape[..(bshape.len() - 2)],
    ])?;
    let mut c_bc_shape: TVec<D> = c_bc_shape_prefix;
    let (mut m, mut ka) = (ashape[ashape.len() - 2].clone(), ashape[ashape.len() - 1].clone());
    let (mut kb, mut n) = (bshape[bshape.len() - 2].clone(), bshape[bshape.len() - 1].clone());
    if a_trans {
        std::mem::swap(&mut m, &mut ka);
    }
    if b_trans {
        std::mem::swap(&mut kb, &mut n);
    }
    // The contracted axes of A and B must agree.
    if !ka.compatible_with(&kb) {
        bail!(
            "Inconsistent matmul: a: {} b: {}, a_trans: {} b_trans: {} c_trans: {}",
            ashape.iter().join(","),
            bshape.iter().join(","),
            a_trans,
            b_trans,
            c_trans
        );
    }
    let mut c_shape_final = c_bc_shape.clone();
    if c_trans {
        c_bc_shape.push(n.clone());
        c_bc_shape.push(m.clone());
        if !implicit_n {
            c_shape_final.push(n.clone());
        }
        if !implicit_m {
            c_shape_final.push(m.clone());
        }
    } else {
        c_bc_shape.push(m.clone());
        c_bc_shape.push(n.clone());
        if !implicit_m {
            c_shape_final.push(m.clone());
        }
        if !implicit_n {
            c_shape_final.push(n.clone());
        }
    }
    Ok((ashape, bshape, c_bc_shape, c_shape_final))
}
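
// Illustrative sketch of `compute_shapes` on concrete `usize` shapes, assuming
// `tvec!` and the `DimLike` impl for `usize` are reachable through the
// `crate::internal` prelude re-exports.
#[cfg(test)]
mod compute_shapes_sketch {
    use super::*;
    use crate::internal::*;

    #[test]
    fn plain_2d() -> TractResult<()> {
        // [2,3] x [3,4] -> [2,4]: no transposition, no implicit axes.
        let (_, _, c_bc, c_final) =
            compute_shapes(tvec![2usize, 3], tvec![3, 4], false, false, false)?;
        assert_eq!(c_bc, tvec![2usize, 4]);
        assert_eq!(c_final, tvec![2usize, 4]);
        Ok(())
    }

    #[test]
    fn implicit_m_from_vector_a() -> TractResult<()> {
        // A is a vector [3]: m is implicit, so it is dropped from the final shape.
        let (_, _, c_bc, c_final) =
            compute_shapes(tvec![3usize], tvec![3, 4], false, false, false)?;
        assert_eq!(c_bc, tvec![1usize, 4]);
        assert_eq!(c_final, tvec![4usize]);
        Ok(())
    }
}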