use tract_data::itertools::Itertools;
use tract_linalg::Scaler;
use tract_ndarray::Ix2;
use tract_num_traits::One;

use super::eval::dequant_inputs;
use super::optimize::*;
use super::EinSum;
use crate::internal::*;
use crate::ops::einsum::block_quant_aware_input_shape;
use crate::ops::konst::Const;

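/// Rewrite every EinSum in the model as a `BasicMatMul`, wiring in whatever
/// axis-permutation ops are needed on its inputs and output. Plain einsums
/// must have exactly two inputs; quantized ones take nine (a, b, bias, then
/// six scalar quantization parameters).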
pub fn rewrite_einsums_as_matmul(model: &mut TypedModel) -> TractResult<()> {
    let rules = Rewriter::default().with_rule_for::<EinSum>("einsum-to-matmul", einsum_rules);
    rules.rewrite(&(), model)
}

fn einsum_rules(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    node_name: &str,
    op: &EinSum,
) -> TractResult<Option<TypedModelPatch>> {
    if !((op.q_params.is_none() && node.inputs.len() == 2)
        || (op.q_params.is_some() && node.inputs.len() == 9))
    {
        return Ok(None);
    }
    if op.q_params.is_some()
        && model.node_input_facts(node.id)?.iter().skip(3).any(|i| i.konst.is_none())
    {
        return Ok(None);
    }
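    // ensure_mkn_axes either annotates the op with its m, k and n axes, emits
    // a preparatory patch to be applied first, or reports that this einsum
    // cannot be expressed as a matrix multiplication.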
    let op = match ensure_mkn_axes(op, model, node).context("Figuring out m, k and n axes")? {
        AxesOrPatch::Annotated(op) => op,
        AxesOrPatch::Patch(p) => return Ok(Some(p)),
        AxesOrPatch::NotAMatMul(s, axes) => {
            bail!("{} is not a matmul: {s} ({})", op.axes, axes.iter().map(|a| a.repr).join(", "))
        }
    };
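    // Every axis that is not m, k or n becomes part of a common batch prefix,
    // kept in einsum declaration order.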
    let prefix: String = op
        .axes
        .iter_all_axes()
        .filter(|a| ![op.m_axis, op.k_axis, op.n_axis].contains(a))
        .map(|a| a.repr)
        .collect();
    let mut patch = TypedModelPatch::default();
    let inputs = patch.taps(model, &node.inputs)?;
    let mut wire = tvec!(inputs[0], inputs[1]);

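    // Bring A to either {prefix}mk or {prefix}km order, whichever requires the
    // fewer axis ops, and record the choice as transpose_a for the matmul.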
    let (m, k, n) = (op.m_axis.repr, op.k_axis.repr, op.n_axis.repr);
    let a_order_es: String = op.axes.axes(InOut::In(0)).map(|a| a.repr).collect();
    let a_order_mm = format!("{prefix}{m}{k}");
    let a_order_mm_t = format!("{prefix}{k}{m}");
    let a_transform = format!("{}->{}", a_order_es, a_order_mm)
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let a_transform_t = format!("{}->{}", a_order_es, a_order_mm_t)
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let transpose_a = a_transform.len() > a_transform_t.len();
    let a_transform = if transpose_a { a_transform_t } else { a_transform };
    let name = format!("{node_name}.fix_a");
    for op in a_transform {
        wire[0] = patch.wire_node(&name, op, &[wire[0]])?[0];
    }
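    // If A collapsed to a constant, rebuild it so the opaque fact (e.g. a
    // block-quantized payload description) from the original input survives on
    // the patched constant, then propagate the same fact to the wire.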
    if let Some(op) = patch.node_mut(wire[0].node).op_as_mut::<Const>() {
        *op = Const::new_with_opt_opaque_fact(
            op.val().clone(),
            model.outlet_fact(node.inputs[0])?.opaque_fact.clone(),
        )?;
    }
    patch
        .outlet_fact_mut(wire[0])?
        .opaque_fact
        .clone_from(&model.outlet_fact(node.inputs[0])?.opaque_fact);
    let b_order_es: String = op.axes.axes(InOut::In(1)).map(|a| a.repr).collect();
    let b_order_mm = format!("{prefix}{k}{n}");
    let b_order_mm_t = format!("{prefix}{n}{k}");
    let b_transform = format!("{}->{}", b_order_es, b_order_mm)
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let b_transform_t = format!("{}->{}", b_order_es, b_order_mm_t)
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let transpose_b = b_transform.len() > b_transform_t.len();
    let b_transform = if transpose_b { b_transform_t } else { b_transform };
    let name = format!("{node_name}.fix_b");
    for op in b_transform {
        wire[1] = patch.wire_node(&name, op, &[wire[1]])?[0];
    }

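    // The output goes the other way: map from the matmul's {prefix}mn (or
    // {prefix}nm) layout back to the einsum's declared output order.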
    let c_order_es: String = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
    let c_order_mm = format!("{prefix}{m}{n}");
    let c_order_mm_t = format!("{prefix}{n}{m}");
    let c_transform = format!("{}->{}", c_order_mm, c_order_es)
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let c_transform_t = format!("{}->{}", c_order_mm_t, c_order_es)
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let transpose_c = c_transform.len() > c_transform_t.len();
    let c_transform = if transpose_c { c_transform_t } else { c_transform };
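    // For quantized einsums, the guard above guarantees the six quantization
    // inputs are constants, so they can be folded into a static ZpScale on the
    // output type (indices 4 and 5 of the slice are c0 and c_scale).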
    let quantize_output = if let Some(qp) = op.q_params {
        let qparams: Vec<&Tensor> = inputs[3..9]
            .iter()
            .map(|f| {
                patch
                    .outlet_fact(*f)?
                    .konst
                    .as_deref()
                    .context("Can only translate fixed scalar quantization")
            })
            .try_collect()?;
        Some(qp.with_qparams(QParams::ZpScale {
            zero_point: qparams[4].cast_to_scalar::<i32>()?,
            scale: qparams[5].cast_to_scalar::<f32>()?,
        }))
    } else {
        None
    };
    wire = patch.wire_node(
        node_name,
        BasicMatMul { transpose_a, transpose_b, transpose_c, quantize_output },
        &wire,
    )?;

    for (ix, op) in c_transform.into_iter().enumerate() {
        wire = patch.wire_node(format!("{node_name}.fix_c.{ix}"), op, &wire)?;
    }
    patch.shunt_outside(model, node.id.into(), wire[0])?;
    Ok(Some(patch))
}

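/// A plain matrix multiplication operator: C = A × B over the two innermost
/// axes, with unit batch axes broadcasting against the other input. Each
/// operand (and the output) may independently be transposed, and the output
/// may optionally be requantized.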
#[derive(Clone, Debug, Copy, Default)]
pub struct BasicMatMul {
    pub transpose_a: bool,
    pub transpose_b: bool,
    pub transpose_c: bool,
    pub quantize_output: Option<DatumType>,
}

impl BasicMatMul {
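    // Batch dimensions broadcast pairwise: a unit dim on one side takes the
    // other side's extent. E.g. a: [1, 3, 2] and b: [5, 2, 4] (m=3, k=2, n=4,
    // no transposition) yield an output of shape [5, 3, 4].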
    fn output_shape<D: DimLike + One>(&self, a: &[D], b: &[D]) -> TVec<D> {
        let rank = a.len();
        let mut output: TVec<D> = (0..rank - 2)
            .map(|ix| if a[ix].is_one() { b[ix].clone() } else { a[ix].clone() })
            .collect();
        output.push(a[rank - 2 + self.transpose_a as usize].clone());
        output.push(b[rank - 2 + !self.transpose_b as usize].clone());
        if self.transpose_c {
            output.swap(rank - 2, rank - 1);
        }
        output
    }

    fn mm<Acc: Datum + tract_ndarray::LinalgScalar>(
        &self,
        acc: &mut Tensor,
        a: &Tensor,
        b: &Tensor,
    ) -> TractResult<()> {
        use crate::ndarray::Dimension;
        let a = a.to_array_view::<Acc>()?;
        let b = b.to_array_view::<Acc>()?;
        let mut c = acc.to_array_view_mut::<Acc>()?;
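        // Walk the batch prefix of the output; clamping the index with
        // d.min(shape - 1) makes unit batch dims on either input broadcast.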
        for prefix in tract_ndarray::indices(&c.shape()[..c.ndim() - 2]) {
            let mut a = a.view();
            let mut b = b.view();
            let mut c = c.view_mut();
            for &d in prefix.slice().iter() {
                a.index_axis_inplace(tract_ndarray::Axis(0), d.min(a.shape()[0] - 1));
                b.index_axis_inplace(tract_ndarray::Axis(0), d.min(b.shape()[0] - 1));
                c.index_axis_inplace(tract_ndarray::Axis(0), d);
            }
            let a = a.into_dimensionality::<Ix2>().unwrap();
            let b = b.into_dimensionality::<Ix2>().unwrap();
            let mut c = c.into_dimensionality::<Ix2>().unwrap();
            let a = if self.transpose_a { a.t() } else { a };
            let b = if self.transpose_b { b.t() } else { b };
            if self.transpose_c {
                c.assign(&b.t().dot(&a.t()))
            } else {
                c.assign(&a.dot(&b))
            }
        }
        Ok(())
    }
}

impl Op for BasicMatMul {
    fn name(&self) -> Cow<str> {
        "MatMul".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!(
            "transpose_a: {} transpose_b: {} transpose_c: {} q: {:?}",
            self.transpose_a, self.transpose_b, self.transpose_c, self.quantize_output
        )])
    }

    op_as_typed_op!();
}

impl EvalOp for BasicMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let c_dt = if inputs[0].datum_type().is_number() {
            inputs[0].datum_type()
        } else if inputs[1].datum_type().is_number() {
            inputs[1].datum_type()
        } else {
            f32::datum_type()
        };
        let inputs = dequant_inputs(c_dt, inputs)?;

        let output_shape = self.output_shape(inputs[0].shape(), inputs[1].shape());

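        // Quantized path: accumulate in i32 with zero-points subtracted, then
        // requantize with the combined scale a_scale * b_scale / c_scale using
        // round-to-even, and cast back to the quantized output type.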
        if let Some(qp) = self.quantize_output {
            let mut acc = Tensor::zero_dt(i32::datum_type(), &output_shape)?;
            let mut a_i32 = inputs[0].cast_to::<i32>()?.into_owned();
            a_i32
                .as_slice_mut::<i32>()?
                .iter_mut()
                .for_each(|x| *x -= inputs[0].datum_type().zp_scale().0);
            let mut b_i32 = inputs[1].cast_to::<i32>()?.into_owned();
            b_i32
                .as_slice_mut::<i32>()?
                .iter_mut()
                .for_each(|x| *x -= inputs[1].datum_type().zp_scale().0);
            self.mm::<i32>(&mut acc, &a_i32, &b_i32)?;
            let scale = inputs[0].datum_type().zp_scale().1 * inputs[1].datum_type().zp_scale().1
                / qp.zp_scale().1;
            let scaler = Scaler::new(scale, tract_linalg::mmm::RoundingPolicy::Even);
            acc.to_array_view_mut::<i32>()?.iter_mut().for_each(|x| *x = *x * scaler);
            let mut c: Tensor = acc.cast_to_dt(qp.unquantized())?.into_owned();
            unsafe { c.set_datum_type(qp) };
            Ok(tvec!(c.into_tvalue()))
        } else {
            let mut c = Tensor::zero_dt(c_dt, &output_shape)?;
            dispatch_floatlike!(Self::mm(c_dt)(self, &mut c, &inputs[0], &inputs[1]))?;
            Ok(tvec!(c.into_tvalue()))
        }
    }
}

impl TypedOp for BasicMatMul {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let [a, b] = inputs else {
            bail!("Expects 2 inputs");
        };
        let a_shape = block_quant_aware_input_shape(inputs[0])?;
        let b_shape = block_quant_aware_input_shape(inputs[1])?;
        let dt = self.quantize_output.unwrap_or(if a.datum_type.is_number() {
            a.datum_type
        } else {
            b.datum_type
        });
        Ok(tvec!(dt.fact(self.output_shape(&a_shape, &b_shape))))
    }

    as_op!();
}

#[cfg(test)]
mod test {
    use super::*;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest::test_runner::{TestCaseResult, TestRunner};
    use tract_data::itertools::Itertools;

    pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
        let shape = shape.to_vec();
        let len = shape.iter().product::<usize>();
        vec((-10i8..=10i8).prop_map(|i| i as f32), len..=len)
            .prop_map(move |vec| tensor1(&vec).into_shape(&shape).unwrap())
            .boxed()
    }

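    // Draw one extent in 2..6 for every axis appearing in at least one input,
    // then project those extents onto the two input shapes.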
    fn full_shapes(e: &AxesMapping) -> BoxedStrategy<(Vec<usize>, Vec<usize>)> {
        let e = e.clone();
        let inputs_axes = e
            .iter_all_axes()
            .filter(|axis| axis.inputs[0].len() + axis.inputs[1].len() > 0)
            .cloned()
            .collect_vec();
        let dims = vec![2usize..6; inputs_axes.len()];
        dims.prop_map(move |dims| {
            let a: Vec<usize> = e
                .axes(InOut::In(0))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            let b: Vec<usize> = e
                .axes(InOut::In(1))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            (a, b)
        })
        .boxed()
    }

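    // Property test driver: for a given einsum expression, generate input
    // shapes (degrading any non-k, non-disappearing axis to 1 to exercise
    // broadcasting), then check the rewritten model against the reference.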
    fn test_expr(expr: &str) -> TestCaseResult {
        let expr = expr.to_string();
        let mut runner = TestRunner::default();
        let axes: AxesMapping = expr.parse().unwrap();
        fn is_k(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.inputs[1 - input].len() == 1 && axis.outputs[0].len() == 0
        }
        fn is_disappearing_axis(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.outputs[0].len() == 0
        }
        let cases = full_shapes(&axes)
            .prop_flat_map(|(a, b)| {
                (
                    a.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            if is_k(&axes, 0, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disappearing_axis(&axes, 0, ix) {
                                Just(1).boxed()
                            } else {
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                    b.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            if is_k(&axes, 1, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disappearing_axis(&axes, 1, ix) {
                                Just(1).boxed()
                            } else {
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                )
            })
            .prop_flat_map(|(a_shape, b_shape)| (tensor(&a_shape), tensor(&b_shape)))
            .prop_map(|(a, b)| EinSumProblem { expr: expr.clone(), a, b });
        runner.run(&cases, |pb| pb.check().map_err(|e| TestCaseError::fail(e.to_string())))?;
        Ok(())
    }

    #[derive(Debug, Clone, PartialEq)]
    struct EinSumProblem {
        expr: String,
        a: Tensor,
        b: Tensor,
    }

    impl EinSumProblem {
        fn check(&self) -> TractResult<()> {
            let mut model = TypedModel::default();
            let sa = model.add_source("a", f32::fact(self.a.shape())).unwrap();
            let sb = model.add_source("b", f32::fact(self.b.shape())).unwrap();
            let einsum = model
                .wire_node(
                    "einsum",
                    EinSum::new(self.expr.parse().unwrap(), f32::datum_type()),
                    &[sa, sb],
                )
                .unwrap();
            model.set_output_outlets(&einsum).unwrap();
            let a = self.a.clone().into_tvalue();
            let b = self.b.clone().into_tvalue();
            let inputs = tvec!(a, b);
            let reference =
                TypedRunnableModel::new(&model).unwrap().run(inputs.clone()).unwrap().remove(0);
            rewrite_einsums_as_matmul(&mut model)?;
            assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
            let test = TypedRunnableModel::new(&model).unwrap().run(inputs).unwrap().remove(0);
            reference.close_enough(&test, true).unwrap();
            Ok(())
        }
    }

    #[rustfmt::skip] #[test] fn prop_mk_kn_mn() -> TestCaseResult { test_expr("mk,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_km_kn_mn() -> TestCaseResult { test_expr("km,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_nk_mn() -> TestCaseResult { test_expr("mk,nk->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_kn_nm() -> TestCaseResult { test_expr("mk,kn->nm") }
    #[rustfmt::skip] #[test] fn prop_k_kn_mn() -> TestCaseResult { test_expr("k,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_k_mn() -> TestCaseResult { test_expr("mk,k->mn") }
    #[rustfmt::skip] #[test] fn prop_m_n_mn() -> TestCaseResult { test_expr("m,n->mn") }
    #[rustfmt::skip] #[test] fn prop_amk_akn_amn() -> TestCaseResult { test_expr("amk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_mk_akn_amn() -> TestCaseResult { test_expr("mk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_btgi_gih_tgh() -> TestCaseResult { test_expr("btgi,gih->tgh") }
    #[rustfmt::skip] #[test] fn prop_tgi_gih_btgh() -> TestCaseResult { test_expr("tgi,gih->btgh") }

    #[test]
    fn k_kn_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "k,kn->mn".to_string(),
            a: tensor1(&[0f32, 0f32]),
            b: tensor2(&[[0f32, 0.], [0., 0.]]),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[2, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_kn_nm_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,kn->mn".to_string(),
            a: Tensor::zero::<f32>(&[3, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_2() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_3() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn km_anbck_bmn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "km,anbck->bmn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 1, 1, 1, 2]).unwrap(),
        }
        .check()
    }

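    // End-to-end check of the quantized path: wire a full 9-input quantized
    // einsum (a, b, bias and six scalar q-params) and verify the rewrite
    // removes every EinSum.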
    #[test]
    fn q() -> TractResult<()> {
        let qp = QParams::ZpScale { zero_point: 0, scale: 0.1 };
        let op = EinSum {
            axes: "mk,kn,m,,,,,,->mn".parse()?,
            operating_dt: i32::datum_type(),
            q_params: Some(DatumType::QI8(qp)),
        };
        let mut model = TypedModel::default();
        let inputs = [
            model.add_source("a", DatumType::QI8(qp).fact([3, 2]))?,
            model.add_source("b", DatumType::QI8(qp).fact([2, 4]))?,
            model.add_source("bias", i32::datum_type().fact([3]))?,
            model.add_const("a0", tensor0(qp.zp_scale().0))?,
            model.add_const("a_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("b0", tensor0(qp.zp_scale().0))?,
            model.add_const("b_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("c0", tensor0(qp.zp_scale().0))?,
            model.add_const("c_scale", tensor0(qp.zp_scale().1))?,
        ];
        let wire = model.wire_node("einsum", op.clone(), &inputs)?;
        model.set_output_outlets(&wire)?;
        rewrite_einsums_as_matmul(&mut model)?;
        assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
        Ok(())
    }
}