use tract_data::itertools::Itertools;
use tract_linalg::Scaler;
use tract_ndarray::Ix2;
use tract_num_traits::One;

use super::einsum_matmul::EinSumMatMul;
use super::eval::dequant_inputs;
use crate::internal::*;
use crate::ops::einsum::block_quant_aware_input_shape;
use crate::ops::konst::Const;

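/// Context for the einsum-to-prefix-matmul rewrite rule. When
/// `ensure_strict_matmul_semantic` is set, the rewrite only accepts einsums
/// whose accumulator type matches what a plain matmul over the input types
/// would produce.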
#[derive(Debug, Default)]
pub struct EinSumToPrefixMatmulCtx {
    pub ensure_strict_matmul_semantic: bool,
}

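/// Rewrite every matmul-like EinSum of `model` into a `PrefixMatMul`,
/// inserting whatever axis permutations are needed to bring both operands
/// into a `prefix..mk` / `prefix..kn` layout.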
pub fn rewrite_einsum_to_prefix_matmul(
    model: &mut TypedModel,
    ensure_strict_matmul_semantic: bool,
) -> TractResult<()> {
    super::einsum_matmul::detect_all(model)?;
    let ctx = EinSumToPrefixMatmulCtx { ensure_strict_matmul_semantic };
    Rewriter::default().with_rule_for("einsum-to-prefix-matmul", rule).rewrite(&ctx, model)
}

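/// Rewrite a single detected EinSum matmul. Only plain two-input
/// floating-point matmuls and nine-input quantized matmuls with constant
/// quantization parameters are handled; anything else is left untouched.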
fn rule(
    ctx: &EinSumToPrefixMatmulCtx,
    model: &TypedModel,
    node: &TypedNode,
    node_name: &str,
    op: &EinSumMatMul,
) -> TractResult<Option<TypedModelPatch>> {
    let is_fp_mm = op.q_params.is_none() && node.inputs.len() == 2;
    let is_q_mm = op.q_params.is_some() && node.inputs.len() == 9;
    rule_if!(is_fp_mm || is_q_mm);
    rule_if!(
        op.q_params.is_none()
            || model.node_input_facts(node.id)?.iter().skip(3).all(|i| i.konst.is_some())
    );
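    // Every axis other than m, k and n becomes part of the broadcast prefix.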
    let prefix: String = op
        .axes
        .iter_all_axes()
        .filter(|a| ![op.m_axis, op.k_axis, op.n_axis].contains(&a.repr))
        .map(|a| a.repr)
        .collect();
    let mut patch = TypedModelPatch::default();
    let inputs = patch.taps(model, &node.inputs)?;
    let mut wire = tvec!(inputs[0], inputs[1]);

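    // For each operand, derive the axis ops reaching both the natural and the
    // transposed matmul layout, and keep whichever path is shorter; the
    // matmul absorbs the residual transposition as a flag.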
    let (m, k, n) = (op.m_axis, op.k_axis, op.n_axis);
    let a_order_es: String = op.axes.axes(InOut::In(0)).map(|a| a.repr).collect();
    let a_order_mm = format!("{prefix}{m}{k}");
    let a_order_mm_t = format!("{prefix}{k}{m}");
    let a_transform =
        format!("{a_order_es}->{a_order_mm}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
    let a_transform_t =
        format!("{a_order_es}->{a_order_mm_t}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
    let transpose_a = a_transform.len() > a_transform_t.len();
    let a_transform = if transpose_a { a_transform_t } else { a_transform };
    let name = format!("{node_name}.fix_a");
    for op in a_transform {
        wire[0] = patch.wire_node(&name, op, &[wire[0]])?[0];
    }
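    // If A ends up constant (typical for weights), rebuild the constant with
    // the original input's opaque fact (e.g. block-quantization metadata) and
    // mirror that fact on the rewired outlet.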
    if let Some(op) = patch.node_mut(wire[0].node).op_as_mut::<Const>() {
        *op = Const::new_with_opt_opaque_fact(
            op.val().clone(),
            model.outlet_fact(node.inputs[0])?.opaque_fact.clone(),
        )?;
    }
    patch
        .outlet_fact_mut(wire[0])?
        .opaque_fact
        .clone_from(&model.outlet_fact(node.inputs[0])?.opaque_fact);
    let b_order_es: String = op.axes.axes(InOut::In(1)).map(|a| a.repr).collect();
    let b_order_mm = format!("{prefix}{k}{n}");
    let b_order_mm_t = format!("{prefix}{n}{k}");
    let b_transform =
        format!("{b_order_es}->{b_order_mm}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
    let b_transform_t =
        format!("{b_order_es}->{b_order_mm_t}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
    let transpose_b = b_transform.len() > b_transform_t.len();
    let b_transform = if transpose_b { b_transform_t } else { b_transform };
    let name = format!("{node_name}.fix_b");
    for op in b_transform {
        wire[1] = patch.wire_node(&name, op, &[wire[1]])?[0];
    }

    let c_order_es: String = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
    let c_order_mm = format!("{prefix}{m}{n}");
    let c_order_mm_t = format!("{prefix}{n}{m}");
    let c_transform =
        format!("{c_order_mm}->{c_order_es}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
    let c_transform_t =
        format!("{c_order_mm_t}->{c_order_es}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
    let transpose_c = c_transform.len() > c_transform_t.len();
    let c_transform = if transpose_c { c_transform_t } else { c_transform };
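    // Quantized case: the six quantization parameters (inputs 3..9) are known
    // constants; only c0 (qparams[4]) and c_scale (qparams[5]) define the
    // output quantization carried by the PrefixMatMul itself.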
    let quantize_output = if let Some(qp) = op.q_params {
        let qparams: Vec<&Tensor> = inputs[3..9]
            .iter()
            .map(|f| {
                patch
                    .outlet_fact(*f)?
                    .konst
                    .as_deref()
                    .context("Can only translate fixed scalar quantization")
            })
            .try_collect()?;
        Some(qp.with_qparams(QParams::ZpScale {
            zero_point: qparams[4].cast_to_scalar::<i32>()?,
            scale: qparams[5].cast_to_scalar::<f32>()?,
        }))
    } else {
        None
    };

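    // Under strict matmul semantics the accumulator type must be exactly what
    // a plain matmul over these inputs would produce; operating_dt is then
    // left as None so eval derives it from the input types.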
    let operating_dt = if ctx.ensure_strict_matmul_semantic {
        let input_facts = model.node_input_facts(node.id)?;
        let a_dt = input_facts[0].datum_type;
        let b_dt = input_facts[1].datum_type;
        let operating_dt = quantize_output.unwrap_or(op.operating_dt);
        let allowed_dt = matmul_semantic_output_dt(&a_dt, &b_dt);

        ensure!(
            operating_dt == allowed_dt,
            format!(
                "Strict matmul semantics require operating_dt to be {allowed_dt:?} \
                 for (a: {a_dt:?}, b: {b_dt:?}) but got {operating_dt:?}."
            )
        );

        None
    } else {
        Some(op.operating_dt)
    };

    wire = patch.wire_node(
        node_name,
        PrefixMatMul { transpose_a, transpose_b, transpose_c, quantize_output, operating_dt },
        &wire,
    )?;

    for (ix, op) in c_transform.into_iter().enumerate() {
        wire = patch.wire_node(format!("{node_name}.fix_c.{ix}"), op, &wire)?;
    }
    patch.shunt_outside(model, node.id.into(), wire[0])?;
    Ok(Some(patch))
}

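/// Output type of a matmul under plain semantics: the first numeric input
/// type wins; if neither input is a plain number (e.g. both carry
/// block-quantized payloads), fall back to f32.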
fn matmul_semantic_output_dt(a_dt: &DatumType, b_dt: &DatumType) -> DatumType {
    if a_dt.is_number() {
        *a_dt
    } else if b_dt.is_number() {
        *b_dt
    } else {
        f32::datum_type()
    }
}

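/// Matrix multiplication over the two innermost axes, broadcasting over all
/// leading (prefix) axes. The transposition flags apply to each operand's
/// two matrix axes. `quantize_output` requantizes an i32 accumulator to the
/// given type; `operating_dt` forces the accumulator type, falling back to
/// the inputs' matmul semantic type when `None`.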
#[derive(Clone, Debug, Copy)]
pub struct PrefixMatMul {
    pub transpose_a: bool,
    pub transpose_b: bool,
    pub transpose_c: bool,
    pub quantize_output: Option<DatumType>,
    pub operating_dt: Option<DatumType>,
}

impl PrefixMatMul {
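    /// Broadcast the prefixes (a dim of 1 yields to the other operand's dim),
    /// then append m and n, swapping them when the output is transposed.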
    fn output_shape<D: DimLike + One>(&self, a: &[D], b: &[D]) -> TVec<D> {
        let rank = a.len();
        let mut output: TVec<D> = (0..rank - 2)
            .map(|ix| if a[ix].is_one() { b[ix].clone() } else { a[ix].clone() })
            .collect();
        output.push(a[rank - 2 + self.transpose_a as usize].clone());
        output.push(b[rank - 2 + !self.transpose_b as usize].clone());
        if self.transpose_c {
            output.swap(rank - 2, rank - 1);
        }
        output
    }

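    /// Reference kernel: cast both operands to the accumulator type `Acc` and
    /// run one ndarray `dot` per prefix index, clamping indices on size-1
    /// prefix dimensions to implement broadcasting.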
    fn mm<Acc: Datum + tract_ndarray::LinalgScalar>(
        &self,
        acc: &mut Tensor,
        a: &Tensor,
        b: &Tensor,
    ) -> TractResult<()> {
        use crate::ndarray::Dimension;
        let casted_a = a.cast_to::<Acc>()?;
        let a = casted_a.to_dense_array_view::<Acc>()?;
        let casted_b = b.cast_to::<Acc>()?;
        let b = casted_b.to_dense_array_view::<Acc>()?;
        let mut c_dense = acc.try_as_dense_mut()?;
        let mut c = c_dense.to_array_view_mut::<Acc>()?;
        for prefix in tract_ndarray::indices(&c.shape()[..c.ndim() - 2]) {
            let mut a = a.view();
            let mut b = b.view();
            let mut c = c.view_mut();
            for &d in prefix.slice().iter() {
                a.index_axis_inplace(tract_ndarray::Axis(0), d.min(a.shape()[0] - 1));
                b.index_axis_inplace(tract_ndarray::Axis(0), d.min(b.shape()[0] - 1));
                c.index_axis_inplace(tract_ndarray::Axis(0), d);
            }
            let a = a.into_dimensionality::<Ix2>().unwrap();
            let b = b.into_dimensionality::<Ix2>().unwrap();
            let mut c = c.into_dimensionality::<Ix2>().unwrap();
            let a = if self.transpose_a { a.t() } else { a };
            let b = if self.transpose_b { b.t() } else { b };
            if self.transpose_c { c.assign(&b.t().dot(&a.t())) } else { c.assign(&a.dot(&b)) }
        }
        Ok(())
    }
}

impl Op for PrefixMatMul {
    fn name(&self) -> StaticName {
        "PrefixMatMul".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!(
            "transpose_a: {} transpose_b: {} transpose_c: {} q: {:?}",
            self.transpose_a, self.transpose_b, self.transpose_c, self.quantize_output
        )])
    }

    op_as_typed_op!();
}

impl EvalOp for PrefixMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let c_dt = self.operating_dt.unwrap_or_else(|| {
            let a_dt = inputs[0].datum_type();
            let b_dt = inputs[1].datum_type();
            matmul_semantic_output_dt(&a_dt, &b_dt)
        });

        let inputs = dequant_inputs(c_dt, inputs)?;

        let output_shape = self.output_shape(inputs[0].shape(), inputs[1].shape());

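        // Quantized path: subtract each operand's zero point, accumulate in
        // i32, rescale by a_scale * b_scale / c_scale with round-to-even,
        // then cast down and stamp the quantized datum type on the result.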
        if let Some(qp) = self.quantize_output {
            let mut acc = Tensor::zero_dt(i32::datum_type(), &output_shape)?;
            let mut a_i32 = inputs[0].cast_to::<i32>()?.into_owned();
            a_i32
                .try_as_dense_mut()?
                .as_slice_mut::<i32>()?
                .iter_mut()
                .for_each(|x| *x -= inputs[0].datum_type().zp_scale().0);
            let mut b_i32 = inputs[1].cast_to::<i32>()?.into_owned();
            b_i32
                .try_as_dense_mut()?
                .as_slice_mut::<i32>()?
                .iter_mut()
                .for_each(|x| *x -= inputs[1].datum_type().zp_scale().0);
            self.mm::<i32>(&mut acc, &a_i32, &b_i32)?;
            let scale = inputs[0].datum_type().zp_scale().1 * inputs[1].datum_type().zp_scale().1
                / qp.zp_scale().1;
            let scaler = Scaler::new(scale, tract_linalg::mmm::RoundingPolicy::Even);
            acc.to_dense_array_view_mut::<i32>()?.iter_mut().for_each(|x| *x = *x * scaler);
            let mut c: Tensor = acc.cast_to_dt(qp.unquantized())?.into_owned();
            unsafe { c.set_datum_type(qp) };
            Ok(tvec!(c.into_tvalue()))
        } else {
            let mut c = Tensor::zero_dt(c_dt, &output_shape)?;
            dispatch_floatlike!(Self::mm(c_dt)(self, &mut c, &inputs[0], &inputs[1]))?;
            Ok(tvec!(c.into_tvalue()))
        }
    }
}

impl TypedOp for PrefixMatMul {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let [a, b] = inputs else {
            bail!("Expects 2 inputs");
        };
        let a_shape = block_quant_aware_input_shape(a)?;
        let b_shape = block_quant_aware_input_shape(b)?;
        let dt = self
            .quantize_output
            .or(self.operating_dt)
            .unwrap_or(matmul_semantic_output_dt(&a.datum_type, &b.datum_type));
        Ok(tvec!(dt.fact(self.output_shape(&a_shape, &b_shape))))
    }

    as_op!();
}

#[cfg(test)]
mod test {
    use crate::ops::einsum::EinSum;

    use super::*;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest::test_runner::{TestCaseResult, TestRunner};
    use tract_data::itertools::Itertools;

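    /// Strategy: a tensor of the given shape filled with small integer-valued
    /// f32s, so float accumulation stays exact and comparisons are stable.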
    pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
        let shape = shape.to_vec();
        let len = shape.iter().product::<usize>();
        vec((-10i8..=10i8).prop_map(|i| i as f32), len..=len)
            .prop_map(move |vec| tensor1(&vec).into_shape(&shape).unwrap())
            .boxed()
    }

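    /// Draw one concrete dimension (2..6) per axis of the expression, then
    /// derive the matching full shapes for both inputs.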
    fn full_shapes(e: &AxesMapping) -> BoxedStrategy<(Vec<usize>, Vec<usize>)> {
        let e = e.clone();
        let inputs_axes = e
            .iter_all_axes()
            .filter(|axis| axis.inputs[0].len() + axis.inputs[1].len() > 0)
            .cloned()
            .collect_vec();
        let dims = vec![2usize..6; inputs_axes.len()];
        dims.prop_map(move |dims| {
            let a: Vec<usize> = e
                .axes(InOut::In(0))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            let b: Vec<usize> = e
                .axes(InOut::In(1))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            (a, b)
        })
        .boxed()
    }

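    /// Exercise one einsum expression: non-k input dimensions may shrink to 1
    /// to cover broadcasting, while axes that disappear from the output
    /// (other than k, the reduction axis) are pinned to 1.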
    fn test_expr(expr: &str) -> TestCaseResult {
        let expr = expr.to_string();
        let mut runner = TestRunner::default();
        let axes: AxesMapping = expr.parse().unwrap();
        fn is_k(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.inputs[1 - input].len() == 1 && axis.outputs[0].len() == 0
        }
        fn is_disappearing_axis(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.outputs[0].len() == 0
        }
        let cases = full_shapes(&axes)
            .prop_flat_map(|(a, b)| {
                (
                    a.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            if is_k(&axes, 0, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disappearing_axis(&axes, 0, ix) {
                                Just(1).boxed()
                            } else {
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                    b.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            if is_k(&axes, 1, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disappearing_axis(&axes, 1, ix) {
                                Just(1).boxed()
                            } else {
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                )
            })
            .prop_flat_map(|(a_shape, b_shape)| (tensor(&a_shape), tensor(&b_shape)))
            .prop_map(|(a, b)| EinSumProblem { expr: expr.clone(), a, b });
        runner.run(&cases, |pb| pb.check().map_err(|e| TestCaseError::fail(e.to_string())))?;
        Ok(())
    }

    #[derive(Debug, Clone, PartialEq)]
    struct EinSumProblem {
        expr: String,
        a: Tensor,
        b: Tensor,
    }

    impl EinSumProblem {
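        /// Run the expression as a plain EinSum, apply the rewrite, check that
        /// no EinSum survives, and compare both runs.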
        fn check(&self) -> TractResult<()> {
            let mut model = TypedModel::default();
            let sa = model.add_source("a", f32::fact(self.a.shape()))?;
            let sb = model.add_source("b", f32::fact(self.b.shape()))?;
            let einsum = model.wire_node(
                "einsum",
                EinSum::new(self.expr.parse().unwrap(), f32::datum_type()),
                &[sa, sb],
            )?;
            model.set_output_outlets(&einsum)?;
            let a = self.a.clone().into_tvalue();
            let b = self.b.clone().into_tvalue();
            let inputs = tvec!(a, b);
            let reference = TypedRunnableModel::new(model.clone())?.run(inputs.clone())?.remove(0);
            rewrite_einsum_to_prefix_matmul(&mut model, true)?;
            assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
            let test = TypedRunnableModel::new(model)?.run(inputs)?.remove(0);
            reference.close_enough(&test, true)
        }
    }

    #[rustfmt::skip] #[test] fn prop_mk_kn_mn() -> TestCaseResult { test_expr("mk,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_km_kn_mn() -> TestCaseResult { test_expr("km,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_nk_mn() -> TestCaseResult { test_expr("mk,nk->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_kn_nm() -> TestCaseResult { test_expr("mk,kn->nm") }
    #[rustfmt::skip] #[test] fn prop_k_kn_mn() -> TestCaseResult { test_expr("k,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_k_mn() -> TestCaseResult { test_expr("mk,k->mn") }
    #[rustfmt::skip] #[test] fn prop_m_n_mn() -> TestCaseResult { test_expr("m,n->mn") }
    #[rustfmt::skip] #[test] fn prop_amk_akn_amn() -> TestCaseResult { test_expr("amk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_mk_akn_amn() -> TestCaseResult { test_expr("mk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_btgi_gih_tgh() -> TestCaseResult { test_expr("btgi,gih->tgh") }
    #[rustfmt::skip] #[test] fn prop_tgi_gih_btgh() -> TestCaseResult { test_expr("tgi,gih->btgh") }

    #[test]
    fn k_kn_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "k,kn->mn".to_string(),
            a: tensor1(&[0f32, 0f32]),
            b: tensor2(&[[0f32, 0.], [0., 0.]]),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[2, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_kn_nm_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,kn->mn".to_string(),
            a: Tensor::zero::<f32>(&[3, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_2() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_3() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn km_anbck_bmn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "km,anbck->bmn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 1, 1, 1, 2]).unwrap(),
        }
        .check()
    }

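    /// A quantized einsum with constant scalar qparams must also be rewritten.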
    #[test]
    fn q() -> TractResult<()> {
        let qp = QParams::ZpScale { zero_point: 0, scale: 0.1 };
        let op = EinSum {
            axes: "mk,kn,m,,,,,,->mn".parse()?,
            operating_dt: i32::datum_type(),
            q_params: Some(DatumType::QI8(qp)),
        };
        let mut model = TypedModel::default();
        let inputs = [
            model.add_source("a", DatumType::QI8(qp).fact([3, 2]))?,
            model.add_source("b", DatumType::QI8(qp).fact([2, 4]))?,
            model.add_source("bias", i32::datum_type().fact([3]))?,
            model.add_const("a0", tensor0(qp.zp_scale().0))?,
            model.add_const("a_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("b0", tensor0(qp.zp_scale().0))?,
            model.add_const("b_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("c0", tensor0(qp.zp_scale().0))?,
            model.add_const("c_scale", tensor0(qp.zp_scale().1))?,
        ];
        let wire = model.wire_node("einsum", op.clone(), &inputs)?;
        model.set_output_outlets(&wire)?;
        rewrite_einsum_to_prefix_matmul(&mut model, true)?;
        assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
        Ok(())
    }
}
544}