1use std::collections::HashMap;
4
5use snafu::ResultExt;
6
7use crate::Tensor;
8use crate::error::UOpSnafu;
9use crate::reduce::AxisSpec;
10
/// Local shorthand for the crate-wide `crate::Result<T>`.
type Result<T> = crate::Result<T>;
12
/// Computes the permutation that sorts `slice` ascending.
///
/// The returned indices satisfy `slice[out[0]] <= slice[out[1]] <= ...`.
/// Equal elements keep their original relative order (the sort is stable).
fn argsort<T: Ord>(slice: &[T]) -> Vec<usize> {
    let mut order: Vec<usize> = Vec::with_capacity(slice.len());
    order.extend(0..slice.len());
    // `sort_by_key` is stable, matching the original `sort_by` on `cmp`.
    order.sort_by_key(|&idx| &slice[idx]);
    order
}
18
19impl Tensor {
20 pub fn einsum(formula: &str, operands: &[&Tensor]) -> Result<Tensor> {
21 let mut xs: Vec<Tensor> = operands.iter().map(|t| (*t).clone()).collect();
22 let formula = formula.replace(' ', "");
23
24 let formula = if formula.contains("...") {
26 let all_chars: std::collections::HashSet<char> =
27 formula.chars().filter(|c| c.is_ascii_alphabetic()).collect();
28 let ell: String = ('a'..='z').chain('A'..='Z').filter(|c| !all_chars.contains(c)).collect();
29
30 let lhs = formula.split("->").next().unwrap();
31 let input_strs: Vec<&str> = lhs.split(',').collect();
32
33 let ell_n: Vec<usize> = input_strs
34 .iter()
35 .zip(xs.iter())
36 .map(|(s, x)| {
37 if s.contains("...") {
38 let ndim = x.ndim().unwrap();
39 let non_ell_chars = s.len() - 3; ndim.saturating_sub(non_ell_chars)
41 } else {
42 0
43 }
44 })
45 .collect();
46
47 let max_ell_n = *ell_n.iter().max().unwrap_or(&0);
48
49 let mut new_inputs: Vec<String> = Vec::new();
50 for (i, s) in input_strs.iter().enumerate() {
51 let replacement = &ell[max_ell_n - ell_n[i]..max_ell_n];
52 new_inputs.push(s.replace("...", replacement));
53 }
54
55 let new_lhs = new_inputs.join(",");
56
57 let ell_chars: std::collections::HashSet<char> = ell[..max_ell_n].chars().collect();
59 let auto: String = {
60 let mut chars: Vec<char> = lhs
61 .chars()
62 .filter(|c| {
63 c.is_ascii_alphabetic() && *c != '.' && lhs.matches(*c).count() == 1 && !ell_chars.contains(c)
64 })
65 .collect();
66 chars.sort();
67 chars.into_iter().collect()
68 };
69
70 if formula.contains("->") {
71 let rhs = formula.split("->").nth(1).unwrap();
72 let new_rhs = rhs.replace("...", &ell[..max_ell_n]);
73 format!("{new_lhs}->{new_rhs}")
74 } else {
75 format!("{new_lhs}->{}{auto}", &ell[..max_ell_n])
76 }
77 } else {
78 formula
79 };
80
81 let (lhs, rhs) = if formula.contains("->") {
83 let parts: Vec<&str> = formula.split("->").collect();
84 (parts[0].to_string(), parts[1].to_string())
85 } else {
86 let auto: String = {
87 let mut chars: Vec<char> =
88 formula.chars().filter(|c| c.is_ascii_alphabetic() && formula.matches(*c).count() == 1).collect();
89 chars.sort();
90 chars.into_iter().collect()
91 };
92 (formula.clone(), auto)
93 };
94
95 let mut inputs: Vec<String> = lhs.split(',').map(|s| s.to_string()).collect();
96
97 for i in 0..inputs.len() {
99 let mut s = inputs[i].clone();
100 let mut x = xs[i].clone();
101 let unique_chars: Vec<char> = {
102 let mut seen = std::collections::HashSet::new();
103 s.chars().filter(move |c| seen.insert(*c)).collect()
104 };
105 for c in unique_chars {
106 while s.matches(c).count() > 1 {
107 let j = s.find(c).unwrap();
108 let k = s[j + 1..].find(c).unwrap() + j + 1;
109 let shape = x.shape()?;
110 let n = shape[j].as_const().unwrap();
111 let ndim = x.ndim()?;
112
113 if ndim > 2 {
114 let mut perm: Vec<isize> =
116 (0..ndim).filter(|&d| d != j && d != k).map(|d| d as isize).collect();
117 perm.push(j as isize);
118 perm.push(k as isize);
119 x = x.try_permute(&perm)?;
120
121 x = x.flatten_range(-2, -1)?;
123
124 let new_ndim = x.ndim()?;
126 let mut padding = vec![(0isize, 0isize); new_ndim];
127 padding[new_ndim - 1] = (0, n as isize);
128 x = x.try_pad(&padding)?;
129
130 x = x.unflatten(-1, &[n as isize, (n + 1) as isize])?;
132
133 let cur_ndim = x.ndim()?;
135 let mut ranges: Vec<(isize, isize)> =
136 x.shape()?.iter().map(|d| (0, d.as_const().unwrap() as isize)).collect();
137 ranges[cur_ndim - 1] = (0, 1);
138 x = x.try_shrink(&ranges)?;
139 x = x.try_squeeze(Some(-1))?;
140 } else {
141 x = x.flatten()?;
144 let stride = n + 1;
145 x = x.try_stride(&[stride as isize])?;
146 }
147
148 s = format!("{}{}", &s[..k], &s[k + 1..]);
150 }
151 }
152 inputs[i] = s;
153 xs[i] = x;
154 }
155
156 let mut sz: HashMap<char, usize> = HashMap::new();
158 for (s, x) in inputs.iter().zip(xs.iter()) {
159 let shape = x.shape()?;
160 for (c, dim) in s.chars().zip(shape.iter()) {
161 let dim_val = dim.as_const().unwrap();
162 sz.insert(c, dim_val);
163 }
164 }
165
166 let mut alpha: Vec<char> = sz.keys().copied().collect();
167 alpha.sort();
168
169 let full_shape: Vec<isize> = alpha.iter().map(|c| sz[c] as isize).collect();
171
172 let mut aligned: Vec<Tensor> = Vec::new();
173 for (s, x) in inputs.iter().zip(xs.iter()) {
174 if s.is_empty() {
175 aligned.push(x.clone());
176 } else {
177 let mut sorted_chars: Vec<char> = s.chars().collect();
178
179 let mut char_positions: Vec<(char, usize)> = s.chars().enumerate().map(|(i, c)| (c, i)).collect();
180 char_positions.sort_by_key(|(c, _)| *c);
181 let perm: Vec<isize> = char_positions.iter().map(|(_, pos)| *pos as isize).collect();
182
183 sorted_chars.sort();
184
185 let x = x.try_permute(&perm)?;
186
187 let reshape: Vec<isize> =
189 alpha.iter().map(|c| if sorted_chars.contains(c) { sz[c] as isize } else { 1 }).collect();
190 let x = x.try_reshape(&reshape)?;
191
192 let x = x.try_expand(&full_shape)?;
194 aligned.push(x);
195 }
196 }
197
198 let mut product = aligned[0].clone();
200 for t in aligned.iter().skip(1) {
201 product = product.try_mul(t)?;
202 }
203
204 let sum_axes: Vec<isize> =
206 alpha.iter().enumerate().filter(|(_, c)| !rhs.contains(**c)).map(|(i, _)| i as isize).collect();
207
208 if !sum_axes.is_empty() {
209 product = product.sum_with().axes(AxisSpec::Multiple(sum_axes)).call()?;
210 }
211
212 if !rhs.is_empty() {
214 let rhs_chars: Vec<char> = rhs.chars().collect();
215 let rhs_order = argsort(&argsort(&rhs_chars));
216 let perm: Vec<isize> = rhs_order.iter().map(|&i| i as isize).collect();
217 product = product.try_permute(&perm)?;
218 }
219
220 Ok(product)
221 }
222
223 fn flatten_range(&self, start: isize, end: isize) -> Result<Tensor> {
225 let shape = self.shape()?;
226 let ndim = shape.len();
227 let start = Self::normalize_axis(start, ndim)?;
228 let end = Self::normalize_axis(end, ndim)?;
229
230 let mut new_shape: Vec<isize> = Vec::new();
231 let mut merged = 1isize;
232 for (i, d) in shape.iter().enumerate() {
233 let v = d.as_const().unwrap() as isize;
234 if i >= start && i <= end {
235 merged *= v;
236 if i == end {
237 new_shape.push(merged);
238 }
239 } else {
240 new_shape.push(v);
241 }
242 }
243 self.try_reshape(&new_shape)
244 }
245
246 fn try_stride(&self, strides: &[isize]) -> Result<Tensor> {
248 let shape = self.shape()?;
249 let ndim = shape.len();
250 assert_eq!(strides.len(), ndim);
251
252 let mut result = self.clone();
253 for (dim, &stride) in strides.iter().enumerate() {
254 if stride == 1 {
255 continue;
256 }
257 let cur_shape = result.shape()?;
258 let dim_size = cur_shape[dim].as_const().unwrap();
259 let new_dim_size = dim_size.div_ceil(stride as usize);
260
261 let mut new_shape = svod_ir::shape::to_vec_isize(&cur_shape).context(UOpSnafu)?;
263
264 let padded_size = new_dim_size * stride as usize;
266 if padded_size != dim_size {
267 let mut padding = vec![(0isize, 0isize); result.ndim()?];
268 padding[dim] = (0, (padded_size - dim_size) as isize);
269 result = result.try_pad(&padding)?;
270 new_shape[dim] = padded_size as isize;
271 }
272
273 new_shape.splice(dim..=dim, [new_dim_size as isize, stride]);
275 result = result.try_reshape(&new_shape)?;
276
277 let mut ranges: Vec<(isize, isize)> =
278 result.shape()?.iter().map(|d| (0, d.as_const().unwrap() as isize)).collect();
279 ranges[dim + 1] = (0, 1);
280 result = result.try_shrink(&ranges)?;
281 result = result.try_squeeze(Some((dim + 1) as isize))?;
282 }
283 Ok(result)
284 }
285}