Skip to main content

shrew_core/
dynamic_shape.rs

1// Dynamic Shapes — Symbolic shape tracking and runtime shape resolution
2//
3// Deep learning models often have dimensions that vary at runtime:
4//   - Batch size: different in training vs inference, or varies each step
5//   - Sequence length: varies per input in NLP models
6//   - Image size: may differ with data augmentation
7//
8// This module provides a bridge between:
9//   - The IR's symbolic shape system (Dim::Symbolic, Dim::Dynamic)
10//   - The runtime's concrete shape system (Shape = Vec<usize>)
11//
12// COMPONENTS:
13//
14//   SymDim            — A dimension that can be fixed, symbolic, or dynamic
15//   SymbolicShape     — A shape pattern with mixed fixed/symbolic/dynamic dims
16//   ShapeEnv          — Environment mapping symbolic names → concrete values
17//   ShapeGuard        — Validates concrete tensors against shape patterns
18//
19// WORKFLOW:
20//
21//   1. Define model shapes with symbolic dims: [Batch, SeqLen, 768]
22//   2. At runtime, bind symbolic dims: { Batch → 32, SeqLen → 128 }
23//   3. Resolve symbolic shapes to concrete: [32, 128, 768]
24//   4. Or validate: does this concrete tensor match the expected pattern?
25//
26// EXAMPLES:
27//
28//   let pattern = SymbolicShape::new(vec![
29//       SymDim::Symbolic("Batch".into()),
30//       SymDim::Fixed(784),
31//   ]);
32//   let mut env = ShapeEnv::new();
33//   env.bind("Batch", 32);
34//   let concrete = pattern.resolve(&env)?; // Shape([32, 784])
35
36use std::collections::HashMap;
37use std::fmt;
38
39use crate::error::{Error, Result};
40use crate::shape::Shape;
41
42// SymDim — A single dimension that can be fixed, named, or dynamic
43
/// One axis of a tensor shape, whose extent may be known, named, or unknown.
///
/// Runtime-side counterpart of the IR's `Dim` enum: shape validation and
/// resolution in this module operate on `SymDim` values.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SymDim {
    /// An extent known ahead of time, e.g. `768` or `50257`.
    Fixed(usize),
    /// A named extent such as `"Batch"` or `"SeqLen"`; its concrete value
    /// comes from a [`ShapeEnv`] binding at runtime.
    Symbolic(String),
    /// A wildcard extent with no name to bind. It accepts any concrete size
    /// and is meant for axes that are genuinely unknown until runtime.
    Dynamic,
}
59
60impl SymDim {
61    /// Create a fixed dimension.
62    pub fn fixed(n: usize) -> Self {
63        SymDim::Fixed(n)
64    }
65
66    /// Create a named symbolic dimension.
67    pub fn symbolic(name: impl Into<String>) -> Self {
68        SymDim::Symbolic(name.into())
69    }
70
71    /// Create a dynamic (wildcard) dimension.
72    pub fn dynamic() -> Self {
73        SymDim::Dynamic
74    }
75
76    /// Is this a concrete (fixed) dimension?
77    pub fn is_fixed(&self) -> bool {
78        matches!(self, SymDim::Fixed(_))
79    }
80
81    /// Is this a symbolic (named) dimension?
82    pub fn is_symbolic(&self) -> bool {
83        matches!(self, SymDim::Symbolic(_))
84    }
85
86    /// Is this fully dynamic?
87    pub fn is_dynamic(&self) -> bool {
88        matches!(self, SymDim::Dynamic)
89    }
90
91    /// Try to resolve this dimension to a concrete value.
92    ///
93    /// - Fixed: returns the value directly
94    /// - Symbolic: looks up the name in the environment
95    /// - Dynamic: returns None (cannot resolve without a concrete value)
96    pub fn resolve(&self, env: &ShapeEnv) -> Option<usize> {
97        match self {
98            SymDim::Fixed(n) => Some(*n),
99            SymDim::Symbolic(name) => env.get(name),
100            SymDim::Dynamic => None,
101        }
102    }
103
104    /// Check if a concrete value matches this dimension pattern.
105    ///
106    /// - Fixed(n): value must equal n
107    /// - Symbolic: checks env if bound, otherwise matches any value
108    /// - Dynamic: matches any value
109    pub fn matches(&self, value: usize, env: &ShapeEnv) -> bool {
110        match self {
111            SymDim::Fixed(n) => value == *n,
112            SymDim::Symbolic(name) => {
113                if let Some(bound) = env.get(name) {
114                    value == bound
115                } else {
116                    true // unbound symbolic matches anything
117                }
118            }
119            SymDim::Dynamic => true,
120        }
121    }
122
123    /// Try to unify this dimension with a concrete value, potentially
124    /// binding a symbolic name in the environment.
125    ///
126    /// Returns true if unification succeeds.
127    pub fn unify(&self, value: usize, env: &mut ShapeEnv) -> bool {
128        match self {
129            SymDim::Fixed(n) => value == *n,
130            SymDim::Symbolic(name) => {
131                if let Some(bound) = env.get(name) {
132                    value == bound
133                } else {
134                    env.bind(name, value);
135                    true
136                }
137            }
138            SymDim::Dynamic => true,
139        }
140    }
141}
142
143impl fmt::Display for SymDim {
144    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145        match self {
146            SymDim::Fixed(n) => write!(f, "{n}"),
147            SymDim::Symbolic(s) => write!(f, "{s}"),
148            SymDim::Dynamic => write!(f, "?"),
149        }
150    }
151}
152
153impl From<usize> for SymDim {
154    fn from(n: usize) -> Self {
155        SymDim::Fixed(n)
156    }
157}
158
159impl From<&str> for SymDim {
160    fn from(s: &str) -> Self {
161        SymDim::Symbolic(s.to_string())
162    }
163}
164
165// SymbolicShape — A shape pattern with mixed fixed/symbolic/dynamic dims
166
167/// A shape pattern that can contain fixed, symbolic, and dynamic dimensions.
168///
169/// Think of it as a "shape template" that can be resolved to a concrete
170/// `Shape` once all symbolic dimensions are bound.
171///
172/// # Examples
173/// ```ignore
174/// // Define a shape pattern for transformer input
175/// let pattern = SymbolicShape::from(vec![
176///     SymDim::symbolic("Batch"),
177///     SymDim::symbolic("SeqLen"),
178///     SymDim::fixed(768),
179/// ]);
180///
181/// // Resolve with concrete bindings
182/// let mut env = ShapeEnv::new();
183/// env.bind("Batch", 32);
184/// env.bind("SeqLen", 128);
185/// let concrete = pattern.resolve(&env)?; // Shape([32, 128, 768])
186///
187/// // Or validate a concrete tensor shape
188/// let guard = ShapeGuard::new(pattern);
189/// assert!(guard.validate_shape(&Shape::new(vec![32, 128, 768]), &env));
190/// ```
191#[derive(Debug, Clone, PartialEq, Eq)]
192pub struct SymbolicShape {
193    dims: Vec<SymDim>,
194}
195
196impl SymbolicShape {
197    /// Create a new symbolic shape from a vector of SymDim.
198    pub fn new(dims: Vec<SymDim>) -> Self {
199        Self { dims }
200    }
201
202    /// Create a fully-fixed symbolic shape from a concrete shape.
203    pub fn from_shape(shape: &Shape) -> Self {
204        Self {
205            dims: shape.dims().iter().map(|&d| SymDim::Fixed(d)).collect(),
206        }
207    }
208
209    /// Number of dimensions.
210    pub fn rank(&self) -> usize {
211        self.dims.len()
212    }
213
214    /// Get the dimension patterns.
215    pub fn dims(&self) -> &[SymDim] {
216        &self.dims
217    }
218
219    /// Check if all dimensions are fixed (fully concrete).
220    pub fn is_concrete(&self) -> bool {
221        self.dims.iter().all(|d| d.is_fixed())
222    }
223
224    /// Check if any dimension is symbolic or dynamic.
225    pub fn has_symbolic(&self) -> bool {
226        self.dims.iter().any(|d| !d.is_fixed())
227    }
228
229    /// Get all symbolic dimension names used in this shape.
230    pub fn symbolic_names(&self) -> Vec<&str> {
231        self.dims
232            .iter()
233            .filter_map(|d| match d {
234                SymDim::Symbolic(name) => Some(name.as_str()),
235                _ => None,
236            })
237            .collect()
238    }
239
240    /// Try to resolve this symbolic shape to a concrete Shape.
241    ///
242    /// Returns an error if any symbolic dimension is not bound in the env,
243    /// or if any dimension is Dynamic (cannot resolve without a value).
244    pub fn resolve(&self, env: &ShapeEnv) -> Result<Shape> {
245        let mut concrete = Vec::with_capacity(self.dims.len());
246        for (i, dim) in self.dims.iter().enumerate() {
247            match dim.resolve(env) {
248                Some(n) => concrete.push(n),
249                None => {
250                    return Err(Error::msg(format!(
251                        "cannot resolve dimension {} ({}) — not bound in environment",
252                        i, dim
253                    )));
254                }
255            }
256        }
257        Ok(Shape::new(concrete))
258    }
259
260    /// Try to resolve, falling back to a default for unresolved dims.
261    /// Dynamic/unbound symbolic dims use the provided default value.
262    pub fn resolve_with_default(&self, env: &ShapeEnv, default: usize) -> Shape {
263        let concrete: Vec<usize> = self
264            .dims
265            .iter()
266            .map(|d| d.resolve(env).unwrap_or(default))
267            .collect();
268        Shape::new(concrete)
269    }
270
271    /// Check if a concrete shape matches this pattern.
272    ///
273    /// Returns true if the shapes have the same rank and each dimension
274    /// matches (fixed dims must be equal, symbolic dims must match their
275    /// binding if bound, dynamic dims match anything).
276    pub fn matches(&self, shape: &Shape, env: &ShapeEnv) -> bool {
277        if self.rank() != shape.rank() {
278            return false;
279        }
280        self.dims
281            .iter()
282            .zip(shape.dims().iter())
283            .all(|(pattern, &value)| pattern.matches(value, env))
284    }
285
286    /// Unify this pattern with a concrete shape, binding symbolic dims.
287    ///
288    /// Returns true if unification succeeded (all dims compatible and
289    /// symbolic bindings are consistent). On success, newly discovered
290    /// bindings are added to `env`.
291    pub fn unify(&self, shape: &Shape, env: &mut ShapeEnv) -> bool {
292        if self.rank() != shape.rank() {
293            return false;
294        }
295        // First pass: check consistency without modifying env
296        let mut new_bindings = Vec::new();
297        for (pattern, &value) in self.dims.iter().zip(shape.dims().iter()) {
298            match pattern {
299                SymDim::Fixed(n) => {
300                    if value != *n {
301                        return false;
302                    }
303                }
304                SymDim::Symbolic(name) => {
305                    if let Some(bound) = env.get(name) {
306                        if value != bound {
307                            return false;
308                        }
309                    } else {
310                        // Check consistency with other new bindings
311                        if let Some(&prev) =
312                            new_bindings.iter().find_map(|(n, v): &(&str, usize)| {
313                                if *n == name.as_str() {
314                                    Some(v)
315                                } else {
316                                    None
317                                }
318                            })
319                        {
320                            if value != prev {
321                                return false;
322                            }
323                        } else {
324                            new_bindings.push((name.as_str(), value));
325                        }
326                    }
327                }
328                SymDim::Dynamic => {} // always matches
329            }
330        }
331        // Apply new bindings
332        for (name, value) in new_bindings {
333            env.bind(name, value);
334        }
335        true
336    }
337
338    /// Compute the output shape of a broadcasting operation between two
339    /// symbolic shapes. Returns None if shapes are incompatible.
340    pub fn broadcast(&self, other: &SymbolicShape) -> Option<SymbolicShape> {
341        let rank = self.rank().max(other.rank());
342        let mut result = Vec::with_capacity(rank);
343
344        for i in 0..rank {
345            let a = if i < rank - self.rank() {
346                &SymDim::Fixed(1)
347            } else {
348                &self.dims[i - (rank - self.rank())]
349            };
350            let b = if i < rank - other.rank() {
351                &SymDim::Fixed(1)
352            } else {
353                &other.dims[i - (rank - other.rank())]
354            };
355
356            match (a, b) {
357                (SymDim::Fixed(1), _) => result.push(b.clone()),
358                (_, SymDim::Fixed(1)) => result.push(a.clone()),
359                (SymDim::Fixed(x), SymDim::Fixed(y)) if x == y => result.push(a.clone()),
360                (SymDim::Fixed(_), SymDim::Fixed(_)) => return None, // incompatible
361                (SymDim::Symbolic(s), SymDim::Symbolic(t)) if s == t => result.push(a.clone()),
362                (SymDim::Dynamic, _) | (_, SymDim::Dynamic) => result.push(SymDim::Dynamic),
363                (SymDim::Symbolic(_), _) | (_, SymDim::Symbolic(_)) => result.push(SymDim::Dynamic),
364            }
365        }
366        Some(SymbolicShape::new(result))
367    }
368}
369
370impl fmt::Display for SymbolicShape {
371    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
372        write!(f, "[")?;
373        for (i, d) in self.dims.iter().enumerate() {
374            if i > 0 {
375                write!(f, ", ")?;
376            }
377            write!(f, "{d}")?;
378        }
379        write!(f, "]")
380    }
381}
382
383impl From<Vec<SymDim>> for SymbolicShape {
384    fn from(dims: Vec<SymDim>) -> Self {
385        Self::new(dims)
386    }
387}
388
389impl From<Shape> for SymbolicShape {
390    fn from(shape: Shape) -> Self {
391        Self::from_shape(&shape)
392    }
393}
394
395// ShapeEnv — Environment for symbolic dimension bindings
396
/// Maps symbolic dimension names to concrete values.
///
/// Used during shape resolution to convert symbolic shapes to concrete ones.
/// It is also populated by unification (`SymDim::unify`,
/// `SymbolicShape::unify`, `ShapeGuard::validate_and_bind`).
///
/// Two environments compare equal when they hold exactly the same bindings.
///
/// # Examples
/// ```ignore
/// let mut env = ShapeEnv::new();
/// env.bind("Batch", 32);
/// env.bind("SeqLen", 128);
/// assert_eq!(env.get("Batch"), Some(32));
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShapeEnv {
    /// Symbolic name → bound extent.
    bindings: HashMap<String, usize>,
}
412
413impl ShapeEnv {
414    /// Create an empty shape environment.
415    pub fn new() -> Self {
416        Self {
417            bindings: HashMap::new(),
418        }
419    }
420
421    /// Bind a symbolic name to a concrete value.
422    pub fn bind(&mut self, name: impl Into<String>, value: usize) {
423        self.bindings.insert(name.into(), value);
424    }
425
426    /// Look up a symbolic name.
427    pub fn get(&self, name: &str) -> Option<usize> {
428        self.bindings.get(name).copied()
429    }
430
431    /// Check if a name is bound.
432    pub fn is_bound(&self, name: &str) -> bool {
433        self.bindings.contains_key(name)
434    }
435
436    /// Get all bindings.
437    pub fn bindings(&self) -> &HashMap<String, usize> {
438        &self.bindings
439    }
440
441    /// Number of bindings.
442    pub fn len(&self) -> usize {
443        self.bindings.len()
444    }
445
446    /// Is the environment empty?
447    pub fn is_empty(&self) -> bool {
448        self.bindings.is_empty()
449    }
450
451    /// Merge another environment into this one.
452    /// Returns an error if there are conflicting bindings.
453    pub fn merge(&mut self, other: &ShapeEnv) -> Result<()> {
454        for (name, &value) in &other.bindings {
455            if let Some(&existing) = self.bindings.get(name) {
456                if existing != value {
457                    return Err(Error::msg(format!(
458                        "conflicting binding for '{}': {} vs {}",
459                        name, existing, value
460                    )));
461                }
462            } else {
463                self.bindings.insert(name.clone(), value);
464            }
465        }
466        Ok(())
467    }
468}
469
470impl Default for ShapeEnv {
471    fn default() -> Self {
472        Self::new()
473    }
474}
475
476impl From<&[(&str, usize)]> for ShapeEnv {
477    fn from(bindings: &[(&str, usize)]) -> Self {
478        let mut env = ShapeEnv::new();
479        for &(name, value) in bindings {
480            env.bind(name, value);
481        }
482        env
483    }
484}
485
486// ShapeGuard — Runtime shape validation against patterns
487
488/// Validates concrete tensor shapes against symbolic patterns.
489///
490/// Useful for checking model inputs at runtime, ensuring shapes are
491/// compatible before attempting operations that would fail.
492///
493/// # Examples
494/// ```ignore
495/// let guard = ShapeGuard::new("input")
496///     .expect(SymbolicShape::from(vec![
497///         SymDim::symbolic("Batch"),
498///         SymDim::fixed(784),
499///     ]));
500///
501/// // This shape is valid (batch=32, features=784)
502/// guard.validate(&Shape::new(vec![32, 784]), &env)?;
503///
504/// // This would error (wrong feature dim)
505/// guard.validate(&Shape::new(vec![32, 100]), &env); // Error!
506/// ```
507#[derive(Debug, Clone)]
508pub struct ShapeGuard {
509    /// Name of the tensor being guarded (for error messages).
510    name: String,
511    /// Expected shape pattern.
512    pattern: SymbolicShape,
513}
514
515impl ShapeGuard {
516    /// Create a shape guard with a name and expected pattern.
517    pub fn new(name: impl Into<String>, pattern: SymbolicShape) -> Self {
518        Self {
519            name: name.into(),
520            pattern,
521        }
522    }
523
524    /// Validate a concrete shape against the expected pattern.
525    ///
526    /// Returns Ok(()) if the shape matches, or an error describing
527    /// the mismatch.
528    pub fn validate(&self, shape: &Shape, env: &ShapeEnv) -> Result<()> {
529        if self.pattern.rank() != shape.rank() {
530            return Err(Error::msg(format!(
531                "shape mismatch for '{}': expected rank {} ({}), got rank {} ({:?})",
532                self.name,
533                self.pattern.rank(),
534                self.pattern,
535                shape.rank(),
536                shape.dims()
537            )));
538        }
539        for (i, (expected, &actual)) in self
540            .pattern
541            .dims()
542            .iter()
543            .zip(shape.dims().iter())
544            .enumerate()
545        {
546            if !expected.matches(actual, env) {
547                return Err(Error::msg(format!(
548                    "shape mismatch for '{}' at dim {}: expected {}, got {}",
549                    self.name, i, expected, actual
550                )));
551            }
552        }
553        Ok(())
554    }
555
556    /// Validate and bind — simultaneously validates the shape and binds
557    /// any unbound symbolic dimensions. Returns Ok if consistent.
558    pub fn validate_and_bind(&self, shape: &Shape, env: &mut ShapeEnv) -> Result<()> {
559        if !self.pattern.unify(shape, env) {
560            Err(Error::msg(format!(
561                "shape mismatch for '{}': expected {}, got {:?}",
562                self.name,
563                self.pattern,
564                shape.dims()
565            )))
566        } else {
567            Ok(())
568        }
569    }
570}
571
572// Tests
573
// Unit tests for the dynamic-shape types: `SymDim`, `SymbolicShape`,
// `ShapeEnv`, and `ShapeGuard`, plus integration checks that show a single
// `ShapeEnv` keeping symbolic dims consistent across several tensors.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shape::Shape;

    // ── SymDim ──

    #[test]
    fn test_symdim_fixed() {
        let d = SymDim::fixed(32);
        assert!(d.is_fixed());
        assert!(!d.is_symbolic());
        assert!(!d.is_dynamic());
        // A fixed dim resolves without any bindings.
        assert_eq!(d.resolve(&ShapeEnv::new()), Some(32));
        assert_eq!(format!("{d}"), "32");
    }

    #[test]
    fn test_symdim_symbolic_bound() {
        let d = SymDim::symbolic("Batch");
        let mut env = ShapeEnv::new();
        env.bind("Batch", 64);

        assert!(d.is_symbolic());
        assert_eq!(d.resolve(&env), Some(64));
        // Once bound, only the bound value matches.
        assert!(d.matches(64, &env));
        assert!(!d.matches(32, &env));
    }

    #[test]
    fn test_symdim_symbolic_unbound() {
        let d = SymDim::symbolic("SeqLen");
        let env = ShapeEnv::new();

        assert_eq!(d.resolve(&env), None);
        assert!(d.matches(100, &env)); // unbound matches anything
        assert!(d.matches(200, &env));
    }

    #[test]
    fn test_symdim_dynamic() {
        let d = SymDim::dynamic();
        let env = ShapeEnv::new();
        assert!(d.is_dynamic());
        // Dynamic never resolves, but matches any value.
        assert_eq!(d.resolve(&env), None);
        assert!(d.matches(999, &env));
    }

    #[test]
    fn test_symdim_unify() {
        let d = SymDim::symbolic("N");
        let mut env = ShapeEnv::new();
        // First unify binds the name.
        assert!(d.unify(42, &mut env));
        assert_eq!(env.get("N"), Some(42));
        // Same value: unify again succeeds
        assert!(d.unify(42, &mut env));
        // Different value: unify fails
        assert!(!d.unify(99, &mut env));
    }

    #[test]
    fn test_symdim_from() {
        // `From<usize>` gives Fixed; `From<&str>` gives Symbolic.
        let d: SymDim = 32usize.into();
        assert_eq!(d, SymDim::Fixed(32));
        let d: SymDim = "Batch".into();
        assert_eq!(d, SymDim::Symbolic("Batch".to_string()));
    }

    // ── SymbolicShape ──

    #[test]
    fn test_symbolic_shape_basic() {
        let s = SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(784)]);
        assert_eq!(s.rank(), 2);
        assert!(!s.is_concrete());
        assert!(s.has_symbolic());
        assert_eq!(s.symbolic_names(), vec!["Batch"]);
        assert_eq!(format!("{s}"), "[Batch, 784]");
    }

    #[test]
    fn test_symbolic_shape_resolve() {
        let s = SymbolicShape::new(vec![
            SymDim::symbolic("Batch"),
            SymDim::symbolic("SeqLen"),
            SymDim::fixed(768),
        ]);
        let mut env = ShapeEnv::new();
        env.bind("Batch", 32);
        env.bind("SeqLen", 128);

        // Every symbolic dim is bound, so resolution succeeds.
        let concrete = s.resolve(&env).unwrap();
        assert_eq!(concrete.dims(), &[32, 128, 768]);
    }

    #[test]
    fn test_symbolic_shape_resolve_fails_unbound() {
        let s = SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(784)]);
        let env = ShapeEnv::new();
        assert!(s.resolve(&env).is_err());
    }

    #[test]
    fn test_symbolic_shape_resolve_dynamic_fails() {
        // Dynamic dims have no name to bind, so `resolve` always errors on them.
        let s = SymbolicShape::new(vec![SymDim::dynamic(), SymDim::fixed(784)]);
        let env = ShapeEnv::new();
        assert!(s.resolve(&env).is_err());
    }

    #[test]
    fn test_symbolic_shape_resolve_with_default() {
        let s = SymbolicShape::new(vec![
            SymDim::dynamic(),
            SymDim::symbolic("N"),
            SymDim::fixed(768),
        ]);
        let mut env = ShapeEnv::new();
        env.bind("N", 100);
        // The dynamic dim falls back to the default (1); bound N uses its binding.
        let concrete = s.resolve_with_default(&env, 1);
        assert_eq!(concrete.dims(), &[1, 100, 768]);
    }

    #[test]
    fn test_symbolic_shape_matches() {
        let s = SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(784)]);
        let mut env = ShapeEnv::new();
        env.bind("Batch", 32);

        assert!(s.matches(&Shape::new(vec![32, 784]), &env));
        assert!(!s.matches(&Shape::new(vec![32, 100]), &env)); // wrong feature dim
        assert!(!s.matches(&Shape::new(vec![64, 784]), &env)); // wrong batch
        assert!(!s.matches(&Shape::new(vec![32, 784, 1]), &env)); // wrong rank
    }

    #[test]
    fn test_symbolic_shape_unify() {
        let s = SymbolicShape::new(vec![
            SymDim::symbolic("Batch"),
            SymDim::symbolic("SeqLen"),
            SymDim::fixed(768),
        ]);
        let shape = Shape::new(vec![16, 256, 768]);
        let mut env = ShapeEnv::new();

        // Unification discovers and records both symbolic bindings.
        assert!(s.unify(&shape, &mut env));
        assert_eq!(env.get("Batch"), Some(16));
        assert_eq!(env.get("SeqLen"), Some(256));
    }

    #[test]
    fn test_symbolic_shape_unify_consistency() {
        // Same symbolic name used twice must have same value
        let s = SymbolicShape::new(vec![SymDim::symbolic("N"), SymDim::symbolic("N")]);

        let mut env1 = ShapeEnv::new();
        assert!(s.unify(&Shape::new(vec![32, 32]), &mut env1));
        assert_eq!(env1.get("N"), Some(32));

        let mut env2 = ShapeEnv::new();
        assert!(!s.unify(&Shape::new(vec![32, 64]), &mut env2)); // inconsistent
    }

    #[test]
    fn test_symbolic_shape_unify_fixed_mismatch() {
        let s = SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(784)]);
        let mut env = ShapeEnv::new();
        assert!(!s.unify(&Shape::new(vec![32, 100]), &mut env)); // 100 != 784
    }

    #[test]
    fn test_symbolic_shape_from_concrete() {
        // A shape built from concrete dims is all-Fixed and resolves with no bindings.
        let shape = Shape::new(vec![2, 3, 4]);
        let sym = SymbolicShape::from_shape(&shape);
        assert!(sym.is_concrete());
        assert!(!sym.has_symbolic());
        let resolved = sym.resolve(&ShapeEnv::new()).unwrap();
        assert_eq!(resolved.dims(), &[2, 3, 4]);
    }

    #[test]
    fn test_symbolic_shape_broadcast() {
        // Size-1 axes stretch: [3, 1] x [1, 4] -> [3, 4].
        let a = SymbolicShape::new(vec![SymDim::fixed(3), SymDim::fixed(1)]);
        let b = SymbolicShape::new(vec![SymDim::fixed(1), SymDim::fixed(4)]);
        let c = a.broadcast(&b).unwrap();
        assert_eq!(format!("{c}"), "[3, 4]");
    }

    #[test]
    fn test_symbolic_shape_broadcast_symbolic() {
        // A Fixed(1) axis takes the other side's symbolic dim.
        let a = SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(768)]);
        let b = SymbolicShape::new(vec![SymDim::fixed(1), SymDim::fixed(768)]);
        let c = a.broadcast(&b).unwrap();
        assert_eq!(format!("{c}"), "[Batch, 768]");
    }

    #[test]
    fn test_symbolic_shape_broadcast_incompatible() {
        let a = SymbolicShape::new(vec![SymDim::fixed(3)]);
        let b = SymbolicShape::new(vec![SymDim::fixed(4)]);
        assert!(a.broadcast(&b).is_none());
    }

    #[test]
    fn test_symbolic_shape_broadcast_rank_extension() {
        // The rank-1 shape is left-padded with a size-1 axis before combining.
        let a = SymbolicShape::new(vec![SymDim::fixed(768)]);
        let b = SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(768)]);
        let c = a.broadcast(&b).unwrap();
        assert_eq!(c.rank(), 2);
        assert_eq!(format!("{c}"), "[Batch, 768]");
    }

    // ── ShapeEnv ──

    #[test]
    fn test_shape_env_basic() {
        let mut env = ShapeEnv::new();
        assert!(env.is_empty());
        env.bind("Batch", 32);
        assert_eq!(env.len(), 1);
        assert_eq!(env.get("Batch"), Some(32));
        assert!(env.is_bound("Batch"));
        assert!(!env.is_bound("SeqLen"));
    }

    #[test]
    fn test_shape_env_merge() {
        let mut a = ShapeEnv::new();
        a.bind("Batch", 32);

        let mut b = ShapeEnv::new();
        b.bind("SeqLen", 128);
        b.bind("Batch", 32); // same value — ok

        a.merge(&b).unwrap();
        assert_eq!(a.get("Batch"), Some(32));
        assert_eq!(a.get("SeqLen"), Some(128));
    }

    #[test]
    fn test_shape_env_merge_conflict() {
        let mut a = ShapeEnv::new();
        a.bind("Batch", 32);

        let mut b = ShapeEnv::new();
        b.bind("Batch", 64); // different value — conflict

        assert!(a.merge(&b).is_err());
    }

    #[test]
    fn test_shape_env_from_slice() {
        let env = ShapeEnv::from([("Batch", 32), ("SeqLen", 128)].as_slice());
        assert_eq!(env.get("Batch"), Some(32));
        assert_eq!(env.get("SeqLen"), Some(128));
    }

    // ── ShapeGuard ──

    #[test]
    fn test_shape_guard_validate() {
        let guard = ShapeGuard::new(
            "input",
            SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(784)]),
        );
        let mut env = ShapeEnv::new();
        env.bind("Batch", 32);

        // OK; then wrong fixed feature dim; then wrong value for bound Batch.
        assert!(guard.validate(&Shape::new(vec![32, 784]), &env).is_ok());
        assert!(guard.validate(&Shape::new(vec![32, 100]), &env).is_err());
        assert!(guard.validate(&Shape::new(vec![64, 784]), &env).is_err());
    }

    #[test]
    fn test_shape_guard_validate_and_bind() {
        let guard = ShapeGuard::new(
            "input",
            SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(784)]),
        );
        let mut env = ShapeEnv::new();

        // First call: binds Batch=32
        guard
            .validate_and_bind(&Shape::new(vec![32, 784]), &mut env)
            .unwrap();
        assert_eq!(env.get("Batch"), Some(32));

        // Second call with same batch: ok
        guard
            .validate_and_bind(&Shape::new(vec![32, 784]), &mut env)
            .unwrap();

        // Third call with different batch: error
        assert!(guard
            .validate_and_bind(&Shape::new(vec![64, 784]), &mut env)
            .is_err());
    }

    #[test]
    fn test_shape_guard_wrong_rank() {
        // A Dynamic axis does not relax the rank check.
        let guard = ShapeGuard::new(
            "x",
            SymbolicShape::new(vec![SymDim::dynamic(), SymDim::fixed(10)]),
        );
        let env = ShapeEnv::new();
        assert!(guard.validate(&Shape::new(vec![10]), &env).is_err());
        assert!(guard.validate(&Shape::new(vec![5, 10, 1]), &env).is_err());
    }

    // ── Integration: multi-tensor shape consistency ──

    #[test]
    fn test_multi_tensor_batch_consistency() {
        // Verify that batch dimensions stay consistent across tensors
        let input_guard = ShapeGuard::new(
            "input",
            SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(784)]),
        );
        let target_guard = ShapeGuard::new(
            "target",
            SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(10)]),
        );

        let mut env = ShapeEnv::new();

        // Input has batch=32
        input_guard
            .validate_and_bind(&Shape::new(vec![32, 784]), &mut env)
            .unwrap();

        // Target must also have batch=32
        target_guard
            .validate_and_bind(&Shape::new(vec![32, 10]), &mut env)
            .unwrap();

        // Target with batch=16 would fail
        // (checked against a clone, which carries the Batch=32 binding)
        let mut env2 = env.clone();
        assert!(target_guard
            .validate_and_bind(&Shape::new(vec![16, 10]), &mut env2)
            .is_err());
    }

    #[test]
    fn test_transformer_shape_pattern() {
        // A transformer model with batch, sequence, and hidden dimensions
        let pattern = SymbolicShape::new(vec![
            SymDim::symbolic("Batch"),
            SymDim::symbolic("SeqLen"),
            SymDim::fixed(768),
        ]);

        // Different batch sizes work
        // (each run uses a fresh env, so earlier bindings do not constrain it)
        let mut env1 = ShapeEnv::new();
        assert!(pattern.unify(&Shape::new(vec![1, 512, 768]), &mut env1));
        assert_eq!(env1.get("Batch"), Some(1));
        assert_eq!(env1.get("SeqLen"), Some(512));

        let mut env2 = ShapeEnv::new();
        assert!(pattern.unify(&Shape::new(vec![32, 128, 768]), &mut env2));
        assert_eq!(env2.get("Batch"), Some(32));
        assert_eq!(env2.get("SeqLen"), Some(128));

        // Wrong hidden dim fails
        let mut env3 = ShapeEnv::new();
        assert!(!pattern.unify(&Shape::new(vec![32, 128, 512]), &mut env3));
    }
}