Skip to main content

svod_tensor/
einsum.rs

1//! Einstein summation convention.
2
3use std::collections::HashMap;
4
5use snafu::ResultExt;
6
7use crate::Tensor;
8use crate::error::UOpSnafu;
9use crate::reduce::AxisSpec;
10
11type Result<T> = crate::Result<T>;
12
/// Returns the indices of `slice` arranged so that visiting them in order
/// yields the elements in ascending order.
///
/// The underlying sort is stable, so indices of equal elements keep their
/// original relative order.
fn argsort<T: Ord>(slice: &[T]) -> Vec<usize> {
    let mut order: Vec<usize> = Vec::with_capacity(slice.len());
    order.extend(0..slice.len());
    // The key borrows from `slice`, not from the index being sorted.
    order.sort_by_key(|&idx| &slice[idx]);
    order
}
18
19impl Tensor {
20    pub fn einsum(formula: &str, operands: &[&Tensor]) -> Result<Tensor> {
21        let mut xs: Vec<Tensor> = operands.iter().map(|t| (*t).clone()).collect();
22        let formula = formula.replace(' ', "");
23
24        // Expand ellipsis
25        let formula = if formula.contains("...") {
26            let all_chars: std::collections::HashSet<char> =
27                formula.chars().filter(|c| c.is_ascii_alphabetic()).collect();
28            let ell: String = ('a'..='z').chain('A'..='Z').filter(|c| !all_chars.contains(c)).collect();
29
30            let lhs = formula.split("->").next().unwrap();
31            let input_strs: Vec<&str> = lhs.split(',').collect();
32
33            let ell_n: Vec<usize> = input_strs
34                .iter()
35                .zip(xs.iter())
36                .map(|(s, x)| {
37                    if s.contains("...") {
38                        let ndim = x.ndim().unwrap();
39                        let non_ell_chars = s.len() - 3; // subtract "..."
40                        ndim.saturating_sub(non_ell_chars)
41                    } else {
42                        0
43                    }
44                })
45                .collect();
46
47            let max_ell_n = *ell_n.iter().max().unwrap_or(&0);
48
49            let mut new_inputs: Vec<String> = Vec::new();
50            for (i, s) in input_strs.iter().enumerate() {
51                let replacement = &ell[max_ell_n - ell_n[i]..max_ell_n];
52                new_inputs.push(s.replace("...", replacement));
53            }
54
55            let new_lhs = new_inputs.join(",");
56
57            // Build auto output: sorted chars that appear exactly once in lhs and are not ellipsis chars
58            let ell_chars: std::collections::HashSet<char> = ell[..max_ell_n].chars().collect();
59            let auto: String = {
60                let mut chars: Vec<char> = lhs
61                    .chars()
62                    .filter(|c| {
63                        c.is_ascii_alphabetic() && *c != '.' && lhs.matches(*c).count() == 1 && !ell_chars.contains(c)
64                    })
65                    .collect();
66                chars.sort();
67                chars.into_iter().collect()
68            };
69
70            if formula.contains("->") {
71                let rhs = formula.split("->").nth(1).unwrap();
72                let new_rhs = rhs.replace("...", &ell[..max_ell_n]);
73                format!("{new_lhs}->{new_rhs}")
74            } else {
75                format!("{new_lhs}->{}{auto}", &ell[..max_ell_n])
76            }
77        } else {
78            formula
79        };
80
81        // Split into lhs and rhs
82        let (lhs, rhs) = if formula.contains("->") {
83            let parts: Vec<&str> = formula.split("->").collect();
84            (parts[0].to_string(), parts[1].to_string())
85        } else {
86            let auto: String = {
87                let mut chars: Vec<char> =
88                    formula.chars().filter(|c| c.is_ascii_alphabetic() && formula.matches(*c).count() == 1).collect();
89                chars.sort();
90                chars.into_iter().collect()
91            };
92            (formula.clone(), auto)
93        };
94
95        let mut inputs: Vec<String> = lhs.split(',').map(|s| s.to_string()).collect();
96
97        // Trace: diagonal for repeated letters
98        for i in 0..inputs.len() {
99            let mut s = inputs[i].clone();
100            let mut x = xs[i].clone();
101            let unique_chars: Vec<char> = {
102                let mut seen = std::collections::HashSet::new();
103                s.chars().filter(move |c| seen.insert(*c)).collect()
104            };
105            for c in unique_chars {
106                while s.matches(c).count() > 1 {
107                    let j = s.find(c).unwrap();
108                    let k = s[j + 1..].find(c).unwrap() + j + 1;
109                    let shape = x.shape()?;
110                    let n = shape[j].as_const().unwrap();
111                    let ndim = x.ndim()?;
112
113                    if ndim > 2 {
114                        // permute so j,k are last two dims
115                        let mut perm: Vec<isize> =
116                            (0..ndim).filter(|&d| d != j && d != k).map(|d| d as isize).collect();
117                        perm.push(j as isize);
118                        perm.push(k as isize);
119                        x = x.try_permute(&perm)?;
120
121                        // flatten last two dims
122                        x = x.flatten_range(-2, -1)?;
123
124                        // pad with n zeros at end of last dim
125                        let new_ndim = x.ndim()?;
126                        let mut padding = vec![(0isize, 0isize); new_ndim];
127                        padding[new_ndim - 1] = (0, n as isize);
128                        x = x.try_pad(&padding)?;
129
130                        // unflatten last dim into [n, n+1]
131                        x = x.unflatten(-1, &[n as isize, (n + 1) as isize])?;
132
133                        // take element 0 along last dim (shrink + squeeze)
134                        let cur_ndim = x.ndim()?;
135                        let mut ranges: Vec<(isize, isize)> =
136                            x.shape()?.iter().map(|d| (0, d.as_const().unwrap() as isize)).collect();
137                        ranges[cur_ndim - 1] = (0, 1);
138                        x = x.try_shrink(&ranges)?;
139                        x = x.try_squeeze(Some(-1))?;
140                    } else {
141                        // 2D diagonal: use flatten + stride approach
142                        // For a [n, n] matrix, diagonal = flatten then take every (n+1)th element
143                        x = x.flatten()?;
144                        let stride = n + 1;
145                        x = x.try_stride(&[stride as isize])?;
146                    }
147
148                    // Remove the second occurrence of c from s
149                    s = format!("{}{}", &s[..k], &s[k + 1..]);
150                }
151            }
152            inputs[i] = s;
153            xs[i] = x;
154        }
155
156        // Build size map
157        let mut sz: HashMap<char, usize> = HashMap::new();
158        for (s, x) in inputs.iter().zip(xs.iter()) {
159            let shape = x.shape()?;
160            for (c, dim) in s.chars().zip(shape.iter()) {
161                let dim_val = dim.as_const().unwrap();
162                sz.insert(c, dim_val);
163            }
164        }
165
166        let mut alpha: Vec<char> = sz.keys().copied().collect();
167        alpha.sort();
168
169        // Align, multiply, sum, permute
170        let full_shape: Vec<isize> = alpha.iter().map(|c| sz[c] as isize).collect();
171
172        let mut aligned: Vec<Tensor> = Vec::new();
173        for (s, x) in inputs.iter().zip(xs.iter()) {
174            if s.is_empty() {
175                aligned.push(x.clone());
176            } else {
177                let mut sorted_chars: Vec<char> = s.chars().collect();
178
179                let mut char_positions: Vec<(char, usize)> = s.chars().enumerate().map(|(i, c)| (c, i)).collect();
180                char_positions.sort_by_key(|(c, _)| *c);
181                let perm: Vec<isize> = char_positions.iter().map(|(_, pos)| *pos as isize).collect();
182
183                sorted_chars.sort();
184
185                let x = x.try_permute(&perm)?;
186
187                // Reshape: insert 1s for missing dims
188                let reshape: Vec<isize> =
189                    alpha.iter().map(|c| if sorted_chars.contains(c) { sz[c] as isize } else { 1 }).collect();
190                let x = x.try_reshape(&reshape)?;
191
192                // Expand to full shape
193                let x = x.try_expand(&full_shape)?;
194                aligned.push(x);
195            }
196        }
197
198        // Multiply all aligned tensors
199        let mut product = aligned[0].clone();
200        for t in aligned.iter().skip(1) {
201            product = product.try_mul(t)?;
202        }
203
204        // Sum over axes not in rhs
205        let sum_axes: Vec<isize> =
206            alpha.iter().enumerate().filter(|(_, c)| !rhs.contains(**c)).map(|(i, _)| i as isize).collect();
207
208        if !sum_axes.is_empty() {
209            product = product.sum_with().axes(AxisSpec::Multiple(sum_axes)).call()?;
210        }
211
212        // Permute to match rhs order
213        if !rhs.is_empty() {
214            let rhs_chars: Vec<char> = rhs.chars().collect();
215            let rhs_order = argsort(&argsort(&rhs_chars));
216            let perm: Vec<isize> = rhs_order.iter().map(|&i| i as isize).collect();
217            product = product.try_permute(&perm)?;
218        }
219
220        Ok(product)
221    }
222
223    /// Flatten a range of dimensions (inclusive).
224    fn flatten_range(&self, start: isize, end: isize) -> Result<Tensor> {
225        let shape = self.shape()?;
226        let ndim = shape.len();
227        let start = Self::normalize_axis(start, ndim)?;
228        let end = Self::normalize_axis(end, ndim)?;
229
230        let mut new_shape: Vec<isize> = Vec::new();
231        let mut merged = 1isize;
232        for (i, d) in shape.iter().enumerate() {
233            let v = d.as_const().unwrap() as isize;
234            if i >= start && i <= end {
235                merged *= v;
236                if i == end {
237                    new_shape.push(merged);
238                }
239            } else {
240                new_shape.push(v);
241            }
242        }
243        self.try_reshape(&new_shape)
244    }
245
246    /// Stride along each dimension (take every nth element).
247    fn try_stride(&self, strides: &[isize]) -> Result<Tensor> {
248        let shape = self.shape()?;
249        let ndim = shape.len();
250        assert_eq!(strides.len(), ndim);
251
252        let mut result = self.clone();
253        for (dim, &stride) in strides.iter().enumerate() {
254            if stride == 1 {
255                continue;
256            }
257            let cur_shape = result.shape()?;
258            let dim_size = cur_shape[dim].as_const().unwrap();
259            let new_dim_size = dim_size.div_ceil(stride as usize);
260
261            // Reshape dim into [new_dim_size, stride], take [:, 0]
262            let mut new_shape = svod_ir::shape::to_vec_isize(&cur_shape).context(UOpSnafu)?;
263
264            // Pad if needed so dim is evenly divisible by stride
265            let padded_size = new_dim_size * stride as usize;
266            if padded_size != dim_size {
267                let mut padding = vec![(0isize, 0isize); result.ndim()?];
268                padding[dim] = (0, (padded_size - dim_size) as isize);
269                result = result.try_pad(&padding)?;
270                new_shape[dim] = padded_size as isize;
271            }
272
273            // Unflatten dim into [new_dim_size, stride]
274            new_shape.splice(dim..=dim, [new_dim_size as isize, stride]);
275            result = result.try_reshape(&new_shape)?;
276
277            let mut ranges: Vec<(isize, isize)> =
278                result.shape()?.iter().map(|d| (0, d.as_const().unwrap() as isize)).collect();
279            ranges[dim + 1] = (0, 1);
280            result = result.try_shrink(&ranges)?;
281            result = result.try_squeeze(Some((dim + 1) as isize))?;
282        }
283        Ok(result)
284    }
285}