Skip to main content

tract_core/ops/einsum/
prefix_matmul.rs

1use tract_data::itertools::Itertools;
2use tract_linalg::Scaler;
3use tract_ndarray::Ix2;
4use tract_num_traits::One;
5
6use super::einsum_matmul::EinSumMatMul;
7use super::eval::dequant_inputs;
8use crate::internal::*;
9use crate::ops::einsum::block_quant_aware_input_shape;
10use crate::ops::konst::Const;
11
/// Context handed to the `einsum-to-prefix-matmul` rewrite rule.
#[derive(Default, Debug)]
pub struct EinSumToPrefixMatmulCtx {
    /// When `true`, the rule fails with an error if the einsum's effective operating type
    /// differs from the type `matmul_semantic_output_dt` derives from its two inputs, and the
    /// produced `PrefixMatMul` carries no explicit operating type.
    /// When `false`, the einsum's operating type is forwarded unchanged.
    pub ensure_strict_matmul_semantic: bool,
}
16
17pub fn rewrite_einsum_to_prefix_matmul(
18    model: &mut TypedModel,
19    ensure_strict_matmul_semantic: bool,
20) -> TractResult<()> {
21    super::einsum_matmul::merge_consecutive_same_role_axes(model)?;
22    super::einsum_matmul::detect_all(model)?;
23    let ctx = EinSumToPrefixMatmulCtx { ensure_strict_matmul_semantic };
24    Rewriter::default().with_rule_for("einsum-to-prefix-matmul", rule).rewrite(&ctx, model)
25}
26
27fn rule(
28    ctx: &EinSumToPrefixMatmulCtx,
29    model: &TypedModel,
30    node: &TypedNode,
31    node_name: &str,
32    op: &EinSumMatMul,
33) -> TractResult<Option<TypedModelPatch>> {
34    // F: 2 inputs
35    // Q: 9 inputs
36    let is_fp_mm = op.q_params.is_none() && node.inputs.len() == 2;
37    let is_q_mm = op.q_params.is_some() && node.inputs.len() == 9;
38    rule_if!(is_fp_mm || is_q_mm);
39    rule_if!(
40        op.q_params.is_none()
41            || model.node_input_facts(node.id)?.iter().skip(3).all(|i| i.konst.is_some())
42    );
43    let prefix: String = op
44        .axes
45        .iter_all_axes()
46        .filter(|a| ![op.m_axis, op.k_axis, op.n_axis].contains(&a.repr))
47        .map(|a| a.repr)
48        .collect();
49    let mut patch = TypedModelPatch::default();
50    let inputs = patch.taps(model, &node.inputs)?;
51    let mut wire = tvec!(inputs[0], inputs[1]);
52
53    let (m, k, n) = (op.m_axis, op.k_axis, op.n_axis);
54    let a_order_es: String = op.axes.axes(InOut::In(0)).map(|a| a.repr).collect();
55    let a_order_mm = format!("{prefix}{m}{k}");
56    let a_order_mm_t = format!("{prefix}{k}{m}");
57    let a_transform =
58        format!("{a_order_es}->{a_order_mm}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
59    let a_transform_t =
60        format!("{a_order_es}->{a_order_mm_t}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
61    let transpose_a = a_transform.len() > a_transform_t.len();
62    let a_transform = if transpose_a { a_transform_t } else { a_transform };
63    let name = format!("{node_name}.fix_a");
64    for op in a_transform {
65        wire[0] = patch.wire_node(&name, op, &[wire[0]])?[0];
66    }
67    // terrible hack to maintain exotic fact through eager propagation of constant through the
68    // axes transformation
69    if let Some(op) = patch.node_mut(wire[0].node).op_as_mut::<Const>() {
70        *op = Const::new_with_opt_exotic_fact(
71            op.val().clone(),
72            model.outlet_fact(node.inputs[0])?.exotic_fact.clone(),
73        )?;
74    }
75    patch
76        .outlet_fact_mut(wire[0])?
77        .exotic_fact
78        .clone_from(&model.outlet_fact(node.inputs[0])?.exotic_fact);
79    // end of hack
80
81    let b_order_es: String = op.axes.axes(InOut::In(1)).map(|a| a.repr).collect();
82    let b_order_mm = format!("{prefix}{k}{n}");
83    let b_order_mm_t = format!("{prefix}{n}{k}");
84    let b_transform =
85        format!("{b_order_es}->{b_order_mm}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
86    let b_transform_t =
87        format!("{b_order_es}->{b_order_mm_t}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
88    let transpose_b = b_transform.len() > b_transform_t.len();
89    let b_transform = if transpose_b { b_transform_t } else { b_transform };
90    let name = format!("{node_name}.fix_b");
91    for op in b_transform {
92        wire[1] = patch.wire_node(&name, op, &[wire[1]])?[0];
93    }
94
95    let c_order_es: String = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
96    let c_order_mm = format!("{prefix}{m}{n}");
97    let c_order_mm_t = format!("{prefix}{n}{m}");
98    let c_transform =
99        format!("{c_order_mm}->{c_order_es}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
100    let c_transform_t =
101        format!("{c_order_mm_t}->{c_order_es}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
102    let transpose_c = c_transform.len() > c_transform_t.len();
103    let c_transform = if transpose_c { c_transform_t } else { c_transform };
104    let quantize_output = if let Some(qp) = op.q_params {
105        let qparams: Vec<&Tensor> = inputs[3..9]
106            .iter()
107            .map(|f| {
108                patch
109                    .outlet_fact(*f)?
110                    .konst
111                    .as_deref()
112                    .context("Can only translate fixed scalar quantization")
113            })
114            .try_collect()?;
115        Some(qp.with_qparams(QParams::ZpScale {
116            zero_point: qparams[4].cast_to_scalar::<i32>()?,
117            scale: qparams[5].cast_to_scalar::<f32>()?,
118        }))
119    } else {
120        None
121    };
122
123    let operating_dt = if ctx.ensure_strict_matmul_semantic {
124        let input_facts = model.node_input_facts(node.id)?;
125        let a_dt = input_facts[0].datum_type;
126        let b_dt = input_facts[1].datum_type;
127        let operating_dt = quantize_output.unwrap_or(op.operating_dt);
128        let a_plain = input_facts[0].is_plain();
129        let b_plain = input_facts[1].is_plain();
130        let allowed_dt = matmul_semantic_output_dt(&a_dt, a_plain, &b_dt, b_plain);
131
132        ensure!(
133            operating_dt == allowed_dt,
134            format!(
135                "Strict matmul semantic require operating_dt to be {allowed_dt:?} \
136                for (a: {a_dt:?}, b:{b_dt:?}) but got {:?}.",
137                op.operating_dt
138            )
139        );
140
141        None
142    } else {
143        Some(op.operating_dt)
144    };
145
146    wire = patch.wire_node(
147        node_name,
148        PrefixMatMul { transpose_a, transpose_b, transpose_c, quantize_output, operating_dt },
149        &wire,
150    )?;
151
152    for (ix, op) in c_transform.into_iter().enumerate() {
153        wire = patch.wire_node(format!("{node_name}.fix_c.{ix}"), op, &wire)?;
154    }
155    patch.shunt_outside(model, node.id.into(), wire[0])?;
156    Ok(Some(patch))
157}
158
159fn matmul_semantic_output_dt(
160    a_dt: &DatumType,
161    a_plain: bool,
162    b_dt: &DatumType,
163    b_plain: bool,
164) -> DatumType {
165    if a_plain && a_dt.is_number() {
166        *a_dt
167    } else if b_plain && b_dt.is_number() {
168        *b_dt
169    } else if a_dt.is_number() {
170        *a_dt
171    } else if b_dt.is_number() {
172        *b_dt
173    } else {
174        f32::datum_type()
175    }
176}
177
178#[derive(Clone, Debug, Copy, PartialEq, Eq)]
179pub struct PrefixMatMul {
180    pub transpose_a: bool,
181    pub transpose_b: bool,
182    pub transpose_c: bool,
183    pub quantize_output: Option<DatumType>,
184    pub operating_dt: Option<DatumType>,
185}
186
187impl PrefixMatMul {
188    fn output_shape<D: DimLike + One>(&self, a: &[D], b: &[D]) -> TVec<D> {
189        let rank = a.len();
190        let mut output: TVec<D> = (0..rank - 2)
191            .map(|ix| if a[ix].is_one() { b[ix].clone() } else { a[ix].clone() })
192            .collect();
193        output.push(a[rank - 2 + self.transpose_a as usize].clone());
194        output.push(b[rank - 2 + !self.transpose_b as usize].clone());
195        if self.transpose_c {
196            output.swap(rank - 2, rank - 1);
197        }
198        output
199    }
200
201    fn mm<Acc: Datum + tract_ndarray::LinalgScalar>(
202        &self,
203        acc: &mut Tensor,
204        a: &Tensor,
205        b: &Tensor,
206    ) -> TractResult<()> {
207        use crate::ndarray::Dimension;
208        let casted_a = a.cast_to::<Acc>()?;
209        let a = casted_a.to_plain_array_view::<Acc>()?;
210        let casted_b = b.cast_to::<Acc>()?;
211        let b = casted_b.to_plain_array_view::<Acc>()?;
212        let mut c_plain = acc.try_as_plain_mut()?;
213        let mut c = c_plain.to_array_view_mut::<Acc>()?;
214        for prefix in tract_ndarray::indices(&c.shape()[..c.ndim() - 2]) {
215            let mut a = a.view();
216            let mut b = b.view();
217            let mut c = c.view_mut();
218            for &d in prefix.slice().iter() {
219                a.index_axis_inplace(tract_ndarray::Axis(0), d.min(a.shape()[0] - 1));
220                b.index_axis_inplace(tract_ndarray::Axis(0), d.min(b.shape()[0] - 1));
221                c.index_axis_inplace(tract_ndarray::Axis(0), d);
222            }
223            let a = a.into_dimensionality::<Ix2>().unwrap();
224            let b = b.into_dimensionality::<Ix2>().unwrap();
225            let mut c = c.into_dimensionality::<Ix2>().unwrap();
226            let a = if self.transpose_a { a.t() } else { a };
227            let b = if self.transpose_b { b.t() } else { b };
228            if self.transpose_c { c.assign(&b.t().dot(&a.t())) } else { c.assign(&a.dot(&b)) }
229        }
230        Ok(())
231    }
232}
233
234impl Op for PrefixMatMul {
235    fn name(&self) -> StaticName {
236        "PrefixMatMul".into()
237    }
238
239    fn info(&self) -> TractResult<Vec<String>> {
240        Ok(vec![format!(
241            "transpose_a: {} transpose_b: {} transpose_c: {} q: {:?}",
242            self.transpose_a, self.transpose_b, self.transpose_c, self.quantize_output
243        )])
244    }
245
246    op_as_typed_op!();
247}
248
249impl EvalOp for PrefixMatMul {
250    fn is_stateless(&self) -> bool {
251        true
252    }
253
254    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
255        let c_dt = self.operating_dt.unwrap_or_else(|| {
256            let a_dt = inputs[0].datum_type();
257            let b_dt = inputs[1].datum_type();
258            matmul_semantic_output_dt(&a_dt, inputs[0].is_plain(), &b_dt, inputs[1].is_plain())
259        });
260
261        let inputs = dequant_inputs(c_dt, inputs)?;
262
263        let output_shape = self.output_shape(inputs[0].shape(), inputs[1].shape());
264
265        if let Some(qp) = self.quantize_output {
266            let mut acc = Tensor::zero_dt(i32::datum_type(), &output_shape)?;
267            let mut a_i32 = inputs[0].cast_to::<i32>()?.into_owned();
268            a_i32
269                .try_as_plain_mut()?
270                .as_slice_mut::<i32>()?
271                .iter_mut()
272                .for_each(|x| *x -= inputs[0].datum_type().zp_scale().0);
273            let mut b_i32 = inputs[1].cast_to::<i32>()?.into_owned();
274            b_i32
275                .try_as_plain_mut()?
276                .as_slice_mut::<i32>()?
277                .iter_mut()
278                .for_each(|x| *x -= inputs[1].datum_type().zp_scale().0);
279            self.mm::<i32>(&mut acc, &a_i32, &b_i32)?;
280            let scale = inputs[0].datum_type().zp_scale().1 * inputs[1].datum_type().zp_scale().1
281                / qp.zp_scale().1;
282            let scaler = Scaler::new(scale, tract_linalg::mmm::RoundingPolicy::Even);
283            acc.to_plain_array_view_mut::<i32>()?.iter_mut().for_each(|x| *x = *x * scaler);
284            let mut c: Tensor = acc.cast_to_dt(qp.unquantized())?.into_owned();
285            unsafe { c.set_datum_type(qp) };
286            Ok(tvec!(c.into_tvalue()))
287        } else {
288            let mut c = Tensor::zero_dt(c_dt, &output_shape)?;
289            dispatch_floatlike!(Self::mm(c_dt)(self, &mut c, &inputs[0], &inputs[1]))?;
290            Ok(tvec!(c.into_tvalue()))
291        }
292    }
293}
294
295impl TypedOp for PrefixMatMul {
296    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
297        let [a, b] = inputs else {
298            bail!("Expects 2 inputs");
299        };
300        let a_shape = block_quant_aware_input_shape(a)?;
301        let b_shape = block_quant_aware_input_shape(b)?;
302        let dt = self.quantize_output.or(self.operating_dt).unwrap_or(matmul_semantic_output_dt(
303            &a.datum_type,
304            a.is_plain(),
305            &b.datum_type,
306            b.is_plain(),
307        ));
308        Ok(tvec!(dt.fact(self.output_shape(&a_shape, &b_shape))))
309    }
310
311    as_op!();
312}
313
#[cfg(test)]
mod test {
    use crate::ops::einsum::EinSum;

    use super::*;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest::test_runner::{TestCaseResult, TestRunner};
    use tract_data::itertools::Itertools;

    /// Strategy for an `f32` tensor of the given shape, filled with integer values taken in
    /// `-10..=10`.
    pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
        let shape = shape.to_vec();
        let len = shape.iter().product::<usize>();
        vec((-10i8..=10i8).prop_map(|i| i as f32), len..=len)
            .prop_map(move |vec| tensor1(&vec).into_shape(&shape).unwrap())
            .boxed()
    }

    /// Strategy for a pair of input shapes (A, B) consistent with the expression `e`: every
    /// axis appearing in an input gets one dimension in `2..6`, shared by both inputs.
    fn full_shapes(e: &AxesMapping) -> BoxedStrategy<(Vec<usize>, Vec<usize>)> {
        let e = e.clone();
        let inputs_axes = e
            .iter_all_axes()
            .filter(|axis| axis.inputs[0].len() + axis.inputs[1].len() > 0)
            .cloned()
            .collect_vec();
        let dims = vec![2usize..6; inputs_axes.len()];
        dims.prop_map(move |dims| {
            let a: Vec<usize> = e
                .axes(InOut::In(0))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            let b: Vec<usize> = e
                .axes(InOut::In(1))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            (a, b)
        })
        .boxed()
    }

    /// Property test for one einsum expression: generates shapes (with some axes collapsed
    /// to one), fills tensors, and checks the rewritten model against the einsum reference.
    fn test_expr(expr: &str) -> TestCaseResult {
        let expr = expr.to_string();
        let mut runner = TestRunner::default();
        let axes: AxesMapping = expr.parse().unwrap();
        /// The axis appears once in the other input and not in the output.
        fn is_k(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.inputs[1 - input].len() == 1 && axis.outputs[0].len() == 0
        }
        /// The axis does not appear in the output.
        fn is_disappearing_axis(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.outputs[0].len() == 0
        }
        /// Per-axis dimension strategies for one input: k keeps its full size, other axes
        /// missing from the output are forced to one, the rest is either one or full size.
        fn dim_strategies(
            axes: &AxesMapping,
            input: usize,
            shape: &[usize],
        ) -> Vec<BoxedStrategy<usize>> {
            shape
                .iter()
                .enumerate()
                .map(|(ix, d)| {
                    if is_k(axes, input, ix) {
                        prop_oneof![Just(*d)].boxed()
                    } else if is_disappearing_axis(axes, input, ix) {
                        Just(1usize).boxed()
                    } else {
                        prop_oneof![Just(1usize), Just(*d)].boxed()
                    }
                })
                .collect_vec()
        }
        let cases = full_shapes(&axes)
            .prop_flat_map(|(a, b)| (dim_strategies(&axes, 0, &a), dim_strategies(&axes, 1, &b)))
            .prop_flat_map(|(a_shape, b_shape)| (tensor(&a_shape), tensor(&b_shape)))
            .prop_map(|(a, b)| EinSumProblem { expr: expr.clone(), a, b });
        runner.run(&cases, |pb| pb.check().map_err(|e| TestCaseError::fail(e.to_string())))?;
        Ok(())
    }

    /// One concrete einsum instance: expression plus both operand values.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct EinSumProblem {
        expr: String,
        a: Tensor,
        b: Tensor,
    }

    impl EinSumProblem {
        /// Runs the einsum as-is for reference, rewrites the model in strict mode, checks no
        /// `EinSum` node is left, and compares the rewritten model's output to the reference.
        fn check(&self) -> TractResult<()> {
            let mut model = TypedModel::default();
            let sa = model.add_source("a", f32::fact(self.a.shape()))?;
            let sb = model.add_source("b", f32::fact(self.b.shape()))?;
            let einsum = model.wire_node(
                "einsum",
                EinSum::new(self.expr.parse().unwrap(), f32::datum_type()),
                &[sa, sb],
            )?;
            model.select_output_outlets(&einsum)?;
            let a = self.a.clone().into_tvalue();
            let b = self.b.clone().into_tvalue();
            let inputs = tvec!(a, b);
            let reference = TypedRunnableModel::new(model.clone())?.run(inputs.clone())?.remove(0);
            rewrite_einsum_to_prefix_matmul(&mut model, true)?;
            assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
            let test = TypedRunnableModel::new(model)?.run(inputs)?.remove(0);
            reference.close_enough(&test, true)
        }
    }

    #[rustfmt::skip] #[test] fn prop_mk_kn_mn() -> TestCaseResult { test_expr("mk,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_km_kn_mn() -> TestCaseResult { test_expr("km,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_nk_mn() -> TestCaseResult { test_expr("mk,nk->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_kn_nm() -> TestCaseResult { test_expr("mk,kn->nm") }
    #[rustfmt::skip] #[test] fn prop_k_kn_mn() -> TestCaseResult { test_expr("k,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_k_mn() -> TestCaseResult { test_expr("mk,k->mn") }
    #[rustfmt::skip] #[test] fn prop_m_n_mn() -> TestCaseResult { test_expr("m,n->mn") }
    #[rustfmt::skip] #[test] fn prop_amk_akn_amn() -> TestCaseResult { test_expr("amk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_mk_akn_amn() -> TestCaseResult { test_expr("mk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_btgi_gih_tgh() -> TestCaseResult { test_expr("btgi,gih->tgh") }
    #[rustfmt::skip] #[test] fn prop_tgi_gih_btgh() -> TestCaseResult { test_expr("tgi,gih->btgh") }

    #[test]
    fn k_kn_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "k,kn->mn".to_string(),
            a: tensor1(&[0f32, 0f32]),
            b: tensor2(&[[0f32, 0.], [0., 0.]]),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[2, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    // The expression matches the test name (transposed output); it previously read
    // "mk,kn->mn", which left the `nm` regression case unexercised.
    #[test]
    fn mk_kn_nm_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,kn->nm".to_string(),
            a: Tensor::zero::<f32>(&[3, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_2() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_3() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn km_anbck_bmn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "km,anbck->bmn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 1, 1, 1, 2]).unwrap(),
        }
        .check()
    }

    /// Quantized einsum (9 inputs, constant quantization parameters) must be rewritten too.
    #[test]
    fn q() -> TractResult<()> {
        let qp = QParams::ZpScale { zero_point: 0, scale: 0.1 };
        let op = EinSum {
            axes: "mk,kn,m,,,,,,->mn".parse()?,
            operating_dt: i32::datum_type(),
            q_params: Some(DatumType::QI8(qp)),
        };
        let mut model = TypedModelPatch::default();
        let inputs = [
            model.add_source("a", DatumType::QI8(qp).fact([3, 2]))?,
            model.add_source("b", DatumType::QI8(qp).fact([2, 4]))?,
            model.add_source("bias", i32::datum_type().fact([3]))?,
            model.add_const("a0", tensor0(qp.zp_scale().0))?,
            model.add_const("a_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("b0", tensor0(qp.zp_scale().0))?,
            model.add_const("b_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("c0", tensor0(qp.zp_scale().0))?,
            model.add_const("c_scale", tensor0(qp.zp_scale().1))?,
        ];
        let wire = model.wire_node("einsum", op.clone(), &inputs)?;
        model.select_output_outlets(&wire)?;
        rewrite_einsum_to_prefix_matmul(&mut model, true)?;
        assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
        Ok(())
    }
}
558}