1#[macro_export]
2macro_rules! dot {
39 ($lhs:ident: $($l_dim:literal)x+ * $rhs:ident: $($r_dim:literal)x+ => $out:ident: $($o_dim:literal)x+) => {
40 impl<T, U, V> std::ops::Mul<$rhs<U>> for $lhs<T>
41 where
42 T: tensor_macros::traits::TensorTrait + std::ops::Mul<U, Output=V>,
43 U: tensor_macros::traits::TensorTrait,
44 V: tensor_macros::traits::TensorTrait,
45 {
46 type Output = $out<V>;
47
48 fn mul(self, rhs: $rhs<U>) -> Self::Output {
49 let mut out = Self::Output::new();
50
51 split!(self, rhs, out; $($l_dim),*; $($r_dim),*; $($o_dim),*;;;;;);
52
53 out
54 }
55 }
56 };
57}
58
59#[macro_export]
60macro_rules! split {
61 ($($i:expr),*;
76 $($ls:literal),*; $r:literal $(,$rs:literal)*; $o1:literal, $o2:literal $(,$os:literal)*;
77 $($rr:literal),*;
78 $($c:literal),*; $($or:literal),*;;) => {
79 split!($($i),*;
80 $($ls),*; $($rs),*; $($os),*;
81 $r $(,$rr)*;
82 $($c,)* $o1; $($or,)* $o1, $o2;;
83 );
84 };
85
86 ($($i:expr),*;
87 $($ls:literal),*;; $o1:literal, $o2:literal $(,$os:literal)*;
88 $($rr:literal),*;
89 $($c:literal),*; $($or:literal),*;;) => {
90 split!($($i),*;
91 $($ls),*;; $($os),*;
92 $($rr),*;
93 $($c,)* $o1; $($or,)* $o1, $o2;;
94 );
95 };
96
97 ($($i:expr),*;
98 $($ls:literal),*; $r:literal $(,$rs:literal)*;;
99 $($rr:literal),*;
100 $c1:literal $(,$c:literal)*; $or1:literal $(,$or:literal)*; $($ol:literal),*;) => {
101 split!($($i),*;
102 $($ls),*; $($rs),*;;
103 $r $(,$rr)*;
104 $($c),*; $($or),*; $($ol,)* $or1;
105 );
106 };
107
108 ($($i:expr),*;
109 $($ls:literal),*;;;
110 $($rr:literal),*;
111 $c1:literal $(,$c:literal)*; $or1:literal $(,$or:literal)*; $($ol:literal),*;) => {
112 split!($($i),*;
113 $($ls),*;;;
114 $($rr),*;
115 $($c),*; $($or),*; $($ol,)* $or1;
116 );
117 };
118
119 ($($i:expr),*;
120 $($ls:literal),*; $r:literal $(,$rs:literal)*;;
121 $($rr:literal),*;
122 ; $or1:literal $(,$or:literal)*; $($ol:literal),*; $($orr:literal),*) => {
123 split!($($i),*;
124 $($ls),*; $($rs),*;;
125 $r $(,$rr)*;
126 ; $($or),*; $($ol),*; $or1 $(,$orr)*
127 );
128 };
129
130 ($($i:expr),*;
131 $($ls:literal),*;;;
132 $($rr:literal),*;
133 ; $or1:literal $(,$or:literal)*; $($ol:literal),*; $($orr:literal),*) => {
134 split!($($i),*;
135 $($ls),*;;;
136 $($rr),*;
137 ; $($or),*; $($ol),*; $or1 $(,$orr)*
138 );
139 };
140
141 ($($i:expr),*;
142 $($ls:literal),*; $r:literal $(,$rs:literal)*;;
143 $($rr:literal),*;
144 ;; $($ol:literal),*; $($orr:literal),*) => {
145 split!($($i),*;
146 $($ls),*; $($rs),*;;
147 $r $(,$rr)*;
148 ;; $($ol),*; $($orr),*
149 );
150 };
151
152 ($($i:expr),*;
153 $($ls:literal),*;;;
154 $($rr:literal),*;
155 ;; $($ol:literal),*; $($orr:literal),*) => {
156 split!(~ $($i),*;
157 $($ls),*;
158 $($rr),*;
159 $($ol),*; $($orr),*;;;
160 );
161 };
162
163 (~ $($i:expr),*;
170 $l1:literal $(,$l:literal)*;
171 $r1:literal $(,$r:literal)*;
172 $ol1:literal $(,$ol:literal)*;
173 $or1:literal $(,$or:literal)*;
174 $($ld:literal),*; $($rd:literal),*;) => {
175 assert_eq!($l1, $ol1, "Bad dimensions for tensor product");
176 assert_eq!($r1, $or1, "Bad dimensions for tensor product");
177 split!(~ $($i),*;
178 $($l),*;
179 $($r),*;
180 $($ol),*; $($or),*;
181 $($ld,)* $l1;
182 $r1 $(,$rd)*;
183 );
184 };
185 (~ $($i:expr),*;
186 $l1:literal $(,$l:literal)*;
187 $r1:literal $(,$r:literal),*;
188 ;;
189 $($ld:literal),*; $($rd:literal),*;
190 $($sd:literal),*) => {
191 split!(~ $($i),*;
192 $($l),*;
193 $($r),*;
194 ;;
195 $($ld),*; $($rd),*;
196 $($sd,)* $l1
197 );
198 };
199 (~ $($i:expr),*;
200 ;;
201 ;;
202 $($ld:literal),*; $($rd:literal),*;
203 $($sd:literal),*) => {
204 make_dot!($($i),*; $($ld),*; $($sd),*; $($rd),*;;;);
205 }
206}
207
208#[macro_export]
209macro_rules! make_dot {
210 ($l:expr, $r:expr, $o:expr; $d:literal $(,$ld:literal)*; $($sd:literal),*; $($rd:literal),*; $($lv:ident),*; $($sv:ident),*; $($rv:ident),*) => {
211 for i in 0..$d {
212 make_dot!($l, $r, $o; $($ld),*; $($sd),*; $($rd),*; $($lv,)* i; $($sv),*; $($rv),*);
213 }
214 };
215 ($l:expr, $r:expr, $o:expr; ; $d:literal $(,$sd:literal)*; $($rd:literal),*; $($lv:ident),*; $($sv:ident),*; $($rv:ident),*) => {
216 for j in 0..$d {
217 make_dot!($l, $r, $o; ; $($sd),*; $($rd),*; $($lv),*; $($sv,)* j; $($rv),*);
218 }
219 };
220 ($l:expr, $r:expr, $o:expr; ;; $d:literal $(,$rd:literal)*; $($lv:ident),*; $($sv:ident),*; $($rv:ident),*) => {
221 for k in 0..$d {
222 make_dot!($l, $r, $o; ;; $($rd),*; $($lv),*; $($sv),*; $($rv,)* k);
223 }
224 };
225 ($l:expr, $r:expr, $o:expr; ;;; $($lv:ident),+; $($sv:ident),+; $($rv:ident),*) => {
226 $o[($($lv),* $(,$rv),*)] += $l[($($lv),* $(,$sv)*)] * $r[($($sv,)* $($rv),*)]
227 };
228 ($l:expr, $r:expr, $o:expr; ;;;; $($sv:ident),+; $($rv:ident),+) => {
229 $o[($($rv),*)] += $l[($($lv,)* $($sv)*)] * $r[($($sv),* $(,$rv)*)]
230 };
231 ($l:expr, $r:expr, $o:expr; ;;;; $($sv:ident),+;) => {
232 $o += $l[($($lv),* $(,$sv)*)] * $r[($($sv,)* $($rv),*)]
233 };
234}
235
236#[cfg(test)]
237mod tests {
238 use crate as tensor_macros;
239
240 tensor!(T243: 2 x 4 x 3);
241 tensor!(M43: 4 x 3 x 1);
242 tensor!(V2: 2 x 1);
243
244 dot!(T243: 2 x 4 x 3 * M43: 4 x 3 x 1 => V2: 2 x 1);
245
246 #[test]
247 fn dot() {
248 let l = T243([
249 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
250 ]);
251 let r = M43([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
252 assert_eq!(l * r, V2([506, 1298]));
253 }
254}