Skip to main content

tract_core/ops/einsum/
prefix_matmul.rs

1use tract_data::itertools::Itertools;
2use tract_linalg::Scaler;
3use tract_ndarray::Ix2;
4use tract_num_traits::One;
5
6use super::einsum_matmul::EinSumMatMul;
7use super::eval::dequant_inputs;
8use crate::internal::*;
9use crate::ops::einsum::block_quant_aware_input_shape;
10use crate::ops::konst::Const;
11
/// Options handed to every application of the "einsum-to-prefix-matmul" rule.
#[derive(Default, Debug)]
pub struct EinSumToPrefixMatmulCtx {
    /// When `true`, the rule fails unless the effective result type equals the
    /// one `matmul_semantic_output_dt` derives from the two input types, and the
    /// generated `PrefixMatMul` carries no explicit `operating_dt`.
    /// When `false`, the einsum's own `operating_dt` is forwarded unchecked.
    pub ensure_strict_matmul_semantic: bool,
}
16
17pub fn rewrite_einsum_to_prefix_matmul(
18    model: &mut TypedModel,
19    ensure_strict_matmul_semantic: bool,
20) -> TractResult<()> {
21    super::einsum_matmul::detect_all(model)?;
22    let ctx = EinSumToPrefixMatmulCtx { ensure_strict_matmul_semantic };
23    Rewriter::default().with_rule_for("einsum-to-prefix-matmul", rule).rewrite(&ctx, model)
24}
25
26fn rule(
27    ctx: &EinSumToPrefixMatmulCtx,
28    model: &TypedModel,
29    node: &TypedNode,
30    node_name: &str,
31    op: &EinSumMatMul,
32) -> TractResult<Option<TypedModelPatch>> {
33    // F: 2 inputs
34    // Q: 9 inputs
35    let is_fp_mm = op.q_params.is_none() && node.inputs.len() == 2;
36    let is_q_mm = op.q_params.is_some() && node.inputs.len() == 9;
37    rule_if!(is_fp_mm || is_q_mm);
38    rule_if!(
39        op.q_params.is_none()
40            || model.node_input_facts(node.id)?.iter().skip(3).all(|i| i.konst.is_some())
41    );
42    let prefix: String = op
43        .axes
44        .iter_all_axes()
45        .filter(|a| ![op.m_axis, op.k_axis, op.n_axis].contains(&a.repr))
46        .map(|a| a.repr)
47        .collect();
48    let mut patch = TypedModelPatch::default();
49    let inputs = patch.taps(model, &node.inputs)?;
50    let mut wire = tvec!(inputs[0], inputs[1]);
51
52    let (m, k, n) = (op.m_axis, op.k_axis, op.n_axis);
53    let a_order_es: String = op.axes.axes(InOut::In(0)).map(|a| a.repr).collect();
54    let a_order_mm = format!("{prefix}{m}{k}");
55    let a_order_mm_t = format!("{prefix}{k}{m}");
56    let a_transform =
57        format!("{a_order_es}->{a_order_mm}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
58    let a_transform_t =
59        format!("{a_order_es}->{a_order_mm_t}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
60    let transpose_a = a_transform.len() > a_transform_t.len();
61    let a_transform = if transpose_a { a_transform_t } else { a_transform };
62    let name = format!("{node_name}.fix_a");
63    for op in a_transform {
64        wire[0] = patch.wire_node(&name, op, &[wire[0]])?[0];
65    }
66    // terrible hack to maintain opaque fact through eager propatagation of constant through the
67    // axes transformation
68    if let Some(op) = patch.node_mut(wire[0].node).op_as_mut::<Const>() {
69        *op = Const::new_with_opt_opaque_fact(
70            op.val().clone(),
71            model.outlet_fact(node.inputs[0])?.opaque_fact.clone(),
72        )?;
73    }
74    patch
75        .outlet_fact_mut(wire[0])?
76        .opaque_fact
77        .clone_from(&model.outlet_fact(node.inputs[0])?.opaque_fact);
78    // end of hack
79
80    let b_order_es: String = op.axes.axes(InOut::In(1)).map(|a| a.repr).collect();
81    let b_order_mm = format!("{prefix}{k}{n}");
82    let b_order_mm_t = format!("{prefix}{n}{k}");
83    let b_transform =
84        format!("{b_order_es}->{b_order_mm}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
85    let b_transform_t =
86        format!("{b_order_es}->{b_order_mm_t}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
87    let transpose_b = b_transform.len() > b_transform_t.len();
88    let b_transform = if transpose_b { b_transform_t } else { b_transform };
89    let name = format!("{node_name}.fix_b");
90    for op in b_transform {
91        wire[1] = patch.wire_node(&name, op, &[wire[1]])?[0];
92    }
93
94    let c_order_es: String = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
95    let c_order_mm = format!("{prefix}{m}{n}");
96    let c_order_mm_t = format!("{prefix}{n}{m}");
97    let c_transform =
98        format!("{c_order_mm}->{c_order_es}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
99    let c_transform_t =
100        format!("{c_order_mm_t}->{c_order_es}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
101    let transpose_c = c_transform.len() > c_transform_t.len();
102    let c_transform = if transpose_c { c_transform_t } else { c_transform };
103    let quantize_output = if let Some(qp) = op.q_params {
104        let qparams: Vec<&Tensor> = inputs[3..9]
105            .iter()
106            .map(|f| {
107                patch
108                    .outlet_fact(*f)?
109                    .konst
110                    .as_deref()
111                    .context("Can only translate fixed scalar quantization")
112            })
113            .try_collect()?;
114        Some(qp.with_qparams(QParams::ZpScale {
115            zero_point: qparams[4].cast_to_scalar::<i32>()?,
116            scale: qparams[5].cast_to_scalar::<f32>()?,
117        }))
118    } else {
119        None
120    };
121
122    let operating_dt = if ctx.ensure_strict_matmul_semantic {
123        let input_facts = model.node_input_facts(node.id)?;
124        let a_dt = input_facts[0].datum_type;
125        let b_dt = input_facts[1].datum_type;
126        let operating_dt = quantize_output.unwrap_or(op.operating_dt);
127        let allowed_dt = matmul_semantic_output_dt(&a_dt, &b_dt);
128
129        ensure!(
130            operating_dt == allowed_dt,
131            format!(
132                "Strict matmul semantic require operating_dt to be {allowed_dt:?} \
133                for (a: {a_dt:?}, b:{b_dt:?}) but got {:?}.",
134                op.operating_dt
135            )
136        );
137
138        None
139    } else {
140        Some(op.operating_dt)
141    };
142
143    wire = patch.wire_node(
144        node_name,
145        PrefixMatMul { transpose_a, transpose_b, transpose_c, quantize_output, operating_dt },
146        &wire,
147    )?;
148
149    for (ix, op) in c_transform.into_iter().enumerate() {
150        wire = patch.wire_node(format!("{node_name}.fix_c.{ix}"), op, &wire)?;
151    }
152    patch.shunt_outside(model, node.id.into(), wire[0])?;
153    Ok(Some(patch))
154}
155
156fn matmul_semantic_output_dt(a_dt: &DatumType, b_dt: &DatumType) -> DatumType {
157    if a_dt.is_number() {
158        *a_dt
159    } else if b_dt.is_number() {
160        *b_dt
161    } else {
162        f32::datum_type()
163    }
164}
165
166#[derive(Clone, Debug, Copy)]
167pub struct PrefixMatMul {
168    pub transpose_a: bool,
169    pub transpose_b: bool,
170    pub transpose_c: bool,
171    pub quantize_output: Option<DatumType>,
172    pub operating_dt: Option<DatumType>,
173}
174
175impl PrefixMatMul {
176    fn output_shape<D: DimLike + One>(&self, a: &[D], b: &[D]) -> TVec<D> {
177        let rank = a.len();
178        let mut output: TVec<D> = (0..rank - 2)
179            .map(|ix| if a[ix].is_one() { b[ix].clone() } else { a[ix].clone() })
180            .collect();
181        output.push(a[rank - 2 + self.transpose_a as usize].clone());
182        output.push(b[rank - 2 + !self.transpose_b as usize].clone());
183        if self.transpose_c {
184            output.swap(rank - 2, rank - 1);
185        }
186        output
187    }
188
189    fn mm<Acc: Datum + tract_ndarray::LinalgScalar>(
190        &self,
191        acc: &mut Tensor,
192        a: &Tensor,
193        b: &Tensor,
194    ) -> TractResult<()> {
195        use crate::ndarray::Dimension;
196        let casted_a = a.cast_to::<Acc>()?;
197        let a = casted_a.to_dense_array_view::<Acc>()?;
198        let casted_b = b.cast_to::<Acc>()?;
199        let b = casted_b.to_dense_array_view::<Acc>()?;
200        let mut c_dense = acc.try_as_dense_mut()?;
201        let mut c = c_dense.to_array_view_mut::<Acc>()?;
202        for prefix in tract_ndarray::indices(&c.shape()[..c.ndim() - 2]) {
203            let mut a = a.view();
204            let mut b = b.view();
205            let mut c = c.view_mut();
206            for &d in prefix.slice().iter() {
207                a.index_axis_inplace(tract_ndarray::Axis(0), d.min(a.shape()[0] - 1));
208                b.index_axis_inplace(tract_ndarray::Axis(0), d.min(b.shape()[0] - 1));
209                c.index_axis_inplace(tract_ndarray::Axis(0), d);
210            }
211            let a = a.into_dimensionality::<Ix2>().unwrap();
212            let b = b.into_dimensionality::<Ix2>().unwrap();
213            let mut c = c.into_dimensionality::<Ix2>().unwrap();
214            let a = if self.transpose_a { a.t() } else { a };
215            let b = if self.transpose_b { b.t() } else { b };
216            if self.transpose_c { c.assign(&b.t().dot(&a.t())) } else { c.assign(&a.dot(&b)) }
217        }
218        Ok(())
219    }
220}
221
222impl Op for PrefixMatMul {
223    fn name(&self) -> StaticName {
224        "PrefixMatMul".into()
225    }
226
227    fn info(&self) -> TractResult<Vec<String>> {
228        Ok(vec![format!(
229            "transpose_a: {} transpose_b: {} transpose_c: {} q: {:?}",
230            self.transpose_a, self.transpose_b, self.transpose_c, self.quantize_output
231        )])
232    }
233
234    op_as_typed_op!();
235}
236
237impl EvalOp for PrefixMatMul {
238    fn is_stateless(&self) -> bool {
239        true
240    }
241
242    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
243        let c_dt = self.operating_dt.unwrap_or_else(|| {
244            let a_dt = inputs[0].datum_type();
245            let b_dt = inputs[1].datum_type();
246            matmul_semantic_output_dt(&a_dt, &b_dt)
247        });
248
249        let inputs = dequant_inputs(c_dt, inputs)?;
250
251        let output_shape = self.output_shape(inputs[0].shape(), inputs[1].shape());
252
253        if let Some(qp) = self.quantize_output {
254            let mut acc = Tensor::zero_dt(i32::datum_type(), &output_shape)?;
255            let mut a_i32 = inputs[0].cast_to::<i32>()?.into_owned();
256            a_i32
257                .try_as_dense_mut()?
258                .as_slice_mut::<i32>()?
259                .iter_mut()
260                .for_each(|x| *x -= inputs[0].datum_type().zp_scale().0);
261            let mut b_i32 = inputs[1].cast_to::<i32>()?.into_owned();
262            b_i32
263                .try_as_dense_mut()?
264                .as_slice_mut::<i32>()?
265                .iter_mut()
266                .for_each(|x| *x -= inputs[1].datum_type().zp_scale().0);
267            self.mm::<i32>(&mut acc, &a_i32, &b_i32)?;
268            let scale = inputs[0].datum_type().zp_scale().1 * inputs[1].datum_type().zp_scale().1
269                / qp.zp_scale().1;
270            let scaler = Scaler::new(scale, tract_linalg::mmm::RoundingPolicy::Even);
271            acc.to_dense_array_view_mut::<i32>()?.iter_mut().for_each(|x| *x = *x * scaler);
272            let mut c: Tensor = acc.cast_to_dt(qp.unquantized())?.into_owned();
273            unsafe { c.set_datum_type(qp) };
274            Ok(tvec!(c.into_tvalue()))
275        } else {
276            let mut c = Tensor::zero_dt(c_dt, &output_shape)?;
277            dispatch_floatlike!(Self::mm(c_dt)(self, &mut c, &inputs[0], &inputs[1]))?;
278            Ok(tvec!(c.into_tvalue()))
279        }
280    }
281}
282
283impl TypedOp for PrefixMatMul {
284    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
285        let [a, b] = inputs else {
286            bail!("Expects 2 inputs");
287        };
288        let a_shape = block_quant_aware_input_shape(a)?;
289        let b_shape = block_quant_aware_input_shape(b)?;
290        let dt = self
291            .quantize_output
292            .or(self.operating_dt)
293            .unwrap_or(matmul_semantic_output_dt(&a.datum_type, &b.datum_type));
294        Ok(tvec!(dt.fact(self.output_shape(&a_shape, &b_shape))))
295    }
296
297    as_op!();
298}
299
#[cfg(test)]
mod test {
    use crate::ops::einsum::EinSum;

    use super::*;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest::test_runner::{TestCaseResult, TestRunner};
    use tract_data::itertools::Itertools;

    /// Strategy producing an f32 tensor of the given shape, filled with
    /// integral values in `-10..=10`.
    pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
        let shape = shape.to_vec();
        let len = shape.iter().product::<usize>();
        vec((-10i8..=10i8).prop_map(|i| i as f32), len..=len)
            .prop_map(move |vec| tensor1(&vec).into_shape(&shape).unwrap())
            .boxed()
    }

    /// Strategy producing consistent shapes for both inputs of `e`: one size
    /// in `2..6` is drawn per axis present in an input, then laid out in each
    /// input's own axis order.
    fn full_shapes(e: &AxesMapping) -> BoxedStrategy<(Vec<usize>, Vec<usize>)> {
        let e = e.clone();
        let inputs_axes = e
            .iter_all_axes()
            .filter(|axis| axis.inputs[0].len() + axis.inputs[1].len() > 0)
            .cloned()
            .collect_vec();
        let dims = vec![2usize..6; inputs_axes.len()];
        dims.prop_map(move |dims| {
            let a: Vec<usize> = e
                .axes(InOut::In(0))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            let b: Vec<usize> = e
                .axes(InOut::In(1))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            (a, b)
        })
        .boxed()
    }

    /// Property test: for random shapes and values, the rewritten model must
    /// give the same result as the original einsum for `expr`.
    fn test_expr(expr: &str) -> TestCaseResult {
        let expr = expr.to_string();
        let mut runner = TestRunner::default();
        let axes: AxesMapping = expr.parse().unwrap();
        /// The axis appears once in the other input and not in the output.
        fn is_k(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.inputs[1 - input].len() == 1 && axis.outputs[0].is_empty()
        }
        /// The axis does not appear in the output.
        fn is_disappearing_axis(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.outputs[0].is_empty()
        }
        /// Size strategy for one input axis: k keeps its full size, other
        /// disappearing axes are forced to 1, remaining axes are either 1
        /// (to exercise broadcasting) or their full size.
        fn dim_strategy(
            axes: &AxesMapping,
            input: usize,
            position: usize,
            dim: usize,
        ) -> BoxedStrategy<usize> {
            if is_k(axes, input, position) {
                prop_oneof![Just(dim)].boxed()
            } else if is_disappearing_axis(axes, input, position) {
                Just(1usize).boxed()
            } else {
                prop_oneof![Just(1usize), Just(dim)].boxed()
            }
        }
        let cases = full_shapes(&axes)
            .prop_flat_map(|(a, b)| {
                (
                    a.iter()
                        .enumerate()
                        .map(|(ix, d)| dim_strategy(&axes, 0, ix, *d))
                        .collect_vec(),
                    b.iter()
                        .enumerate()
                        .map(|(ix, d)| dim_strategy(&axes, 1, ix, *d))
                        .collect_vec(),
                )
            })
            .prop_flat_map(|(a_shape, b_shape)| (tensor(&a_shape), tensor(&b_shape)))
            .prop_map(|(a, b)| EinSumProblem { expr: expr.clone(), a, b });
        runner.run(&cases, |pb| pb.check().map_err(|e| TestCaseError::fail(e.to_string())))?;
        Ok(())
    }

    /// One concrete einsum case: an expression and its two input tensors.
    #[derive(Debug, Clone, PartialEq)]
    struct EinSumProblem {
        expr: String,
        a: Tensor,
        b: Tensor,
    }

    impl EinSumProblem {
        /// Runs the einsum as is, then again after the rewrite (strict matmul
        /// semantic on), checks no `EinSum` node is left, and compares results.
        fn check(&self) -> TractResult<()> {
            let mut model = TypedModel::default();
            let sa = model.add_source("a", f32::fact(self.a.shape()))?;
            let sb = model.add_source("b", f32::fact(self.b.shape()))?;
            let einsum = model.wire_node(
                "einsum",
                EinSum::new(self.expr.parse().unwrap(), f32::datum_type()),
                &[sa, sb],
            )?;
            model.set_output_outlets(&einsum)?;
            let a = self.a.clone().into_tvalue();
            let b = self.b.clone().into_tvalue();
            let inputs = tvec!(a, b);
            let reference = TypedRunnableModel::new(model.clone())?.run(inputs.clone())?.remove(0);
            rewrite_einsum_to_prefix_matmul(&mut model, true)?;
            assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
            let test = TypedRunnableModel::new(model)?.run(inputs)?.remove(0);
            reference.close_enough(&test, true)
        }
    }

    #[rustfmt::skip] #[test] fn prop_mk_kn_mn() -> TestCaseResult { test_expr("mk,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_km_kn_mn() -> TestCaseResult { test_expr("km,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_nk_mn() -> TestCaseResult { test_expr("mk,nk->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_kn_nm() -> TestCaseResult { test_expr("mk,kn->nm") }
    #[rustfmt::skip] #[test] fn prop_k_kn_mn() -> TestCaseResult { test_expr("k,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_k_mn() -> TestCaseResult { test_expr("mk,k->mn") }
    #[rustfmt::skip] #[test] fn prop_m_n_mn() -> TestCaseResult { test_expr("m,n->mn") }
    #[rustfmt::skip] #[test] fn prop_amk_akn_amn() -> TestCaseResult { test_expr("amk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_mk_akn_amn() -> TestCaseResult { test_expr("mk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_btgi_gih_tgh() -> TestCaseResult { test_expr("btgi,gih->tgh") }
    #[rustfmt::skip] #[test] fn prop_tgi_gih_btgh() -> TestCaseResult { test_expr("tgi,gih->btgh") }

    #[test]
    fn k_kn_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "k,kn->mn".to_string(),
            a: tensor1(&[0f32, 0f32]),
            b: tensor2(&[[0f32, 0.], [0., 0.]]),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[2, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    // The expression now matches the test name: the output is in `nm` order,
    // so that m (3) != n (2) actually exercises the transposed output shape.
    #[test]
    fn mk_kn_nm_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,kn->nm".to_string(),
            a: Tensor::zero::<f32>(&[3, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_2() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_3() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn km_anbck_bmn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "km,anbck->bmn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 1, 1, 1, 2]).unwrap(),
        }
        .check()
    }

    /// A quantized einsum with constant quantization parameters must be
    /// rewritten too: no `EinSum` node may remain.
    #[test]
    fn q() -> TractResult<()> {
        let qp = QParams::ZpScale { zero_point: 0, scale: 0.1 };
        let op = EinSum {
            axes: "mk,kn,m,,,,,,->mn".parse()?,
            operating_dt: i32::datum_type(),
            q_params: Some(DatumType::QI8(qp)),
        };
        let mut model = TypedModelPatch::default();
        let inputs = [
            model.add_source("a", DatumType::QI8(qp).fact([3, 2]))?,
            model.add_source("b", DatumType::QI8(qp).fact([2, 4]))?,
            model.add_source("bias", i32::datum_type().fact([3]))?,
            model.add_const("a0", tensor0(qp.zp_scale().0))?,
            model.add_const("a_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("b0", tensor0(qp.zp_scale().0))?,
            model.add_const("b_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("c0", tensor0(qp.zp_scale().0))?,
            model.add_const("c_scale", tensor0(qp.zp_scale().1))?,
        ];
        // `op` is not used afterwards, so it is moved rather than cloned.
        let wire = model.wire_node("einsum", op, &inputs)?;
        model.set_output_outlets(&wire)?;
        rewrite_einsum_to_prefix_matmul(&mut model, true)?;
        assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
        Ok(())
    }
}