1use serde::{Deserialize, Serialize};
32use std::collections::{HashMap, HashSet};
33use std::fmt;
34
35use crate::{ParametricType, Term};
36
/// A symbolic integer expression used for sizes and indices in dependent
/// types, e.g. the length `n + 1` of a vector.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum IndexExpr {
    /// A named index variable such as `n`.
    Var(String),
    /// An integer literal.
    Const(i64),
    /// `left + right`.
    Add(Box<IndexExpr>, Box<IndexExpr>),
    /// `left - right`.
    Sub(Box<IndexExpr>, Box<IndexExpr>),
    /// `left * right`.
    Mul(Box<IndexExpr>, Box<IndexExpr>),
    /// `left / right` using `i64` division (truncates toward zero); a zero
    /// divisor makes the expression unevaluable.
    Div(Box<IndexExpr>, Box<IndexExpr>),
    /// The smaller of the two operands.
    Min(Box<IndexExpr>, Box<IndexExpr>),
    /// The larger of the two operands.
    Max(Box<IndexExpr>, Box<IndexExpr>),
}
60
61impl IndexExpr {
62 pub fn var(name: impl Into<String>) -> Self {
64 IndexExpr::Var(name.into())
65 }
66
67 pub fn constant(value: i64) -> Self {
69 IndexExpr::Const(value)
70 }
71
72 #[allow(clippy::should_implement_trait)]
74 pub fn add(left: IndexExpr, right: IndexExpr) -> Self {
75 IndexExpr::Add(Box::new(left), Box::new(right))
76 }
77
78 #[allow(clippy::should_implement_trait)]
80 pub fn sub(left: IndexExpr, right: IndexExpr) -> Self {
81 IndexExpr::Sub(Box::new(left), Box::new(right))
82 }
83
84 #[allow(clippy::should_implement_trait)]
86 pub fn mul(left: IndexExpr, right: IndexExpr) -> Self {
87 IndexExpr::Mul(Box::new(left), Box::new(right))
88 }
89
90 #[allow(clippy::should_implement_trait)]
92 pub fn div(left: IndexExpr, right: IndexExpr) -> Self {
93 IndexExpr::Div(Box::new(left), Box::new(right))
94 }
95
96 pub fn min(left: IndexExpr, right: IndexExpr) -> Self {
98 IndexExpr::Min(Box::new(left), Box::new(right))
99 }
100
101 pub fn max(left: IndexExpr, right: IndexExpr) -> Self {
103 IndexExpr::Max(Box::new(left), Box::new(right))
104 }
105
106 pub fn free_vars(&self) -> HashSet<String> {
108 let mut vars = HashSet::new();
109 self.collect_vars(&mut vars);
110 vars
111 }
112
113 fn collect_vars(&self, vars: &mut HashSet<String>) {
114 match self {
115 IndexExpr::Var(name) => {
116 vars.insert(name.clone());
117 }
118 IndexExpr::Const(_) => {}
119 IndexExpr::Add(l, r)
120 | IndexExpr::Sub(l, r)
121 | IndexExpr::Mul(l, r)
122 | IndexExpr::Div(l, r)
123 | IndexExpr::Min(l, r)
124 | IndexExpr::Max(l, r) => {
125 l.collect_vars(vars);
126 r.collect_vars(vars);
127 }
128 }
129 }
130
131 pub fn substitute(&self, subst: &HashMap<String, IndexExpr>) -> IndexExpr {
133 match self {
134 IndexExpr::Var(name) => subst.get(name).cloned().unwrap_or_else(|| self.clone()),
135 IndexExpr::Const(_) => self.clone(),
136 IndexExpr::Add(l, r) => {
137 IndexExpr::Add(Box::new(l.substitute(subst)), Box::new(r.substitute(subst)))
138 }
139 IndexExpr::Sub(l, r) => {
140 IndexExpr::Sub(Box::new(l.substitute(subst)), Box::new(r.substitute(subst)))
141 }
142 IndexExpr::Mul(l, r) => {
143 IndexExpr::Mul(Box::new(l.substitute(subst)), Box::new(r.substitute(subst)))
144 }
145 IndexExpr::Div(l, r) => {
146 IndexExpr::Div(Box::new(l.substitute(subst)), Box::new(r.substitute(subst)))
147 }
148 IndexExpr::Min(l, r) => {
149 IndexExpr::Min(Box::new(l.substitute(subst)), Box::new(r.substitute(subst)))
150 }
151 IndexExpr::Max(l, r) => {
152 IndexExpr::Max(Box::new(l.substitute(subst)), Box::new(r.substitute(subst)))
153 }
154 }
155 }
156
157 pub fn simplify(&self) -> IndexExpr {
159 match self {
160 IndexExpr::Add(l, r) => match (l.simplify(), r.simplify()) {
161 (IndexExpr::Const(0), e) | (e, IndexExpr::Const(0)) => e,
162 (IndexExpr::Const(a), IndexExpr::Const(b)) => IndexExpr::Const(a + b),
163 (l, r) => IndexExpr::Add(Box::new(l), Box::new(r)),
164 },
165 IndexExpr::Sub(l, r) => match (l.simplify(), r.simplify()) {
166 (e, IndexExpr::Const(0)) => e,
167 (IndexExpr::Const(a), IndexExpr::Const(b)) => IndexExpr::Const(a - b),
168 (l, r) if l == r => IndexExpr::Const(0),
169 (l, r) => IndexExpr::Sub(Box::new(l), Box::new(r)),
170 },
171 IndexExpr::Mul(l, r) => match (l.simplify(), r.simplify()) {
172 (IndexExpr::Const(0), _) | (_, IndexExpr::Const(0)) => IndexExpr::Const(0),
173 (IndexExpr::Const(1), e) | (e, IndexExpr::Const(1)) => e,
174 (IndexExpr::Const(a), IndexExpr::Const(b)) => IndexExpr::Const(a * b),
175 (l, r) => IndexExpr::Mul(Box::new(l), Box::new(r)),
176 },
177 IndexExpr::Div(l, r) => match (l.simplify(), r.simplify()) {
178 (IndexExpr::Const(0), _) => IndexExpr::Const(0),
179 (e, IndexExpr::Const(1)) => e,
180 (IndexExpr::Const(a), IndexExpr::Const(b)) if b != 0 => IndexExpr::Const(a / b),
181 (l, r) if l == r => IndexExpr::Const(1),
182 (l, r) => IndexExpr::Div(Box::new(l), Box::new(r)),
183 },
184 IndexExpr::Min(l, r) => match (l.simplify(), r.simplify()) {
185 (IndexExpr::Const(a), IndexExpr::Const(b)) => IndexExpr::Const(a.min(b)),
186 (l, r) if l == r => l,
187 (l, r) => IndexExpr::Min(Box::new(l), Box::new(r)),
188 },
189 IndexExpr::Max(l, r) => match (l.simplify(), r.simplify()) {
190 (IndexExpr::Const(a), IndexExpr::Const(b)) => IndexExpr::Const(a.max(b)),
191 (l, r) if l == r => l,
192 (l, r) => IndexExpr::Max(Box::new(l), Box::new(r)),
193 },
194 _ => self.clone(),
195 }
196 }
197
198 pub fn try_eval(&self) -> Option<i64> {
200 match self {
201 IndexExpr::Const(v) => Some(*v),
202 IndexExpr::Add(l, r) => Some(l.try_eval()? + r.try_eval()?),
203 IndexExpr::Sub(l, r) => Some(l.try_eval()? - r.try_eval()?),
204 IndexExpr::Mul(l, r) => Some(l.try_eval()? * r.try_eval()?),
205 IndexExpr::Div(l, r) => {
206 let rv = r.try_eval()?;
207 if rv != 0 {
208 Some(l.try_eval()? / rv)
209 } else {
210 None
211 }
212 }
213 IndexExpr::Min(l, r) => Some(l.try_eval()?.min(r.try_eval()?)),
214 IndexExpr::Max(l, r) => Some(l.try_eval()?.max(r.try_eval()?)),
215 IndexExpr::Var(_) => None,
216 }
217 }
218}
219
220impl fmt::Display for IndexExpr {
221 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
222 match self {
223 IndexExpr::Var(name) => write!(f, "{}", name),
224 IndexExpr::Const(v) => write!(f, "{}", v),
225 IndexExpr::Add(l, r) => write!(f, "({} + {})", l, r),
226 IndexExpr::Sub(l, r) => write!(f, "({} - {})", l, r),
227 IndexExpr::Mul(l, r) => write!(f, "({} * {})", l, r),
228 IndexExpr::Div(l, r) => write!(f, "({} / {})", l, r),
229 IndexExpr::Min(l, r) => write!(f, "min({}, {})", l, r),
230 IndexExpr::Max(l, r) => write!(f, "max({}, {})", l, r),
231 }
232 }
233}
234
/// A boolean constraint over index expressions, e.g. `n < 100`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DimConstraint {
    /// `left == right`.
    Eq(IndexExpr, IndexExpr),
    /// `left < right`.
    Lt(IndexExpr, IndexExpr),
    /// `left <= right`.
    Lte(IndexExpr, IndexExpr),
    /// `left > right`.
    Gt(IndexExpr, IndexExpr),
    /// `left >= right`.
    Gte(IndexExpr, IndexExpr),
    /// Conjunction: both constraints hold.
    And(Box<DimConstraint>, Box<DimConstraint>),
    /// Disjunction: at least one constraint holds.
    Or(Box<DimConstraint>, Box<DimConstraint>),
    /// Negation of the inner constraint.
    Not(Box<DimConstraint>),
}
255
256impl DimConstraint {
257 pub fn eq(left: IndexExpr, right: IndexExpr) -> Self {
258 DimConstraint::Eq(left, right)
259 }
260
261 pub fn lt(left: IndexExpr, right: IndexExpr) -> Self {
262 DimConstraint::Lt(left, right)
263 }
264
265 pub fn lte(left: IndexExpr, right: IndexExpr) -> Self {
266 DimConstraint::Lte(left, right)
267 }
268
269 pub fn gt(left: IndexExpr, right: IndexExpr) -> Self {
270 DimConstraint::Gt(left, right)
271 }
272
273 pub fn gte(left: IndexExpr, right: IndexExpr) -> Self {
274 DimConstraint::Gte(left, right)
275 }
276
277 pub fn and(left: DimConstraint, right: DimConstraint) -> Self {
278 DimConstraint::And(Box::new(left), Box::new(right))
279 }
280
281 pub fn or(left: DimConstraint, right: DimConstraint) -> Self {
282 DimConstraint::Or(Box::new(left), Box::new(right))
283 }
284
285 #[allow(clippy::should_implement_trait)]
286 pub fn not(constraint: DimConstraint) -> Self {
287 DimConstraint::Not(Box::new(constraint))
288 }
289
290 pub fn referenced_vars(&self) -> HashSet<String> {
292 let mut vars = HashSet::new();
293 self.collect_referenced_vars(&mut vars);
294 vars
295 }
296
297 fn collect_referenced_vars(&self, vars: &mut HashSet<String>) {
298 match self {
299 DimConstraint::Eq(l, r)
300 | DimConstraint::Lt(l, r)
301 | DimConstraint::Lte(l, r)
302 | DimConstraint::Gt(l, r)
303 | DimConstraint::Gte(l, r) => {
304 vars.extend(l.free_vars());
305 vars.extend(r.free_vars());
306 }
307 DimConstraint::And(l, r) | DimConstraint::Or(l, r) => {
308 l.collect_referenced_vars(vars);
309 r.collect_referenced_vars(vars);
310 }
311 DimConstraint::Not(c) => c.collect_referenced_vars(vars),
312 }
313 }
314
315 pub fn simplify(&self) -> DimConstraint {
317 match self {
318 DimConstraint::Eq(l, r) => DimConstraint::Eq(l.simplify(), r.simplify()),
319 DimConstraint::Lt(l, r) => DimConstraint::Lt(l.simplify(), r.simplify()),
320 DimConstraint::Lte(l, r) => DimConstraint::Lte(l.simplify(), r.simplify()),
321 DimConstraint::Gt(l, r) => DimConstraint::Gt(l.simplify(), r.simplify()),
322 DimConstraint::Gte(l, r) => DimConstraint::Gte(l.simplify(), r.simplify()),
323 DimConstraint::And(l, r) => {
324 DimConstraint::And(Box::new(l.simplify()), Box::new(r.simplify()))
325 }
326 DimConstraint::Or(l, r) => {
327 DimConstraint::Or(Box::new(l.simplify()), Box::new(r.simplify()))
328 }
329 DimConstraint::Not(c) => DimConstraint::Not(Box::new(c.simplify())),
330 }
331 }
332}
333
334impl fmt::Display for DimConstraint {
335 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
336 match self {
337 DimConstraint::Eq(l, r) => write!(f, "{} == {}", l, r),
338 DimConstraint::Lt(l, r) => write!(f, "{} < {}", l, r),
339 DimConstraint::Lte(l, r) => write!(f, "{} <= {}", l, r),
340 DimConstraint::Gt(l, r) => write!(f, "{} > {}", l, r),
341 DimConstraint::Gte(l, r) => write!(f, "{} >= {}", l, r),
342 DimConstraint::And(l, r) => write!(f, "({} ∧ {})", l, r),
343 DimConstraint::Or(l, r) => write!(f, "({} ∨ {})", l, r),
344 DimConstraint::Not(c) => write!(f, "¬{}", c),
345 }
346 }
347}
348
/// A type whose shape may depend on symbolic index expressions.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DependentType {
    /// A plain, non-dependent type.
    Base(ParametricType),
    /// A vector whose length is an index expression.
    Vector {
        /// Number of elements.
        length: IndexExpr,
        /// Type of each element.
        element_type: Box<DependentType>,
    },
    /// A two-dimensional matrix with symbolic row and column counts.
    Matrix {
        /// Number of rows.
        rows: IndexExpr,
        /// Number of columns.
        cols: IndexExpr,
        /// Type of each element.
        element_type: Box<DependentType>,
    },
    /// A tensor with one index expression per dimension.
    Tensor {
        /// Size of each dimension, in the order they are displayed.
        shape: Vec<IndexExpr>,
        /// Type of each element.
        element_type: Box<DependentType>,
    },
    /// A dependent function `(param_name: param_type) -> return_type`.
    DependentFunction {
        /// Parameter name; it is treated as a bound index variable inside
        /// `return_type` when free index variables are collected.
        param_name: String,
        /// Type of the parameter.
        param_type: Box<DependentType>,
        /// Result type, which may mention `param_name`.
        return_type: Box<DependentType>,
    },
    /// A refinement type `{var_name: base_type | predicate}`.
    Refinement {
        /// Name introduced by the refinement (presumably the one `predicate`
        /// refers to — confirm against `Term` usage).
        var_name: String,
        /// The type being refined.
        base_type: Box<DependentType>,
        /// Predicate term; it is not inspected when collecting free index variables.
        predicate: Term,
    },
    /// A type with attached dimension constraints (`base_type where ...`).
    Constrained {
        /// The constrained type.
        base_type: Box<DependentType>,
        /// Constraints whose variables count as free unless an enclosing
        /// dependent function binds them.
        constraints: Vec<DimConstraint>,
    },
}
393
394impl DependentType {
395 pub fn base(param_type: ParametricType) -> Self {
397 DependentType::Base(param_type)
398 }
399
400 pub fn vector(length: IndexExpr, element_type: impl Into<String>) -> Self {
402 DependentType::Vector {
403 length,
404 element_type: Box::new(DependentType::Base(ParametricType::concrete(element_type))),
405 }
406 }
407
408 pub fn matrix(rows: IndexExpr, cols: IndexExpr, element_type: impl Into<String>) -> Self {
410 DependentType::Matrix {
411 rows,
412 cols,
413 element_type: Box::new(DependentType::Base(ParametricType::concrete(element_type))),
414 }
415 }
416
417 pub fn tensor(shape: Vec<IndexExpr>, element_type: impl Into<String>) -> Self {
419 DependentType::Tensor {
420 shape,
421 element_type: Box::new(DependentType::Base(ParametricType::concrete(element_type))),
422 }
423 }
424
425 pub fn dependent_function(
427 param_name: impl Into<String>,
428 param_type: DependentType,
429 return_type: DependentType,
430 ) -> Self {
431 DependentType::DependentFunction {
432 param_name: param_name.into(),
433 param_type: Box::new(param_type),
434 return_type: Box::new(return_type),
435 }
436 }
437
438 pub fn refinement(
440 var_name: impl Into<String>,
441 base_type: DependentType,
442 predicate: Term,
443 ) -> Self {
444 DependentType::Refinement {
445 var_name: var_name.into(),
446 base_type: Box::new(base_type),
447 predicate,
448 }
449 }
450
451 pub fn with_constraints(self, constraints: Vec<DimConstraint>) -> Self {
453 DependentType::Constrained {
454 base_type: Box::new(self),
455 constraints,
456 }
457 }
458
459 pub fn free_index_vars(&self) -> HashSet<String> {
461 let mut vars = HashSet::new();
462 self.collect_free_index_vars(&mut vars, &HashSet::new());
463 vars
464 }
465
466 fn collect_free_index_vars(&self, vars: &mut HashSet<String>, bound: &HashSet<String>) {
467 match self {
468 DependentType::Base(_) => {}
469 DependentType::Vector {
470 length,
471 element_type,
472 } => {
473 vars.extend(length.free_vars().difference(bound).cloned());
474 element_type.collect_free_index_vars(vars, bound);
475 }
476 DependentType::Matrix {
477 rows,
478 cols,
479 element_type,
480 } => {
481 vars.extend(rows.free_vars().difference(bound).cloned());
482 vars.extend(cols.free_vars().difference(bound).cloned());
483 element_type.collect_free_index_vars(vars, bound);
484 }
485 DependentType::Tensor {
486 shape,
487 element_type,
488 } => {
489 for dim in shape {
490 vars.extend(dim.free_vars().difference(bound).cloned());
491 }
492 element_type.collect_free_index_vars(vars, bound);
493 }
494 DependentType::DependentFunction {
495 param_name,
496 param_type,
497 return_type,
498 } => {
499 param_type.collect_free_index_vars(vars, bound);
500 let mut new_bound = bound.clone();
501 new_bound.insert(param_name.clone());
502 return_type.collect_free_index_vars(vars, &new_bound);
503 }
504 DependentType::Refinement {
505 var_name: _,
506 base_type,
507 predicate: _,
508 } => {
509 base_type.collect_free_index_vars(vars, bound);
510 }
511 DependentType::Constrained {
512 base_type,
513 constraints,
514 } => {
515 base_type.collect_free_index_vars(vars, bound);
516 for constraint in constraints {
517 vars.extend(constraint.referenced_vars().difference(bound).cloned());
518 }
519 }
520 }
521 }
522
523 pub fn is_well_formed(&self) -> bool {
525 self.free_index_vars().is_empty()
526 }
527}
528
529impl fmt::Display for DependentType {
530 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
531 match self {
532 DependentType::Base(t) => write!(f, "{}", t),
533 DependentType::Vector {
534 length,
535 element_type,
536 } => write!(f, "Vec<{}, {}>", length, element_type),
537 DependentType::Matrix {
538 rows,
539 cols,
540 element_type,
541 } => write!(f, "Matrix<{}, {}, {}>", rows, cols, element_type),
542 DependentType::Tensor {
543 shape,
544 element_type,
545 } => {
546 write!(f, "Tensor<[")?;
547 for (i, dim) in shape.iter().enumerate() {
548 if i > 0 {
549 write!(f, ", ")?;
550 }
551 write!(f, "{}", dim)?;
552 }
553 write!(f, "], {}>", element_type)
554 }
555 DependentType::DependentFunction {
556 param_name,
557 param_type,
558 return_type,
559 } => write!(f, "({}: {}) -> {}", param_name, param_type, return_type),
560 DependentType::Refinement {
561 var_name,
562 base_type,
563 predicate,
564 } => write!(f, "{{{}:{} | {}}}", var_name, base_type, predicate),
565 DependentType::Constrained {
566 base_type,
567 constraints,
568 } => {
569 write!(f, "{} where ", base_type)?;
570 for (i, c) in constraints.iter().enumerate() {
571 if i > 0 {
572 write!(f, ", ")?;
573 }
574 write!(f, "{}", c)?;
575 }
576 Ok(())
577 }
578 }
579 }
580}
581
/// Concrete values for index variables plus the constraints to check against them.
#[derive(Clone, Debug, Default)]
pub struct DependentTypeContext {
    /// Concrete value of each bound index variable, keyed by variable name.
    index_bindings: HashMap<String, i64>,
    /// Constraints recorded through `add_constraint`, checked by `is_satisfiable`.
    constraints: Vec<DimConstraint>,
}
590
591impl DependentTypeContext {
592 pub fn new() -> Self {
593 Self::default()
594 }
595
596 pub fn bind_index(&mut self, name: impl Into<String>, value: i64) {
598 self.index_bindings.insert(name.into(), value);
599 }
600
601 pub fn add_constraint(&mut self, constraint: DimConstraint) {
603 self.constraints.push(constraint);
604 }
605
606 pub fn is_satisfiable(&self) -> bool {
608 for constraint in &self.constraints {
610 if !self.check_constraint(constraint) {
611 return false;
612 }
613 }
614 true
615 }
616
617 fn check_constraint(&self, constraint: &DimConstraint) -> bool {
618 match constraint {
619 DimConstraint::Eq(l, r) => {
620 let lv = self.eval_index(l);
621 let rv = self.eval_index(r);
622 match (lv, rv) {
623 (Some(a), Some(b)) => a == b,
624 _ => true, }
626 }
627 DimConstraint::Lt(l, r) => {
628 let lv = self.eval_index(l);
629 let rv = self.eval_index(r);
630 match (lv, rv) {
631 (Some(a), Some(b)) => a < b,
632 _ => true,
633 }
634 }
635 DimConstraint::Lte(l, r) => {
636 let lv = self.eval_index(l);
637 let rv = self.eval_index(r);
638 match (lv, rv) {
639 (Some(a), Some(b)) => a <= b,
640 _ => true,
641 }
642 }
643 DimConstraint::Gt(l, r) => {
644 let lv = self.eval_index(l);
645 let rv = self.eval_index(r);
646 match (lv, rv) {
647 (Some(a), Some(b)) => a > b,
648 _ => true,
649 }
650 }
651 DimConstraint::Gte(l, r) => {
652 let lv = self.eval_index(l);
653 let rv = self.eval_index(r);
654 match (lv, rv) {
655 (Some(a), Some(b)) => a >= b,
656 _ => true,
657 }
658 }
659 DimConstraint::And(l, r) => self.check_constraint(l) && self.check_constraint(r),
660 DimConstraint::Or(l, r) => self.check_constraint(l) || self.check_constraint(r),
661 DimConstraint::Not(c) => !self.check_constraint(c),
662 }
663 }
664
665 fn eval_index(&self, expr: &IndexExpr) -> Option<i64> {
666 match expr {
667 IndexExpr::Var(name) => self.index_bindings.get(name).copied(),
668 IndexExpr::Const(v) => Some(*v),
669 IndexExpr::Add(l, r) => Some(self.eval_index(l)? + self.eval_index(r)?),
670 IndexExpr::Sub(l, r) => Some(self.eval_index(l)? - self.eval_index(r)?),
671 IndexExpr::Mul(l, r) => Some(self.eval_index(l)? * self.eval_index(r)?),
672 IndexExpr::Div(l, r) => {
673 let rv = self.eval_index(r)?;
674 if rv != 0 {
675 Some(self.eval_index(l)? / rv)
676 } else {
677 None
678 }
679 }
680 IndexExpr::Min(l, r) => Some(self.eval_index(l)?.min(self.eval_index(r)?)),
681 IndexExpr::Max(l, r) => Some(self.eval_index(l)?.max(self.eval_index(r)?)),
682 }
683 }
684}
685
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an index variable.
    fn v(name: &str) -> IndexExpr {
        IndexExpr::var(name)
    }

    /// Shorthand for an integer literal.
    fn k(value: i64) -> IndexExpr {
        IndexExpr::constant(value)
    }

    #[test]
    fn test_index_expr_basics() {
        assert_eq!(v("n").to_string(), "n");
        assert_eq!(k(10).to_string(), "10");
        assert_eq!(IndexExpr::add(v("n"), v("m")).to_string(), "(n + m)");
    }

    #[test]
    fn test_index_expr_simplification() {
        // (input, expected simplified form)
        let cases = [
            (IndexExpr::add(v("n"), k(0)), v("n")),
            (IndexExpr::mul(v("n"), k(1)), v("n")),
            (IndexExpr::mul(v("n"), k(0)), k(0)),
            (IndexExpr::add(k(5), k(3)), k(8)),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(&input.simplify(), expected);
        }
    }

    #[test]
    fn test_index_expr_eval() {
        assert_eq!(IndexExpr::add(k(5), k(3)).try_eval(), Some(8));
        assert_eq!(IndexExpr::mul(k(4), k(7)).try_eval(), Some(28));
        assert_eq!(IndexExpr::add(v("n"), k(5)).try_eval(), None);
    }

    #[test]
    fn test_dependent_vector_type() {
        let vec_type = DependentType::vector(v("n"), "Int");
        let expected: HashSet<String> = std::iter::once("n".to_string()).collect();

        assert_eq!(vec_type.to_string(), "Vec<n, Int>");
        assert_eq!(vec_type.free_index_vars(), expected);
    }

    #[test]
    fn test_dependent_matrix_type() {
        let matrix_type = DependentType::matrix(v("m"), v("n"), "Float");
        assert_eq!(matrix_type.to_string(), "Matrix<m, n, Float>");
    }

    #[test]
    fn test_dependent_tensor_type() {
        let tensor_type = DependentType::tensor(vec![v("d1"), v("d2"), k(10)], "Float");
        assert_eq!(tensor_type.to_string(), "Tensor<[d1, d2, 10], Float>");
    }

    #[test]
    fn test_dependent_function_type() {
        let func_type = DependentType::dependent_function(
            "n",
            DependentType::base(ParametricType::concrete("Int")),
            DependentType::vector(v("n"), "Bool"),
        );
        assert_eq!(func_type.to_string(), "(n: Int) -> Vec<n, Bool>");
    }

    #[test]
    fn test_dimension_constraints() {
        let upper = DimConstraint::lt(v("n"), k(100));
        let lower = DimConstraint::gte(v("n"), k(0));
        let same = DimConstraint::eq(v("n"), v("m"));

        assert_eq!(upper.to_string(), "n < 100");
        assert_eq!(lower.to_string(), "n >= 0");
        assert_eq!(same.to_string(), "n == m");
        assert_eq!(DimConstraint::and(upper, lower).to_string(), "(n < 100 ∧ n >= 0)");
    }

    #[test]
    fn test_constrained_type() {
        let constrained = DependentType::vector(v("n"), "Int")
            .with_constraints(vec![DimConstraint::lte(v("n"), k(100))]);
        assert_eq!(constrained.to_string(), "Vec<n, Int> where n <= 100");
    }

    #[test]
    fn test_type_context_satisfiability() {
        let mut ctx = DependentTypeContext::new();
        ctx.bind_index("n", 50);

        ctx.add_constraint(DimConstraint::lte(v("n"), k(100)));
        assert!(ctx.is_satisfiable());

        ctx.add_constraint(DimConstraint::gt(v("n"), k(100)));
        assert!(!ctx.is_satisfiable());
    }

    #[test]
    fn test_refinement_type() {
        let refined = DependentType::refinement(
            "x",
            DependentType::base(ParametricType::concrete("Int")),
            Term::var("x"),
        );
        assert!(refined.to_string().contains("{x:Int |"));
    }

    #[test]
    fn test_free_index_vars_in_complex_type() {
        // `n` is bound by the function parameter, so the square matrix is closed.
        let square = DependentType::matrix(v("n"), v("n"), "Float");
        let func_type = DependentType::dependent_function(
            "n",
            DependentType::base(ParametricType::concrete("Int")),
            square,
        );
        assert!(func_type.is_well_formed());
    }

    #[test]
    fn test_index_substitution() {
        let subst: HashMap<String, IndexExpr> =
            std::iter::once(("n".to_string(), k(10))).collect();
        let result = IndexExpr::add(v("n"), v("m")).substitute(&subst);
        assert_eq!(result.to_string(), "(10 + m)");
    }
}