// tract_core/ops/einsum/prefix_matmul.rs
1use tract_data::itertools::Itertools;
2use tract_linalg::Scaler;
3use tract_ndarray::Ix2;
4use tract_num_traits::One;
5
6use super::einsum_matmul::EinSumMatMul;
7use super::eval::dequant_inputs;
8use crate::internal::*;
9use crate::ops::einsum::block_quant_aware_input_shape;
10use crate::ops::konst::Const;
11
/// Context passed to the einsum-to-prefix-matmul rewrite rule.
#[derive(Debug, Default)]
pub struct EinSumToPrefixMatmulCtx {
    // When true, the rewrite fails (instead of translating) if the einsum's
    // effective operating datum type does not match the strict "matmul
    // semantic" derived from the input dtypes (see matmul_semantic_output_dt).
    pub ensure_strict_matmul_semantic: bool,
}
16
/// Detect matmul-like einsums in `model`, then rewrite each of them into a
/// [`PrefixMatMul`] operator plus whatever axis-permutation ops are needed to
/// bring inputs and output into canonical `prefix..mk` / `prefix..kn` /
/// `prefix..mn` order.
pub fn rewrite_einsum_to_prefix_matmul(
    model: &mut TypedModel,
    ensure_strict_matmul_semantic: bool,
) -> TractResult<()> {
    // First pass: tag eligible EinSum nodes as EinSumMatMul.
    super::einsum_matmul::detect_all(model)?;
    let ctx = EinSumToPrefixMatmulCtx { ensure_strict_matmul_semantic };
    Rewriter::default().with_rule_for("einsum-to-prefix-matmul", rule).rewrite(&ctx, model)
}
25
26fn rule(
27    ctx: &EinSumToPrefixMatmulCtx,
28    model: &TypedModel,
29    node: &TypedNode,
30    node_name: &str,
31    op: &EinSumMatMul,
32) -> TractResult<Option<TypedModelPatch>> {
33    // F: 2 inputs
34    // Q: 9 inputs
35    let is_fp_mm = op.q_params.is_none() && node.inputs.len() == 2;
36    let is_q_mm = op.q_params.is_some() && node.inputs.len() == 9;
37    rule_if!(is_fp_mm || is_q_mm);
38    rule_if!(
39        op.q_params.is_none()
40            || model.node_input_facts(node.id)?.iter().skip(3).all(|i| i.konst.is_some())
41    );
42    let prefix: String = op
43        .axes
44        .iter_all_axes()
45        .filter(|a| ![op.m_axis, op.k_axis, op.n_axis].contains(&a.repr))
46        .map(|a| a.repr)
47        .collect();
48    let mut patch = TypedModelPatch::default();
49    let inputs = patch.taps(model, &node.inputs)?;
50    let mut wire = tvec!(inputs[0], inputs[1]);
51
52    let (m, k, n) = (op.m_axis, op.k_axis, op.n_axis);
53    let a_order_es: String = op.axes.axes(InOut::In(0)).map(|a| a.repr).collect();
54    let a_order_mm = format!("{prefix}{m}{k}");
55    let a_order_mm_t = format!("{prefix}{k}{m}");
56    let a_transform =
57        format!("{a_order_es}->{a_order_mm}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
58    let a_transform_t =
59        format!("{a_order_es}->{a_order_mm_t}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
60    let transpose_a = a_transform.len() > a_transform_t.len();
61    let a_transform = if transpose_a { a_transform_t } else { a_transform };
62    let name = format!("{node_name}.fix_a");
63    for op in a_transform {
64        wire[0] = patch.wire_node(&name, op, &[wire[0]])?[0];
65    }
66    // terrible hack to maintain opaque fact through eager propatagation of constant through the
67    // axes transformation
68    if let Some(op) = patch.node_mut(wire[0].node).op_as_mut::<Const>() {
69        *op = Const::new_with_opt_opaque_fact(
70            op.val().clone(),
71            model.outlet_fact(node.inputs[0])?.opaque_fact.clone(),
72        )?;
73    }
74    patch
75        .outlet_fact_mut(wire[0])?
76        .opaque_fact
77        .clone_from(&model.outlet_fact(node.inputs[0])?.opaque_fact);
78    // end of hack
79
80    let b_order_es: String = op.axes.axes(InOut::In(1)).map(|a| a.repr).collect();
81    let b_order_mm = format!("{prefix}{k}{n}");
82    let b_order_mm_t = format!("{prefix}{n}{k}");
83    let b_transform =
84        format!("{b_order_es}->{b_order_mm}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
85    let b_transform_t =
86        format!("{b_order_es}->{b_order_mm_t}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
87    let transpose_b = b_transform.len() > b_transform_t.len();
88    let b_transform = if transpose_b { b_transform_t } else { b_transform };
89    let name = format!("{node_name}.fix_b");
90    for op in b_transform {
91        wire[1] = patch.wire_node(&name, op, &[wire[1]])?[0];
92    }
93
94    let c_order_es: String = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
95    let c_order_mm = format!("{prefix}{m}{n}");
96    let c_order_mm_t = format!("{prefix}{n}{m}");
97    let c_transform =
98        format!("{c_order_mm}->{c_order_es}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
99    let c_transform_t =
100        format!("{c_order_mm_t}->{c_order_es}").parse::<AxesMapping>()?.translate_to_axis_ops()?;
101    let transpose_c = c_transform.len() > c_transform_t.len();
102    let c_transform = if transpose_c { c_transform_t } else { c_transform };
103    let quantize_output = if let Some(qp) = op.q_params {
104        let qparams: Vec<&Tensor> = inputs[3..9]
105            .iter()
106            .map(|f| {
107                patch
108                    .outlet_fact(*f)?
109                    .konst
110                    .as_deref()
111                    .context("Can only translate fixed scalar quantization")
112            })
113            .try_collect()?;
114        Some(qp.with_qparams(QParams::ZpScale {
115            zero_point: qparams[4].cast_to_scalar::<i32>()?,
116            scale: qparams[5].cast_to_scalar::<f32>()?,
117        }))
118    } else {
119        None
120    };
121
122    let operating_dt = if ctx.ensure_strict_matmul_semantic {
123        let input_facts = model.node_input_facts(node.id)?;
124        let a_dt = input_facts[0].datum_type;
125        let b_dt = input_facts[1].datum_type;
126        let operating_dt = quantize_output.unwrap_or(op.operating_dt);
127        let allowed_dt = matmul_semantic_output_dt(&a_dt, &b_dt);
128
129        ensure!(
130            operating_dt == allowed_dt,
131            format!(
132                "Strict matmul semantic require operating_dt to be {allowed_dt:?} \
133                for (a: {a_dt:?}, b:{b_dt:?}) but got {:?}.",
134                op.operating_dt
135            )
136        );
137
138        None
139    } else {
140        Some(op.operating_dt)
141    };
142
143    wire = patch.wire_node(
144        node_name,
145        PrefixMatMul { transpose_a, transpose_b, transpose_c, quantize_output, operating_dt },
146        &wire,
147    )?;
148
149    for (ix, op) in c_transform.into_iter().enumerate() {
150        wire = patch.wire_node(format!("{node_name}.fix_c.{ix}"), op, &wire)?;
151    }
152    patch.shunt_outside(model, node.id.into(), wire[0])?;
153    Ok(Some(patch))
154}
155
156fn matmul_semantic_output_dt(a_dt: &DatumType, b_dt: &DatumType) -> DatumType {
157    if a_dt.is_number() {
158        *a_dt
159    } else if b_dt.is_number() {
160        *b_dt
161    } else {
162        f32::datum_type()
163    }
164}
165
/// Batched matrix product over the two trailing axes of its two inputs, with
/// leading ("prefix") axes broadcast pairwise (a size-1 dim matches any size).
#[derive(Clone, Debug, Copy)]
pub struct PrefixMatMul {
    // A's trailing axes are [k, m] instead of [m, k].
    pub transpose_a: bool,
    // B's trailing axes are [n, k] instead of [k, n].
    pub transpose_b: bool,
    // Output's trailing axes are [n, m] instead of [m, n].
    pub transpose_c: bool,
    // When set, output is quantized to this dtype (zp/scale carried by it).
    pub quantize_output: Option<DatumType>,
    // Accumulator/output dtype; None means "derive from inputs at eval time".
    pub operating_dt: Option<DatumType>,
}
174
impl PrefixMatMul {
    /// Output shape for inputs `a` and `b` (assumed same rank — TODO confirm
    /// ranks are aligned upstream). Prefix dims broadcast: a size-1 dim in `a`
    /// yields `b`'s dim. The two trailing dims are m (from `a`) and n (from
    /// `b`), swapped when `transpose_c` is set.
    fn output_shape<D: DimLike + One>(&self, a: &[D], b: &[D]) -> TVec<D> {
        let rank = a.len();
        let mut output: TVec<D> = (0..rank - 2)
            .map(|ix| if a[ix].is_one() { b[ix].clone() } else { a[ix].clone() })
            .collect();
        // m is A's last axis when transposed ([.., k, m]), else second-to-last.
        output.push(a[rank - 2 + self.transpose_a as usize].clone());
        // n is B's last axis when *not* transposed ([.., k, n]), else second-to-last.
        output.push(b[rank - 2 + !self.transpose_b as usize].clone());
        if self.transpose_c {
            output.swap(rank - 2, rank - 1);
        }
        output
    }

    /// Reference implementation: cast both inputs to the accumulator type
    /// `Acc`, then perform one 2-D `dot` per prefix index, writing into `acc`
    /// (which must already have the output shape and dtype `Acc`).
    fn mm<Acc: Datum + tract_ndarray::LinalgScalar>(
        &self,
        acc: &mut Tensor,
        a: &Tensor,
        b: &Tensor,
    ) -> TractResult<()> {
        use crate::ndarray::Dimension;
        let casted_a = a.cast_to::<Acc>()?;
        let a = casted_a.to_array_view::<Acc>()?;
        let casted_b = b.cast_to::<Acc>()?;
        let b = casted_b.to_array_view::<Acc>()?;
        let mut c = acc.to_array_view_mut::<Acc>()?;
        // Iterate over every multi-index of the output's prefix dims.
        for prefix in tract_ndarray::indices(&c.shape()[..c.ndim() - 2]) {
            let mut a = a.view();
            let mut b = b.view();
            let mut c = c.view_mut();
            for &d in prefix.slice().iter() {
                // min() clamps the index to 0 on broadcast (size-1) input dims.
                a.index_axis_inplace(tract_ndarray::Axis(0), d.min(a.shape()[0] - 1));
                b.index_axis_inplace(tract_ndarray::Axis(0), d.min(b.shape()[0] - 1));
                c.index_axis_inplace(tract_ndarray::Axis(0), d);
            }
            // After indexing away the prefix, exactly two axes remain.
            let a = a.into_dimensionality::<Ix2>().unwrap();
            let b = b.into_dimensionality::<Ix2>().unwrap();
            let mut c = c.into_dimensionality::<Ix2>().unwrap();
            let a = if self.transpose_a { a.t() } else { a };
            let b = if self.transpose_b { b.t() } else { b };
            // c = a·b, computed as (bᵀ·aᵀ) when the output itself is transposed.
            if self.transpose_c { c.assign(&b.t().dot(&a.t())) } else { c.assign(&a.dot(&b)) }
        }
        Ok(())
    }
}
220
impl Op for PrefixMatMul {
    fn name(&self) -> StaticName {
        "PrefixMatMul".into()
    }

    /// One-line human-readable summary of the op's flags for model dumps.
    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!(
            "transpose_a: {} transpose_b: {} transpose_c: {} q: {:?}",
            self.transpose_a, self.transpose_b, self.transpose_c, self.quantize_output
        )])
    }

    op_as_typed_op!();
}
235
impl EvalOp for PrefixMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        // Accumulation/output dtype: explicit operating_dt if set, otherwise
        // derived from input dtypes with the matmul semantic.
        let c_dt = self.operating_dt.unwrap_or_else(|| {
            let a_dt = inputs[0].datum_type();
            let b_dt = inputs[1].datum_type();
            matmul_semantic_output_dt(&a_dt, &b_dt)
        });

        let inputs = dequant_inputs(c_dt, inputs)?;

        let output_shape = self.output_shape(inputs[0].shape(), inputs[1].shape());

        if let Some(qp) = self.quantize_output {
            // Quantized path: subtract each input's zero-point, accumulate the
            // product in i32, rescale, then cast down to the target storage.
            let mut acc = Tensor::zero_dt(i32::datum_type(), &output_shape)?;
            let mut a_i32 = inputs[0].cast_to::<i32>()?.into_owned();
            a_i32
                .as_slice_mut::<i32>()?
                .iter_mut()
                .for_each(|x| *x -= inputs[0].datum_type().zp_scale().0);
            let mut b_i32 = inputs[1].cast_to::<i32>()?.into_owned();
            b_i32
                .as_slice_mut::<i32>()?
                .iter_mut()
                .for_each(|x| *x -= inputs[1].datum_type().zp_scale().0);
            self.mm::<i32>(&mut acc, &a_i32, &b_i32)?;
            // Combined rescaling factor: scale_a * scale_b / scale_c.
            let scale = inputs[0].datum_type().zp_scale().1 * inputs[1].datum_type().zp_scale().1
                / qp.zp_scale().1;
            let scaler = Scaler::new(scale, tract_linalg::mmm::RoundingPolicy::Even);
            acc.to_array_view_mut::<i32>()?.iter_mut().for_each(|x| *x = *x * scaler);
            let mut c: Tensor = acc.cast_to_dt(qp.unquantized())?.into_owned();
            // Storage already matches qp's unquantized type; re-tag the tensor
            // with the quantized dtype (presumably a metadata-only change —
            // NOTE(review): confirm set_datum_type does not reinterpret bytes).
            unsafe { c.set_datum_type(qp) };
            Ok(tvec!(c.into_tvalue()))
        } else {
            // Float path: dispatch mm() on the concrete accumulator type.
            let mut c = Tensor::zero_dt(c_dt, &output_shape)?;
            dispatch_floatlike!(Self::mm(c_dt)(self, &mut c, &inputs[0], &inputs[1]))?;
            Ok(tvec!(c.into_tvalue()))
        }
    }
}
279
impl TypedOp for PrefixMatMul {
    /// Output fact: dtype is quantize_output, else operating_dt, else the
    /// matmul-semantic dtype of the inputs; shape comes from output_shape over
    /// the (block-quant aware) input shapes.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let [a, b] = inputs else {
            bail!("Expects 2 inputs");
        };
        let a_shape = block_quant_aware_input_shape(a)?;
        let b_shape = block_quant_aware_input_shape(b)?;
        let dt = self
            .quantize_output
            .or(self.operating_dt)
            .unwrap_or(matmul_semantic_output_dt(&a.datum_type, &b.datum_type));
        Ok(tvec!(dt.fact(self.output_shape(&a_shape, &b_shape))))
    }

    as_op!();
}
296
#[cfg(test)]
mod test {
    use crate::ops::einsum::EinSum;

    use super::*;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest::test_runner::{TestCaseResult, TestRunner};
    use tract_data::itertools::Itertools;

    /// Strategy producing an f32 tensor of the given shape filled with small
    /// integers (exactly representable, so comparisons are bit-stable).
    pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
        let shape = shape.to_vec();
        let len = shape.iter().product::<usize>();
        vec((-10i8..=10i8).prop_map(|i| i as f32), len..=len)
            .prop_map(move |vec| tensor1(&vec).into_shape(&shape).unwrap())
            .boxed()
    }

    /// Strategy picking a dim (2..6) for every axis used by at least one
    /// input, then materializing the two input shapes in einsum order.
    fn full_shapes(e: &AxesMapping) -> BoxedStrategy<(Vec<usize>, Vec<usize>)> {
        let e = e.clone();
        let inputs_axes = e
            .iter_all_axes()
            .filter(|axis| axis.inputs[0].len() + axis.inputs[1].len() > 0)
            .cloned()
            .collect_vec();
        let dims = vec![2usize..6; inputs_axes.len()];
        dims.prop_map(move |dims| {
            let a: Vec<usize> = e
                .axes(InOut::In(0))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            let b: Vec<usize> = e
                .axes(InOut::In(1))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            (a, b)
        })
        .boxed()
    }

    /// Property test driver for one einsum expression: generate shapes (with
    /// some dims forced to 1 to exercise broadcasting), then tensors, and
    /// check each problem against the reference einsum evaluation.
    fn test_expr(expr: &str) -> TestCaseResult {
        let expr = expr.to_string();
        let mut runner = TestRunner::default();
        let axes: AxesMapping = expr.parse().unwrap();
        // Contraction axis: appears in the other input but not in the output.
        fn is_k(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.inputs[1 - input].len() == 1 && axis.outputs[0].len() == 0
        }
        // Axis absent from the output (summed or squeezed away).
        fn is_disapearing_axis(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.outputs[0].len() == 0
        }
        let cases = full_shapes(&axes)
            .prop_flat_map(|(a, b)| {
                (
                    // k axes keep their dim; disappearing non-k axes become 1;
                    // other axes are either their dim or 1 (broadcast case).
                    a.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            if is_k(&axes, 0, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disapearing_axis(&axes, 0, ix) {
                                Just(1).boxed()
                            } else {
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                    b.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            if is_k(&axes, 1, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disapearing_axis(&axes, 1, ix) {
                                Just(1).boxed()
                            } else {
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                )
            })
            .prop_flat_map(|(a_shape, b_shape)| (tensor(&a_shape), tensor(&b_shape)))
            .prop_map(|(a, b)| EinSumProblem { expr: expr.clone(), a, b });
        runner.run(&cases, |pb| pb.check().map_err(|e| TestCaseError::fail(e.to_string())))?;
        Ok(())
    }

    /// One concrete test case: an einsum expression and its two inputs.
    #[derive(Debug, Clone, PartialEq)]
    struct EinSumProblem {
        expr: String,
        a: Tensor,
        b: Tensor,
    }

    impl EinSumProblem {
        /// Run the expression as a plain EinSum (reference), then rewrite to
        /// PrefixMatMul and check the rewritten model produces the same output
        /// and contains no EinSum anymore.
        fn check(&self) -> TractResult<()> {
            let mut model = TypedModel::default();
            let sa = model.add_source("a", f32::fact(self.a.shape()))?;
            let sb = model.add_source("b", f32::fact(self.b.shape()))?;
            let einsum = model.wire_node(
                "einsum",
                EinSum::new(self.expr.parse().unwrap(), f32::datum_type()),
                &[sa, sb],
            )?;
            model.set_output_outlets(&einsum)?;
            let a = self.a.clone().into_tvalue();
            let b = self.b.clone().into_tvalue();
            let inputs = tvec!(a, b);
            let reference = TypedRunnableModel::new(model.clone())?.run(inputs.clone())?.remove(0);
            rewrite_einsum_to_prefix_matmul(&mut model, true)?;
            assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
            let test = TypedRunnableModel::new(model)?.run(inputs)?.remove(0);
            reference.close_enough(&test, true)
        }
    }

    #[rustfmt::skip] #[test] fn prop_mk_kn_mn() -> TestCaseResult { test_expr("mk,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_km_kn_mn() -> TestCaseResult { test_expr("km,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_nk_mn() -> TestCaseResult { test_expr("mk,nk->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_kn_nm() -> TestCaseResult { test_expr("mk,kn->nm") }
    #[rustfmt::skip] #[test] fn prop_k_kn_mn() -> TestCaseResult { test_expr("k,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_k_mn() -> TestCaseResult { test_expr("mk,k->mn") }
    #[rustfmt::skip] #[test] fn prop_m_n_mn() -> TestCaseResult { test_expr("m,n->mn") }
    #[rustfmt::skip] #[test] fn prop_amk_akn_amn() -> TestCaseResult { test_expr("amk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_mk_akn_amn() -> TestCaseResult { test_expr("mk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_btgi_gih_tgh() -> TestCaseResult { test_expr("btgi,gih->tgh") }
    #[rustfmt::skip] #[test] fn prop_tgi_gih_btgh() -> TestCaseResult { test_expr("tgi,gih->btgh") }

    // Fixed regression cases (previously shrunk proptest failures, presumably).

    #[test]
    fn k_kn_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "k,kn->mn".to_string(),
            a: tensor1(&[0f32, 0f32]),
            b: tensor2(&[[0f32, 0.], [0., 0.]]),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[2, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_kn_nm_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,kn->mn".to_string(),
            a: Tensor::zero::<f32>(&[3, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_2() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_3() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn km_anbck_bmn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "km,anbck->bmn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 1, 1, 1, 2]).unwrap(),
        }
        .check()
    }

    /// Quantized case: a 9-input einsum with constant qparams must also be
    /// rewritten away.
    #[test]
    fn q() -> TractResult<()> {
        let qp = QParams::ZpScale { zero_point: 0, scale: 0.1 };
        let op = EinSum {
            axes: "mk,kn,m,,,,,,->mn".parse()?,
            operating_dt: i32::datum_type(),
            q_params: Some(DatumType::QI8(qp)),
        };
        let mut model = TypedModelPatch::default();
        let inputs = [
            model.add_source("a", DatumType::QI8(qp).fact([3, 2]))?,
            model.add_source("b", DatumType::QI8(qp).fact([2, 4]))?,
            model.add_source("bias", i32::datum_type().fact([3]))?,
            model.add_const("a0", tensor0(qp.zp_scale().0))?,
            model.add_const("a_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("b0", tensor0(qp.zp_scale().0))?,
            model.add_const("b_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("c0", tensor0(qp.zp_scale().0))?,
            model.add_const("c_scale", tensor0(qp.zp_scale().1))?,
        ];
        let wire = model.wire_node("einsum", op.clone(), &inputs)?;
        model.set_output_outlets(&wire)?;
        rewrite_einsum_to_prefix_matmul(&mut model, true)?;
        assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
        Ok(())
    }
}
541}