//! tract_core/ops/einsum/prefix_matmul.rs
//! Rewrites eligible EinSum nodes into the normalized `PrefixMatMul` op.

1use tract_data::itertools::Itertools;
2use tract_linalg::Scaler;
3use tract_ndarray::Ix2;
4use tract_num_traits::One;
5
6use super::einsum_matmul::EinSumMatMul;
7use super::eval::dequant_inputs;
8use crate::internal::*;
9use crate::ops::einsum::block_quant_aware_input_shape;
10use crate::ops::konst::Const;
11
12pub fn rewrite_einsum_to_prefix_matmul(model: &mut TypedModel) -> TractResult<()> {
13    super::einsum_matmul::detect_all(model)?;
14    Rewriter::default().with_rule_for("einsum-to-prefix-matmul", rule).rewrite(&(), model)
15}
16
/// Rewriting rule: replaces one detected einsum-as-matmul node with axis
/// fixups on A and B, a `PrefixMatMul` node, and axis fixups on the output.
/// Returns `Ok(None)` when the node does not match the supported forms.
fn rule(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    node_name: &str,
    op: &EinSumMatMul,
) -> TractResult<Option<TypedModelPatch>> {
    // Supported arities only:
    // F: 2 inputs (a, b)
    // Q: 9 inputs (a, b, bias, a0, a_scale, b0, b_scale, c0, c_scale)
    if !((op.q_params.is_none() && node.inputs.len() == 2)
        || (op.q_params.is_some() && node.inputs.len() == 9))
    {
        return Ok(None);
    }
    // Quantized form: all quantization parameters (inputs 3..) must be known
    // constants, since they are baked into the PrefixMatMul op below.
    if op.q_params.is_some()
        && model.node_input_facts(node.id)?.iter().skip(3).any(|i| i.konst.is_none())
    {
        return Ok(None);
    }
    // Every axis that is not m, k or n forms the (batch) prefix, in order.
    let prefix: String = op
        .axes
        .iter_all_axes()
        .filter(|a| ![op.m_axis, op.k_axis, op.n_axis].contains(&a.repr))
        .map(|a| a.repr)
        .collect();
    let mut patch = TypedModelPatch::default();
    let inputs = patch.taps(model, &node.inputs)?;
    let mut wire = tvec!(inputs[0], inputs[1]);

    let (m, k, n) = (op.m_axis, op.k_axis, op.n_axis);
    // Bring A to either prefix+mk or prefix+km order, whichever needs fewer
    // axis ops, and record the choice as transpose_a.
    let a_order_es: String = op.axes.axes(InOut::In(0)).map(|a| a.repr).collect();
    let a_order_mm = format!("{prefix}{m}{k}");
    let a_order_mm_t = format!("{prefix}{k}{m}");
    let a_transform = format!("{a_order_es}->{a_order_mm}")
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let a_transform_t = format!("{a_order_es}->{a_order_mm_t}")
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let transpose_a = a_transform.len() > a_transform_t.len();
    let a_transform = if transpose_a { a_transform_t } else { a_transform };
    let name = format!("{node_name}.fix_a");
    for op in a_transform {
        wire[0] = patch.wire_node(&name, op, &[wire[0]])?[0];
    }
    // terrible hack to maintain opaque fact through eager propagation of constant through the
    // axes transformation
    if let Some(op) = patch.node_mut(wire[0].node).op_as_mut::<Const>() {
        *op = Const::new_with_opt_opaque_fact(
            op.val().clone(),
            model.outlet_fact(node.inputs[0])?.opaque_fact.clone(),
        )?;
    }
    patch
        .outlet_fact_mut(wire[0])?
        .opaque_fact
        .clone_from(&model.outlet_fact(node.inputs[0])?.opaque_fact);
    // end of hack

    // Same game for B: prefix+kn or prefix+nk, recorded as transpose_b.
    let b_order_es: String = op.axes.axes(InOut::In(1)).map(|a| a.repr).collect();
    let b_order_mm = format!("{prefix}{k}{n}");
    let b_order_mm_t = format!("{prefix}{n}{k}");
    let b_transform = format!("{b_order_es}->{b_order_mm}")
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let b_transform_t = format!("{b_order_es}->{b_order_mm_t}")
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let transpose_b = b_transform.len() > b_transform_t.len();
    let b_transform = if transpose_b { b_transform_t } else { b_transform };
    let name = format!("{node_name}.fix_b");
    for op in b_transform {
        wire[1] = patch.wire_node(&name, op, &[wire[1]])?[0];
    }

    // Output side: map from prefix+mn (or prefix+nm) back to the einsum's
    // expected output order, again picking the cheaper variant.
    let c_order_es: String = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
    let c_order_mm = format!("{prefix}{m}{n}");
    let c_order_mm_t = format!("{prefix}{n}{m}");
    let c_transform = format!("{c_order_mm}->{c_order_es}")
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let c_transform_t = format!("{c_order_mm_t}->{c_order_es}")
        .parse::<AxesMapping>()?
        .translate_to_axis_ops()?;
    let transpose_c = c_transform.len() > c_transform_t.len();
    let c_transform = if transpose_c { c_transform_t } else { c_transform };
    let quantize_output = if let Some(qp) = op.q_params {
        // inputs[3..9] are a0, a_scale, b0, b_scale, c0, c_scale (all
        // constants, checked above); only the output pair c0/c_scale
        // (indices 4 and 5 of this slice) is used here.
        let qparams: Vec<&Tensor> = inputs[3..9]
            .iter()
            .map(|f| {
                patch
                    .outlet_fact(*f)?
                    .konst
                    .as_deref()
                    .context("Can only translate fixed scalar quantization")
            })
            .try_collect()?;
        Some(qp.with_qparams(QParams::ZpScale {
            zero_point: qparams[4].cast_to_scalar::<i32>()?,
            scale: qparams[5].cast_to_scalar::<f32>()?,
        }))
    } else {
        None
    };
    wire = patch.wire_node(
        node_name,
        PrefixMatMul { transpose_a, transpose_b, transpose_c, quantize_output },
        &wire,
    )?;

    for (ix, op) in c_transform.into_iter().enumerate() {
        wire = patch.wire_node(format!("{node_name}.fix_c.{ix}"), op, &wire)?;
    }
    patch.shunt_outside(model, node.id.into(), wire[0])?;
    Ok(Some(patch))
}
133
/// A (batched) matrix multiplication with optional transposition of A, B and
/// the output — the normalized target of the einsum rewrite above.
#[derive(Clone, Debug, Copy, Default)]
pub struct PrefixMatMul {
    /// If true, A is laid out [.., k, m] instead of [.., m, k].
    pub transpose_a: bool,
    /// If true, B is laid out [.., n, k] instead of [.., k, n].
    pub transpose_b: bool,
    /// If true, the output is produced as [.., n, m] instead of [.., m, n].
    pub transpose_c: bool,
    /// When set, eval requantizes the result to this (quantized) datum type.
    pub quantize_output: Option<DatumType>,
}
141
impl PrefixMatMul {
    /// Computes the output shape from both input shapes.
    ///
    /// Assumes both inputs have the same rank. All but the last two axes form
    /// a broadcastable prefix (a 1 in `a` takes the dimension from `b`). The
    /// trailing axes are picked according to the transposition flags.
    fn output_shape<D: DimLike + One>(&self, a: &[D], b: &[D]) -> TVec<D> {
        let rank = a.len();
        let mut output: TVec<D> = (0..rank - 2)
            .map(|ix| if a[ix].is_one() { b[ix].clone() } else { a[ix].clone() })
            .collect();
        // m axis: second-to-last of A, or last when A is transposed
        output.push(a[rank - 2 + self.transpose_a as usize].clone());
        // n axis: last of B, or second-to-last when B is transposed
        output.push(b[rank - 2 + !self.transpose_b as usize].clone());
        if self.transpose_c {
            output.swap(rank - 2, rank - 1);
        }
        output
    }

    /// Computes `a · b` for every prefix (batch) index and writes the result
    /// into `acc`. Size-1 prefix axes in `a` or `b` are broadcast against
    /// `acc`'s prefix via the `min` clamping below.
    fn mm<Acc: Datum + tract_ndarray::LinalgScalar>(
        &self,
        acc: &mut Tensor,
        a: &Tensor,
        b: &Tensor,
    ) -> TractResult<()> {
        use crate::ndarray::Dimension;
        let a = a.to_array_view::<Acc>()?;
        let b = b.to_array_view::<Acc>()?;
        let mut c = acc.to_array_view_mut::<Acc>()?;
        for prefix in tract_ndarray::indices(&c.shape()[..c.ndim() - 2]) {
            let mut a = a.view();
            let mut b = b.view();
            let mut c = c.view_mut();
            for &d in prefix.slice().iter() {
                // Each index_axis_inplace peels off the current leading axis;
                // clamping to len-1 implements broadcast of size-1 prefix axes.
                a.index_axis_inplace(tract_ndarray::Axis(0), d.min(a.shape()[0] - 1));
                b.index_axis_inplace(tract_ndarray::Axis(0), d.min(b.shape()[0] - 1));
                c.index_axis_inplace(tract_ndarray::Axis(0), d);
            }
            // After peeling all prefix axes, exactly two axes remain.
            let a = a.into_dimensionality::<Ix2>().unwrap();
            let b = b.into_dimensionality::<Ix2>().unwrap();
            let mut c = c.into_dimensionality::<Ix2>().unwrap();
            let a = if self.transpose_a { a.t() } else { a };
            let b = if self.transpose_b { b.t() } else { b };
            if self.transpose_c {
                // C^T = (A·B)^T = B^T·A^T
                c.assign(&b.t().dot(&a.t()))
            } else {
                c.assign(&a.dot(&b))
            }
        }
        Ok(())
    }
}
189
190impl Op for PrefixMatMul {
191    fn name(&self) -> StaticName {
192        "PrefixMatMul".into()
193    }
194
195    fn info(&self) -> TractResult<Vec<String>> {
196        Ok(vec![format!(
197            "transpose_a: {} transpose_b: {} transpose_c: {} q: {:?}",
198            self.transpose_a, self.transpose_b, self.transpose_c, self.quantize_output
199        )])
200    }
201
202    op_as_typed_op!();
203}
204
impl EvalOp for PrefixMatMul {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        // Pick the accumulator/output type from the first numeric input,
        // falling back to f32 when neither is numeric (presumably opaque
        // block-quant payloads that dequant_inputs expands — TODO confirm).
        let c_dt = if inputs[0].datum_type().is_number() {
            inputs[0].datum_type()
        } else if inputs[1].datum_type().is_number() {
            inputs[1].datum_type()
        } else {
            f32::datum_type()
        };
        let inputs = dequant_inputs(c_dt, inputs)?;

        let output_shape = self.output_shape(inputs[0].shape(), inputs[1].shape());

        if let Some(qp) = self.quantize_output {
            // Quantized path: accumulate in i32 with input zero-points removed.
            let mut acc = Tensor::zero_dt(i32::datum_type(), &output_shape)?;
            let mut a_i32 = inputs[0].cast_to::<i32>()?.into_owned();
            a_i32
                .as_slice_mut::<i32>()?
                .iter_mut()
                .for_each(|x| *x -= inputs[0].datum_type().zp_scale().0);
            let mut b_i32 = inputs[1].cast_to::<i32>()?.into_owned();
            b_i32
                .as_slice_mut::<i32>()?
                .iter_mut()
                .for_each(|x| *x -= inputs[1].datum_type().zp_scale().0);
            self.mm::<i32>(&mut acc, &a_i32, &b_i32)?;
            // Combined rescale: a_scale * b_scale / c_scale, round-to-even.
            let scale = inputs[0].datum_type().zp_scale().1 * inputs[1].datum_type().zp_scale().1
                / qp.zp_scale().1;
            let scaler = Scaler::new(scale, tract_linalg::mmm::RoundingPolicy::Even);
            acc.to_array_view_mut::<i32>()?.iter_mut().for_each(|x| *x = *x * scaler);
            // NOTE(review): the output zero-point is never re-added here —
            // the in-file q() test uses zp=0; confirm non-zero output zp.
            // Cast to the quantized type's storage type, then re-tag the
            // tensor with the quantized datum type itself.
            let mut c: Tensor = acc.cast_to_dt(qp.unquantized())?.into_owned();
            unsafe { c.set_datum_type(qp) };
            Ok(tvec!(c.into_tvalue()))
        } else {
            let mut c = Tensor::zero_dt(c_dt, &output_shape)?;
            dispatch_floatlike!(Self::mm(c_dt)(self, &mut c, &inputs[0], &inputs[1]))?;
            Ok(tvec!(c.into_tvalue()))
        }
    }
}
249
250impl TypedOp for PrefixMatMul {
251    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
252        let [a, b] = inputs else {
253            bail!("Expects 2 inputs");
254        };
255        let a_shape = block_quant_aware_input_shape(inputs[0])?;
256        let b_shape = block_quant_aware_input_shape(inputs[1])?;
257        let dt = self.quantize_output.unwrap_or(if a.datum_type.is_number() {
258            a.datum_type
259        } else {
260            b.datum_type
261        });
262        Ok(tvec!(dt.fact(self.output_shape(&a_shape, &b_shape))))
263    }
264
265    as_op!();
266}
267
#[cfg(test)]
mod test {
    use crate::ops::einsum::EinSum;

    use super::*;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest::test_runner::{TestCaseResult, TestRunner};
    use tract_data::itertools::Itertools;

    /// Strategy producing an f32 tensor of the given shape, filled with small
    /// integer values in [-10, 10].
    pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
        let shape = shape.to_vec();
        let len = shape.iter().product::<usize>();
        vec((-10i8..=10i8).prop_map(|i| i as f32), len..=len)
            .prop_map(move |vec| tensor1(&vec).into_shape(&shape).unwrap())
            .boxed()
    }

    /// Strategy producing full (pre-broadcast) shapes for both inputs of the
    /// einsum expression, with every shared axis given a consistent size.
    fn full_shapes(e: &AxesMapping) -> BoxedStrategy<(Vec<usize>, Vec<usize>)> {
        let e = e.clone();
        let inputs_axes = e
            .iter_all_axes()
            .filter(|axis| axis.inputs[0].len() + axis.inputs[1].len() > 0)
            .cloned()
            .collect_vec();
        let dims = vec![2usize..6; inputs_axes.len()];
        dims.prop_map(move |dims| {
            let a: Vec<usize> = e
                .axes(InOut::In(0))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            let b: Vec<usize> = e
                .axes(InOut::In(1))
                .map(|a| dims[inputs_axes.iter().position(|b| a == b).unwrap()])
                .collect_vec();
            (a, b)
        })
        .boxed()
    }

    /// Property test: for the given einsum expression, the rewritten model
    /// must match the reference EinSum over random shapes (with and without
    /// broadcast 1s) and random small-integer contents.
    fn test_expr(expr: &str) -> TestCaseResult {
        let expr = expr.to_string();
        let mut runner = TestRunner::default();
        let axes: AxesMapping = expr.parse().unwrap();
        // A k-axis appears once in the other input and not in the output:
        // its size must not be degraded to 1 by the broadcast strategy.
        fn is_k(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.inputs[1 - input].len() == 1 && axis.outputs[0].len() == 0
        }
        // An axis absent from the output is pinned to size 1 below.
        // (was `is_disapearing_axis` — identifier typo fixed)
        fn is_disappearing_axis(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.outputs[0].len() == 0
        }
        let cases = full_shapes(&axes)
            .prop_flat_map(|(a, b)| {
                (
                    a.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            if is_k(&axes, 0, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disappearing_axis(&axes, 0, ix) {
                                Just(1).boxed()
                            } else {
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                    b.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            if is_k(&axes, 1, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disappearing_axis(&axes, 1, ix) {
                                Just(1).boxed()
                            } else {
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                )
            })
            .prop_flat_map(|(a_shape, b_shape)| (tensor(&a_shape), tensor(&b_shape)))
            .prop_map(|(a, b)| EinSumProblem { expr: expr.clone(), a, b });
        runner.run(&cases, |pb| pb.check().map_err(|e| TestCaseError::fail(e.to_string())))?;
        Ok(())
    }

    /// One concrete einsum-vs-rewrite comparison case.
    #[derive(Debug, Clone, PartialEq)]
    struct EinSumProblem {
        expr: String,
        a: Tensor,
        b: Tensor,
    }

    impl EinSumProblem {
        /// Builds an EinSum model, runs it as reference, applies the rewrite,
        /// checks no EinSum remains and that outputs still match.
        fn check(&self) -> TractResult<()> {
            let mut model = TypedModel::default();
            let sa = model.add_source("a", f32::fact(self.a.shape())).unwrap();
            let sb = model.add_source("b", f32::fact(self.b.shape())).unwrap();
            let einsum = model
                .wire_node(
                    "einsum",
                    EinSum::new(self.expr.parse().unwrap(), f32::datum_type()),
                    &[sa, sb],
                )
                .unwrap();
            model.set_output_outlets(&einsum).unwrap();
            let a = self.a.clone().into_tvalue();
            let b = self.b.clone().into_tvalue();
            let inputs = tvec!(a, b);
            let reference =
                TypedRunnableModel::new(&model).unwrap().run(inputs.clone()).unwrap().remove(0);
            rewrite_einsum_to_prefix_matmul(&mut model)?;
            assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
            let test = TypedRunnableModel::new(&model).unwrap().run(inputs).unwrap().remove(0);
            reference.close_enough(&test, true).unwrap();
            Ok(())
        }
    }

    #[rustfmt::skip] #[test] fn prop_mk_kn_mn() -> TestCaseResult { test_expr("mk,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_km_kn_mn() -> TestCaseResult { test_expr("km,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_nk_mn() -> TestCaseResult { test_expr("mk,nk->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_kn_nm() -> TestCaseResult { test_expr("mk,kn->nm") }
    #[rustfmt::skip] #[test] fn prop_k_kn_mn() -> TestCaseResult { test_expr("k,kn->mn") }
    #[rustfmt::skip] #[test] fn prop_mk_k_mn() -> TestCaseResult { test_expr("mk,k->mn") }
    #[rustfmt::skip] #[test] fn prop_m_n_mn() -> TestCaseResult { test_expr("m,n->mn") }
    #[rustfmt::skip] #[test] fn prop_amk_akn_amn() -> TestCaseResult { test_expr("amk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_mk_akn_amn() -> TestCaseResult { test_expr("mk,akn->amn") }
    #[rustfmt::skip] #[test] fn prop_btgi_gih_tgh() -> TestCaseResult { test_expr("btgi,gih->tgh") }
    #[rustfmt::skip] #[test] fn prop_tgi_gih_btgh() -> TestCaseResult { test_expr("tgi,gih->btgh") }

    #[test]
    fn k_kn_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "k,kn->mn".to_string(),
            a: tensor1(&[0f32, 0f32]),
            b: tensor2(&[[0f32, 0.], [0., 0.]]),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[2, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_k_mn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,k->mn".to_string(),
            a: Tensor::zero::<f32>(&[1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn mk_kn_nm_0() -> TractResult<()> {
        EinSumProblem {
            expr: "mk,kn->mn".to_string(),
            a: Tensor::zero::<f32>(&[3, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_1() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_2() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 2]).unwrap(),
        }
        .check()
    }

    #[test]
    fn amk_akn_amn_3() -> TractResult<()> {
        EinSumProblem {
            expr: "amk,akn->amn".to_string(),
            a: Tensor::zero::<f32>(&[1, 1, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2, 1]).unwrap(),
        }
        .check()
    }

    #[test]
    fn km_anbck_bmn_0() -> TractResult<()> {
        EinSumProblem {
            expr: "km,anbck->bmn".to_string(),
            a: Tensor::zero::<f32>(&[2, 1]).unwrap(),
            b: Tensor::zero::<f32>(&[1, 1, 1, 1, 2]).unwrap(),
        }
        .check()
    }

    /// Quantized (9-input) form: just checks the rewrite eliminates EinSum.
    #[test]
    fn q() -> TractResult<()> {
        let qp = QParams::ZpScale { zero_point: 0, scale: 0.1 };
        let op = EinSum {
            axes: "mk,kn,m,,,,,,->mn".parse()?,
            operating_dt: i32::datum_type(),
            q_params: Some(DatumType::QI8(qp)),
        };
        let mut model = TypedModelPatch::default();
        let inputs = [
            model.add_source("a", DatumType::QI8(qp).fact([3, 2]))?,
            model.add_source("b", DatumType::QI8(qp).fact([2, 4]))?,
            model.add_source("bias", i32::datum_type().fact([3]))?,
            model.add_const("a0", tensor0(qp.zp_scale().0))?,
            model.add_const("a_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("b0", tensor0(qp.zp_scale().0))?,
            model.add_const("b_scale", tensor0(qp.zp_scale().1))?,
            model.add_const("c0", tensor0(qp.zp_scale().0))?,
            model.add_const("c_scale", tensor0(qp.zp_scale().1))?,
        ];
        let wire = model.wire_node("einsum", op.clone(), &inputs)?;
        model.set_output_outlets(&wire)?;
        rewrite_einsum_to_prefix_matmul(&mut model)?;
        assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
        Ok(())
    }
}