tract_hir/ops/matmul.rs

use crate::infer::*;
use crate::internal::*;

use tract_core::ops::einsum::EinSum;
use tract_core::tract_data::itertools::Itertools;

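/// Inference-time matrix product with optional transposition of either
/// input and of the output, in the spirit of BLAS GEMM. During wiring it
/// is lowered to a core `EinSum`.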
#[derive(Debug, Clone, Default, Hash)]
pub struct MatMulInference {
    pub a_trans: bool,
    pub b_trans: bool,
    pub c_trans: bool,
}

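// Builder-style setters for the three transposition flags.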
impl MatMulInference {
    pub fn with_a_trans(self, a_trans: bool) -> MatMulInference {
        MatMulInference { a_trans, ..self }
    }

    pub fn with_b_trans(self, b_trans: bool) -> MatMulInference {
        MatMulInference { b_trans, ..self }
    }

    pub fn with_c_trans(self, c_trans: bool) -> MatMulInference {
        MatMulInference { c_trans, ..self }
    }
}

impl Expansion for MatMulInference {
    fn name(&self) -> StaticName {
        "MatMulInference".into()
    }

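    // Inference rules: the two inputs and the output share one datum type,
    // and once both input shapes are known the output shape follows from
    // compute_shapes().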
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        check_input_arity(inputs, 2)?;
        check_output_arity(outputs, 1)?;
        s.equals(&inputs[0].datum_type, &inputs[1].datum_type)?;
        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
        s.given_2(&inputs[0].shape, &inputs[1].shape, move |s, ashape, bshape| {
            let (_, _, _, cshape) =
                compute_shapes(ashape, bshape, self.a_trans, self.b_trans, self.c_trans)?;
            s.equals(&outputs[0].shape, cshape)
        })?;
        Ok(())
    }

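    // Lowering: build an EinSum whose expression is derived from the operand
    // ranks and the transposition flags.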
    fn wire(
        &self,
        prefix: &str,
        target: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let a_rank = target.outlet_fact(inputs[0])?.rank();
        let b_rank = target.outlet_fact(inputs[1])?.rank();
        ensure!(a_rank > 1 || b_rank > 1);
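        // Einsum labels for the inner axes: m/k for a, k/n for b, m/n for the
        // output, swapped when the corresponding transposition flag is set.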
        let mk = if self.a_trans { "km" } else { "mk" };
        let kn = if self.b_trans { "nk" } else { "kn" };
        let mn = if self.c_trans { "nm" } else { "mn" };
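        // A rank-1 operand contributes a bare k axis, and its implicit m (or n)
        // axis is absent from the output, mirroring numpy.matmul semantics.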
        let axes: AxesMapping = if a_rank == 1 {
            let prefix: String = ('a'..).take(b_rank - 2).collect();
            format!("k,{prefix}{kn}->{prefix}n").parse()?
        } else if b_rank == 1 {
            let prefix: String = ('a'..).take(a_rank - 2).collect();
            format!("{prefix}{mk},k->{prefix}m").parse()?
        } else {
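            // Batch prefixes broadcast numpy-style: each operand's prefix is
            // right-aligned against the output prefix, so the shorter operand
            // skips the leading batch labels it does not cover.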
            let c_rank = b_rank.max(a_rank);
            let a_prefix: String =
                ('a'..).take(c_rank - 2).skip(b_rank.saturating_sub(a_rank)).collect();
            let b_prefix: String =
                ('a'..).take(c_rank - 2).skip(a_rank.saturating_sub(b_rank)).collect();
            let c_prefix: String = ('a'..).take(c_rank - 2).collect();
            format!("{a_prefix}{mk},{b_prefix}{kn}->{c_prefix}{mn}").parse()?
        };
        let dt = target.outlet_fact(inputs[0])?.datum_type;
        target.wire_node(prefix, EinSum { axes, operating_dt: dt, q_params: None }, inputs)
    }
}

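/// Computes the shapes involved in the product from the input shapes and the
/// transposition flags. Rank-1 operands are promoted to matrices by inserting
/// a unit axis (dropped again from the final output shape), and shorter shapes
/// are left-padded with ones before broadcasting the batch dimensions.
/// Returns (promoted a shape, promoted b shape, broadcast output shape,
/// final output shape without the implicit axes).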
#[allow(clippy::type_complexity)]
pub fn compute_shapes<D: DimLike>(
    mut ashape: TVec<D>,
    mut bshape: TVec<D>,
    a_trans: bool,
    b_trans: bool,
    c_trans: bool,
) -> TractResult<(TVec<D>, TVec<D>, TVec<D>, TVec<D>)> {
    let mut implicit_m = false;
    let mut implicit_n = false;
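    // Promote vectors to matrices: a rank-1 a becomes a 1×k row and a rank-1 b
    // a k×1 column, with the unit axis inserted where the transposition flag
    // expects it.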
    if ashape.len() < 2 {
        implicit_m = true;
        ashape.insert(a_trans as usize, D::one());
    }
    if bshape.len() < 2 {
        implicit_n = true;
        bshape.insert(!b_trans as usize, D::one());
    }
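    // Left-pad the shorter shape with unit batch dimensions so both operands
    // have the same rank before broadcasting.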
    while ashape.len() < bshape.len() {
        ashape.insert(0, D::one());
    }
    while bshape.len() < ashape.len() {
        bshape.insert(0, D::one());
    }
    let c_bc_shape_prefix = tract_core::broadcast::multi_broadcast(&[
        &ashape[..(ashape.len() - 2)],
        &bshape[..(bshape.len() - 2)],
    ])?;
    let mut c_bc_shape: TVec<D> = c_bc_shape_prefix;
    let (mut m, mut ka) = (ashape[ashape.len() - 2].clone(), ashape[ashape.len() - 1].clone());
    let (mut kb, mut n) = (bshape[bshape.len() - 2].clone(), bshape[bshape.len() - 1].clone());
    if a_trans {
        std::mem::swap(&mut m, &mut ka);
    }
    if b_trans {
        std::mem::swap(&mut kb, &mut n);
    }
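    // The two inner (k) dimensions must agree.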
    if !ka.compatible_with(&kb) {
        bail!(
            "Inconsistent matmul: a: {} b: {}, a_trans: {} b_trans: {} c_trans: {}",
            ashape.iter().join(","),
            bshape.iter().join(","),
            a_trans,
            b_trans,
            c_trans
        );
    }
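    // Assemble the output shapes: the broadcast shape always carries both m
    // and n, while the final shape drops whichever axis was implicit.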
    let mut c_shape_final = c_bc_shape.clone();
    if c_trans {
        c_bc_shape.push(n.clone());
        c_bc_shape.push(m.clone());
        if !implicit_n {
            c_shape_final.push(n.clone());
        }
        if !implicit_m {
            c_shape_final.push(m.clone());
        }
    } else {
        c_bc_shape.push(m.clone());
        c_bc_shape.push(n.clone());
        if !implicit_m {
            c_shape_final.push(m.clone());
        }
        if !implicit_n {
            c_shape_final.push(n.clone());
        }
    }
    Ok((ashape, bshape, c_bc_shape, c_shape_final))
}
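
// A minimal sanity check of compute_shapes, sketched here for illustration;
// it assumes the usize implementation of DimLike and the tvec! macro from
// the crate prelude. The module name `sanity` is ours, not the crate's.
#[cfg(test)]
mod sanity {
    use super::*;

    #[test]
    fn plain_2d() -> TractResult<()> {
        // [2,3] × [3,4] → [2,4]: no broadcasting, no implicit axes.
        let (a, b, c_bc, c) =
            compute_shapes(tvec![2usize, 3], tvec![3, 4], false, false, false)?;
        assert_eq!(a, tvec![2usize, 3]);
        assert_eq!(b, tvec![3usize, 4]);
        assert_eq!(c_bc, tvec![2usize, 4]);
        assert_eq!(c, tvec![2usize, 4]);
        Ok(())
    }

    #[test]
    fn vector_times_matrix() -> TractResult<()> {
        // A rank-1 a is promoted to a 1×3 row; the unit m axis stays in the
        // broadcast shape but is dropped from the final output shape.
        let (a, _, c_bc, c) =
            compute_shapes(tvec![3usize], tvec![3, 4], false, false, false)?;
        assert_eq!(a, tvec![1usize, 3]);
        assert_eq!(c_bc, tvec![1usize, 4]);
        assert_eq!(c, tvec![4usize]);
        Ok(())
    }
}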