steelix/ir/ops/math/
gemm.rs1use std::borrow::Cow;
2
3use crate::ir::ops::shape::multi_broadcast;
4use crate::prelude::*;
5use anyhow::bail;
6use smallvec::smallvec;
7use steelix_onnx::onnx_pb;
8
/// Shape-inference / cost model for an ONNX `Gemm` node.
///
/// Only the two transpose attributes are stored; they are populated by `build_gemm`.
#[derive(Debug, Clone)]
// Suppressed once at item level (rather than per field) so the transpose flags may be
// stored on every node even on code paths that never read them.
#[allow(dead_code)]
pub struct Gemm {
    /// ONNX `transA` attribute (per the ONNX spec, non-zero means A is used transposed).
    trans_a: usize,
    /// ONNX `transB` attribute (per the ONNX spec, non-zero means B is used transposed).
    trans_b: usize,
}
16
17impl Gemm {
18 fn compute_cost(
19 &self,
20 a_shape: Shape,
21 b_shape: Shape,
22 ab_shape: Shape,
23 c_shape: Shape,
24 ) -> OpCost {
25 let m = a_shape[0];
26 let n = a_shape[1];
27 let p = b_shape[1];
28 let ab_flops = m * n * (2 * p - 1);
29 let ab_c_flops = ab_shape[0] * ab_shape[1] * (2 * c_shape[1] - 1);
30 OpCost {
31 flops: ab_flops + ab_c_flops,
32 ..Default::default()
33 }
34 }
35}
36
37impl Op for Gemm {
38 fn name(&self) -> Cow<str> {
39 "Gemm".into()
40 }
41
42 fn op_group(&self) -> OpGroup {
43 OpGroup::Transform
44 }
45
46 fn realize(&self, providers: PVec) -> anyhow::Result<RealizedOp> {
48 validate_providers(&providers, 2, 3, &self.name())?;
49
50 let a = &providers[0];
51 let b = &providers[1];
52 let c = if providers.len() == 2 {
53 Tensor::new(a.dt, shape!(1)).into_arc_tensor()
54 } else {
55 providers[2].clone()
56 };
57
58 let a_shape = &a.shape;
59 let b_shape = &b.shape;
60
61 let matching_dim = |a_shape: &Shape, b_shape: &Shape| -> anyhow::Result<usize> {
62 for i in 0..a_shape.len() {
63 if a_shape[i] == b_shape[i] {
64 return Ok(i);
65 }
66 }
67 bail!(
68 "GEMM: No equal dimension found in {:?} and {:?}",
69 a_shape,
70 b_shape
71 );
72 };
73
74 let ab_shape = {
75 let mut ab_shape = smallvec![0;2];
76 let mut a_shape = a_shape.clone();
77 let mut b_shape = b_shape.clone();
78 let a_dim = matching_dim(&a_shape, &b_shape)?;
79 let b_dim = matching_dim(&b_shape, &a_shape)?;
80 a_shape.remove(a_dim);
81 b_shape.remove(b_dim);
82 ab_shape[0] = a_shape[0];
83 ab_shape[1] = b_shape[0];
84 Shape(ab_shape)
85 };
86
87 let c_shape = multi_broadcast(&[ab_shape.clone(), c.shape.clone()])
88 .expect("Could not broadcast C -> A*B in GEMM");
89
90 let res = Tensor::new(providers[0].dt, ab_shape.clone());
91
92 Ok(RealizedOp {
93 cost: self.compute_cost(a_shape.clone(), b_shape.clone(), ab_shape, c_shape),
94 outputs: smallvec![res.into_arc_tensor()],
95 })
96 }
97}
98
99pub fn build_gemm(proto: &onnx_pb::NodeProto) -> Result<BoxOp, anyhow::Error> {
100 let trans_a = proto.get_attribute("transA", Some(0))? as usize;
101 let trans_b = proto.get_attribute("transB", Some(0))? as usize;
102 Ok(Box::new(Gemm { trans_a, trans_b }) as BoxOp)
103}