tract_core/ops/einsum/as_matmul.rs

use tract_data::itertools::Itertools;
use tract_linalg::Scaler;
use tract_ndarray::Ix2;
use tract_num_traits::One;

use super::eval::dequant_inputs;
use super::optimize::*;
use super::EinSum;
use crate::internal::*;
use crate::ops::einsum::block_quant_aware_input_shape;
use crate::ops::konst::Const;

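/// Model-level pass: rewrite every `EinSum` node as a `BasicMatMul` plus the
/// axis-manipulation ops needed to line the operands up. A minimal usage
/// sketch (the assertion mirrors what the tests below check):
///
/// ```ignore
/// rewrite_einsums_as_matmul(&mut model)?;
/// assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
/// ```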
pub fn rewrite_einsums_as_matmul(model: &mut TypedModel) -> TractResult<()> {
    let rules = Rewriter::default().with_rule_for::<EinSum>("einsum-to-matmul", einsum_rules);
    rules.rewrite(&(), model)
}

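/// Rewrite rule for a single `EinSum` node: returns `Ok(None)` to leave the
/// node alone, or `Ok(Some(patch))` substituting a `BasicMatMul` plus axis
/// fixups for it.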
fn einsum_rules(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    node_name: &str,
    op: &EinSum,
) -> TractResult<Option<TypedModelPatch>> {
    // Float case: 2 inputs (a, b).
    // Quantized case: 9 inputs (a, b, bias, a0, a_scale, b0, b_scale, c0, c_scale).
    if !((op.q_params.is_none() && node.inputs.len() == 2)
        || (op.q_params.is_some() && node.inputs.len() == 9))
    {
        return Ok(None);
    }
    // Only rewrite quantized einsums whose quantization parameters are constants.
    if op.q_params.is_some()
        && model.node_input_facts(node.id)?.iter().skip(3).any(|i| i.konst.is_none())
    {
        return Ok(None);
    }
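    // ensure_mkn_axes classifies the einsum: it either annotates the op with
    // its m, k and n axes, returns an intermediate patch to apply first, or
    // rejects the node as not expressible as a matmul.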
    let op = match ensure_mkn_axes(op, model, node).context("Figuring out m, k and n axes")? {
        AxesOrPatch::Annotated(op) => op,
        AxesOrPatch::Patch(p) => return Ok(Some(p)),
        AxesOrPatch::NotAMatMul(s, axes) => {
            bail!("{} is not a matmul: {s} ({})", op.axes, axes.iter().map(|a| a.repr).join(", "))
        }
    };
    let prefix: String = op
        .axes
        .iter_all_axes()
        .filter(|a| ![op.m_axis, op.k_axis, op.n_axis].contains(a))
        .map(|a| a.repr)
        .collect();
    let mut patch = TypedModelPatch::default();
    let inputs = patch.taps(model, &node.inputs)?;
    let mut wire = tvec!(inputs[0], inputs[1]);

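    // For each operand, compute the axis ops that bring it to the canonical
    // matmul layout ({prefix}mk for a, {prefix}kn for b) and to the
    // transposed layout ({prefix}km / {prefix}nk), then keep whichever is
    // cheaper and let BasicMatMul absorb the difference via its transpose flags.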
    let (m, k, n) = (op.m_axis.repr, op.k_axis.repr, op.n_axis.repr);
    let a_order_es: String = op.axes.axes(InOut::In(0)).map(|a| a.repr).collect();
    let a_order_mm = format!("{prefix}{m}{k}");
    let a_order_mm_t = format!("{prefix}{k}{m}");
    let a_transform = format!("{}->{}", a_order_es, a_order_mm)
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let a_transform_t = format!("{}->{}", a_order_es, a_order_mm_t)
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let transpose_a = a_transform.len() > a_transform_t.len();
    let a_transform = if transpose_a { a_transform_t } else { a_transform };
    let name = format!("{node_name}.fix_a");
    for op in a_transform {
        wire[0] = patch.wire_node(&name, op, &[wire[0]])?[0];
    }
    // Terrible hack: maintain the opaque fact through the eager propagation
    // of the constant through the axis transformations.
    if let Some(op) = patch.node_mut(wire[0].node).op_as_mut::<Const>() {
        *op = Const::new_with_opt_opaque_fact(
            op.val().clone(),
            model.outlet_fact(node.inputs[0])?.opaque_fact.clone(),
        )?;
    }
    patch
        .outlet_fact_mut(wire[0])?
        .opaque_fact
        .clone_from(&model.outlet_fact(node.inputs[0])?.opaque_fact);
    // end of hack

    let b_order_es: String = op.axes.axes(InOut::In(1)).map(|a| a.repr).collect();
    let b_order_mm = format!("{prefix}{k}{n}");
    let b_order_mm_t = format!("{prefix}{n}{k}");
    let b_transform = format!("{}->{}", b_order_es, b_order_mm)
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let b_transform_t = format!("{}->{}", b_order_es, b_order_mm_t)
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let transpose_b = b_transform.len() > b_transform_t.len();
    let b_transform = if transpose_b { b_transform_t } else { b_transform };
    let name = format!("{node_name}.fix_b");
    for op in b_transform {
        wire[1] = patch.wire_node(&name, op, &[wire[1]])?[0];
    }

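    // The output mapping goes the other way: c_transform converts from the
    // matmul layout ({prefix}mn, or {prefix}nm when transposed) back to the
    // axis order the einsum expression promised.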
    let c_order_es: String = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
    let c_order_mm = format!("{prefix}{m}{n}");
    let c_order_mm_t = format!("{prefix}{n}{m}");
    let c_transform = format!("{}->{}", c_order_mm, c_order_es)
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let c_transform_t = format!("{}->{}", c_order_mm_t, c_order_es)
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let transpose_c = c_transform.len() > c_transform_t.len();
    let c_transform = if transpose_c { c_transform_t } else { c_transform };
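    // inputs[3..9] are the six quantization parameters (a0, a_scale, b0,
    // b_scale, c0, c_scale), checked to be constants above; qparams[4] and
    // qparams[5] are therefore the output zero-point and scale.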
    let quantize_output = if let Some(qp) = op.q_params {
        let qparams: Vec<&Tensor> = inputs[3..9]
            .iter()
            .map(|f| {
                patch
                    .outlet_fact(*f)?
                    .konst
                    .as_deref()
                    .context("Can only translate fixed scalar quantization")
            })
            .try_collect()?;
        Some(qp.with_qparams(QParams::ZpScale {
            zero_point: qparams[4].cast_to_scalar::<i32>()?,
            scale: qparams[5].cast_to_scalar::<f32>()?,
        }))
    } else {
        None
    };
    wire = patch.wire_node(
        node_name,
        BasicMatMul { transpose_a, transpose_b, transpose_c, quantize_output },
        &wire,
    )?;

    for (ix, op) in c_transform.into_iter().enumerate() {
        wire = patch.wire_node(format!("{node_name}.fix_c.{ix}"), op, &wire)?;
    }
    patch.shunt_outside(model, node.id.into(), wire[0])?;
    Ok(Some(patch))
}

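/// Simple batched matrix multiplication. The two last axes are the matrix
/// dimensions; leading axes are batch dimensions, and a size-1 batch
/// dimension on one operand broadcasts against the other. `transpose_*`
/// flip the corresponding operand (or the output); `quantize_output`
/// requests i32 accumulation followed by requantization to that type.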
#[derive(Clone, Debug, Copy, Default)]
pub struct BasicMatMul {
    pub transpose_a: bool,
    pub transpose_b: bool,
    pub transpose_c: bool,
    pub quantize_output: Option<DatumType>,
}

impl BasicMatMul {
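    /// Shape inference: broadcast the leading (batch) dims, then pick m from
    /// `a` and n from `b` according to the transpose flags. E.g. with
    /// a: [1, m, k] and b: [batch, k, n] and no transposition, the result is
    /// [batch, m, n].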
    fn output_shape<D: DimLike + One>(&self, a: &[D], b: &[D]) -> TVec<D> {
        let rank = a.len();
        let mut output: TVec<D> = (0..rank - 2)
            .map(|ix| if a[ix].is_one() { b[ix].clone() } else { a[ix].clone() })
            .collect();
        output.push(a[rank - 2 + self.transpose_a as usize].clone());
        output.push(b[rank - 2 + !self.transpose_b as usize].clone());
        if self.transpose_c {
            output.swap(rank - 2, rank - 1);
        }
        output
    }

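    /// Reference kernel, generic over the accumulator type: iterate over the
    /// batch indices of the output, clamp the matching index on each operand
    /// (so size-1 batch dims broadcast), and run a 2-D dot product per batch
    /// element. `transpose_c` is honored through the identity (a·b)ᵀ = bᵀ·aᵀ.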
    fn mm<Acc: Datum + tract_ndarray::LinalgScalar>(
        &self,
        acc: &mut Tensor,
        a: &Tensor,
        b: &Tensor,
    ) -> TractResult<()> {
        use crate::ndarray::Dimension;
        let a = a.to_array_view::<Acc>()?;
        let b = b.to_array_view::<Acc>()?;
        let mut c = acc.to_array_view_mut::<Acc>()?;
        for prefix in tract_ndarray::indices(&c.shape()[..c.ndim() - 2]) {
            let mut a = a.view();
            let mut b = b.view();
            let mut c = c.view_mut();
            for &d in prefix.slice().iter() {
                a.index_axis_inplace(tract_ndarray::Axis(0), d.min(a.shape()[0] - 1));
                b.index_axis_inplace(tract_ndarray::Axis(0), d.min(b.shape()[0] - 1));
                c.index_axis_inplace(tract_ndarray::Axis(0), d);
            }
            let a = a.into_dimensionality::<Ix2>().unwrap();
            let b = b.into_dimensionality::<Ix2>().unwrap();
            let mut c = c.into_dimensionality::<Ix2>().unwrap();
            let a = if self.transpose_a { a.t() } else { a };
            let b = if self.transpose_b { b.t() } else { b };
            if self.transpose_c {
                c.assign(&b.t().dot(&a.t()))
            } else {
                c.assign(&a.dot(&b))
            }
        }
        Ok(())
    }
}

impl Op for BasicMatMul {
    fn name(&self) -> Cow<str> {
        "MatMul".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!(
            "transpose_a: {} transpose_b: {} transpose_c: {} q: {:?}",
            self.transpose_a, self.transpose_b, self.transpose_c, self.quantize_output
        )])
    }

    op_as_typed_op!();
}

impl EvalOp for BasicMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        // Accumulate in the first plain numeric input type; fall back to f32
        // when both inputs are opaque (e.g. block-quantized).
        let c_dt = if inputs[0].datum_type().is_number() {
            inputs[0].datum_type()
        } else if inputs[1].datum_type().is_number() {
            inputs[1].datum_type()
        } else {
            f32::datum_type()
        };
        let inputs = dequant_inputs(c_dt, inputs)?;

        let output_shape = self.output_shape(inputs[0].shape(), inputs[1].shape());

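        // Quantized path: shift both operands by their zero-points, multiply
        // in i32, rescale by a_scale * b_scale / c_scale (round-to-even),
        // then cast down and re-tag the tensor with the quantized output type.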
        if let Some(qp) = self.quantize_output {
            let mut acc = Tensor::zero_dt(i32::datum_type(), &output_shape)?;
            let mut a_i32 = inputs[0].cast_to::<i32>()?.into_owned();
            a_i32
                .as_slice_mut::<i32>()?
                .iter_mut()
                .for_each(|x| *x -= inputs[0].datum_type().zp_scale().0);
            let mut b_i32 = inputs[1].cast_to::<i32>()?.into_owned();
            b_i32
                .as_slice_mut::<i32>()?
                .iter_mut()
                .for_each(|x| *x -= inputs[1].datum_type().zp_scale().0);
            self.mm::<i32>(&mut acc, &a_i32, &b_i32)?;
            let scale = inputs[0].datum_type().zp_scale().1 * inputs[1].datum_type().zp_scale().1
                / qp.zp_scale().1;
            let scaler = Scaler::new(scale, tract_linalg::mmm::RoundingPolicy::Even);
            acc.to_array_view_mut::<i32>()?.iter_mut().for_each(|x| *x = *x * scaler);
            let mut c: Tensor = acc.cast_to_dt(qp.unquantized())?.into_owned();
            unsafe { c.set_datum_type(qp) };
            Ok(tvec!(c.into_tvalue()))
        } else {
            let mut c = Tensor::zero_dt(c_dt, &output_shape)?;
            dispatch_floatlike!(Self::mm(c_dt)(self, &mut c, &inputs[0], &inputs[1]))?;
            Ok(tvec!(c.into_tvalue()))
        }
    }
}

impl TypedOp for BasicMatMul {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let [a, b] = inputs else {
            bail!("Expects 2 inputs");
        };
        let a_shape = block_quant_aware_input_shape(inputs[0])?;
        let b_shape = block_quant_aware_input_shape(inputs[1])?;
        let dt = self.quantize_output.unwrap_or(if a.datum_type.is_number() {
            a.datum_type
        } else {
            b.datum_type
        });
        Ok(tvec!(dt.fact(self.output_shape(&a_shape, &b_shape))))
    }

    as_op!();
}

#[cfg(test)]
mod test {
    use super::*;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest::test_runner::{TestCaseResult, TestRunner};
    use tract_data::itertools::Itertools;

    // Strategy generating an f32 tensor of the given shape, filled with small
    // integer values.
    pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
        let shape = shape.to_vec();
        let len = shape.iter().product::<usize>();
        vec((-10i8..=10i8).prop_map(|i| i as f32), len..=len)
            .prop_map(move |vec| tensor1(&vec).into_shape(&shape).unwrap())
            .boxed()
    }

    // Strategy drawing a size (2..6) for every axis appearing in the inputs,
    // then deriving the full (un-broadcast) shapes of both operands.
    fn full_shapes(e: &AxesMapping) -> BoxedStrategy<(Vec<usize>, Vec<usize>)> {
        let e = e.clone();
        let inputs_axes = e
            .iter_all_axes()
            .filter(|axis| axis.inputs[0].len() + axis.inputs[1].len() > 0)
            .cloned()
            .collect_vec();
        let dims = vec![2usize..6; inputs_axes.len()];
        dims.prop_map(move |dims| {
            let a: Vec<usize> = e
                .axes(InOut::In(0))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            let b: Vec<usize> = e
                .axes(InOut::In(1))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            (a, b)
        })
        .boxed()
    }

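    // Property test for one einsum expression: draw full-size shapes, then
    // selectively degrade dims to 1 to exercise broadcasting. Contracted (k)
    // axes keep their drawn size, non-k axes absent from the output are
    // pinned to 1, and any other axis is either kept or broadcast. Each case
    // checks the rewritten model against the reference EinSum evaluation.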
    fn test_expr(expr: &str) -> TestCaseResult {
        let expr = expr.to_string();
        let mut runner = TestRunner::default();
        let axes: AxesMapping = expr.parse().unwrap();
        fn is_k(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.inputs[1 - input].len() == 1 && axis.outputs[0].len() == 0
        }
        fn is_disappearing_axis(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.outputs[0].len() == 0
        }
        let cases = full_shapes(&axes)
            .prop_flat_map(|(a, b)| {
                (
                    a.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            if is_k(&axes, 0, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disappearing_axis(&axes, 0, ix) {
                                Just(1).boxed()
                            } else {
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                    b.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            if is_k(&axes, 1, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disappearing_axis(&axes, 1, ix) {
                                Just(1).boxed()
                            } else {
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                )
            })
            .prop_flat_map(|(a_shape, b_shape)| (tensor(&a_shape), tensor(&b_shape)))
            .prop_map(|(a, b)| EinSumProblem { expr: expr.clone(), a, b });
        runner.run(&cases, |pb| pb.check().map_err(|e| TestCaseError::fail(e.to_string())))?;
        Ok(())
    }

    #[derive(Debug, Clone, PartialEq)]
    struct EinSumProblem {
        expr: String,
        a: Tensor,
        b: Tensor,
    }

    impl EinSumProblem {
        fn check(&self) -> TractResult<()> {
            let mut model = TypedModel::default();
            let sa = model.add_source("a", f32::fact(self.a.shape())).unwrap();
            let sb = model.add_source("b", f32::fact(self.b.shape())).unwrap();
            let einsum = model
                .wire_node(
                    "einsum",
                    EinSum::new(self.expr.parse().unwrap(), f32::datum_type()),
                    &[sa, sb],
                )
                .unwrap();
            model.set_output_outlets(&einsum).unwrap();
            let a = self.a.clone().into_tvalue();
            let b = self.b.clone().into_tvalue();
            let inputs = tvec!(a, b);
            let reference =
                TypedRunnableModel::new(&model).unwrap().run(inputs.clone()).unwrap().remove(0);
            rewrite_einsums_as_matmul(&mut model)?;
            assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
            let test = TypedRunnableModel::new(&model).unwrap().run(inputs).unwrap().remove(0);
            reference.close_enough(&test, true).unwrap();
            Ok(())
        }
    }

    #[rustfmt::skip] #[test] fn prop_mk_kn_mn() -> TestCaseResult { test_expr("mk,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_km_kn_mn() -> TestCaseResult { test_expr("km,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_nk_mn() -> TestCaseResult { test_expr("mk,nk->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_kn_nm() -> TestCaseResult { test_expr("mk,kn->nm") }
    #[rustfmt::skip] #[test] fn prop_k_kn_mn() -> TestCaseResult { test_expr("k,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_k_mn() -> TestCaseResult { test_expr("mk,k->mn") }
    #[rustfmt::skip] #[test] fn prop_m_n_mn() -> TestCaseResult { test_expr("m,n->mn") }
    #[rustfmt::skip] #[test] fn prop_amk_akn_amn() -> TestCaseResult { test_expr("amk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_mk_akn_amn() -> TestCaseResult { test_expr("mk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_btgi_gih_tgh() -> TestCaseResult { test_expr("btgi,gih->tgh") }
    #[rustfmt::skip] #[test] fn prop_tgi_gih_btgh() -> TestCaseResult { test_expr("tgi,gih->btgh") }

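    // Fixed cases exercising degenerate and broadcast shapes for the
    // expressions covered by the property tests above.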
    #[test]
    fn k_kn_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "k,kn->mn".to_string(),
            a: tensor1(&[0f32, 0f32]),
            b: tensor2(&[[0f32, 0.], [0., 0.]]),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[2, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_kn_nm_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,kn->nm".to_string(),
            a: Tensor::zero::<f32>(&[3, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_2() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_3() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn km_anbck_bmn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "km,anbck->bmn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 1, 1, 1, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn q() -> TractResult<()> {
        let qp = QParams::ZpScale { zero_point: 0, scale: 0.1 };
        let op = EinSum {
            axes: "mk,kn,m,,,,,,->mn".parse()?,
            operating_dt: i32::datum_type(),
            q_params: Some(DatumType::QI8(qp)),
        };
        let mut model = TypedModel::default();
        let inputs = [
            model.add_source("a", DatumType::QI8(qp).fact([3, 2]))?,
            model.add_source("b", DatumType::QI8(qp).fact([2, 4]))?,
            model.add_source("bias", i32::datum_type().fact([3]))?,
            model.add_const("a0", tensor0(qp.zp_scale().0))?,
            model.add_const("a_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("b0", tensor0(qp.zp_scale().0))?,
            model.add_const("b_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("c0", tensor0(qp.zp_scale().0))?,
            model.add_const("c_scale", tensor0(qp.zp_scale().1))?,
        ];
        let wire = model.wire_node("einsum", op.clone(), &inputs)?;
        model.set_output_outlets(&wire)?;
        rewrite_einsums_as_matmul(&mut model)?;
        assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
        Ok(())
    }
}