1use tract_data::itertools::Itertools;
2use tract_linalg::Scaler;
3use tract_ndarray::Ix2;
4use tract_num_traits::One;
5
6use super::einsum_matmul::EinSumMatMul;
7use super::eval::dequant_inputs;
8use crate::internal::*;
9use crate::ops::einsum::block_quant_aware_input_shape;
10use crate::ops::konst::Const;
11
/// Context threaded through the einsum-to-prefix-matmul rewrite rule.
#[derive(Debug, Default)]
pub struct EinSumToPrefixMatmulCtx {
    /// When true, the rewrite errors out if the einsum's operating datum type
    /// differs from the one `matmul_semantic_output_dt` would pick for the
    /// input types (instead of carrying the explicit operating type over).
    pub ensure_strict_matmul_semantic: bool,
}
16
/// Rewrite every matmul-like `EinSum` in `model` into a `PrefixMatMul` node
/// (plus the axis-permutation nodes needed around it).
///
/// The passes run in order: first merge consecutive same-role axes, then tag
/// matmul-like einsums as `EinSumMatMul`, then apply the rewrite rule below.
pub fn rewrite_einsum_to_prefix_matmul(
    model: &mut TypedModel,
    ensure_strict_matmul_semantic: bool,
) -> TractResult<()> {
    super::einsum_matmul::merge_consecutive_same_role_axes(model)?;
    super::einsum_matmul::detect_all(model)?;
    let ctx = EinSumToPrefixMatmulCtx { ensure_strict_matmul_semantic };
    Rewriter::default().with_rule_for("einsum-to-prefix-matmul", rule).rewrite(&ctx, model)
}
26
27fn rule(
28 ctx: &EinSumToPrefixMatmulCtx,
29 model: &TypedModel,
30 node: &TypedNode,
31 node_name: &str,
32 op: &EinSumMatMul,
33) -> TractResult<Option<TypedModelPatch>> {
34 let is_fp_mm = op.q_params.is_none() && node.inputs.len() == 2;
37 let is_q_mm = op.q_params.is_some() && node.inputs.len() == 9;
38 rule_if!(is_fp_mm || is_q_mm);
39 rule_if!(
40 op.q_params.is_none()
41 || model.node_input_facts(node.id)?.iter().skip(3).all(|i| i.konst.is_some())
42 );
43 let prefix: String = op
44 .axes
45 .iter_all_axes()
46 .filter(|a| ![op.m_axis, op.k_axis, op.n_axis].contains(&a.repr))
47 .map(|a| a.repr)
48 .collect();
49 let mut patch = TypedModelPatch::default();
50 let inputs = patch.taps(model, &node.inputs)?;
51 let mut wire = tvec!(inputs[0], inputs[1]);
52
53 let (m, k, n) = (op.m_axis, op.k_axis, op.n_axis);
54 let a_order_es: String = op.axes.axes(InOut::In(0)).map(|a| a.repr).collect();
55 let a_order_mm = format!("{prefix}{m}{k}");
56 let a_order_mm_t = format!("{prefix}{k}{m}");
57 let a_transform =
58 format!("{a_order_es}->{a_order_mm}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
59 let a_transform_t =
60 format!("{a_order_es}->{a_order_mm_t}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
61 let transpose_a = a_transform.len() > a_transform_t.len();
62 let a_transform = if transpose_a { a_transform_t } else { a_transform };
63 let name = format!("{node_name}.fix_a");
64 for op in a_transform {
65 wire[0] = patch.wire_node(&name, op, &[wire[0]])?[0];
66 }
67 if let Some(op) = patch.node_mut(wire[0].node).op_as_mut::<Const>() {
70 *op = Const::new_with_opt_exotic_fact(
71 op.val().clone(),
72 model.outlet_fact(node.inputs[0])?.exotic_fact.clone(),
73 )?;
74 }
75 patch
76 .outlet_fact_mut(wire[0])?
77 .exotic_fact
78 .clone_from(&model.outlet_fact(node.inputs[0])?.exotic_fact);
79 let b_order_es: String = op.axes.axes(InOut::In(1)).map(|a| a.repr).collect();
82 let b_order_mm = format!("{prefix}{k}{n}");
83 let b_order_mm_t = format!("{prefix}{n}{k}");
84 let b_transform =
85 format!("{b_order_es}->{b_order_mm}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
86 let b_transform_t =
87 format!("{b_order_es}->{b_order_mm_t}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
88 let transpose_b = b_transform.len() > b_transform_t.len();
89 let b_transform = if transpose_b { b_transform_t } else { b_transform };
90 let name = format!("{node_name}.fix_b");
91 for op in b_transform {
92 wire[1] = patch.wire_node(&name, op, &[wire[1]])?[0];
93 }
94
95 let c_order_es: String = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
96 let c_order_mm = format!("{prefix}{m}{n}");
97 let c_order_mm_t = format!("{prefix}{n}{m}");
98 let c_transform =
99 format!("{c_order_mm}->{c_order_es}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
100 let c_transform_t =
101 format!("{c_order_mm_t}->{c_order_es}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
102 let transpose_c = c_transform.len() > c_transform_t.len();
103 let c_transform = if transpose_c { c_transform_t } else { c_transform };
104 let quantize_output = if let Some(qp) = op.q_params {
105 let qparams: Vec<&Tensor> = inputs[3..9]
106 .iter()
107 .map(|f| {
108 patch
109 .outlet_fact(*f)?
110 .konst
111 .as_deref()
112 .context("Can only translate fixed scalar quantization")
113 })
114 .try_collect()?;
115 Some(qp.with_qparams(QParams::ZpScale {
116 zero_point: qparams[4].cast_to_scalar::<i32>()?,
117 scale: qparams[5].cast_to_scalar::<f32>()?,
118 }))
119 } else {
120 None
121 };
122
123 let operating_dt = if ctx.ensure_strict_matmul_semantic {
124 let input_facts = model.node_input_facts(node.id)?;
125 let a_dt = input_facts[0].datum_type;
126 let b_dt = input_facts[1].datum_type;
127 let operating_dt = quantize_output.unwrap_or(op.operating_dt);
128 let a_plain = input_facts[0].is_plain();
129 let b_plain = input_facts[1].is_plain();
130 let allowed_dt = matmul_semantic_output_dt(&a_dt, a_plain, &b_dt, b_plain);
131
132 ensure!(
133 operating_dt == allowed_dt,
134 format!(
135 "Strict matmul semantic require operating_dt to be {allowed_dt:?} \
136 for (a: {a_dt:?}, b:{b_dt:?}) but got {:?}.",
137 op.operating_dt
138 )
139 );
140
141 None
142 } else {
143 Some(op.operating_dt)
144 };
145
146 wire = patch.wire_node(
147 node_name,
148 PrefixMatMul { transpose_a, transpose_b, transpose_c, quantize_output, operating_dt },
149 &wire,
150 )?;
151
152 for (ix, op) in c_transform.into_iter().enumerate() {
153 wire = patch.wire_node(format!("{node_name}.fix_c.{ix}"), op, &wire)?;
154 }
155 patch.shunt_outside(model, node.id.into(), wire[0])?;
156 Ok(Some(patch))
157}
158
159fn matmul_semantic_output_dt(
160 a_dt: &DatumType,
161 a_plain: bool,
162 b_dt: &DatumType,
163 b_plain: bool,
164) -> DatumType {
165 if a_plain && a_dt.is_number() {
166 *a_dt
167 } else if b_plain && b_dt.is_number() {
168 *b_dt
169 } else if a_dt.is_number() {
170 *a_dt
171 } else if b_dt.is_number() {
172 *b_dt
173 } else {
174 f32::datum_type()
175 }
176}
177
/// Batched matrix multiplication over the last two axes, with a shared
/// broadcastable prefix over all leading axes.
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
pub struct PrefixMatMul {
    /// A's last two axes are (k, m) instead of (m, k).
    pub transpose_a: bool,
    /// B's last two axes are (n, k) instead of (k, n).
    pub transpose_b: bool,
    /// Output's last two axes are (n, m) instead of (m, n).
    pub transpose_c: bool,
    /// When set, output is requantized to this (quantized) datum type.
    pub quantize_output: Option<DatumType>,
    /// Explicit accumulator/output type; when None it is derived from the
    /// input types via `matmul_semantic_output_dt`.
    pub operating_dt: Option<DatumType>,
}
186
impl PrefixMatMul {
    /// Compute the output shape from the two (equal-rank) input shapes.
    ///
    /// Prefix dimensions broadcast: a size-1 dim on A is replaced by B's.
    /// The trailing axes are (m, n), read from the transposed position of
    /// A/B when the corresponding flag is set, then swapped if `transpose_c`.
    fn output_shape<D: DimLike + One>(&self, a: &[D], b: &[D]) -> TVec<D> {
        let rank = a.len();
        let mut output: TVec<D> = (0..rank - 2)
            .map(|ix| if a[ix].is_one() { b[ix].clone() } else { a[ix].clone() })
            .collect();
        // m axis: a[rank-2] normally, a[rank-1] when A is transposed.
        output.push(a[rank - 2 + self.transpose_a as usize].clone());
        // n axis: b[rank-1] normally, b[rank-2] when B is transposed.
        output.push(b[rank - 2 + !self.transpose_b as usize].clone());
        if self.transpose_c {
            output.swap(rank - 2, rank - 1);
        }
        output
    }

    /// Reference matmul: casts both inputs to accumulator type `Acc` and,
    /// for each prefix index, computes one 2D `dot` product into `acc`.
    ///
    /// Size-1 prefix dims broadcast thanks to the `d.min(len - 1)` clamping
    /// below. When `transpose_c` is set, `C^T = B^T · A^T` is assigned
    /// instead of transposing the destination view.
    fn mm<Acc: Datum + tract_ndarray::LinalgScalar>(
        &self,
        acc: &mut Tensor,
        a: &Tensor,
        b: &Tensor,
    ) -> TractResult<()> {
        use crate::ndarray::Dimension;
        let casted_a = a.cast_to::<Acc>()?;
        let a = casted_a.to_plain_array_view::<Acc>()?;
        let casted_b = b.cast_to::<Acc>()?;
        let b = casted_b.to_plain_array_view::<Acc>()?;
        let mut c_plain = acc.try_as_plain_mut()?;
        let mut c = c_plain.to_array_view_mut::<Acc>()?;
        // Iterate over all prefix coordinates of the output.
        for prefix in tract_ndarray::indices(&c.shape()[..c.ndim() - 2]) {
            let mut a = a.view();
            let mut b = b.view();
            let mut c = c.view_mut();
            // Peel off one prefix axis at a time; clamping the index
            // implements broadcasting of size-1 dims in a or b.
            for &d in prefix.slice().iter() {
                a.index_axis_inplace(tract_ndarray::Axis(0), d.min(a.shape()[0] - 1));
                b.index_axis_inplace(tract_ndarray::Axis(0), d.min(b.shape()[0] - 1));
                c.index_axis_inplace(tract_ndarray::Axis(0), d);
            }
            let a = a.into_dimensionality::<Ix2>().unwrap();
            let b = b.into_dimensionality::<Ix2>().unwrap();
            let mut c = c.into_dimensionality::<Ix2>().unwrap();
            let a = if self.transpose_a { a.t() } else { a };
            let b = if self.transpose_b { b.t() } else { b };
            if self.transpose_c { c.assign(&b.t().dot(&a.t())) } else { c.assign(&a.dot(&b)) }
        }
        Ok(())
    }
}
233
impl Op for PrefixMatMul {
    fn name(&self) -> StaticName {
        "PrefixMatMul".into()
    }

    // One-line summary of the transposition flags and output quantization,
    // shown in model dumps.
    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!(
            "transpose_a: {} transpose_b: {} transpose_c: {} q: {:?}",
            self.transpose_a, self.transpose_b, self.transpose_c, self.quantize_output
        )])
    }

    op_as_typed_op!();
}
248
impl EvalOp for PrefixMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        // Accumulator/output type: explicit when set, otherwise derived from
        // the input types by the strict matmul semantic rules.
        let c_dt = self.operating_dt.unwrap_or_else(|| {
            let a_dt = inputs[0].datum_type();
            let b_dt = inputs[1].datum_type();
            matmul_semantic_output_dt(&a_dt, inputs[0].is_plain(), &b_dt, inputs[1].is_plain())
        });

        // NOTE(review): presumably converts non-plain (e.g. block-quantized)
        // inputs to plain c_dt tensors — confirm against super::eval.
        let inputs = dequant_inputs(c_dt, inputs)?;

        let output_shape = self.output_shape(inputs[0].shape(), inputs[1].shape());

        if let Some(qp) = self.quantize_output {
            // Quantized path: accumulate in i32 on zero-point-centered
            // operands, then rescale and cast back to the quantized type.
            let mut acc = Tensor::zero_dt(i32::datum_type(), &output_shape)?;
            let mut a_i32 = inputs[0].cast_to::<i32>()?.into_owned();
            a_i32
                .try_as_plain_mut()?
                .as_slice_mut::<i32>()?
                .iter_mut()
                .for_each(|x| *x -= inputs[0].datum_type().zp_scale().0);
            let mut b_i32 = inputs[1].cast_to::<i32>()?.into_owned();
            b_i32
                .try_as_plain_mut()?
                .as_slice_mut::<i32>()?
                .iter_mut()
                .for_each(|x| *x -= inputs[1].datum_type().zp_scale().0);
            self.mm::<i32>(&mut acc, &a_i32, &b_i32)?;
            // Combined requantization factor: scale_a * scale_b / scale_c.
            let scale = inputs[0].datum_type().zp_scale().1 * inputs[1].datum_type().zp_scale().1
                / qp.zp_scale().1;
            let scaler = Scaler::new(scale, tract_linalg::mmm::RoundingPolicy::Even);
            acc.to_plain_array_view_mut::<i32>()?.iter_mut().for_each(|x| *x = *x * scaler);
            // Cast to the unquantized storage type, then stamp the quantized
            // datum type on the tensor without touching the bytes.
            let mut c: Tensor = acc.cast_to_dt(qp.unquantized())?.into_owned();
            unsafe { c.set_datum_type(qp) };
            Ok(tvec!(c.into_tvalue()))
        } else {
            // Float-like path: dispatch mm over the accumulator type.
            let mut c = Tensor::zero_dt(c_dt, &output_shape)?;
            dispatch_floatlike!(Self::mm(c_dt)(self, &mut c, &inputs[0], &inputs[1]))?;
            Ok(tvec!(c.into_tvalue()))
        }
    }
}
294
impl TypedOp for PrefixMatMul {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let [a, b] = inputs else {
            bail!("Expects 2 inputs");
        };
        // Use block-quant-aware shapes so packed formats report their
        // logical (unpacked) geometry.
        let a_shape = block_quant_aware_input_shape(a)?;
        let b_shape = block_quant_aware_input_shape(b)?;
        // Output type priority: quantized output, then explicit operating
        // type, then semantic inference from the input types.
        let dt = self.quantize_output.or(self.operating_dt).unwrap_or(matmul_semantic_output_dt(
            &a.datum_type,
            a.is_plain(),
            &b.datum_type,
            b.is_plain(),
        ));
        Ok(tvec!(dt.fact(self.output_shape(&a_shape, &b_shape))))
    }

    as_op!();
}
313
#[cfg(test)]
mod test {
    use crate::ops::einsum::EinSum;

    use super::*;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest::test_runner::{TestCaseResult, TestRunner};
    use tract_data::itertools::Itertools;

    // Strategy producing an f32 tensor of the given shape filled with small
    // integer values (exact in f32, so output comparison is noise-free).
    pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
        let shape = shape.to_vec();
        let len = shape.iter().product::<usize>();
        vec((-10i8..=10i8).prop_map(|i| i as f32), len..=len)
            .prop_map(move |vec| tensor1(&vec).into_shape(&shape).unwrap())
            .boxed()
    }

    // Strategy picking a size in 2..6 for every axis used by either input of
    // the mapping, and returning the two full (un-broadcast) input shapes.
    fn full_shapes(e: &AxesMapping) -> BoxedStrategy<(Vec<usize>, Vec<usize>)> {
        let e = e.clone();
        let inputs_axes = e
            .iter_all_axes()
            .filter(|axis| axis.inputs[0].len() + axis.inputs[1].len() > 0)
            .cloned()
            .collect_vec();
        let dims = vec![2usize..6; inputs_axes.len()];
        dims.prop_map(move |dims| {
            let a: Vec<usize> = e
                .axes(InOut::In(0))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            let b: Vec<usize> = e
                .axes(InOut::In(1))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            (a, b)
        })
        .boxed()
    }

    // Property-test driver for one einsum expression: generate full shapes,
    // randomly degrade broadcastable axes to 1, generate tensors, and check
    // the rewritten model against the reference EinSum run.
    fn test_expr(expr: &str) -> TestCaseResult {
        let expr = expr.to_string();
        let mut runner = TestRunner::default();
        let axes: AxesMapping = expr.parse().unwrap();
        // A contracted (k-like) axis: present in the other input, absent from
        // the output. It must keep its full size so operands stay compatible.
        fn is_k(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.inputs[1 - input].len() == 1 && axis.outputs[0].len() == 0
        }
        // An axis absent from the output; forced to size 1 so it sums away.
        fn is_disapearing_axis(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.outputs[0].len() == 0
        }
        let cases = full_shapes(&axes)
            .prop_flat_map(|(a, b)| {
                (
                    a.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            if is_k(&axes, 0, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disapearing_axis(&axes, 0, ix) {
                                Just(1).boxed()
                            } else {
                                // Regular axis: either keep it or broadcast it.
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                    b.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            if is_k(&axes, 1, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disapearing_axis(&axes, 1, ix) {
                                Just(1).boxed()
                            } else {
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                )
            })
            .prop_flat_map(|(a_shape, b_shape)| (tensor(&a_shape), tensor(&b_shape)))
            .prop_map(|(a, b)| EinSumProblem { expr: expr.clone(), a, b });
        runner.run(&cases, |pb| pb.check().map_err(|e| TestCaseError::fail(e.to_string())))?;
        Ok(())
    }

    // One concrete test case: the einsum expression plus both input tensors.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct EinSumProblem {
        expr: String,
        a: Tensor,
        b: Tensor,
    }

    impl EinSumProblem {
        // Build a model around the EinSum, run it for reference, apply the
        // rewrite in strict mode, then check that no EinSum node survives and
        // that the rewritten model's output matches the reference.
        fn check(&self) -> TractResult<()> {
            let mut model = TypedModel::default();
            let sa = model.add_source("a", f32::fact(self.a.shape()))?;
            let sb = model.add_source("b", f32::fact(self.b.shape()))?;
            let einsum = model.wire_node(
                "einsum",
                EinSum::new(self.expr.parse().unwrap(), f32::datum_type()),
                &[sa, sb],
            )?;
            model.select_output_outlets(&einsum)?;
            let a = self.a.clone().into_tvalue();
            let b = self.b.clone().into_tvalue();
            let inputs = tvec!(a, b);
            let reference = TypedRunnableModel::new(model.clone())?.run(inputs.clone())?.remove(0);
            rewrite_einsum_to_prefix_matmul(&mut model, true)?;
            assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
            let test = TypedRunnableModel::new(model)?.run(inputs)?.remove(0);
            reference.close_enough(&test, true)
        }
    }

    // Property tests over representative axis layouts (transpositions,
    // missing m/n axes, batch prefixes, broadcast prefixes).
    #[rustfmt::skip] #[test] fn prop_mk_kn_mn() -> TestCaseResult { test_expr("mk,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_km_kn_mn() -> TestCaseResult { test_expr("km,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_nk_mn() -> TestCaseResult { test_expr("mk,nk->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_kn_nm() -> TestCaseResult { test_expr("mk,kn->nm") }
    #[rustfmt::skip] #[test] fn prop_k_kn_mn() -> TestCaseResult { test_expr("k,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_k_mn() -> TestCaseResult { test_expr("mk,k->mn") }
    #[rustfmt::skip] #[test] fn prop_m_n_mn() -> TestCaseResult { test_expr("m,n->mn") }
    #[rustfmt::skip] #[test] fn prop_amk_akn_amn() -> TestCaseResult { test_expr("amk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_mk_akn_amn() -> TestCaseResult { test_expr("mk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_btgi_gih_tgh() -> TestCaseResult { test_expr("btgi,gih->tgh") }
    #[rustfmt::skip] #[test] fn prop_tgi_gih_btgh() -> TestCaseResult { test_expr("tgi,gih->btgh") }

    // Fixed regression cases (previously failing shapes/expressions).
    #[test]
    fn k_kn_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "k,kn->mn".to_string(),
            a: tensor1(&[0f32, 0f32]),
            b: tensor2(&[[0f32, 0.], [0., 0.]]),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[2, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_kn_nm_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,kn->mn".to_string(),
            a: Tensor::zero::<f32>(&[3, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_2() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_3() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn km_anbck_bmn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "km,anbck->bmn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 1, 1, 1, 2]).unwrap(),
        }
        .check()
    }

    // Quantized einsum: checks the rewrite accepts the 9-input form with
    // constant scalar quantization parameters and removes the EinSum.
    #[test]
    fn q() -> TractResult<()> {
        let qp = QParams::ZpScale { zero_point: 0, scale: 0.1 };
        let op = EinSum {
            axes: "mk,kn,m,,,,,,->mn".parse()?,
            operating_dt: i32::datum_type(),
            q_params: Some(DatumType::QI8(qp)),
        };
        let mut model = TypedModelPatch::default();
        let inputs = [
            model.add_source("a", DatumType::QI8(qp).fact([3, 2]))?,
            model.add_source("b", DatumType::QI8(qp).fact([2, 4]))?,
            model.add_source("bias", i32::datum_type().fact([3]))?,
            model.add_const("a0", tensor0(qp.zp_scale().0))?,
            model.add_const("a_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("b0", tensor0(qp.zp_scale().0))?,
            model.add_const("b_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("c0", tensor0(qp.zp_scale().0))?,
            model.add_const("c_scale", tensor0(qp.zp_scale().1))?,
        ];
        let wire = model.wire_node("einsum", op.clone(), &inputs)?;
        model.select_output_outlets(&wire)?;
        rewrite_einsum_to_prefix_matmul(&mut model, true)?;
        assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
        Ok(())
    }
}