steelix/ir/ops/math/
gemm.rs

1use std::borrow::Cow;
2
3use crate::ir::ops::shape::multi_broadcast;
4use crate::prelude::*;
5use anyhow::bail;
6use smallvec::smallvec;
7use steelix_onnx::onnx_pb;
8
/// IR node for the ONNX `Gemm` operator.
///
/// Carries the two transpose attributes exactly as they were read from the
/// node proto by `build_gemm`.
#[derive(Clone, Debug)]
pub struct Gemm {
    /// Value of the `transA` attribute (`0` when the attribute is absent).
    #[allow(dead_code)]
    trans_a: usize,
    /// Value of the `transB` attribute (`0` when the attribute is absent).
    #[allow(dead_code)]
    trans_b: usize,
}
16
17impl Gemm {
18    fn compute_cost(
19        &self,
20        a_shape: Shape,
21        b_shape: Shape,
22        ab_shape: Shape,
23        c_shape: Shape,
24    ) -> OpCost {
25        let m = a_shape[0];
26        let n = a_shape[1];
27        let p = b_shape[1];
28        let ab_flops = m * n * (2 * p - 1);
29        let ab_c_flops = ab_shape[0] * ab_shape[1] * (2 * c_shape[1] - 1);
30        OpCost {
31            flops: ab_flops + ab_c_flops,
32            ..Default::default()
33        }
34    }
35}
36
37impl Op for Gemm {
38    fn name(&self) -> Cow<str> {
39        "Gemm".into()
40    }
41
42    fn op_group(&self) -> OpGroup {
43        OpGroup::Transform
44    }
45
46    //TODO: support transpose
47    fn realize(&self, providers: PVec) -> anyhow::Result<RealizedOp> {
48        validate_providers(&providers, 2, 3, &self.name())?;
49
50        let a = &providers[0];
51        let b = &providers[1];
52        let c = if providers.len() == 2 {
53            Tensor::new(a.dt, shape!(1)).into_arc_tensor()
54        } else {
55            providers[2].clone()
56        };
57
58        let a_shape = &a.shape;
59        let b_shape = &b.shape;
60
61        let matching_dim = |a_shape: &Shape, b_shape: &Shape| -> anyhow::Result<usize> {
62            for i in 0..a_shape.len() {
63                if a_shape[i] == b_shape[i] {
64                    return Ok(i);
65                }
66            }
67            bail!(
68                "GEMM: No equal dimension found in {:?} and {:?}",
69                a_shape,
70                b_shape
71            );
72        };
73
74        let ab_shape = {
75            let mut ab_shape = smallvec![0;2];
76            let mut a_shape = a_shape.clone();
77            let mut b_shape = b_shape.clone();
78            let a_dim = matching_dim(&a_shape, &b_shape)?;
79            let b_dim = matching_dim(&b_shape, &a_shape)?;
80            a_shape.remove(a_dim);
81            b_shape.remove(b_dim);
82            ab_shape[0] = a_shape[0];
83            ab_shape[1] = b_shape[0];
84            Shape(ab_shape)
85        };
86
87        let c_shape = multi_broadcast(&[ab_shape.clone(), c.shape.clone()])
88            .expect("Could not broadcast C -> A*B in GEMM");
89
90        let res = Tensor::new(providers[0].dt, ab_shape.clone());
91
92        Ok(RealizedOp {
93            cost: self.compute_cost(a_shape.clone(), b_shape.clone(), ab_shape, c_shape),
94            outputs: smallvec![res.into_arc_tensor()],
95        })
96    }
97}
98
99pub fn build_gemm(proto: &onnx_pb::NodeProto) -> Result<BoxOp, anyhow::Error> {
100    let trans_a = proto.get_attribute("transA", Some(0))? as usize;
101    let trans_b = proto.get_attribute("transB", Some(0))? as usize;
102    Ok(Box::new(Gemm { trans_a, trans_b }) as BoxOp)
103}