1use std::collections::HashMap;
27use std::fmt;
28
/// A symbolic dimension expression over `usize` sizes.
///
/// Expressions are built from literal sizes and named dimension variables and are
/// evaluated against a `DependentTypeContext` that supplies the variable bindings.
/// `Eq` and `Hash` are derived so expressions can be used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DimExpr {
    /// A literal size.
    Const(usize),
    /// A named dimension variable, resolved through the evaluation context.
    Var(String),
    /// `lhs + rhs`
    Add(Box<DimExpr>, Box<DimExpr>),
    /// `lhs - rhs`
    Sub(Box<DimExpr>, Box<DimExpr>),
    /// `lhs * rhs`
    Mul(Box<DimExpr>, Box<DimExpr>),
    /// Truncating integer division `lhs / rhs`.
    Div(Box<DimExpr>, Box<DimExpr>),
    /// `max(lhs, rhs)`
    Max(Box<DimExpr>, Box<DimExpr>),
    /// `min(lhs, rhs)`
    Min(Box<DimExpr>, Box<DimExpr>),
    /// Ceiling division `ceil(lhs / rhs)`.
    CeilDiv(Box<DimExpr>, Box<DimExpr>),
}
51
52impl DimExpr {
53 pub fn constant(value: usize) -> Self {
55 DimExpr::Const(value)
56 }
57
58 pub fn var(name: impl Into<String>) -> Self {
60 DimExpr::Var(name.into())
61 }
62
63 #[allow(clippy::should_implement_trait)]
65 pub fn add(self, other: DimExpr) -> Self {
66 DimExpr::Add(Box::new(self), Box::new(other))
67 }
68
69 #[allow(clippy::should_implement_trait)]
71 pub fn sub(self, other: DimExpr) -> Self {
72 DimExpr::Sub(Box::new(self), Box::new(other))
73 }
74
75 #[allow(clippy::should_implement_trait)]
77 pub fn mul(self, other: DimExpr) -> Self {
78 DimExpr::Mul(Box::new(self), Box::new(other))
79 }
80
81 #[allow(clippy::should_implement_trait)]
83 pub fn div(self, other: DimExpr) -> Self {
84 DimExpr::Div(Box::new(self), Box::new(other))
85 }
86
87 pub fn max(self, other: DimExpr) -> Self {
89 DimExpr::Max(Box::new(self), Box::new(other))
90 }
91
92 pub fn min(self, other: DimExpr) -> Self {
94 DimExpr::Min(Box::new(self), Box::new(other))
95 }
96
97 pub fn ceil_div(self, other: DimExpr) -> Self {
99 DimExpr::CeilDiv(Box::new(self), Box::new(other))
100 }
101
102 pub fn eval(&self, ctx: &DependentTypeContext) -> Option<usize> {
104 match self {
105 DimExpr::Const(n) => Some(*n),
106 DimExpr::Var(name) => ctx.get_dim(name).copied(),
107 DimExpr::Add(a, b) => Some(a.eval(ctx)? + b.eval(ctx)?),
108 DimExpr::Sub(a, b) => {
109 let av = a.eval(ctx)?;
110 let bv = b.eval(ctx)?;
111 av.checked_sub(bv)
112 }
113 DimExpr::Mul(a, b) => Some(a.eval(ctx)? * b.eval(ctx)?),
114 DimExpr::Div(a, b) => {
115 let bv = b.eval(ctx)?;
116 if bv == 0 {
117 return None;
118 }
119 Some(a.eval(ctx)? / bv)
120 }
121 DimExpr::Max(a, b) => Some(a.eval(ctx)?.max(b.eval(ctx)?)),
122 DimExpr::Min(a, b) => Some(a.eval(ctx)?.min(b.eval(ctx)?)),
123 DimExpr::CeilDiv(a, b) => {
124 let av = a.eval(ctx)?;
125 let bv = b.eval(ctx)?;
126 if bv == 0 {
127 return None;
128 }
129 Some(av.div_ceil(bv))
130 }
131 }
132 }
133
134 pub fn free_variables(&self) -> Vec<String> {
136 match self {
137 DimExpr::Const(_) => vec![],
138 DimExpr::Var(name) => vec![name.clone()],
139 DimExpr::Add(a, b)
140 | DimExpr::Sub(a, b)
141 | DimExpr::Mul(a, b)
142 | DimExpr::Div(a, b)
143 | DimExpr::Max(a, b)
144 | DimExpr::Min(a, b)
145 | DimExpr::CeilDiv(a, b) => {
146 let mut vars = a.free_variables();
147 vars.extend(b.free_variables());
148 vars.sort();
149 vars.dedup();
150 vars
151 }
152 }
153 }
154
155 pub fn substitute(&self, var: &str, expr: &DimExpr) -> DimExpr {
157 match self {
158 DimExpr::Const(n) => DimExpr::Const(*n),
159 DimExpr::Var(name) => {
160 if name == var {
161 expr.clone()
162 } else {
163 DimExpr::Var(name.clone())
164 }
165 }
166 DimExpr::Add(a, b) => DimExpr::Add(
167 Box::new(a.substitute(var, expr)),
168 Box::new(b.substitute(var, expr)),
169 ),
170 DimExpr::Sub(a, b) => DimExpr::Sub(
171 Box::new(a.substitute(var, expr)),
172 Box::new(b.substitute(var, expr)),
173 ),
174 DimExpr::Mul(a, b) => DimExpr::Mul(
175 Box::new(a.substitute(var, expr)),
176 Box::new(b.substitute(var, expr)),
177 ),
178 DimExpr::Div(a, b) => DimExpr::Div(
179 Box::new(a.substitute(var, expr)),
180 Box::new(b.substitute(var, expr)),
181 ),
182 DimExpr::Max(a, b) => DimExpr::Max(
183 Box::new(a.substitute(var, expr)),
184 Box::new(b.substitute(var, expr)),
185 ),
186 DimExpr::Min(a, b) => DimExpr::Min(
187 Box::new(a.substitute(var, expr)),
188 Box::new(b.substitute(var, expr)),
189 ),
190 DimExpr::CeilDiv(a, b) => DimExpr::CeilDiv(
191 Box::new(a.substitute(var, expr)),
192 Box::new(b.substitute(var, expr)),
193 ),
194 }
195 }
196
197 pub fn simplify(&self) -> DimExpr {
199 match self {
200 DimExpr::Add(a, b) => {
201 let a = a.simplify();
202 let b = b.simplify();
203 match (&a, &b) {
204 (DimExpr::Const(x), DimExpr::Const(y)) => DimExpr::Const(x + y),
205 (DimExpr::Const(0), _) => b,
206 (_, DimExpr::Const(0)) => a,
207 _ => DimExpr::Add(Box::new(a), Box::new(b)),
208 }
209 }
210 DimExpr::Sub(a, b) => {
211 let a = a.simplify();
212 let b = b.simplify();
213 match (&a, &b) {
214 (DimExpr::Const(x), DimExpr::Const(y)) => DimExpr::Const(x.saturating_sub(*y)),
215 (_, DimExpr::Const(0)) => a,
216 _ => DimExpr::Sub(Box::new(a), Box::new(b)),
217 }
218 }
219 DimExpr::Mul(a, b) => {
220 let a = a.simplify();
221 let b = b.simplify();
222 match (&a, &b) {
223 (DimExpr::Const(x), DimExpr::Const(y)) => DimExpr::Const(x * y),
224 (DimExpr::Const(0), _) | (_, DimExpr::Const(0)) => DimExpr::Const(0),
225 (DimExpr::Const(1), _) => b,
226 (_, DimExpr::Const(1)) => a,
227 _ => DimExpr::Mul(Box::new(a), Box::new(b)),
228 }
229 }
230 DimExpr::Div(a, b) => {
231 let a = a.simplify();
232 let b = b.simplify();
233 match (&a, &b) {
234 (DimExpr::Const(x), DimExpr::Const(y)) if *y != 0 => DimExpr::Const(x / y),
235 (DimExpr::Const(0), _) => DimExpr::Const(0),
236 (_, DimExpr::Const(1)) => a,
237 _ => DimExpr::Div(Box::new(a), Box::new(b)),
238 }
239 }
240 DimExpr::Max(a, b) => {
241 let a = a.simplify();
242 let b = b.simplify();
243 match (&a, &b) {
244 (DimExpr::Const(x), DimExpr::Const(y)) => DimExpr::Const((*x).max(*y)),
245 _ => DimExpr::Max(Box::new(a), Box::new(b)),
246 }
247 }
248 DimExpr::Min(a, b) => {
249 let a = a.simplify();
250 let b = b.simplify();
251 match (&a, &b) {
252 (DimExpr::Const(x), DimExpr::Const(y)) => DimExpr::Const((*x).min(*y)),
253 _ => DimExpr::Min(Box::new(a), Box::new(b)),
254 }
255 }
256 DimExpr::CeilDiv(a, b) => {
257 let a = a.simplify();
258 let b = b.simplify();
259 match (&a, &b) {
260 (DimExpr::Const(x), DimExpr::Const(y)) if *y != 0 => {
261 DimExpr::Const(x.div_ceil(*y))
262 }
263 _ => DimExpr::CeilDiv(Box::new(a), Box::new(b)),
264 }
265 }
266 other => other.clone(),
267 }
268 }
269
270 pub fn is_equal(&self, other: &DimExpr, ctx: &DependentTypeContext) -> bool {
272 if self == other {
274 return true;
275 }
276
277 match (self.eval(ctx), other.eval(ctx)) {
279 (Some(a), Some(b)) => a == b,
280 _ => false,
281 }
282 }
283}
284
285impl fmt::Display for DimExpr {
286 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287 match self {
288 DimExpr::Const(n) => write!(f, "{}", n),
289 DimExpr::Var(name) => write!(f, "{}", name),
290 DimExpr::Add(a, b) => write!(f, "({} + {})", a, b),
291 DimExpr::Sub(a, b) => write!(f, "({} - {})", a, b),
292 DimExpr::Mul(a, b) => write!(f, "({} * {})", a, b),
293 DimExpr::Div(a, b) => write!(f, "({} / {})", a, b),
294 DimExpr::Max(a, b) => write!(f, "max({}, {})", a, b),
295 DimExpr::Min(a, b) => write!(f, "min({}, {})", a, b),
296 DimExpr::CeilDiv(a, b) => write!(f, "ceil({} / {})", a, b),
297 }
298 }
299}
300
301#[derive(Debug, Clone)]
303pub struct DependentType {
304 pub base_type: String,
306 pub type_params: Vec<String>,
308 pub dim_params: Vec<DimExpr>,
310 pub name: Option<String>,
312 pub description: Option<String>,
314 pub constraints: Vec<DimConstraint>,
316}
317
318#[derive(Debug, Clone)]
320pub struct DimConstraint {
321 pub lhs: DimExpr,
323 pub relation: DimRelation,
325 pub rhs: DimExpr,
327 pub message: Option<String>,
329}
330
/// Relation asserted between the two sides of a dimension constraint.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DimRelation {
    /// `lhs == rhs`
    Equal,
    /// `lhs != rhs`
    NotEqual,
    /// `lhs < rhs`
    LessThan,
    /// `lhs <= rhs`
    LessThanOrEqual,
    /// `lhs > rhs`
    GreaterThan,
    /// `lhs >= rhs`
    GreaterThanOrEqual,
    /// `rhs` is non-zero and divides `lhs` exactly.
    DivisibleBy,
}
349
350impl DependentType {
351 pub fn new(base_type: impl Into<String>) -> Self {
353 DependentType {
354 base_type: base_type.into(),
355 type_params: Vec::new(),
356 dim_params: Vec::new(),
357 name: None,
358 description: None,
359 constraints: Vec::new(),
360 }
361 }
362
363 pub fn scalar(element_type: impl Into<String>) -> Self {
365 DependentType::new(element_type)
366 }
367
368 pub fn vector(element_type: impl Into<String>, length: DimExpr) -> Self {
370 DependentType {
371 base_type: "Vector".to_string(),
372 type_params: vec![element_type.into()],
373 dim_params: vec![length],
374 name: None,
375 description: None,
376 constraints: Vec::new(),
377 }
378 }
379
380 pub fn matrix(element_type: impl Into<String>, rows: DimExpr, cols: DimExpr) -> Self {
382 DependentType {
383 base_type: "Matrix".to_string(),
384 type_params: vec![element_type.into()],
385 dim_params: vec![rows, cols],
386 name: None,
387 description: None,
388 constraints: Vec::new(),
389 }
390 }
391
392 pub fn tensor(element_type: impl Into<String>, dims: Vec<DimExpr>) -> Self {
394 DependentType {
395 base_type: "Tensor".to_string(),
396 type_params: vec![element_type.into()],
397 dim_params: dims,
398 name: None,
399 description: None,
400 constraints: Vec::new(),
401 }
402 }
403
404 pub fn with_name(mut self, name: impl Into<String>) -> Self {
406 self.name = Some(name.into());
407 self
408 }
409
410 pub fn with_description(mut self, description: impl Into<String>) -> Self {
412 self.description = Some(description.into());
413 self
414 }
415
416 pub fn with_type_param(mut self, param: impl Into<String>) -> Self {
418 self.type_params.push(param.into());
419 self
420 }
421
422 pub fn with_dim_param(mut self, dim: DimExpr) -> Self {
424 self.dim_params.push(dim);
425 self
426 }
427
428 pub fn with_constraint(mut self, constraint: DimConstraint) -> Self {
430 self.constraints.push(constraint);
431 self
432 }
433
434 pub fn type_name(&self) -> String {
436 if let Some(name) = &self.name {
437 return name.clone();
438 }
439
440 if self.dim_params.is_empty() && self.type_params.is_empty() {
441 return self.base_type.clone();
442 }
443
444 let mut result = self.base_type.clone();
445 if !self.type_params.is_empty() || !self.dim_params.is_empty() {
446 result.push('<');
447
448 let mut parts = Vec::new();
449 for tp in &self.type_params {
450 parts.push(tp.clone());
451 }
452 for dp in &self.dim_params {
453 parts.push(format!("{}", dp));
454 }
455
456 result.push_str(&parts.join(", "));
457 result.push('>');
458 }
459 result
460 }
461
462 pub fn eval_shape(&self, ctx: &DependentTypeContext) -> Option<Vec<usize>> {
464 self.dim_params.iter().map(|d| d.eval(ctx)).collect()
465 }
466
467 pub fn rank(&self) -> usize {
469 self.dim_params.len()
470 }
471
472 pub fn free_variables(&self) -> Vec<String> {
474 let mut vars = Vec::new();
475 for dim in &self.dim_params {
476 vars.extend(dim.free_variables());
477 }
478 for constraint in &self.constraints {
479 vars.extend(constraint.lhs.free_variables());
480 vars.extend(constraint.rhs.free_variables());
481 }
482 vars.sort();
483 vars.dedup();
484 vars
485 }
486
487 pub fn check_constraints(&self, ctx: &DependentTypeContext) -> Result<(), String> {
489 for constraint in &self.constraints {
490 if !constraint.check(ctx) {
491 let msg = constraint.message.clone().unwrap_or_else(|| {
492 format!(
493 "Constraint violated: {} {:?} {}",
494 constraint.lhs, constraint.relation, constraint.rhs
495 )
496 });
497 return Err(msg);
498 }
499 }
500 Ok(())
501 }
502
503 pub fn is_compatible_with(&self, other: &DependentType, ctx: &DependentTypeContext) -> bool {
505 if self.base_type != other.base_type {
507 return false;
508 }
509
510 if self.type_params != other.type_params {
512 return false;
513 }
514
515 if self.dim_params.len() != other.dim_params.len() {
517 return false;
518 }
519
520 for (a, b) in self.dim_params.iter().zip(&other.dim_params) {
521 if !a.is_equal(b, ctx) {
522 return false;
523 }
524 }
525
526 true
527 }
528}
529
530impl DimConstraint {
531 pub fn new(lhs: DimExpr, relation: DimRelation, rhs: DimExpr) -> Self {
533 DimConstraint {
534 lhs,
535 relation,
536 rhs,
537 message: None,
538 }
539 }
540
541 pub fn with_message(mut self, message: impl Into<String>) -> Self {
543 self.message = Some(message.into());
544 self
545 }
546
547 pub fn check(&self, ctx: &DependentTypeContext) -> bool {
549 let lhs_val = match self.lhs.eval(ctx) {
550 Some(v) => v,
551 None => return false,
552 };
553 let rhs_val = match self.rhs.eval(ctx) {
554 Some(v) => v,
555 None => return false,
556 };
557
558 match self.relation {
559 DimRelation::Equal => lhs_val == rhs_val,
560 DimRelation::NotEqual => lhs_val != rhs_val,
561 DimRelation::LessThan => lhs_val < rhs_val,
562 DimRelation::LessThanOrEqual => lhs_val <= rhs_val,
563 DimRelation::GreaterThan => lhs_val > rhs_val,
564 DimRelation::GreaterThanOrEqual => lhs_val >= rhs_val,
565 DimRelation::DivisibleBy => rhs_val != 0 && lhs_val % rhs_val == 0,
566 }
567 }
568}
569
570#[derive(Debug, Clone, Default)]
572pub struct DependentTypeContext {
573 dims: HashMap<String, usize>,
575 types: HashMap<String, DependentType>,
577}
578
579impl DependentTypeContext {
580 pub fn new() -> Self {
582 DependentTypeContext {
583 dims: HashMap::new(),
584 types: HashMap::new(),
585 }
586 }
587
588 pub fn set_dim(&mut self, name: impl Into<String>, value: usize) {
590 self.dims.insert(name.into(), value);
591 }
592
593 pub fn get_dim(&self, name: &str) -> Option<&usize> {
595 self.dims.get(name)
596 }
597
598 pub fn set_type(&mut self, name: impl Into<String>, ty: DependentType) {
600 self.types.insert(name.into(), ty);
601 }
602
603 pub fn get_type(&self, name: &str) -> Option<&DependentType> {
605 self.types.get(name)
606 }
607
608 pub fn has_dim(&self, name: &str) -> bool {
610 self.dims.contains_key(name)
611 }
612
613 pub fn dim_names(&self) -> Vec<&str> {
615 self.dims.keys().map(|s| s.as_str()).collect()
616 }
617
618 pub fn clear_dims(&mut self) {
620 self.dims.clear();
621 }
622
623 pub fn unify(&mut self, a: &DimExpr, b: &DimExpr) -> bool {
627 match (a, b) {
628 (DimExpr::Var(va), DimExpr::Var(vb)) if va == vb => true,
629 (DimExpr::Var(va), expr) | (expr, DimExpr::Var(va)) => {
630 if let Some(&existing) = self.dims.get(va) {
631 if let Some(val) = expr.eval(self) {
632 existing == val
633 } else {
634 false
635 }
636 } else if let Some(val) = expr.eval(self) {
637 self.dims.insert(va.clone(), val);
638 true
639 } else {
640 false
641 }
642 }
643 (DimExpr::Const(ca), DimExpr::Const(cb)) => ca == cb,
644 _ => {
645 match (a.eval(self), b.eval(self)) {
647 (Some(va), Some(vb)) => va == vb,
648 _ => false,
649 }
650 }
651 }
652 }
653}
654
655#[derive(Debug, Clone, Default)]
657pub struct DependentTypeRegistry {
658 types: HashMap<String, DependentType>,
660}
661
662impl DependentTypeRegistry {
663 pub fn new() -> Self {
665 DependentTypeRegistry {
666 types: HashMap::new(),
667 }
668 }
669
670 pub fn register(&mut self, ty: DependentType) {
672 let name = ty.type_name();
673 self.types.insert(name, ty);
674 }
675
676 pub fn get(&self, name: &str) -> Option<&DependentType> {
678 self.types.get(name)
679 }
680
681 pub fn contains(&self, name: &str) -> bool {
683 self.types.contains_key(name)
684 }
685
686 pub fn type_names(&self) -> Vec<&str> {
688 self.types.keys().map(|s| s.as_str()).collect()
689 }
690
691 pub fn len(&self) -> usize {
693 self.types.len()
694 }
695
696 pub fn is_empty(&self) -> bool {
698 self.types.is_empty()
699 }
700}
701
702pub mod patterns {
704 use super::*;
705
706 pub fn square_matrix(element_type: impl Into<String>, size: DimExpr) -> DependentType {
708 DependentType::matrix(element_type, size.clone(), size).with_name("SquareMatrix")
709 }
710
711 pub fn identity_matrix(size: DimExpr) -> DependentType {
713 DependentType::matrix("Float", size.clone(), size).with_name("IdentityMatrix")
714 }
715
716 pub fn batch_vector(
718 element_type: impl Into<String>,
719 batch: DimExpr,
720 length: DimExpr,
721 ) -> DependentType {
722 DependentType::tensor(element_type, vec![batch, length]).with_name("BatchVector")
723 }
724
725 pub fn batch_matrix(
727 element_type: impl Into<String>,
728 batch: DimExpr,
729 rows: DimExpr,
730 cols: DimExpr,
731 ) -> DependentType {
732 DependentType::tensor(element_type, vec![batch, rows, cols]).with_name("BatchMatrix")
733 }
734
735 pub fn image_tensor(
737 batch: DimExpr,
738 channels: DimExpr,
739 height: DimExpr,
740 width: DimExpr,
741 ) -> DependentType {
742 DependentType::tensor("Float", vec![batch, channels, height, width])
743 .with_name("ImageTensor")
744 }
745
746 pub fn attention_tensor(
748 batch: DimExpr,
749 heads: DimExpr,
750 seq_len: DimExpr,
751 head_dim: DimExpr,
752 ) -> DependentType {
753 DependentType::tensor("Float", vec![batch, heads, seq_len, head_dim])
754 .with_name("AttentionTensor")
755 }
756}
757
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context whose dimension variables are bound as listed.
    fn ctx_with(bindings: &[(&str, usize)]) -> DependentTypeContext {
        let mut ctx = DependentTypeContext::new();
        for &(name, value) in bindings {
            ctx.set_dim(name, value);
        }
        ctx
    }

    #[test]
    fn test_dim_expr_const() {
        let ctx = ctx_with(&[]);
        assert_eq!(DimExpr::Const(42).eval(&ctx), Some(42));
    }

    #[test]
    fn test_dim_expr_var() {
        let ctx = ctx_with(&[("n", 10)]);
        assert_eq!(DimExpr::Var("n".to_string()).eval(&ctx), Some(10));
    }

    #[test]
    fn test_dim_expr_arithmetic() {
        let ctx = ctx_with(&[("n", 10), ("m", 3)]);
        let n = || DimExpr::var("n");
        let m = || DimExpr::var("m");

        assert_eq!(n().add(m()).eval(&ctx), Some(13));
        assert_eq!(n().mul(m()).eval(&ctx), Some(30));
        assert_eq!(n().div(m()).eval(&ctx), Some(3));
    }

    #[test]
    fn test_dim_expr_max_min() {
        let ctx = ctx_with(&[("a", 5), ("b", 10)]);
        let larger = DimExpr::var("a").max(DimExpr::var("b"));
        let smaller = DimExpr::var("a").min(DimExpr::var("b"));

        assert_eq!(larger.eval(&ctx), Some(10));
        assert_eq!(smaller.eval(&ctx), Some(5));
    }

    #[test]
    fn test_dim_expr_simplify() {
        let folded = DimExpr::constant(5).add(DimExpr::constant(3)).simplify();
        assert_eq!(folded, DimExpr::Const(8));

        let identity = DimExpr::var("x").add(DimExpr::constant(0)).simplify();
        assert_eq!(identity, DimExpr::Var("x".to_string()));
    }

    #[test]
    fn test_vector_type() {
        let vec_ty = DependentType::vector("Float", DimExpr::var("n"));
        let ctx = ctx_with(&[("n", 100)]);

        assert_eq!(vec_ty.eval_shape(&ctx), Some(vec![100]));
        assert_eq!(vec_ty.rank(), 1);
    }

    #[test]
    fn test_matrix_type() {
        let mat_ty = DependentType::matrix("Float", DimExpr::var("m"), DimExpr::var("n"));
        let ctx = ctx_with(&[("m", 10), ("n", 20)]);

        assert_eq!(mat_ty.eval_shape(&ctx), Some(vec![10, 20]));
        assert_eq!(mat_ty.rank(), 2);
    }

    #[test]
    fn test_dim_constraint() {
        let positive = DimConstraint::new(
            DimExpr::var("n"),
            DimRelation::GreaterThan,
            DimExpr::constant(0),
        );

        let mut ctx = ctx_with(&[("n", 10)]);
        assert!(positive.check(&ctx));

        ctx.set_dim("n", 0);
        assert!(!positive.check(&ctx));
    }

    #[test]
    fn test_type_with_constraints() {
        let square_only = DimConstraint::new(DimExpr::var("m"), DimRelation::Equal, DimExpr::var("n"))
            .with_message("Matrix must be square");
        let ty = DependentType::matrix("Float", DimExpr::var("m"), DimExpr::var("n"))
            .with_constraint(square_only);

        let mut ctx = ctx_with(&[("m", 10), ("n", 10)]);
        assert!(ty.check_constraints(&ctx).is_ok());

        ctx.set_dim("n", 20);
        assert!(ty.check_constraints(&ctx).is_err());
    }

    #[test]
    fn test_type_compatibility() {
        let float_mat = || DependentType::matrix("Float", DimExpr::var("m"), DimExpr::var("n"));
        let ctx = ctx_with(&[("m", 10), ("n", 20)]);

        assert!(float_mat().is_compatible_with(&float_mat(), &ctx));

        let int_mat = DependentType::matrix("Int", DimExpr::var("m"), DimExpr::var("n"));
        assert!(!float_mat().is_compatible_with(&int_mat, &ctx));
    }

    #[test]
    fn test_free_variables() {
        let expr = DimExpr::var("n")
            .add(DimExpr::var("m"))
            .mul(DimExpr::var("k"));
        let vars = expr.free_variables();

        assert_eq!(vars.len(), 3);
        for expected in &["k", "m", "n"] {
            assert!(vars.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_substitute() {
        let expr = DimExpr::var("n").add(DimExpr::constant(5));
        let replaced = expr.substitute("n", &DimExpr::constant(10));

        assert_eq!(replaced.eval(&ctx_with(&[])), Some(15));
    }

    #[test]
    fn test_ceil_div() {
        let expr = DimExpr::constant(10).ceil_div(DimExpr::constant(3));
        assert_eq!(expr.eval(&ctx_with(&[])), Some(4));
    }

    #[test]
    fn test_context_unify() {
        let mut ctx = DependentTypeContext::new();
        let n = DimExpr::var("n");

        // Binds the unbound variable.
        assert!(ctx.unify(&n, &DimExpr::constant(10)));
        assert_eq!(ctx.get_dim("n"), Some(&10));

        // Consistent with the existing binding.
        assert!(ctx.unify(&n, &DimExpr::constant(10)));

        // Conflicts with the existing binding.
        assert!(!ctx.unify(&n, &DimExpr::constant(20)));
    }

    #[test]
    fn test_patterns() {
        let ctx = ctx_with(&[
            ("n", 64),
            ("batch", 32),
            ("heads", 8),
            ("seq_len", 512),
            ("head_dim", 64),
        ]);

        let square = patterns::square_matrix("Float", DimExpr::var("n"));
        assert_eq!(square.eval_shape(&ctx), Some(vec![64, 64]));

        let attention = patterns::attention_tensor(
            DimExpr::var("batch"),
            DimExpr::var("heads"),
            DimExpr::var("seq_len"),
            DimExpr::var("head_dim"),
        );
        assert_eq!(attention.eval_shape(&ctx), Some(vec![32, 8, 512, 64]));
    }

    #[test]
    fn test_registry() {
        let mut registry = DependentTypeRegistry::new();
        let float_vector =
            DependentType::vector("Float", DimExpr::var("n")).with_name("FloatVector");
        registry.register(float_vector);

        assert!(registry.contains("FloatVector"));
        assert_eq!(registry.len(), 1);

        let stored = registry.get("FloatVector").unwrap();
        assert_eq!(stored.base_type, "Vector");
    }

    #[test]
    fn test_dim_display() {
        let sum = DimExpr::var("n").add(DimExpr::var("m"));
        assert_eq!(sum.to_string(), "(n + m)");

        let product = DimExpr::var("a").mul(DimExpr::constant(2));
        assert_eq!(product.to_string(), "(a * 2)");
    }

    #[test]
    fn test_type_name() {
        let ty = DependentType::matrix("Float", DimExpr::var("m"), DimExpr::var("n"));
        assert_eq!(ty.type_name(), "Matrix<Float, m, n>");

        let named = ty.with_name("MyMatrix");
        assert_eq!(named.type_name(), "MyMatrix");
    }
}