1use std::collections::HashMap;
37use std::fmt;
38
39use crate::error::{Error, Result};
40use crate::shape::Shape;
41
/// One dimension of a [`SymbolicShape`] pattern.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum SymDim {
    /// A dimension with a known, constant extent.
    Fixed(usize),
    /// A named dimension whose extent is looked up in (or bound into) a [`ShapeEnv`].
    Symbolic(String),
    /// An unconstrained dimension: it matches any extent and never resolves.
    Dynamic,
}
59
60impl SymDim {
61 pub fn fixed(n: usize) -> Self {
63 SymDim::Fixed(n)
64 }
65
66 pub fn symbolic(name: impl Into<String>) -> Self {
68 SymDim::Symbolic(name.into())
69 }
70
71 pub fn dynamic() -> Self {
73 SymDim::Dynamic
74 }
75
76 pub fn is_fixed(&self) -> bool {
78 matches!(self, SymDim::Fixed(_))
79 }
80
81 pub fn is_symbolic(&self) -> bool {
83 matches!(self, SymDim::Symbolic(_))
84 }
85
86 pub fn is_dynamic(&self) -> bool {
88 matches!(self, SymDim::Dynamic)
89 }
90
91 pub fn resolve(&self, env: &ShapeEnv) -> Option<usize> {
97 match self {
98 SymDim::Fixed(n) => Some(*n),
99 SymDim::Symbolic(name) => env.get(name),
100 SymDim::Dynamic => None,
101 }
102 }
103
104 pub fn matches(&self, value: usize, env: &ShapeEnv) -> bool {
110 match self {
111 SymDim::Fixed(n) => value == *n,
112 SymDim::Symbolic(name) => {
113 if let Some(bound) = env.get(name) {
114 value == bound
115 } else {
116 true }
118 }
119 SymDim::Dynamic => true,
120 }
121 }
122
123 pub fn unify(&self, value: usize, env: &mut ShapeEnv) -> bool {
128 match self {
129 SymDim::Fixed(n) => value == *n,
130 SymDim::Symbolic(name) => {
131 if let Some(bound) = env.get(name) {
132 value == bound
133 } else {
134 env.bind(name, value);
135 true
136 }
137 }
138 SymDim::Dynamic => true,
139 }
140 }
141}
142
143impl fmt::Display for SymDim {
144 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145 match self {
146 SymDim::Fixed(n) => write!(f, "{n}"),
147 SymDim::Symbolic(s) => write!(f, "{s}"),
148 SymDim::Dynamic => write!(f, "?"),
149 }
150 }
151}
152
153impl From<usize> for SymDim {
154 fn from(n: usize) -> Self {
155 SymDim::Fixed(n)
156 }
157}
158
159impl From<&str> for SymDim {
160 fn from(s: &str) -> Self {
161 SymDim::Symbolic(s.to_string())
162 }
163}
164
165#[derive(Debug, Clone, PartialEq, Eq)]
192pub struct SymbolicShape {
193 dims: Vec<SymDim>,
194}
195
196impl SymbolicShape {
197 pub fn new(dims: Vec<SymDim>) -> Self {
199 Self { dims }
200 }
201
202 pub fn from_shape(shape: &Shape) -> Self {
204 Self {
205 dims: shape.dims().iter().map(|&d| SymDim::Fixed(d)).collect(),
206 }
207 }
208
209 pub fn rank(&self) -> usize {
211 self.dims.len()
212 }
213
214 pub fn dims(&self) -> &[SymDim] {
216 &self.dims
217 }
218
219 pub fn is_concrete(&self) -> bool {
221 self.dims.iter().all(|d| d.is_fixed())
222 }
223
224 pub fn has_symbolic(&self) -> bool {
226 self.dims.iter().any(|d| !d.is_fixed())
227 }
228
229 pub fn symbolic_names(&self) -> Vec<&str> {
231 self.dims
232 .iter()
233 .filter_map(|d| match d {
234 SymDim::Symbolic(name) => Some(name.as_str()),
235 _ => None,
236 })
237 .collect()
238 }
239
240 pub fn resolve(&self, env: &ShapeEnv) -> Result<Shape> {
245 let mut concrete = Vec::with_capacity(self.dims.len());
246 for (i, dim) in self.dims.iter().enumerate() {
247 match dim.resolve(env) {
248 Some(n) => concrete.push(n),
249 None => {
250 return Err(Error::msg(format!(
251 "cannot resolve dimension {} ({}) — not bound in environment",
252 i, dim
253 )));
254 }
255 }
256 }
257 Ok(Shape::new(concrete))
258 }
259
260 pub fn resolve_with_default(&self, env: &ShapeEnv, default: usize) -> Shape {
263 let concrete: Vec<usize> = self
264 .dims
265 .iter()
266 .map(|d| d.resolve(env).unwrap_or(default))
267 .collect();
268 Shape::new(concrete)
269 }
270
271 pub fn matches(&self, shape: &Shape, env: &ShapeEnv) -> bool {
277 if self.rank() != shape.rank() {
278 return false;
279 }
280 self.dims
281 .iter()
282 .zip(shape.dims().iter())
283 .all(|(pattern, &value)| pattern.matches(value, env))
284 }
285
286 pub fn unify(&self, shape: &Shape, env: &mut ShapeEnv) -> bool {
292 if self.rank() != shape.rank() {
293 return false;
294 }
295 let mut new_bindings = Vec::new();
297 for (pattern, &value) in self.dims.iter().zip(shape.dims().iter()) {
298 match pattern {
299 SymDim::Fixed(n) => {
300 if value != *n {
301 return false;
302 }
303 }
304 SymDim::Symbolic(name) => {
305 if let Some(bound) = env.get(name) {
306 if value != bound {
307 return false;
308 }
309 } else {
310 if let Some(&prev) =
312 new_bindings.iter().find_map(|(n, v): &(&str, usize)| {
313 if *n == name.as_str() {
314 Some(v)
315 } else {
316 None
317 }
318 })
319 {
320 if value != prev {
321 return false;
322 }
323 } else {
324 new_bindings.push((name.as_str(), value));
325 }
326 }
327 }
328 SymDim::Dynamic => {} }
330 }
331 for (name, value) in new_bindings {
333 env.bind(name, value);
334 }
335 true
336 }
337
338 pub fn broadcast(&self, other: &SymbolicShape) -> Option<SymbolicShape> {
341 let rank = self.rank().max(other.rank());
342 let mut result = Vec::with_capacity(rank);
343
344 for i in 0..rank {
345 let a = if i < rank - self.rank() {
346 &SymDim::Fixed(1)
347 } else {
348 &self.dims[i - (rank - self.rank())]
349 };
350 let b = if i < rank - other.rank() {
351 &SymDim::Fixed(1)
352 } else {
353 &other.dims[i - (rank - other.rank())]
354 };
355
356 match (a, b) {
357 (SymDim::Fixed(1), _) => result.push(b.clone()),
358 (_, SymDim::Fixed(1)) => result.push(a.clone()),
359 (SymDim::Fixed(x), SymDim::Fixed(y)) if x == y => result.push(a.clone()),
360 (SymDim::Fixed(_), SymDim::Fixed(_)) => return None, (SymDim::Symbolic(s), SymDim::Symbolic(t)) if s == t => result.push(a.clone()),
362 (SymDim::Dynamic, _) | (_, SymDim::Dynamic) => result.push(SymDim::Dynamic),
363 (SymDim::Symbolic(_), _) | (_, SymDim::Symbolic(_)) => result.push(SymDim::Dynamic),
364 }
365 }
366 Some(SymbolicShape::new(result))
367 }
368}
369
370impl fmt::Display for SymbolicShape {
371 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
372 write!(f, "[")?;
373 for (i, d) in self.dims.iter().enumerate() {
374 if i > 0 {
375 write!(f, ", ")?;
376 }
377 write!(f, "{d}")?;
378 }
379 write!(f, "]")
380 }
381}
382
383impl From<Vec<SymDim>> for SymbolicShape {
384 fn from(dims: Vec<SymDim>) -> Self {
385 Self::new(dims)
386 }
387}
388
389impl From<Shape> for SymbolicShape {
390 fn from(shape: Shape) -> Self {
391 Self::from_shape(&shape)
392 }
393}
394
/// A set of bindings from symbolic dimension names to concrete extents.
#[derive(Clone, Debug)]
pub struct ShapeEnv {
    /// Name → extent; a name has at most one binding.
    bindings: HashMap<String, usize>,
}
412
413impl ShapeEnv {
414 pub fn new() -> Self {
416 Self {
417 bindings: HashMap::new(),
418 }
419 }
420
421 pub fn bind(&mut self, name: impl Into<String>, value: usize) {
423 self.bindings.insert(name.into(), value);
424 }
425
426 pub fn get(&self, name: &str) -> Option<usize> {
428 self.bindings.get(name).copied()
429 }
430
431 pub fn is_bound(&self, name: &str) -> bool {
433 self.bindings.contains_key(name)
434 }
435
436 pub fn bindings(&self) -> &HashMap<String, usize> {
438 &self.bindings
439 }
440
441 pub fn len(&self) -> usize {
443 self.bindings.len()
444 }
445
446 pub fn is_empty(&self) -> bool {
448 self.bindings.is_empty()
449 }
450
451 pub fn merge(&mut self, other: &ShapeEnv) -> Result<()> {
454 for (name, &value) in &other.bindings {
455 if let Some(&existing) = self.bindings.get(name) {
456 if existing != value {
457 return Err(Error::msg(format!(
458 "conflicting binding for '{}': {} vs {}",
459 name, existing, value
460 )));
461 }
462 } else {
463 self.bindings.insert(name.clone(), value);
464 }
465 }
466 Ok(())
467 }
468}
469
470impl Default for ShapeEnv {
471 fn default() -> Self {
472 Self::new()
473 }
474}
475
476impl From<&[(&str, usize)]> for ShapeEnv {
477 fn from(bindings: &[(&str, usize)]) -> Self {
478 let mut env = ShapeEnv::new();
479 for &(name, value) in bindings {
480 env.bind(name, value);
481 }
482 env
483 }
484}
485
486#[derive(Debug, Clone)]
508pub struct ShapeGuard {
509 name: String,
511 pattern: SymbolicShape,
513}
514
515impl ShapeGuard {
516 pub fn new(name: impl Into<String>, pattern: SymbolicShape) -> Self {
518 Self {
519 name: name.into(),
520 pattern,
521 }
522 }
523
524 pub fn validate(&self, shape: &Shape, env: &ShapeEnv) -> Result<()> {
529 if self.pattern.rank() != shape.rank() {
530 return Err(Error::msg(format!(
531 "shape mismatch for '{}': expected rank {} ({}), got rank {} ({:?})",
532 self.name,
533 self.pattern.rank(),
534 self.pattern,
535 shape.rank(),
536 shape.dims()
537 )));
538 }
539 for (i, (expected, &actual)) in self
540 .pattern
541 .dims()
542 .iter()
543 .zip(shape.dims().iter())
544 .enumerate()
545 {
546 if !expected.matches(actual, env) {
547 return Err(Error::msg(format!(
548 "shape mismatch for '{}' at dim {}: expected {}, got {}",
549 self.name, i, expected, actual
550 )));
551 }
552 }
553 Ok(())
554 }
555
556 pub fn validate_and_bind(&self, shape: &Shape, env: &mut ShapeEnv) -> Result<()> {
559 if !self.pattern.unify(shape, env) {
560 Err(Error::msg(format!(
561 "shape mismatch for '{}': expected {}, got {:?}",
562 self.name,
563 self.pattern,
564 shape.dims()
565 )))
566 } else {
567 Ok(())
568 }
569 }
570}
571
/// Unit tests for `SymDim`, `SymbolicShape`, `ShapeEnv` and `ShapeGuard`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shape::Shape;

    // ---- SymDim ------------------------------------------------------------

    #[test]
    fn test_symdim_fixed() {
        let d = SymDim::fixed(32);
        assert!(d.is_fixed());
        assert!(!d.is_symbolic());
        assert!(!d.is_dynamic());
        // A fixed dim resolves to itself even in an empty environment.
        assert_eq!(d.resolve(&ShapeEnv::new()), Some(32));
        assert_eq!(format!("{d}"), "32");
    }

    #[test]
    fn test_symdim_symbolic_bound() {
        let d = SymDim::symbolic("Batch");
        let mut env = ShapeEnv::new();
        env.bind("Batch", 64);

        assert!(d.is_symbolic());
        assert_eq!(d.resolve(&env), Some(64));
        // Once bound, only the bound value matches.
        assert!(d.matches(64, &env));
        assert!(!d.matches(32, &env));
    }

    #[test]
    fn test_symdim_symbolic_unbound() {
        let d = SymDim::symbolic("SeqLen");
        let env = ShapeEnv::new();

        assert_eq!(d.resolve(&env), None);
        // An unbound symbol accepts any value.
        assert!(d.matches(100, &env));
        assert!(d.matches(200, &env));
    }

    #[test]
    fn test_symdim_dynamic() {
        let d = SymDim::dynamic();
        let env = ShapeEnv::new();
        assert!(d.is_dynamic());
        // Dynamic dims never resolve but match anything.
        assert_eq!(d.resolve(&env), None);
        assert!(d.matches(999, &env));
    }

    #[test]
    fn test_symdim_unify() {
        let d = SymDim::symbolic("N");
        let mut env = ShapeEnv::new();
        // First unification binds N.
        assert!(d.unify(42, &mut env));
        assert_eq!(env.get("N"), Some(42));
        // Same value again is consistent.
        assert!(d.unify(42, &mut env));
        // A different value conflicts with the existing binding.
        assert!(!d.unify(99, &mut env));
    }

    #[test]
    fn test_symdim_from() {
        let d: SymDim = 32usize.into();
        assert_eq!(d, SymDim::Fixed(32));
        let d: SymDim = "Batch".into();
        assert_eq!(d, SymDim::Symbolic("Batch".to_string()));
    }

    // ---- SymbolicShape -----------------------------------------------------

    #[test]
    fn test_symbolic_shape_basic() {
        let s = SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(784)]);
        assert_eq!(s.rank(), 2);
        assert!(!s.is_concrete());
        assert!(s.has_symbolic());
        assert_eq!(s.symbolic_names(), vec!["Batch"]);
        assert_eq!(format!("{s}"), "[Batch, 784]");
    }

    #[test]
    fn test_symbolic_shape_resolve() {
        let s = SymbolicShape::new(vec![
            SymDim::symbolic("Batch"),
            SymDim::symbolic("SeqLen"),
            SymDim::fixed(768),
        ]);
        let mut env = ShapeEnv::new();
        env.bind("Batch", 32);
        env.bind("SeqLen", 128);

        let concrete = s.resolve(&env).unwrap();
        assert_eq!(concrete.dims(), &[32, 128, 768]);
    }

    #[test]
    fn test_symbolic_shape_resolve_fails_unbound() {
        // "Batch" is not bound, so resolution must fail.
        let s = SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(784)]);
        let env = ShapeEnv::new();
        assert!(s.resolve(&env).is_err());
    }

    #[test]
    fn test_symbolic_shape_resolve_dynamic_fails() {
        // Dynamic dims can never be resolved.
        let s = SymbolicShape::new(vec![SymDim::dynamic(), SymDim::fixed(784)]);
        let env = ShapeEnv::new();
        assert!(s.resolve(&env).is_err());
    }

    #[test]
    fn test_symbolic_shape_resolve_with_default() {
        let s = SymbolicShape::new(vec![
            SymDim::dynamic(),
            SymDim::symbolic("N"),
            SymDim::fixed(768),
        ]);
        let mut env = ShapeEnv::new();
        env.bind("N", 100);
        // The dynamic dim falls back to the default (1).
        let concrete = s.resolve_with_default(&env, 1);
        assert_eq!(concrete.dims(), &[1, 100, 768]);
    }

    #[test]
    fn test_symbolic_shape_matches() {
        let s = SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(784)]);
        let mut env = ShapeEnv::new();
        env.bind("Batch", 32);

        assert!(s.matches(&Shape::new(vec![32, 784]), &env));
        // Wrong fixed dim.
        assert!(!s.matches(&Shape::new(vec![32, 100]), &env));
        // Batch is bound to 32.
        assert!(!s.matches(&Shape::new(vec![64, 784]), &env));
        // Rank mismatch.
        assert!(!s.matches(&Shape::new(vec![32, 784, 1]), &env));
    }

    #[test]
    fn test_symbolic_shape_unify() {
        let s = SymbolicShape::new(vec![
            SymDim::symbolic("Batch"),
            SymDim::symbolic("SeqLen"),
            SymDim::fixed(768),
        ]);
        let shape = Shape::new(vec![16, 256, 768]);
        let mut env = ShapeEnv::new();

        // Unification binds every previously unbound symbol.
        assert!(s.unify(&shape, &mut env));
        assert_eq!(env.get("Batch"), Some(16));
        assert_eq!(env.get("SeqLen"), Some(256));
    }

    #[test]
    fn test_symbolic_shape_unify_consistency() {
        // The same symbol used twice must take the same value.
        let s = SymbolicShape::new(vec![SymDim::symbolic("N"), SymDim::symbolic("N")]);

        let mut env1 = ShapeEnv::new();
        assert!(s.unify(&Shape::new(vec![32, 32]), &mut env1));
        assert_eq!(env1.get("N"), Some(32));

        let mut env2 = ShapeEnv::new();
        // Inconsistent: N cannot be both 32 and 64.
        assert!(!s.unify(&Shape::new(vec![32, 64]), &mut env2));
    }

    #[test]
    fn test_symbolic_shape_unify_fixed_mismatch() {
        let s = SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(784)]);
        let mut env = ShapeEnv::new();
        // 100 != 784 for the fixed dim.
        assert!(!s.unify(&Shape::new(vec![32, 100]), &mut env));
    }

    #[test]
    fn test_symbolic_shape_from_concrete() {
        let shape = Shape::new(vec![2, 3, 4]);
        let sym = SymbolicShape::from_shape(&shape);
        assert!(sym.is_concrete());
        assert!(!sym.has_symbolic());
        // A concrete pattern resolves without any bindings.
        let resolved = sym.resolve(&ShapeEnv::new()).unwrap();
        assert_eq!(resolved.dims(), &[2, 3, 4]);
    }

    #[test]
    fn test_symbolic_shape_broadcast() {
        let a = SymbolicShape::new(vec![SymDim::fixed(3), SymDim::fixed(1)]);
        let b = SymbolicShape::new(vec![SymDim::fixed(1), SymDim::fixed(4)]);
        let c = a.broadcast(&b).unwrap();
        assert_eq!(format!("{c}"), "[3, 4]");
    }

    #[test]
    fn test_symbolic_shape_broadcast_symbolic() {
        // A fixed 1 broadcasts against a symbolic dim, keeping the symbol.
        let a = SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(768)]);
        let b = SymbolicShape::new(vec![SymDim::fixed(1), SymDim::fixed(768)]);
        let c = a.broadcast(&b).unwrap();
        assert_eq!(format!("{c}"), "[Batch, 768]");
    }

    #[test]
    fn test_symbolic_shape_broadcast_incompatible() {
        // Different fixed extents (neither 1) cannot broadcast.
        let a = SymbolicShape::new(vec![SymDim::fixed(3)]);
        let b = SymbolicShape::new(vec![SymDim::fixed(4)]);
        assert!(a.broadcast(&b).is_none());
    }

    #[test]
    fn test_symbolic_shape_broadcast_rank_extension() {
        // The lower-rank pattern is padded on the left with 1s.
        let a = SymbolicShape::new(vec![SymDim::fixed(768)]);
        let b = SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(768)]);
        let c = a.broadcast(&b).unwrap();
        assert_eq!(c.rank(), 2);
        assert_eq!(format!("{c}"), "[Batch, 768]");
    }

    // ---- ShapeEnv ----------------------------------------------------------

    #[test]
    fn test_shape_env_basic() {
        let mut env = ShapeEnv::new();
        assert!(env.is_empty());
        env.bind("Batch", 32);
        assert_eq!(env.len(), 1);
        assert_eq!(env.get("Batch"), Some(32));
        assert!(env.is_bound("Batch"));
        assert!(!env.is_bound("SeqLen"));
    }

    #[test]
    fn test_shape_env_merge() {
        let mut a = ShapeEnv::new();
        a.bind("Batch", 32);

        let mut b = ShapeEnv::new();
        b.bind("SeqLen", 128);
        // Same value as in `a`, so no conflict.
        b.bind("Batch", 32);
        a.merge(&b).unwrap();
        assert_eq!(a.get("Batch"), Some(32));
        assert_eq!(a.get("SeqLen"), Some(128));
    }

    #[test]
    fn test_shape_env_merge_conflict() {
        let mut a = ShapeEnv::new();
        a.bind("Batch", 32);

        let mut b = ShapeEnv::new();
        // Conflicts with a's binding of Batch = 32.
        b.bind("Batch", 64);
        assert!(a.merge(&b).is_err());
    }

    #[test]
    fn test_shape_env_from_slice() {
        let env = ShapeEnv::from([("Batch", 32), ("SeqLen", 128)].as_slice());
        assert_eq!(env.get("Batch"), Some(32));
        assert_eq!(env.get("SeqLen"), Some(128));
    }

    // ---- ShapeGuard --------------------------------------------------------

    #[test]
    fn test_shape_guard_validate() {
        let guard = ShapeGuard::new(
            "input",
            SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(784)]),
        );
        let mut env = ShapeEnv::new();
        env.bind("Batch", 32);

        assert!(guard.validate(&Shape::new(vec![32, 784]), &env).is_ok());
        // Wrong fixed dim.
        assert!(guard.validate(&Shape::new(vec![32, 100]), &env).is_err());
        // Batch is bound to 32.
        assert!(guard.validate(&Shape::new(vec![64, 784]), &env).is_err());
    }

    #[test]
    fn test_shape_guard_validate_and_bind() {
        let guard = ShapeGuard::new(
            "input",
            SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(784)]),
        );
        let mut env = ShapeEnv::new();

        // First call binds Batch = 32.
        guard
            .validate_and_bind(&Shape::new(vec![32, 784]), &mut env)
            .unwrap();
        assert_eq!(env.get("Batch"), Some(32));

        // Same batch size again is consistent.
        guard
            .validate_and_bind(&Shape::new(vec![32, 784]), &mut env)
            .unwrap();

        // A different batch size conflicts with the binding.
        assert!(guard
            .validate_and_bind(&Shape::new(vec![64, 784]), &mut env)
            .is_err());
    }

    #[test]
    fn test_shape_guard_wrong_rank() {
        let guard = ShapeGuard::new(
            "x",
            SymbolicShape::new(vec![SymDim::dynamic(), SymDim::fixed(10)]),
        );
        let env = ShapeEnv::new();
        // Rank mismatches are rejected even though the leading dim is dynamic.
        assert!(guard.validate(&Shape::new(vec![10]), &env).is_err());
        assert!(guard.validate(&Shape::new(vec![5, 10, 1]), &env).is_err());
    }

    // ---- Scenarios ---------------------------------------------------------

    #[test]
    fn test_multi_tensor_batch_consistency() {
        // Two tensors sharing the "Batch" symbol through one environment.
        let input_guard = ShapeGuard::new(
            "input",
            SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(784)]),
        );
        let target_guard = ShapeGuard::new(
            "target",
            SymbolicShape::new(vec![SymDim::symbolic("Batch"), SymDim::fixed(10)]),
        );

        let mut env = ShapeEnv::new();

        // The input binds Batch = 32.
        input_guard
            .validate_and_bind(&Shape::new(vec![32, 784]), &mut env)
            .unwrap();

        // The target agrees with Batch = 32.
        target_guard
            .validate_and_bind(&Shape::new(vec![32, 10]), &mut env)
            .unwrap();

        // Work on a copy so `env` keeps its bindings; a batch of 16 must conflict.
        let mut env2 = env.clone();
        assert!(target_guard
            .validate_and_bind(&Shape::new(vec![16, 10]), &mut env2)
            .is_err());
    }

    #[test]
    fn test_transformer_shape_pattern() {
        // [Batch, SeqLen, 768] hidden-state style pattern.
        let pattern = SymbolicShape::new(vec![
            SymDim::symbolic("Batch"),
            SymDim::symbolic("SeqLen"),
            SymDim::fixed(768),
        ]);

        // Each fresh environment binds the symbols independently.
        let mut env1 = ShapeEnv::new();
        assert!(pattern.unify(&Shape::new(vec![1, 512, 768]), &mut env1));
        assert_eq!(env1.get("Batch"), Some(1));
        assert_eq!(env1.get("SeqLen"), Some(512));

        let mut env2 = ShapeEnv::new();
        assert!(pattern.unify(&Shape::new(vec![32, 128, 768]), &mut env2));
        assert_eq!(env2.get("Batch"), Some(32));
        assert_eq!(env2.get("SeqLen"), Some(128));

        // The trailing fixed dim must be 768.
        let mut env3 = ShapeEnv::new();
        assert!(!pattern.unify(&Shape::new(vec![32, 128, 512]), &mut env3));
    }
}