tensor_macros/
dot.rs

1#[macro_export]
2/// Creates a tensor dot product function
3///
4/// # Example
5///
6/// ```rust
7
8/// #![feature(try_from)]
9///
10/// #[macro_use]
11/// use tensor_macros::*;
12/// use tensor_macros::tensor::*;
13///
14/// tensor!(T243: 2 x 4 x 3);
15/// tensor!(M43: 4 x 3 x 1);
16/// tensor!(V2: 2 x 1);
17///
18/// dot!(T243: 2 x 4 x 3 * M43: 4 x 3 x 1 => V2: 2 x 1);
19///
20/// let l = T243([
21///     0, 1, 2, 3,
22///     4, 5, 6, 7,
23///     
24///     8, 9, 10, 11,
25///     12, 13, 14, 15,
26///
27///     16, 17, 18, 19,
28///     20, 21, 22, 23,
29/// ]);
30/// let r = M43([
31///     0, 1, 2,
32///     3, 4, 5,
33///     6, 7, 8,
34///     9, 10, 11
35/// ]);
36/// assert_eq!(l * r, V2([506, 1298]));
37/// ```
38macro_rules! dot {
39    ($lhs:ident: $($l_dim:literal)x+ * $rhs:ident: $($r_dim:literal)x+ => $out:ident: $($o_dim:literal)x+) => {
40        impl<T, U, V> std::ops::Mul<$rhs<U>> for $lhs<T>
41        where
42            T: tensor_macros::traits::TensorTrait + std::ops::Mul<U, Output=V>,
43            U: tensor_macros::traits::TensorTrait,
44            V: tensor_macros::traits::TensorTrait,
45        {
46            type Output = $out<V>;
47
48            fn mul(self, rhs: $rhs<U>) -> Self::Output {
49                let mut out = Self::Output::new();
50
51                split!(self, rhs, out; $($l_dim),*; $($r_dim),*; $($o_dim),*;;;;;);
52
53                out
54            }
55        }
56    };
57}
58
59#[macro_export]
60macro_rules! split {
61    // left; right; output; right rev; count; out right; out left; out right rev
62    // l    ; r    ; o      ; rr   ; c  ; or     ; ol; orr
63    // 2 4 3; 3 2 1; 2 4 2 1;
64    // 2 4 3;   2 1;     2 1; 3    ; 2  ; 2 4    ;
65    // 2 4 3;     1;        ; 2 3  ; 2 2; 2 4 2 1;
66    // 2 4 3;      ;        ; 1 2 3;   2;   4 2 1; 2
67    // 2 4 3;      ;        ; 1 2 3;    ;     2 1; 2 4
68    // 2 4 3;      ;        ; 1 2 3;    ;       1; 2 4; 2
69    // 2 4 3;      ;        ; 1 2 3;    ;        ; 2 4; 1 2
70
71    // 2 4 3; 4 3 1; 2 1;
72    // 2 4 3;   3 1;    ;     4; 2; 2 1;  ;
73    // 2 4 3;     1;    ;   3 4;  ;   1; 2;
74    // 2 4 3;      ;    ; 1 3 4;  ;    ; 2; 1
75    ($($i:expr),*;
76        $($ls:literal),*; $r:literal $(,$rs:literal)*; $o1:literal, $o2:literal $(,$os:literal)*;
77        $($rr:literal),*;
78        $($c:literal),*; $($or:literal),*;;) => {
79        split!($($i),*;
80            $($ls),*; $($rs),*; $($os),*;
81            $r $(,$rr)*;
82            $($c,)* $o1; $($or,)* $o1, $o2;;
83        );
84    };
85
86    ($($i:expr),*;
87        $($ls:literal),*;; $o1:literal, $o2:literal $(,$os:literal)*;
88        $($rr:literal),*;
89        $($c:literal),*; $($or:literal),*;;) => {
90        split!($($i),*;
91            $($ls),*;; $($os),*;
92            $($rr),*;
93            $($c,)* $o1; $($or,)* $o1, $o2;;
94        );
95    };
96
97    ($($i:expr),*;
98        $($ls:literal),*; $r:literal $(,$rs:literal)*;;
99        $($rr:literal),*;
100        $c1:literal $(,$c:literal)*; $or1:literal $(,$or:literal)*; $($ol:literal),*;) => {
101        split!($($i),*;
102            $($ls),*; $($rs),*;;
103            $r $(,$rr)*;
104            $($c),*; $($or),*; $($ol,)* $or1;
105        );
106    };
107
108    ($($i:expr),*;
109        $($ls:literal),*;;;
110        $($rr:literal),*;
111        $c1:literal $(,$c:literal)*; $or1:literal $(,$or:literal)*; $($ol:literal),*;) => {
112        split!($($i),*;
113            $($ls),*;;;
114            $($rr),*;
115            $($c),*; $($or),*; $($ol,)* $or1;
116        );
117    };
118
119    ($($i:expr),*;
120        $($ls:literal),*; $r:literal $(,$rs:literal)*;;
121        $($rr:literal),*;
122        ; $or1:literal $(,$or:literal)*; $($ol:literal),*; $($orr:literal),*) => {
123        split!($($i),*;
124            $($ls),*; $($rs),*;;
125            $r $(,$rr)*;
126            ; $($or),*; $($ol),*; $or1 $(,$orr)*
127        );
128    };
129
130    ($($i:expr),*;
131        $($ls:literal),*;;;
132        $($rr:literal),*;
133        ; $or1:literal $(,$or:literal)*; $($ol:literal),*; $($orr:literal),*) => {
134        split!($($i),*;
135            $($ls),*;;;
136            $($rr),*;
137            ; $($or),*; $($ol),*; $or1 $(,$orr)*
138        );
139    };
140
141    ($($i:expr),*;
142        $($ls:literal),*; $r:literal $(,$rs:literal)*;;
143        $($rr:literal),*;
144        ;; $($ol:literal),*; $($orr:literal),*) => {
145        split!($($i),*;
146            $($ls),*; $($rs),*;;
147            $r $(,$rr)*;
148            ;; $($ol),*; $($orr),*
149        );
150    };
151
152    ($($i:expr),*;
153        $($ls:literal),*;;;
154        $($rr:literal),*;
155        ;; $($ol:literal),*; $($orr:literal),*) => {
156        split!(~ $($i),*;
157            $($ls),*;
158            $($rr),*;
159            $($ol),*; $($orr),*;;;
160        );
161    };
162
163    // Actually performing the split
164    // 2, 4, 3 * 3, 2, 1 -> 2, 4, 2, 1
165    // 2, 4, 3 | 1, 2, 3 | 2, 4 | 1, 2 |      |      |
166    //    4, 3 |    2, 3 |    4 |    2 | 2    |    1 |
167    //       3 |       3 |      |      | 2, 4 | 2, 1 |   |
168    //         |         |      |      | 2, 4 | 2, 1 | 3 |
169    (~ $($i:expr),*;
170        $l1:literal $(,$l:literal)*;
171        $r1:literal $(,$r:literal)*;
172        $ol1:literal $(,$ol:literal)*;
173        $or1:literal $(,$or:literal)*;
174        $($ld:literal),*; $($rd:literal),*;) => {
175        assert_eq!($l1, $ol1, "Bad dimensions for tensor product");
176        assert_eq!($r1, $or1, "Bad dimensions for tensor product");
177        split!(~ $($i),*;
178            $($l),*;
179            $($r),*;
180            $($ol),*; $($or),*;
181            $($ld,)* $l1;
182            $r1 $(,$rd)*;
183        );
184    };
185    (~ $($i:expr),*;
186        $l1:literal $(,$l:literal)*;
187        $r1:literal $(,$r:literal),*;
188        ;;
189        $($ld:literal),*; $($rd:literal),*;
190        $($sd:literal),*) => {
191        split!(~ $($i),*;
192            $($l),*;
193            $($r),*;
194            ;;
195            $($ld),*; $($rd),*;
196            $($sd,)* $l1
197        );
198    };
199    (~ $($i:expr),*;
200        ;;
201        ;;
202        $($ld:literal),*; $($rd:literal),*;
203        $($sd:literal),*) => {
204        make_dot!($($i),*; $($ld),*; $($sd),*; $($rd),*;;;);
205    }
206}
207
208#[macro_export]
209macro_rules! make_dot {
210    ($l:expr, $r:expr, $o:expr; $d:literal $(,$ld:literal)*; $($sd:literal),*; $($rd:literal),*; $($lv:ident),*; $($sv:ident),*; $($rv:ident),*) => {
211        for i in 0..$d {
212            make_dot!($l, $r, $o; $($ld),*; $($sd),*; $($rd),*; $($lv,)* i; $($sv),*; $($rv),*);
213        }
214    };
215    ($l:expr, $r:expr, $o:expr; ; $d:literal $(,$sd:literal)*; $($rd:literal),*; $($lv:ident),*; $($sv:ident),*; $($rv:ident),*) => {
216        for j in 0..$d {
217            make_dot!($l, $r, $o; ; $($sd),*; $($rd),*; $($lv),*; $($sv,)* j; $($rv),*);
218        }
219    };
220    ($l:expr, $r:expr, $o:expr; ;; $d:literal $(,$rd:literal)*; $($lv:ident),*; $($sv:ident),*; $($rv:ident),*) => {
221        for k in 0..$d {
222            make_dot!($l, $r, $o; ;; $($rd),*; $($lv),*; $($sv),*; $($rv,)* k);
223        }
224    };
225    ($l:expr, $r:expr, $o:expr; ;;; $($lv:ident),+; $($sv:ident),+; $($rv:ident),*) => {
226        $o[($($lv),* $(,$rv),*)] += $l[($($lv),* $(,$sv)*)] * $r[($($sv,)* $($rv),*)]
227    };
228    ($l:expr, $r:expr, $o:expr; ;;;; $($sv:ident),+; $($rv:ident),+) => {
229        $o[($($rv),*)] += $l[($($lv,)* $($sv)*)] * $r[($($sv),* $(,$rv)*)]
230    };
231    ($l:expr, $r:expr, $o:expr; ;;;; $($sv:ident),+;) => {
232        $o += $l[($($lv),* $(,$sv)*)] * $r[($($sv,)* $($rv),*)]
233    };
234}
235
236#[cfg(test)]
237mod tests {
238    use crate as tensor_macros;
239
240    tensor!(T243: 2 x 4 x 3);
241    tensor!(M43: 4 x 3 x 1);
242    tensor!(V2: 2 x 1);
243
244    dot!(T243: 2 x 4 x 3 * M43: 4 x 3 x 1 => V2: 2 x 1);
245
246    #[test]
247    fn dot() {
248        let l = T243([
249            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
250        ]);
251        let r = M43([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
252        assert_eq!(l * r, V2([506, 1298]));
253    }
254}