1use std::collections::HashMap;
8use std::sync::Arc;
9use thiserror::Error;
10
/// A single tensor dimension whose size may be known, named, or derived.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SymbolicDim {
    /// A dimension with a known, concrete size.
    Fixed(usize),
    /// A named dimension whose size is not known from the dimension alone.
    Symbolic(Arc<str>),
    /// The product of two dimensions. Never simplified automatically.
    Product(Box<SymbolicDim>, Box<SymbolicDim>),
}

impl SymbolicDim {
    /// Creates a dimension of known size `n`.
    pub fn fixed(n: usize) -> Self {
        SymbolicDim::Fixed(n)
    }

    /// Creates a named symbolic dimension.
    pub fn symbolic(name: impl Into<Arc<str>>) -> Self {
        SymbolicDim::Symbolic(name.into())
    }

    /// Creates the product `a * b` without simplifying it.
    pub fn product(a: SymbolicDim, b: SymbolicDim) -> Self {
        SymbolicDim::Product(Box::new(a), Box::new(b))
    }

    /// Returns `true` only for the [`SymbolicDim::Fixed`] variant.
    ///
    /// A [`SymbolicDim::Product`] of fixed factors is *not* reported as fixed;
    /// use [`SymbolicDim::concrete_value`] to test whether a size is computable.
    pub fn is_fixed(&self) -> bool {
        matches!(self, SymbolicDim::Fixed(_))
    }

    /// Returns `true` only for the [`SymbolicDim::Symbolic`] variant.
    pub fn is_symbolic(&self) -> bool {
        matches!(self, SymbolicDim::Symbolic(_))
    }

    /// Returns the concrete size of this dimension without consulting any
    /// environment of bindings.
    ///
    /// Returns `None` if any factor is [`SymbolicDim::Symbolic`], or if a
    /// product's size does not fit in `usize`.
    pub fn concrete_value(&self) -> Option<usize> {
        match self {
            SymbolicDim::Fixed(n) => Some(*n),
            SymbolicDim::Symbolic(_) => None,
            // `checked_mul` instead of `*`: an overflowing size is not a real
            // size, and plain multiplication would panic in debug builds and
            // silently wrap in release builds.
            SymbolicDim::Product(a, b) => a.concrete_value()?.checked_mul(b.concrete_value()?),
        }
    }
}

impl std::fmt::Display for SymbolicDim {
    /// Formats fixed sizes as numbers, symbols by name, and products as
    /// `(a*b)` with explicit parentheses.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SymbolicDim::Fixed(n) => write!(f, "{}", n),
            SymbolicDim::Symbolic(s) => write!(f, "{}", s),
            SymbolicDim::Product(a, b) => write!(f, "({}*{})", a, b),
        }
    }
}
66
67pub type SymbolicShape = Vec<SymbolicDim>;
69
70#[derive(Debug, Clone, PartialEq, Eq)]
72pub enum SymbolicShapeConstraint {
73 Equal(SymbolicDim, SymbolicDim),
75 GreaterThan(SymbolicDim, SymbolicDim),
77 Multiple(SymbolicDim, SymbolicDim),
79}
80
81#[derive(Debug, Error)]
83pub enum ShapeError {
84 #[error("Dimension contradiction: cannot unify {0} with {1}")]
85 Contradiction(String, String),
86 #[error("Unresolved symbolic dimension: {0}")]
87 Unresolved(String),
88 #[error("Invalid einsum spec: {0}")]
89 InvalidSpec(String),
90 #[error("Arity mismatch: expected {expected} inputs, got {got}")]
91 ArityMismatch { expected: usize, got: usize },
92}
93
94#[derive(Debug, Default)]
99pub struct SymbolicShapeEnv {
100 bindings: HashMap<Arc<str>, SymbolicDim>,
102 constraints: Vec<SymbolicShapeConstraint>,
104}
105
106impl SymbolicShapeEnv {
107 pub fn new() -> Self {
108 Self::default()
109 }
110
111 pub fn resolve(&self, dim: &SymbolicDim) -> SymbolicDim {
114 match dim {
115 SymbolicDim::Symbolic(name) => {
116 if let Some(bound) = self.bindings.get(name) {
117 self.resolve(bound)
118 } else {
119 dim.clone()
120 }
121 }
122 SymbolicDim::Product(a, b) => SymbolicDim::product(self.resolve(a), self.resolve(b)),
123 SymbolicDim::Fixed(_) => dim.clone(),
124 }
125 }
126
127 pub fn concrete_value(&self, dim: &SymbolicDim) -> Option<usize> {
129 self.resolve(dim).concrete_value()
130 }
131
132 pub fn unify(&mut self, a: &SymbolicDim, b: &SymbolicDim) -> Result<SymbolicDim, ShapeError> {
135 let ra = self.resolve(a);
136 let rb = self.resolve(b);
137 match (&ra, &rb) {
138 (SymbolicDim::Fixed(x), SymbolicDim::Fixed(y)) => {
139 if x == y {
140 Ok(ra)
141 } else {
142 Err(ShapeError::Contradiction(
143 format!("{}", x),
144 format!("{}", y),
145 ))
146 }
147 }
148 (SymbolicDim::Symbolic(name_a), SymbolicDim::Symbolic(name_b)) => {
149 if name_a == name_b {
151 Ok(ra)
152 } else {
153 self.bindings.insert(name_a.clone(), rb.clone());
155 Ok(rb)
156 }
157 }
158 (SymbolicDim::Symbolic(name), _) => {
159 self.bindings.insert(name.clone(), rb.clone());
160 Ok(rb)
161 }
162 (_, SymbolicDim::Symbolic(name)) => {
163 self.bindings.insert(name.clone(), ra.clone());
164 Ok(ra)
165 }
166 (SymbolicDim::Product(_, _), SymbolicDim::Fixed(_)) => {
168 if let Some(va) = ra.concrete_value() {
169 if let Some(vb) = rb.concrete_value() {
170 if va == vb {
171 Ok(ra)
172 } else {
173 Err(ShapeError::Contradiction(
174 format!("{}", va),
175 format!("{}", vb),
176 ))
177 }
178 } else {
179 Ok(ra)
180 }
181 } else {
182 self.add_constraint(SymbolicShapeConstraint::Equal(ra, rb));
184 Ok(SymbolicDim::symbolic("_unresolved"))
185 }
186 }
187 (SymbolicDim::Fixed(_), SymbolicDim::Product(_, _)) => self.unify(b, a),
188 _ => Ok(ra),
189 }
190 }
191
192 pub fn add_constraint(&mut self, c: SymbolicShapeConstraint) {
194 self.constraints.push(c);
195 }
196
197 pub fn check_consistency(&self) -> bool {
199 for c in &self.constraints {
200 match c {
201 SymbolicShapeConstraint::Equal(a, b) => {
202 if let (Some(va), Some(vb)) = (self.concrete_value(a), self.concrete_value(b)) {
203 if va != vb {
204 return false;
205 }
206 }
207 }
208 SymbolicShapeConstraint::GreaterThan(a, b) => {
209 if let (Some(va), Some(vb)) = (self.concrete_value(a), self.concrete_value(b)) {
210 if va <= vb {
211 return false;
212 }
213 }
214 }
215 SymbolicShapeConstraint::Multiple(a, b) => {
216 if let (Some(va), Some(vb)) = (self.concrete_value(a), self.concrete_value(b)) {
217 if vb == 0 || va % vb != 0 {
218 return false;
219 }
220 }
221 }
222 }
223 }
224 true
225 }
226
227 pub fn binding_count(&self) -> usize {
229 self.bindings.len()
230 }
231
232 pub fn bound_names(&self) -> impl Iterator<Item = &Arc<str>> {
234 self.bindings.keys()
235 }
236}
237
238pub fn propagate_einsum_shapes(
248 spec: &str,
249 input_shapes: &[SymbolicShape],
250 env: &mut SymbolicShapeEnv,
251) -> Result<SymbolicShape, ShapeError> {
252 let arrow_pos = spec
254 .find("->")
255 .ok_or_else(|| ShapeError::InvalidSpec(format!("missing '->' in einsum spec: {}", spec)))?;
256 let inputs_part = &spec[..arrow_pos];
257 let output_part = &spec[arrow_pos + 2..];
258
259 let operand_specs: Vec<&str> = inputs_part.split(',').collect();
260 if operand_specs.len() != input_shapes.len() {
261 return Err(ShapeError::ArityMismatch {
262 expected: operand_specs.len(),
263 got: input_shapes.len(),
264 });
265 }
266
267 let mut index_map: HashMap<char, SymbolicDim> = HashMap::new();
269
270 for (op_spec, shape) in operand_specs.iter().zip(input_shapes.iter()) {
271 let chars: Vec<char> = op_spec.chars().filter(|c| c.is_alphabetic()).collect();
272 if chars.len() != shape.len() {
273 return Err(ShapeError::InvalidSpec(format!(
274 "spec '{}' has {} indices but shape has {} dims",
275 op_spec,
276 chars.len(),
277 shape.len()
278 )));
279 }
280 for (ch, dim) in chars.iter().zip(shape.iter()) {
281 if let Some(existing) = index_map.get(ch) {
282 let unified = env.unify(existing, dim)?;
284 index_map.insert(*ch, unified);
285 } else {
286 index_map.insert(*ch, env.resolve(dim));
287 }
288 }
289 }
290
291 let output_chars: Vec<char> = output_part.chars().filter(|c| c.is_alphabetic()).collect();
293 let mut out_shape = Vec::with_capacity(output_chars.len());
294 for ch in output_chars {
295 let dim = index_map
296 .get(&ch)
297 .cloned()
298 .unwrap_or_else(|| SymbolicDim::symbolic(format!("_out_{}", ch)));
299 out_shape.push(env.resolve(&dim));
300 }
301
302 Ok(out_shape)
303}
304
305pub fn propagate_chain(
307 specs: &[&str],
308 initial_shapes: &[SymbolicShape],
309 env: &mut SymbolicShapeEnv,
310) -> Result<Vec<SymbolicShape>, ShapeError> {
311 let mut results = Vec::new();
312 let mut current_shapes: Vec<SymbolicShape> = initial_shapes.to_vec();
313 for spec in specs {
314 let out = propagate_einsum_shapes(spec, ¤t_shapes, env)?;
315 results.push(out.clone());
316 current_shapes = vec![out];
318 }
319 Ok(results)
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fixed_dim_equality() {
        assert_eq!(SymbolicDim::fixed(3), SymbolicDim::fixed(3));
        assert_ne!(SymbolicDim::fixed(3), SymbolicDim::fixed(4));
    }

    #[test]
    fn test_symbolic_dim_display() {
        let d = SymbolicDim::symbolic("batch");
        assert_eq!(format!("{}", d), "batch");
    }

    #[test]
    fn test_product_dim_display() {
        let p = SymbolicDim::product(SymbolicDim::symbolic("N"), SymbolicDim::fixed(4));
        assert_eq!(p.to_string(), "(N*4)");
    }

    #[test]
    fn test_fixed_dim_concrete_value() {
        assert_eq!(SymbolicDim::fixed(42).concrete_value(), Some(42));
    }

    #[test]
    fn test_symbolic_dim_no_concrete_value() {
        assert_eq!(SymbolicDim::symbolic("N").concrete_value(), None);
    }

    #[test]
    fn test_product_dim_resolves_when_both_fixed() {
        let p = SymbolicDim::product(SymbolicDim::fixed(3), SymbolicDim::fixed(4));
        assert_eq!(p.concrete_value(), Some(12));
    }

    #[test]
    fn test_product_dim_unresolved_when_symbolic() {
        let p = SymbolicDim::product(SymbolicDim::symbolic("N"), SymbolicDim::fixed(4));
        assert_eq!(p.concrete_value(), None);
    }

    #[test]
    fn test_unify_fixed_same() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let result = env.unify(&SymbolicDim::fixed(5), &SymbolicDim::fixed(5))?;
        assert_eq!(result, SymbolicDim::fixed(5));
        Ok(())
    }

    #[test]
    fn test_unify_fixed_contradiction() {
        let mut env = SymbolicShapeEnv::new();
        let result = env.unify(&SymbolicDim::fixed(3), &SymbolicDim::fixed(4));
        assert!(result.is_err());
    }

    #[test]
    fn test_unify_symbolic_binds_to_fixed() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        env.unify(&SymbolicDim::symbolic("N"), &SymbolicDim::fixed(7))?;
        assert_eq!(env.concrete_value(&SymbolicDim::symbolic("N")), Some(7));
        Ok(())
    }

    #[test]
    fn test_unify_fixed_binds_symbolic() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        env.unify(&SymbolicDim::fixed(4), &SymbolicDim::symbolic("M"))?;
        assert_eq!(env.concrete_value(&SymbolicDim::symbolic("M")), Some(4));
        Ok(())
    }

    #[test]
    fn test_unify_two_symbolics() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        env.unify(&SymbolicDim::symbolic("A"), &SymbolicDim::symbolic("B"))?;
        let ra = env.resolve(&SymbolicDim::symbolic("A"));
        let rb = env.resolve(&SymbolicDim::symbolic("B"));
        assert_eq!(ra, rb);
        Ok(())
    }

    #[test]
    fn test_resolve_chain() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        env.unify(&SymbolicDim::symbolic("A"), &SymbolicDim::symbolic("B"))?;
        env.unify(&SymbolicDim::symbolic("B"), &SymbolicDim::fixed(10))?;
        assert_eq!(env.concrete_value(&SymbolicDim::symbolic("A")), Some(10));
        Ok(())
    }

    #[test]
    fn test_binding_count() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        assert_eq!(env.binding_count(), 0);
        env.unify(&SymbolicDim::symbolic("N"), &SymbolicDim::fixed(5))?;
        assert_eq!(env.binding_count(), 1);
        Ok(())
    }

    #[test]
    fn test_constraint_consistency_equal() {
        let mut env = SymbolicShapeEnv::new();
        env.add_constraint(SymbolicShapeConstraint::Equal(
            SymbolicDim::fixed(3),
            SymbolicDim::fixed(3),
        ));
        assert!(env.check_consistency());
    }

    #[test]
    fn test_constraint_inconsistency_equal() {
        let mut env = SymbolicShapeEnv::new();
        env.add_constraint(SymbolicShapeConstraint::Equal(
            SymbolicDim::fixed(3),
            SymbolicDim::fixed(5),
        ));
        assert!(!env.check_consistency());
    }

    #[test]
    fn test_constraint_greater_than() {
        let mut env = SymbolicShapeEnv::new();
        env.add_constraint(SymbolicShapeConstraint::GreaterThan(
            SymbolicDim::fixed(10),
            SymbolicDim::fixed(5),
        ));
        assert!(env.check_consistency());
    }

    #[test]
    fn test_constraint_greater_than_violated() {
        let mut env = SymbolicShapeEnv::new();
        env.add_constraint(SymbolicShapeConstraint::GreaterThan(
            SymbolicDim::fixed(5),
            SymbolicDim::fixed(10),
        ));
        assert!(!env.check_consistency());
    }

    #[test]
    fn test_constraint_multiple() {
        let mut env = SymbolicShapeEnv::new();
        env.add_constraint(SymbolicShapeConstraint::Multiple(
            SymbolicDim::fixed(12),
            SymbolicDim::fixed(4),
        ));
        assert!(env.check_consistency());
    }

    #[test]
    fn test_constraint_multiple_zero_divisor_is_inconsistent() {
        let mut env = SymbolicShapeEnv::new();
        env.add_constraint(SymbolicShapeConstraint::Multiple(
            SymbolicDim::fixed(12),
            SymbolicDim::fixed(0),
        ));
        assert!(!env.check_consistency());
    }

    #[test]
    fn test_constraint_with_unbound_symbol_is_deferred() {
        let mut env = SymbolicShapeEnv::new();
        env.add_constraint(SymbolicShapeConstraint::Equal(
            SymbolicDim::symbolic("N"),
            SymbolicDim::fixed(3),
        ));
        assert!(env.check_consistency());
    }

    #[test]
    fn test_constraint_sees_later_bindings() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        env.add_constraint(SymbolicShapeConstraint::Equal(
            SymbolicDim::symbolic("N"),
            SymbolicDim::fixed(4),
        ));
        assert!(env.check_consistency());
        env.unify(&SymbolicDim::symbolic("N"), &SymbolicDim::fixed(3))?;
        assert!(!env.check_consistency());
        Ok(())
    }

    #[test]
    fn test_propagate_matmul_symbolic() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![SymbolicDim::symbolic("M"), SymbolicDim::symbolic("K")];
        let shape_b = vec![SymbolicDim::symbolic("K"), SymbolicDim::symbolic("N")];
        let out = propagate_einsum_shapes("ij,jk->ik", &[shape_a, shape_b], &mut env)?;
        assert_eq!(out.len(), 2);
        assert_eq!(format!("{}", out[0]), "M");
        assert_eq!(format!("{}", out[1]), "N");
        Ok(())
    }

    #[test]
    fn test_propagate_matmul_fixed() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![SymbolicDim::fixed(4), SymbolicDim::fixed(3)];
        let shape_b = vec![SymbolicDim::fixed(3), SymbolicDim::fixed(5)];
        let out = propagate_einsum_shapes("ij,jk->ik", &[shape_a, shape_b], &mut env)?;
        assert_eq!(out[0].concrete_value(), Some(4));
        assert_eq!(out[1].concrete_value(), Some(5));
        Ok(())
    }

    #[test]
    fn test_propagate_contraction_unifies_k() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![SymbolicDim::symbolic("M"), SymbolicDim::symbolic("K")];
        let shape_b = vec![SymbolicDim::symbolic("K"), SymbolicDim::fixed(5)];
        let out = propagate_einsum_shapes("ij,jk->ik", &[shape_a, shape_b], &mut env)?;
        assert_eq!(out, vec![SymbolicDim::symbolic("M"), SymbolicDim::fixed(5)]);
        Ok(())
    }

    #[test]
    fn test_propagate_contraction_binds_symbol_to_fixed() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![SymbolicDim::symbolic("M"), SymbolicDim::symbolic("K")];
        let shape_b = vec![SymbolicDim::fixed(3), SymbolicDim::symbolic("N")];
        let out = propagate_einsum_shapes("ij,jk->ik", &[shape_a, shape_b], &mut env)?;
        assert_eq!(out, vec![SymbolicDim::symbolic("M"), SymbolicDim::symbolic("N")]);
        assert_eq!(env.concrete_value(&SymbolicDim::symbolic("K")), Some(3));
        Ok(())
    }

    #[test]
    fn test_propagate_repeated_index_within_operand() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let shape = vec![SymbolicDim::symbolic("N"), SymbolicDim::fixed(3)];
        let out = propagate_einsum_shapes("ii->i", &[shape], &mut env)?;
        assert_eq!(out, vec![SymbolicDim::fixed(3)]);
        Ok(())
    }

    #[test]
    fn test_propagate_inner_product() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![SymbolicDim::symbolic("N")];
        let shape_b = vec![SymbolicDim::symbolic("N")];
        let out = propagate_einsum_shapes("i,i->", &[shape_a, shape_b], &mut env)?;
        assert_eq!(out.len(), 0);
        Ok(())
    }

    #[test]
    fn test_propagate_batch_matmul() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![
            SymbolicDim::symbolic("B"),
            SymbolicDim::symbolic("M"),
            SymbolicDim::symbolic("K"),
        ];
        let shape_b = vec![
            SymbolicDim::symbolic("B"),
            SymbolicDim::symbolic("K"),
            SymbolicDim::symbolic("N"),
        ];
        let out = propagate_einsum_shapes("bij,bjk->bik", &[shape_a, shape_b], &mut env)?;
        assert_eq!(out.len(), 3);
        assert_eq!(format!("{}", out[0]), "B");
        Ok(())
    }

    #[test]
    fn test_propagate_arity_mismatch_error() {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![SymbolicDim::fixed(3), SymbolicDim::fixed(4)];
        let result = propagate_einsum_shapes("ij,jk->ik", &[shape_a], &mut env);
        assert!(matches!(result, Err(ShapeError::ArityMismatch { .. })));
    }

    #[test]
    fn test_propagate_rank_mismatch_error() {
        let mut env = SymbolicShapeEnv::new();
        let shape = vec![SymbolicDim::fixed(3)];
        let result = propagate_einsum_shapes("ij->i", &[shape], &mut env);
        assert!(matches!(result, Err(ShapeError::InvalidSpec(_))));
    }

    #[test]
    fn test_propagate_missing_arrow_error() {
        let mut env = SymbolicShapeEnv::new();
        let shape = vec![SymbolicDim::fixed(3)];
        let result = propagate_einsum_shapes("i,j", &[shape.clone(), shape], &mut env);
        assert!(matches!(result, Err(ShapeError::InvalidSpec(_))));
    }
}