Skip to main content

tensorlogic_infer/
symbolic_shape.rs

//! Symbolic shape support for TensorLogic.
//!
//! Enables shape inference for graphs with unknown or dynamic dimensions.
//! Uses a unification-based approach: symbolic names act as type variables
//! that get resolved when unified with concrete sizes.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use thiserror::Error;
10
/// A single tensor dimension — either a known size, a symbolic name, or a
/// product of two other dimensions.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SymbolicDim {
    /// A fixed, known size.
    Fixed(usize),
    /// A symbolic name (e.g., "batch", "seq_len", "N").
    Symbolic(Arc<str>),
    /// A product of two dimensions (e.g., batch * seq_len).
    Product(Box<SymbolicDim>, Box<SymbolicDim>),
}

impl SymbolicDim {
    /// Build a dimension with a known size `n`.
    pub fn fixed(n: usize) -> Self {
        SymbolicDim::Fixed(n)
    }

    /// Build a symbolic dimension called `name`.
    pub fn symbolic(name: impl Into<Arc<str>>) -> Self {
        SymbolicDim::Symbolic(name.into())
    }

    /// Build the product dimension `a * b`.
    pub fn product(a: SymbolicDim, b: SymbolicDim) -> Self {
        SymbolicDim::Product(Box::new(a), Box::new(b))
    }

    /// Returns `true` only for the `Fixed` variant (a product of fixed sizes
    /// is still reported as not fixed).
    pub fn is_fixed(&self) -> bool {
        matches!(self, SymbolicDim::Fixed(_))
    }

    /// Returns `true` only for the `Symbolic` variant.
    pub fn is_symbolic(&self) -> bool {
        matches!(self, SymbolicDim::Symbolic(_))
    }

    /// If this dimension is fully resolved, return its concrete value.
    ///
    /// Returns `None` when any symbolic name occurs in the dimension, and
    /// also when a product does not fit in `usize`: an overflowing product
    /// has no representable concrete value, so it is reported as unresolved
    /// instead of panicking (debug builds) or wrapping (release builds).
    pub fn concrete_value(&self) -> Option<usize> {
        match self {
            SymbolicDim::Fixed(n) => Some(*n),
            SymbolicDim::Symbolic(_) => None,
            SymbolicDim::Product(a, b) => {
                let lhs = a.concrete_value()?;
                let rhs = b.concrete_value()?;
                lhs.checked_mul(rhs)
            }
        }
    }
}
56
57impl std::fmt::Display for SymbolicDim {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        match self {
60            SymbolicDim::Fixed(n) => write!(f, "{}", n),
61            SymbolicDim::Symbolic(s) => write!(f, "{}", s),
62            SymbolicDim::Product(a, b) => write!(f, "({}*{})", a, b),
63        }
64    }
65}
66
/// A tensor shape as an ordered vector of symbolic dimensions.
///
/// An empty vector is used for scalar results (e.g. the output of `"i,i->"`).
pub type SymbolicShape = Vec<SymbolicDim>;
69
/// Constraints between symbolic dimensions for consistency checking.
///
/// Constraints are stored by `SymbolicShapeEnv::add_constraint` and evaluated
/// by `SymbolicShapeEnv::check_consistency`, which only tests a constraint
/// once both of its sides have a concrete value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SymbolicShapeConstraint {
    /// Two dimensions must be equal.
    Equal(SymbolicDim, SymbolicDim),
    /// First dimension must be strictly greater than second.
    GreaterThan(SymbolicDim, SymbolicDim),
    /// First dimension must be a multiple of second.
    /// A concrete second dimension of zero is always reported as violated.
    Multiple(SymbolicDim, SymbolicDim),
}
80
/// Error types for symbolic shape operations.
#[derive(Debug, Error)]
pub enum ShapeError {
    /// Two dimensions were unified but resolve to different concrete sizes.
    /// The fields hold the textual form of the two conflicting values.
    #[error("Dimension contradiction: cannot unify {0} with {1}")]
    Contradiction(String, String),
    /// A symbolic dimension has no concrete value where one is required.
    /// NOTE(review): nothing in this module constructs this variant —
    /// presumably it is reserved for callers elsewhere; confirm.
    #[error("Unresolved symbolic dimension: {0}")]
    Unresolved(String),
    /// The einsum spec is malformed: it lacks `->`, or an operand's index
    /// count does not match the rank of the corresponding input shape.
    #[error("Invalid einsum spec: {0}")]
    InvalidSpec(String),
    /// The number of comma-separated operands in the spec (`expected`)
    /// differs from the number of input shapes supplied (`got`).
    #[error("Arity mismatch: expected {expected} inputs, got {got}")]
    ArityMismatch { expected: usize, got: usize },
}
93
/// Unification environment for symbolic shapes.
///
/// Maintains a substitution map from symbolic names to the dimension each
/// name has been bound to. Bindings may chain (a name bound to another name
/// that is itself bound); `resolve` follows such chains to the end.
/// `unify` only binds names that are still unbound after resolution, so an
/// existing binding is never overwritten.
///
/// The derived `Default` yields an empty environment.
#[derive(Debug, Default)]
pub struct SymbolicShapeEnv {
    /// Map from symbolic name → the dimension it is bound to
    /// (which may itself contain further symbolic names).
    bindings: HashMap<Arc<str>, SymbolicDim>,
    /// Registered constraints, evaluated by `check_consistency`.
    constraints: Vec<SymbolicShapeConstraint>,
}
105
106impl SymbolicShapeEnv {
107    pub fn new() -> Self {
108        Self::default()
109    }
110
111    /// Resolve a dimension through the substitution map.
112    /// Returns the most-resolved form available given current bindings.
113    pub fn resolve(&self, dim: &SymbolicDim) -> SymbolicDim {
114        match dim {
115            SymbolicDim::Symbolic(name) => {
116                if let Some(bound) = self.bindings.get(name) {
117                    self.resolve(bound)
118                } else {
119                    dim.clone()
120                }
121            }
122            SymbolicDim::Product(a, b) => SymbolicDim::product(self.resolve(a), self.resolve(b)),
123            SymbolicDim::Fixed(_) => dim.clone(),
124        }
125    }
126
127    /// Try to get a concrete value for a dimension.
128    pub fn concrete_value(&self, dim: &SymbolicDim) -> Option<usize> {
129        self.resolve(dim).concrete_value()
130    }
131
132    /// Unify two dimensions. If both are Fixed, they must be equal.
133    /// If one is Symbolic, it gets bound to the other.
134    pub fn unify(&mut self, a: &SymbolicDim, b: &SymbolicDim) -> Result<SymbolicDim, ShapeError> {
135        let ra = self.resolve(a);
136        let rb = self.resolve(b);
137        match (&ra, &rb) {
138            (SymbolicDim::Fixed(x), SymbolicDim::Fixed(y)) => {
139                if x == y {
140                    Ok(ra)
141                } else {
142                    Err(ShapeError::Contradiction(
143                        format!("{}", x),
144                        format!("{}", y),
145                    ))
146                }
147            }
148            (SymbolicDim::Symbolic(name_a), SymbolicDim::Symbolic(name_b)) => {
149                // Avoid self-binding: if both resolve to the same symbolic name, no-op
150                if name_a == name_b {
151                    Ok(ra)
152                } else {
153                    // Bind name_a → rb (pointing to name_b or its binding)
154                    self.bindings.insert(name_a.clone(), rb.clone());
155                    Ok(rb)
156                }
157            }
158            (SymbolicDim::Symbolic(name), _) => {
159                self.bindings.insert(name.clone(), rb.clone());
160                Ok(rb)
161            }
162            (_, SymbolicDim::Symbolic(name)) => {
163                self.bindings.insert(name.clone(), ra.clone());
164                Ok(ra)
165            }
166            // Product × Fixed: try to resolve
167            (SymbolicDim::Product(_, _), SymbolicDim::Fixed(_)) => {
168                if let Some(va) = ra.concrete_value() {
169                    if let Some(vb) = rb.concrete_value() {
170                        if va == vb {
171                            Ok(ra)
172                        } else {
173                            Err(ShapeError::Contradiction(
174                                format!("{}", va),
175                                format!("{}", vb),
176                            ))
177                        }
178                    } else {
179                        Ok(ra)
180                    }
181                } else {
182                    // Cannot resolve product yet, store as constraint
183                    self.add_constraint(SymbolicShapeConstraint::Equal(ra, rb));
184                    Ok(SymbolicDim::symbolic("_unresolved"))
185                }
186            }
187            (SymbolicDim::Fixed(_), SymbolicDim::Product(_, _)) => self.unify(b, a),
188            _ => Ok(ra),
189        }
190    }
191
192    /// Register a shape constraint for later consistency checking.
193    pub fn add_constraint(&mut self, c: SymbolicShapeConstraint) {
194        self.constraints.push(c);
195    }
196
197    /// Check that all registered constraints are satisfiable given current bindings.
198    pub fn check_consistency(&self) -> bool {
199        for c in &self.constraints {
200            match c {
201                SymbolicShapeConstraint::Equal(a, b) => {
202                    if let (Some(va), Some(vb)) = (self.concrete_value(a), self.concrete_value(b)) {
203                        if va != vb {
204                            return false;
205                        }
206                    }
207                }
208                SymbolicShapeConstraint::GreaterThan(a, b) => {
209                    if let (Some(va), Some(vb)) = (self.concrete_value(a), self.concrete_value(b)) {
210                        if va <= vb {
211                            return false;
212                        }
213                    }
214                }
215                SymbolicShapeConstraint::Multiple(a, b) => {
216                    if let (Some(va), Some(vb)) = (self.concrete_value(a), self.concrete_value(b)) {
217                        if vb == 0 || va % vb != 0 {
218                            return false;
219                        }
220                    }
221                }
222            }
223        }
224        true
225    }
226
227    /// Number of bindings currently in the environment.
228    pub fn binding_count(&self) -> usize {
229        self.bindings.len()
230    }
231
232    /// All symbolic names currently bound.
233    pub fn bound_names(&self) -> impl Iterator<Item = &Arc<str>> {
234        self.bindings.keys()
235    }
236}
237
238/// Infer the output shape of an einsum operation given symbolic input shapes.
239///
240/// Uses the einsum spec notation: `"ij,jk->ik"` with inputs `[["M","K"],["K","N"]]`
241/// produces `["M","N"]`.
242///
243/// # Rules
244/// - Each index character maps to a SymbolicDim from its position in the corresponding input
245/// - Shared indices that appear in multiple inputs are unified (must be equal)
246/// - Output indices are collected in spec output order
247pub fn propagate_einsum_shapes(
248    spec: &str,
249    input_shapes: &[SymbolicShape],
250    env: &mut SymbolicShapeEnv,
251) -> Result<SymbolicShape, ShapeError> {
252    // Parse spec: "ij,jk->ik"
253    let arrow_pos = spec
254        .find("->")
255        .ok_or_else(|| ShapeError::InvalidSpec(format!("missing '->' in einsum spec: {}", spec)))?;
256    let inputs_part = &spec[..arrow_pos];
257    let output_part = &spec[arrow_pos + 2..];
258
259    let operand_specs: Vec<&str> = inputs_part.split(',').collect();
260    if operand_specs.len() != input_shapes.len() {
261        return Err(ShapeError::ArityMismatch {
262            expected: operand_specs.len(),
263            got: input_shapes.len(),
264        });
265    }
266
267    // Build index → SymbolicDim map
268    let mut index_map: HashMap<char, SymbolicDim> = HashMap::new();
269
270    for (op_spec, shape) in operand_specs.iter().zip(input_shapes.iter()) {
271        let chars: Vec<char> = op_spec.chars().filter(|c| c.is_alphabetic()).collect();
272        if chars.len() != shape.len() {
273            return Err(ShapeError::InvalidSpec(format!(
274                "spec '{}' has {} indices but shape has {} dims",
275                op_spec,
276                chars.len(),
277                shape.len()
278            )));
279        }
280        for (ch, dim) in chars.iter().zip(shape.iter()) {
281            if let Some(existing) = index_map.get(ch) {
282                // Unify existing with new
283                let unified = env.unify(existing, dim)?;
284                index_map.insert(*ch, unified);
285            } else {
286                index_map.insert(*ch, env.resolve(dim));
287            }
288        }
289    }
290
291    // Collect output shape
292    let output_chars: Vec<char> = output_part.chars().filter(|c| c.is_alphabetic()).collect();
293    let mut out_shape = Vec::with_capacity(output_chars.len());
294    for ch in output_chars {
295        let dim = index_map
296            .get(&ch)
297            .cloned()
298            .unwrap_or_else(|| SymbolicDim::symbolic(format!("_out_{}", ch)));
299        out_shape.push(env.resolve(&dim));
300    }
301
302    Ok(out_shape)
303}
304
305/// Convenience: propagate shapes through a chain of einsum operations.
306pub fn propagate_chain(
307    specs: &[&str],
308    initial_shapes: &[SymbolicShape],
309    env: &mut SymbolicShapeEnv,
310) -> Result<Vec<SymbolicShape>, ShapeError> {
311    let mut results = Vec::new();
312    let mut current_shapes: Vec<SymbolicShape> = initial_shapes.to_vec();
313    for spec in specs {
314        let out = propagate_einsum_shapes(spec, &current_shapes, env)?;
315        results.push(out.clone());
316        // For chains, the output becomes the first input of next (simplified)
317        current_shapes = vec![out];
318    }
319    Ok(results)
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    // SymbolicDim tests
    #[test]
    fn test_fixed_dim_equality() {
        assert_eq!(SymbolicDim::fixed(3), SymbolicDim::fixed(3));
        assert_ne!(SymbolicDim::fixed(3), SymbolicDim::fixed(4));
    }

    #[test]
    fn test_symbolic_dim_display() {
        let d = SymbolicDim::symbolic("batch");
        assert_eq!(format!("{}", d), "batch");
    }

    #[test]
    fn test_product_dim_display() {
        let p = SymbolicDim::product(SymbolicDim::symbolic("N"), SymbolicDim::fixed(4));
        assert_eq!(format!("{}", p), "(N*4)");
    }

    #[test]
    fn test_fixed_dim_concrete_value() {
        assert_eq!(SymbolicDim::fixed(42).concrete_value(), Some(42));
    }

    #[test]
    fn test_symbolic_dim_no_concrete_value() {
        assert_eq!(SymbolicDim::symbolic("N").concrete_value(), None);
    }

    #[test]
    fn test_product_dim_resolves_when_both_fixed() {
        let p = SymbolicDim::product(SymbolicDim::fixed(3), SymbolicDim::fixed(4));
        assert_eq!(p.concrete_value(), Some(12));
    }

    #[test]
    fn test_product_dim_unresolved_when_symbolic() {
        let p = SymbolicDim::product(SymbolicDim::symbolic("N"), SymbolicDim::fixed(4));
        assert_eq!(p.concrete_value(), None);
    }

    // SymbolicShapeEnv unification tests
    #[test]
    fn test_unify_fixed_same() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let result = env.unify(&SymbolicDim::fixed(5), &SymbolicDim::fixed(5))?;
        assert_eq!(result, SymbolicDim::fixed(5));
        Ok(())
    }

    #[test]
    fn test_unify_fixed_contradiction() {
        let mut env = SymbolicShapeEnv::new();
        let result = env.unify(&SymbolicDim::fixed(3), &SymbolicDim::fixed(4));
        assert!(matches!(result, Err(ShapeError::Contradiction(_, _))));
    }

    #[test]
    fn test_unify_symbolic_binds_to_fixed() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        env.unify(&SymbolicDim::symbolic("N"), &SymbolicDim::fixed(7))?;
        assert_eq!(env.concrete_value(&SymbolicDim::symbolic("N")), Some(7));
        Ok(())
    }

    #[test]
    fn test_unify_fixed_binds_symbolic() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        env.unify(&SymbolicDim::fixed(4), &SymbolicDim::symbolic("M"))?;
        assert_eq!(env.concrete_value(&SymbolicDim::symbolic("M")), Some(4));
        Ok(())
    }

    #[test]
    fn test_unify_two_symbolics() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        env.unify(&SymbolicDim::symbolic("A"), &SymbolicDim::symbolic("B"))?;
        // Both names must now resolve to the same (still symbolic) dimension.
        let ra = env.resolve(&SymbolicDim::symbolic("A"));
        let rb = env.resolve(&SymbolicDim::symbolic("B"));
        assert_eq!(ra, rb);
        Ok(())
    }

    #[test]
    fn test_unify_same_symbolic_is_noop() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let out = env.unify(&SymbolicDim::symbolic("K"), &SymbolicDim::symbolic("K"))?;
        assert_eq!(out, SymbolicDim::symbolic("K"));
        assert_eq!(env.binding_count(), 0);
        Ok(())
    }

    #[test]
    fn test_resolve_chain() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        env.unify(&SymbolicDim::symbolic("A"), &SymbolicDim::symbolic("B"))?;
        env.unify(&SymbolicDim::symbolic("B"), &SymbolicDim::fixed(10))?;
        assert_eq!(env.concrete_value(&SymbolicDim::symbolic("A")), Some(10));
        Ok(())
    }

    #[test]
    fn test_binding_count() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        assert_eq!(env.binding_count(), 0);
        env.unify(&SymbolicDim::symbolic("N"), &SymbolicDim::fixed(5))?;
        assert_eq!(env.binding_count(), 1);
        Ok(())
    }

    // SymbolicShapeConstraint tests
    #[test]
    fn test_constraint_consistency_equal() {
        let mut env = SymbolicShapeEnv::new();
        env.add_constraint(SymbolicShapeConstraint::Equal(
            SymbolicDim::fixed(3),
            SymbolicDim::fixed(3),
        ));
        assert!(env.check_consistency());
    }

    #[test]
    fn test_constraint_inconsistency_equal() {
        let mut env = SymbolicShapeEnv::new();
        env.add_constraint(SymbolicShapeConstraint::Equal(
            SymbolicDim::fixed(3),
            SymbolicDim::fixed(5),
        ));
        assert!(!env.check_consistency());
    }

    #[test]
    fn test_constraint_greater_than() {
        let mut env = SymbolicShapeEnv::new();
        env.add_constraint(SymbolicShapeConstraint::GreaterThan(
            SymbolicDim::fixed(10),
            SymbolicDim::fixed(5),
        ));
        assert!(env.check_consistency());
    }

    #[test]
    fn test_constraint_greater_than_violated() {
        let mut env = SymbolicShapeEnv::new();
        // Strictly greater: equal values violate the constraint.
        env.add_constraint(SymbolicShapeConstraint::GreaterThan(
            SymbolicDim::fixed(5),
            SymbolicDim::fixed(5),
        ));
        assert!(!env.check_consistency());
    }

    #[test]
    fn test_constraint_multiple() {
        let mut env = SymbolicShapeEnv::new();
        env.add_constraint(SymbolicShapeConstraint::Multiple(
            SymbolicDim::fixed(12),
            SymbolicDim::fixed(4),
        ));
        assert!(env.check_consistency());
    }

    #[test]
    fn test_constraint_multiple_violated() {
        let mut env = SymbolicShapeEnv::new();
        env.add_constraint(SymbolicShapeConstraint::Multiple(
            SymbolicDim::fixed(12),
            SymbolicDim::fixed(5),
        ));
        assert!(!env.check_consistency());
    }

    #[test]
    fn test_constraint_multiple_of_zero_is_violated() {
        let mut env = SymbolicShapeEnv::new();
        env.add_constraint(SymbolicShapeConstraint::Multiple(
            SymbolicDim::fixed(12),
            SymbolicDim::fixed(0),
        ));
        assert!(!env.check_consistency());
    }

    #[test]
    fn test_constraint_with_unbound_symbol_is_satisfiable() {
        let mut env = SymbolicShapeEnv::new();
        env.add_constraint(SymbolicShapeConstraint::Equal(
            SymbolicDim::symbolic("N"),
            SymbolicDim::fixed(3),
        ));
        assert!(env.check_consistency());
    }

    // Einsum propagation tests
    #[test]
    fn test_propagate_matmul_symbolic() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![SymbolicDim::symbolic("M"), SymbolicDim::symbolic("K")];
        let shape_b = vec![SymbolicDim::symbolic("K"), SymbolicDim::symbolic("N")];
        let out = propagate_einsum_shapes("ij,jk->ik", &[shape_a, shape_b], &mut env)?;
        assert_eq!(out.len(), 2);
        assert_eq!(format!("{}", out[0]), "M");
        assert_eq!(format!("{}", out[1]), "N");
        Ok(())
    }

    #[test]
    fn test_propagate_matmul_fixed() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![SymbolicDim::fixed(4), SymbolicDim::fixed(3)];
        let shape_b = vec![SymbolicDim::fixed(3), SymbolicDim::fixed(5)];
        let out = propagate_einsum_shapes("ij,jk->ik", &[shape_a, shape_b], &mut env)?;
        assert_eq!(out[0].concrete_value(), Some(4));
        assert_eq!(out[1].concrete_value(), Some(5));
        Ok(())
    }

    #[test]
    fn test_propagate_contraction_unifies_k() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![SymbolicDim::symbolic("M"), SymbolicDim::symbolic("K")];
        let shape_b = vec![SymbolicDim::symbolic("K"), SymbolicDim::fixed(5)];
        let out = propagate_einsum_shapes("ij,jk->ik", &[shape_a, shape_b], &mut env)?;
        assert_eq!(out, vec![SymbolicDim::symbolic("M"), SymbolicDim::fixed(5)]);
        // K is unified with itself, so it stays unbound and nothing is recorded.
        assert_eq!(env.concrete_value(&SymbolicDim::symbolic("K")), None);
        assert_eq!(env.binding_count(), 0);
        Ok(())
    }

    #[test]
    fn test_propagate_contraction_binds_symbolic_to_fixed() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![SymbolicDim::symbolic("M"), SymbolicDim::symbolic("K")];
        let shape_b = vec![SymbolicDim::fixed(3), SymbolicDim::symbolic("N")];
        let out = propagate_einsum_shapes("ij,jk->ik", &[shape_a, shape_b], &mut env)?;
        assert_eq!(out, vec![SymbolicDim::symbolic("M"), SymbolicDim::symbolic("N")]);
        // The contracted index forces K == 3.
        assert_eq!(env.concrete_value(&SymbolicDim::symbolic("K")), Some(3));
        Ok(())
    }

    #[test]
    fn test_propagate_contraction_mismatch_error() {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![SymbolicDim::fixed(4), SymbolicDim::fixed(3)];
        let shape_b = vec![SymbolicDim::fixed(2), SymbolicDim::fixed(5)];
        let result = propagate_einsum_shapes("ij,jk->ik", &[shape_a, shape_b], &mut env);
        assert!(matches!(result, Err(ShapeError::Contradiction(_, _))));
    }

    #[test]
    fn test_propagate_inner_product() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![SymbolicDim::symbolic("N")];
        let shape_b = vec![SymbolicDim::symbolic("N")];
        let out = propagate_einsum_shapes("i,i->", &[shape_a, shape_b], &mut env)?;
        assert_eq!(out.len(), 0); // scalar output
        Ok(())
    }

    #[test]
    fn test_propagate_batch_matmul() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![
            SymbolicDim::symbolic("B"),
            SymbolicDim::symbolic("M"),
            SymbolicDim::symbolic("K"),
        ];
        let shape_b = vec![
            SymbolicDim::symbolic("B"),
            SymbolicDim::symbolic("K"),
            SymbolicDim::symbolic("N"),
        ];
        let out = propagate_einsum_shapes("bij,bjk->bik", &[shape_a, shape_b], &mut env)?;
        assert_eq!(out.len(), 3);
        assert_eq!(format!("{}", out[0]), "B");
        assert_eq!(format!("{}", out[1]), "M");
        assert_eq!(format!("{}", out[2]), "N");
        Ok(())
    }

    #[test]
    fn test_propagate_arity_mismatch_error() {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![SymbolicDim::fixed(3), SymbolicDim::fixed(4)];
        // Spec expects 2 inputs but we provide only 1
        let result = propagate_einsum_shapes("ij,jk->ik", &[shape_a], &mut env);
        assert!(matches!(
            result,
            Err(ShapeError::ArityMismatch { expected: 2, got: 1 })
        ));
    }

    #[test]
    fn test_propagate_rank_mismatch_error() {
        let mut env = SymbolicShapeEnv::new();
        // Spec names two indices but the shape has only one dimension.
        let shape = vec![SymbolicDim::fixed(3)];
        let result = propagate_einsum_shapes("ij->i", &[shape], &mut env);
        assert!(matches!(result, Err(ShapeError::InvalidSpec(_))));
    }

    #[test]
    fn test_propagate_missing_arrow_error() {
        let mut env = SymbolicShapeEnv::new();
        let shape = vec![SymbolicDim::fixed(3)];
        let result = propagate_einsum_shapes("i,j", &[shape.clone(), shape], &mut env);
        assert!(matches!(result, Err(ShapeError::InvalidSpec(_))));
    }

    #[test]
    fn test_propagate_chain_feeds_output_forward() -> Result<(), ShapeError> {
        let mut env = SymbolicShapeEnv::new();
        let shape_a = vec![SymbolicDim::symbolic("M"), SymbolicDim::symbolic("K")];
        let shape_b = vec![SymbolicDim::symbolic("K"), SymbolicDim::symbolic("N")];
        // Matmul, then transpose the single result.
        let outs = propagate_chain(&["ij,jk->ik", "ab->ba"], &[shape_a, shape_b], &mut env)?;
        assert_eq!(outs.len(), 2);
        assert_eq!(outs[0], vec![SymbolicDim::symbolic("M"), SymbolicDim::symbolic("N")]);
        assert_eq!(outs[1], vec![SymbolicDim::symbolic("N"), SymbolicDim::symbolic("M")]);
        Ok(())
    }
}