1use tract_data::itertools::Itertools;
2use tract_linalg::Scaler;
3use tract_ndarray::Ix2;
4use tract_num_traits::One;
5
6use super::einsum_matmul::EinSumMatMul;
7use super::eval::dequant_inputs;
8use crate::internal::*;
9use crate::ops::einsum::block_quant_aware_input_shape;
10use crate::ops::konst::Const;
11
12pub fn rewrite_einsum_to_prefix_matmul(model: &mut TypedModel) -> TractResult<()> {
13 super::einsum_matmul::detect_all(model)?;
14 Rewriter::default().with_rule_for("einsum-to-prefix-matmul", rule).rewrite(&(), model)
15}
16
/// Rewrite an `EinSumMatMul` node as a `PrefixMatMul`, surrounded by the axis
/// permutations needed to bring A to `prefix..mk` (or `..km`), B to
/// `prefix..kn` (or `..nk`), and to restore the einsum output order
/// afterwards. Returns `Ok(None)` when the node is not in a supported form.
fn rule(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    node_name: &str,
    op: &EinSumMatMul,
) -> TractResult<Option<TypedModelPatch>> {
    // Supported arities only: (a, b) for the plain case, or the 9-input
    // quantized form (a, b, bias, a0, a_scale, b0, b_scale, c0, c_scale).
    if !((op.q_params.is_none() && node.inputs.len() == 2)
        || (op.q_params.is_some() && node.inputs.len() == 9))
    {
        return Ok(None);
    }
    // Quantization parameters (inputs 3..9) must all be constants: they get
    // baked into the PrefixMatMul op below.
    if op.q_params.is_some()
        && model.node_input_facts(node.id)?.iter().skip(3).any(|i| i.konst.is_none())
    {
        return Ok(None);
    }
    // Every axis that is not m, k or n becomes part of the leading "prefix"
    // (batch) axes, kept in einsum declaration order.
    let prefix: String = op
        .axes
        .iter_all_axes()
        .filter(|a| ![op.m_axis, op.k_axis, op.n_axis].contains(&a.repr))
        .map(|a| a.repr)
        .collect();
    let mut patch = TypedModelPatch::default();
    let inputs = patch.taps(model, &node.inputs)?;
    let mut wire = tvec!(inputs[0], inputs[1]);

    let (m, k, n) = (op.m_axis, op.k_axis, op.n_axis);
    // A: compute the permutations to both candidate layouts (prefix+mk and
    // prefix+km), keep whichever needs fewer axis ops, and record whether the
    // cheaper choice implies a transposed A.
    let a_order_es: String = op.axes.axes(InOut::In(0)).map(|a| a.repr).collect();
    let a_order_mm = format!("{prefix}{m}{k}");
    let a_order_mm_t = format!("{prefix}{k}{m}");
    let a_transform = format!("{a_order_es}->{a_order_mm}")
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let a_transform_t = format!("{a_order_es}->{a_order_mm_t}")
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let transpose_a = a_transform.len() > a_transform_t.len();
    let a_transform = if transpose_a { a_transform_t } else { a_transform };
    let name = format!("{node_name}.fix_a");
    for op in a_transform {
        wire[0] = patch.wire_node(&name, op, &[wire[0]])?[0];
    }
    // If A ends up on a constant node, rebuild the constant so it carries the
    // original input's opaque fact (e.g. block-quantization metadata).
    if let Some(op) = patch.node_mut(wire[0].node).op_as_mut::<Const>() {
        *op = Const::new_with_opt_opaque_fact(
            op.val().clone(),
            model.outlet_fact(node.inputs[0])?.opaque_fact.clone(),
        )?;
    }
    // Also propagate the opaque fact onto the patched outlet itself.
    patch
        .outlet_fact_mut(wire[0])?
        .opaque_fact
        .clone_from(&model.outlet_fact(node.inputs[0])?.opaque_fact);
    // B: same game with the prefix+kn / prefix+nk candidate layouts.
    let b_order_es: String = op.axes.axes(InOut::In(1)).map(|a| a.repr).collect();
    let b_order_mm = format!("{prefix}{k}{n}");
    let b_order_mm_t = format!("{prefix}{n}{k}");
    let b_transform = format!("{b_order_es}->{b_order_mm}")
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let b_transform_t = format!("{b_order_es}->{b_order_mm_t}")
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let transpose_b = b_transform.len() > b_transform_t.len();
    let b_transform = if transpose_b { b_transform_t } else { b_transform };
    let name = format!("{node_name}.fix_b");
    for op in b_transform {
        wire[1] = patch.wire_node(&name, op, &[wire[1]])?[0];
    }

    // C: note the mappings go matmul-order -> einsum-order this time, as they
    // are applied downstream of the matmul.
    let c_order_es: String = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
    let c_order_mm = format!("{prefix}{m}{n}");
    let c_order_mm_t = format!("{prefix}{n}{m}");
    let c_transform = format!("{c_order_mm}->{c_order_es}")
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let c_transform_t = format!("{c_order_mm_t}->{c_order_es}")
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let transpose_c = c_transform.len() > c_transform_t.len();
    let c_transform = if transpose_c { c_transform_t } else { c_transform };
    // Quantized case: fold the constant output qparams into the op. After the
    // 3..9 slice, qparams is [a0, a_scale, b0, b_scale, c0, c_scale], so
    // indices 4/5 are the output zero-point and scale.
    let quantize_output = if let Some(qp) = op.q_params {
        let qparams: Vec<&Tensor> = inputs[3..9]
            .iter()
            .map(|f| {
                patch
                    .outlet_fact(*f)?
                    .konst
                    .as_deref()
                    .context("Can only translate fixed scalar quantization")
            })
            .try_collect()?;
        Some(qp.with_qparams(QParams::ZpScale {
            zero_point: qparams[4].cast_to_scalar::<i32>()?,
            scale: qparams[5].cast_to_scalar::<f32>()?,
        }))
    } else {
        None
    };
    wire = patch.wire_node(
        node_name,
        PrefixMatMul { transpose_a, transpose_b, transpose_c, quantize_output },
        &wire,
    )?;

    for (ix, op) in c_transform.into_iter().enumerate() {
        wire = patch.wire_node(format!("{node_name}.fix_c.{ix}"), op, &wire)?;
    }
    patch.shunt_outside(model, node.id.into(), wire[0])?;
    Ok(Some(patch))
}
133
/// Batched matrix multiplication: both operands share a common (broadcastable)
/// run of leading "prefix" axes, followed by the two matmul axes.
#[derive(Clone, Debug, Copy, Default)]
pub struct PrefixMatMul {
    /// A's trailing axes are stored as `km` instead of `mk`.
    pub transpose_a: bool,
    /// B's trailing axes are stored as `nk` instead of `kn`.
    pub transpose_b: bool,
    /// Output is produced as `nm` instead of `mn`.
    pub transpose_c: bool,
    /// When set, accumulate in i32 and requantize the output to this
    /// (quantized) datum type.
    pub quantize_output: Option<DatumType>,
}
141
impl PrefixMatMul {
    /// Infer the output shape from the operand shapes. Prefix axes broadcast
    /// "1 against anything" (b's dim wins when a's is 1); the last two output
    /// axes are m (taken from a) and n (taken from b), swapped when
    /// `transpose_c` is set.
    /// NOTE(review): uses a's rank for the prefix — assumes a and b have the
    /// same rank; confirm with callers.
    fn output_shape<D: DimLike + One>(&self, a: &[D], b: &[D]) -> TVec<D> {
        let rank = a.len();
        let mut output: TVec<D> = (0..rank - 2)
            .map(|ix| if a[ix].is_one() { b[ix].clone() } else { a[ix].clone() })
            .collect();
        // m is a's last-but-one axis, or the last one when a is transposed.
        output.push(a[rank - 2 + self.transpose_a as usize].clone());
        // n is b's last axis, or the last-but-one when b is transposed:
        // `(!transpose_b) as usize` is 1 when b is in kn order.
        output.push(b[rank - 2 + !self.transpose_b as usize].clone());
        if self.transpose_c {
            output.swap(rank - 2, rank - 1);
        }
        output
    }

    /// Compute `acc = a · b` one 2-D slice at a time, with `Acc` as both the
    /// operand and accumulator element type. Iterates over the output's
    /// prefix coordinates; size-1 prefix axes of a/b broadcast by clamping
    /// the index.
    fn mm<Acc: Datum + tract_ndarray::LinalgScalar>(
        &self,
        acc: &mut Tensor,
        a: &Tensor,
        b: &Tensor,
    ) -> TractResult<()> {
        use crate::ndarray::Dimension;
        let a = a.to_array_view::<Acc>()?;
        let b = b.to_array_view::<Acc>()?;
        let mut c = acc.to_array_view_mut::<Acc>()?;
        for prefix in tract_ndarray::indices(&c.shape()[..c.ndim() - 2]) {
            let mut a = a.view();
            let mut b = b.view();
            let mut c = c.view_mut();
            // Narrow each view down to the current prefix coordinate; min()
            // clamps the index to 0 on broadcast (size-1) axes.
            for &d in prefix.slice().iter() {
                a.index_axis_inplace(tract_ndarray::Axis(0), d.min(a.shape()[0] - 1));
                b.index_axis_inplace(tract_ndarray::Axis(0), d.min(b.shape()[0] - 1));
                c.index_axis_inplace(tract_ndarray::Axis(0), d);
            }
            let a = a.into_dimensionality::<Ix2>().unwrap();
            let b = b.into_dimensionality::<Ix2>().unwrap();
            let mut c = c.into_dimensionality::<Ix2>().unwrap();
            let a = if self.transpose_a { a.t() } else { a };
            let b = if self.transpose_b { b.t() } else { b };
            if self.transpose_c {
                // (a·b)ᵗ == bᵗ·aᵗ, written directly into the nm-ordered slice.
                c.assign(&b.t().dot(&a.t()))
            } else {
                c.assign(&a.dot(&b))
            }
        }
        Ok(())
    }
}
189
190impl Op for PrefixMatMul {
191 fn name(&self) -> StaticName {
192 "PrefixMatMul".into()
193 }
194
195 fn info(&self) -> TractResult<Vec<String>> {
196 Ok(vec![format!(
197 "transpose_a: {} transpose_b: {} transpose_c: {} q: {:?}",
198 self.transpose_a, self.transpose_b, self.transpose_c, self.quantize_output
199 )])
200 }
201
202 op_as_typed_op!();
203}
204
impl EvalOp for PrefixMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        // Pick the working/output type from the first numeric input; fall
        // back to f32 when neither input is a plain number (e.g. both are
        // opaque block-quantized payloads).
        let c_dt = if inputs[0].datum_type().is_number() {
            inputs[0].datum_type()
        } else if inputs[1].datum_type().is_number() {
            inputs[1].datum_type()
        } else {
            f32::datum_type()
        };
        // NOTE(review): presumably converts opaque/block-quant inputs into
        // plain c_dt tensors — confirm against dequant_inputs in eval.rs.
        let inputs = dequant_inputs(c_dt, inputs)?;

        let output_shape = self.output_shape(inputs[0].shape(), inputs[1].shape());

        if let Some(qp) = self.quantize_output {
            // Quantized path: accumulate in i32 over zero-point-centered
            // operands, then rescale and cast back to the quantized type.
            let mut acc = Tensor::zero_dt(i32::datum_type(), &output_shape)?;
            let mut a_i32 = inputs[0].cast_to::<i32>()?.into_owned();
            a_i32
                .as_slice_mut::<i32>()?
                .iter_mut()
                .for_each(|x| *x -= inputs[0].datum_type().zp_scale().0);
            let mut b_i32 = inputs[1].cast_to::<i32>()?.into_owned();
            b_i32
                .as_slice_mut::<i32>()?
                .iter_mut()
                .for_each(|x| *x -= inputs[1].datum_type().zp_scale().0);
            self.mm::<i32>(&mut acc, &a_i32, &b_i32)?;
            // Combined requantization factor: scale_a * scale_b / scale_c,
            // applied with round-to-even.
            let scale = inputs[0].datum_type().zp_scale().1 * inputs[1].datum_type().zp_scale().1
                / qp.zp_scale().1;
            let scaler = Scaler::new(scale, tract_linalg::mmm::RoundingPolicy::Even);
            acc.to_array_view_mut::<i32>()?.iter_mut().for_each(|x| *x = *x * scaler);
            // NOTE(review): the output zero-point (qp.zp_scale().0) is never
            // added back here; the q() test below only exercises zp == 0 —
            // verify whether non-zero output zero-points are handled upstream.
            // Cast to the unquantized storage type, then stamp the quantized
            // datum type on the tensor without touching the bytes.
            let mut c: Tensor = acc.cast_to_dt(qp.unquantized())?.into_owned();
            unsafe { c.set_datum_type(qp) };
            Ok(tvec!(c.into_tvalue()))
        } else {
            // Float-like path: zero-init the output and dispatch mm on c_dt.
            let mut c = Tensor::zero_dt(c_dt, &output_shape)?;
            dispatch_floatlike!(Self::mm(c_dt)(self, &mut c, &inputs[0], &inputs[1]))?;
            Ok(tvec!(c.into_tvalue()))
        }
    }
}
249
250impl TypedOp for PrefixMatMul {
251 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
252 let [a, b] = inputs else {
253 bail!("Expects 2 inputs");
254 };
255 let a_shape = block_quant_aware_input_shape(inputs[0])?;
256 let b_shape = block_quant_aware_input_shape(inputs[1])?;
257 let dt = self.quantize_output.unwrap_or(if a.datum_type.is_number() {
258 a.datum_type
259 } else {
260 b.datum_type
261 });
262 Ok(tvec!(dt.fact(self.output_shape(&a_shape, &b_shape))))
263 }
264
265 as_op!();
266}
267
#[cfg(test)]
mod test {
    use crate::ops::einsum::EinSum;

    use super::*;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest::test_runner::{TestCaseResult, TestRunner};
    use tract_data::itertools::Itertools;

    /// Strategy producing an f32 tensor of the given shape with small integer
    /// values (-10..=10), so float matmuls are exact and comparable.
    pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
        let shape = shape.to_vec();
        let len = shape.iter().product::<usize>();
        vec((-10i8..=10i8).prop_map(|i| i as f32), len..=len)
            .prop_map(move |vec| tensor1(&vec).into_shape(&shape).unwrap())
            .boxed()
    }

    /// Strategy producing a consistent pair of full (un-broadcast) shapes for
    /// both einsum inputs: each named axis gets one dim in 2..6, shared
    /// between the inputs that use it.
    fn full_shapes(e: &AxesMapping) -> BoxedStrategy<(Vec<usize>, Vec<usize>)> {
        let e = e.clone();
        let inputs_axes = e
            .iter_all_axes()
            .filter(|axis| axis.inputs[0].len() + axis.inputs[1].len() > 0)
            .cloned()
            .collect_vec();
        let dims = vec![2usize..6; inputs_axes.len()];
        dims.prop_map(move |dims| {
            let a: Vec<usize> = e
                .axes(InOut::In(0))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            let b: Vec<usize> = e
                .axes(InOut::In(1))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            (a, b)
        })
        .boxed()
    }

    /// Property test driver: for a given einsum expression, generate shape
    /// pairs (randomly degrading non-k axes to 1 to exercise broadcasting),
    /// fill them with random tensors, and check the rewrite against the
    /// reference einsum evaluation.
    fn test_expr(expr: &str) -> TestCaseResult {
        let expr = expr.to_string();
        let mut runner = TestRunner::default();
        let axes: AxesMapping = expr.parse().unwrap();
        // The contraction axis: present in both inputs, absent from the output.
        fn is_k(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.inputs[1 - input].len() == 1 && axis.outputs[0].len() == 0
        }
        // An input axis that does not appear in the output (summed away).
        fn is_disapearing_axis(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.outputs[0].len() == 0
        }
        let cases = full_shapes(&axes)
            .prop_flat_map(|(a, b)| {
                (
                    a.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            // k keeps its full size, disappearing axes are
                            // forced to 1, others may broadcast (1 or full).
                            if is_k(&axes, 0, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disapearing_axis(&axes, 0, ix) {
                                Just(1).boxed()
                            } else {
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                    b.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            if is_k(&axes, 1, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disapearing_axis(&axes, 1, ix) {
                                Just(1).boxed()
                            } else {
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                )
            })
            .prop_flat_map(|(a_shape, b_shape)| (tensor(&a_shape), tensor(&b_shape)))
            .prop_map(|(a, b)| EinSumProblem { expr: expr.clone(), a, b });
        runner.run(&cases, |pb| pb.check().map_err(|e| TestCaseError::fail(e.to_string())))?;
        Ok(())
    }

    /// One concrete test case: an einsum expression plus its two operands.
    #[derive(Debug, Clone, PartialEq)]
    struct EinSumProblem {
        expr: String,
        a: Tensor,
        b: Tensor,
    }

    impl EinSumProblem {
        /// Build a model with a single EinSum, run it as reference, apply the
        /// rewrite (checking all EinSum nodes are gone), rerun and compare.
        fn check(&self) -> TractResult<()> {
            let mut model = TypedModel::default();
            let sa = model.add_source("a", f32::fact(self.a.shape())).unwrap();
            let sb = model.add_source("b", f32::fact(self.b.shape())).unwrap();
            let einsum = model
                .wire_node(
                    "einsum",
                    EinSum::new(self.expr.parse().unwrap(), f32::datum_type()),
                    &[sa, sb],
                )
                .unwrap();
            model.set_output_outlets(&einsum).unwrap();
            let a = self.a.clone().into_tvalue();
            let b = self.b.clone().into_tvalue();
            let inputs = tvec!(a, b);
            let reference =
                TypedRunnableModel::new(&model).unwrap().run(inputs.clone()).unwrap().remove(0);
            rewrite_einsum_to_prefix_matmul(&mut model)?;
            assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
            let test = TypedRunnableModel::new(&model).unwrap().run(inputs).unwrap().remove(0);
            reference.close_enough(&test, true).unwrap();
            Ok(())
        }
    }

    // Property tests over representative expressions: plain matmul, both
    // transposes, transposed output, vector operands, outer product, batched
    // and partially-batched forms, and grouped-attention-like layouts.
    #[rustfmt::skip] #[test] fn prop_mk_kn_mn() -> TestCaseResult { test_expr("mk,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_km_kn_mn() -> TestCaseResult { test_expr("km,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_nk_mn() -> TestCaseResult { test_expr("mk,nk->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_kn_nm() -> TestCaseResult { test_expr("mk,kn->nm") }
    #[rustfmt::skip] #[test] fn prop_k_kn_mn() -> TestCaseResult { test_expr("k,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_k_mn() -> TestCaseResult { test_expr("mk,k->mn") }
    #[rustfmt::skip] #[test] fn prop_m_n_mn() -> TestCaseResult { test_expr("m,n->mn") }
    #[rustfmt::skip] #[test] fn prop_amk_akn_amn() -> TestCaseResult { test_expr("amk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_mk_akn_amn() -> TestCaseResult { test_expr("mk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_btgi_gih_tgh() -> TestCaseResult { test_expr("btgi,gih->tgh") }
    #[rustfmt::skip] #[test] fn prop_tgi_gih_btgh() -> TestCaseResult { test_expr("tgi,gih->btgh") }

    // Pinned regressions found by the property tests above.
    #[test]
    fn k_kn_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "k,kn->mn".to_string(),
            a: tensor1(&[0f32, 0f32]),
            b: tensor2(&[[0f32, 0.], [0., 0.]]),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[2, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_kn_nm_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,kn->mn".to_string(),
            a: Tensor::zero::<f32>(&[3, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_2() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_3() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn km_anbck_bmn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "km,anbck->bmn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 1, 1, 1, 2]).unwrap(),
        }
        .check()
    }

    /// Quantized rewrite smoke test: 9-input quantized einsum (zero_point 0,
    /// scale 0.1 everywhere) must be fully converted away from EinSum.
    #[test]
    fn q() -> TractResult<()> {
        let qp = QParams::ZpScale { zero_point: 0, scale: 0.1 };
        let op = EinSum {
            axes: "mk,kn,m,,,,,,->mn".parse()?,
            operating_dt: i32::datum_type(),
            q_params: Some(DatumType::QI8(qp)),
        };
        // NOTE(review): a TypedModelPatch is used as the model here,
        // presumably via Deref to TypedModel — confirm this is intentional.
        let mut model = TypedModelPatch::default();
        let inputs = [
            model.add_source("a", DatumType::QI8(qp).fact([3, 2]))?,
            model.add_source("b", DatumType::QI8(qp).fact([2, 4]))?,
            model.add_source("bias", i32::datum_type().fact([3]))?,
            model.add_const("a0", tensor0(qp.zp_scale().0))?,
            model.add_const("a_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("b0", tensor0(qp.zp_scale().0))?,
            model.add_const("b_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("c0", tensor0(qp.zp_scale().0))?,
            model.add_const("c_scale", tensor0(qp.zp_scale().1))?,
        ];
        let wire = model.wire_node("einsum", op.clone(), &inputs)?;
        model.set_output_outlets(&wire)?;
        rewrite_einsum_to_prefix_matmul(&mut model)?;
        assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
        Ok(())
    }
}