Skip to main content

tensorlogic_ir/
dependent.rs

1//! Dependent type system for value-dependent types in TensorLogic.
2//!
3//! This module implements dependent types, where types can depend on runtime values.
4//! This is crucial for tensor operations where dimensions are first-class values.
5//!
6//! # Examples
7//!
8//! ```
9//! use tensorlogic_ir::dependent::{DependentType, IndexExpr, DimConstraint};
10//!
11//! // Vector of length n: Vec<n, T>
12//! let n = IndexExpr::var("n");
13//! let vec_n_int = DependentType::vector(n.clone(), "Int");
14//!
15//! // Matrix with dimensions m×n: Matrix<m, n, T>
16//! let m = IndexExpr::var("m");
17//! let matrix_type = DependentType::matrix(m.clone(), n.clone(), "Float");
18//!
19//! // Bounded vector: Vec<n, T> where n <= 100
20//! let constraint = DimConstraint::lte(n.clone(), IndexExpr::constant(100));
21//! ```
22//!
23//! # Key Features
24//!
25//! - **Index expressions**: Arithmetic on dimension variables
26//! - **Dependent function types**: (x: T) -> U(x)
27//! - **Refinement types**: Types with predicates on values
28//! - **Dimension constraints**: Bounds and relationships between dimensions
29//! - **Type-level computation**: Compute types from values
30
31use serde::{Deserialize, Serialize};
32use std::collections::{HashMap, HashSet};
33use std::fmt;
34
35use crate::{ParametricType, Term};
36
37/// Index expression for dimension calculations.
38///
39/// Index expressions represent compile-time or runtime values used in type indices,
40/// particularly for tensor dimensions.
41#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
42pub enum IndexExpr {
43    /// Variable index (e.g., n, m)
44    Var(String),
45    /// Constant index value
46    Const(i64),
47    /// Addition: e1 + e2
48    Add(Box<IndexExpr>, Box<IndexExpr>),
49    /// Subtraction: e1 - e2
50    Sub(Box<IndexExpr>, Box<IndexExpr>),
51    /// Multiplication: e1 * e2
52    Mul(Box<IndexExpr>, Box<IndexExpr>),
53    /// Division: e1 / e2
54    Div(Box<IndexExpr>, Box<IndexExpr>),
55    /// Minimum: min(e1, e2)
56    Min(Box<IndexExpr>, Box<IndexExpr>),
57    /// Maximum: max(e1, e2)
58    Max(Box<IndexExpr>, Box<IndexExpr>),
59}
60
61impl IndexExpr {
62    /// Create a variable index expression
63    pub fn var(name: impl Into<String>) -> Self {
64        IndexExpr::Var(name.into())
65    }
66
67    /// Create a constant index expression
68    pub fn constant(value: i64) -> Self {
69        IndexExpr::Const(value)
70    }
71
72    /// Addition
73    #[allow(clippy::should_implement_trait)]
74    pub fn add(left: IndexExpr, right: IndexExpr) -> Self {
75        IndexExpr::Add(Box::new(left), Box::new(right))
76    }
77
78    /// Subtraction
79    #[allow(clippy::should_implement_trait)]
80    pub fn sub(left: IndexExpr, right: IndexExpr) -> Self {
81        IndexExpr::Sub(Box::new(left), Box::new(right))
82    }
83
84    /// Multiplication
85    #[allow(clippy::should_implement_trait)]
86    pub fn mul(left: IndexExpr, right: IndexExpr) -> Self {
87        IndexExpr::Mul(Box::new(left), Box::new(right))
88    }
89
90    /// Division
91    #[allow(clippy::should_implement_trait)]
92    pub fn div(left: IndexExpr, right: IndexExpr) -> Self {
93        IndexExpr::Div(Box::new(left), Box::new(right))
94    }
95
96    /// Minimum
97    pub fn min(left: IndexExpr, right: IndexExpr) -> Self {
98        IndexExpr::Min(Box::new(left), Box::new(right))
99    }
100
101    /// Maximum
102    pub fn max(left: IndexExpr, right: IndexExpr) -> Self {
103        IndexExpr::Max(Box::new(left), Box::new(right))
104    }
105
106    /// Get all free variables in this expression
107    pub fn free_vars(&self) -> HashSet<String> {
108        let mut vars = HashSet::new();
109        self.collect_vars(&mut vars);
110        vars
111    }
112
113    fn collect_vars(&self, vars: &mut HashSet<String>) {
114        match self {
115            IndexExpr::Var(name) => {
116                vars.insert(name.clone());
117            }
118            IndexExpr::Const(_) => {}
119            IndexExpr::Add(l, r)
120            | IndexExpr::Sub(l, r)
121            | IndexExpr::Mul(l, r)
122            | IndexExpr::Div(l, r)
123            | IndexExpr::Min(l, r)
124            | IndexExpr::Max(l, r) => {
125                l.collect_vars(vars);
126                r.collect_vars(vars);
127            }
128        }
129    }
130
131    /// Substitute variables with index expressions
132    pub fn substitute(&self, subst: &HashMap<String, IndexExpr>) -> IndexExpr {
133        match self {
134            IndexExpr::Var(name) => subst.get(name).cloned().unwrap_or_else(|| self.clone()),
135            IndexExpr::Const(_) => self.clone(),
136            IndexExpr::Add(l, r) => {
137                IndexExpr::Add(Box::new(l.substitute(subst)), Box::new(r.substitute(subst)))
138            }
139            IndexExpr::Sub(l, r) => {
140                IndexExpr::Sub(Box::new(l.substitute(subst)), Box::new(r.substitute(subst)))
141            }
142            IndexExpr::Mul(l, r) => {
143                IndexExpr::Mul(Box::new(l.substitute(subst)), Box::new(r.substitute(subst)))
144            }
145            IndexExpr::Div(l, r) => {
146                IndexExpr::Div(Box::new(l.substitute(subst)), Box::new(r.substitute(subst)))
147            }
148            IndexExpr::Min(l, r) => {
149                IndexExpr::Min(Box::new(l.substitute(subst)), Box::new(r.substitute(subst)))
150            }
151            IndexExpr::Max(l, r) => {
152                IndexExpr::Max(Box::new(l.substitute(subst)), Box::new(r.substitute(subst)))
153            }
154        }
155    }
156
157    /// Simplify the index expression
158    pub fn simplify(&self) -> IndexExpr {
159        match self {
160            IndexExpr::Add(l, r) => match (l.simplify(), r.simplify()) {
161                (IndexExpr::Const(0), e) | (e, IndexExpr::Const(0)) => e,
162                (IndexExpr::Const(a), IndexExpr::Const(b)) => IndexExpr::Const(a + b),
163                (l, r) => IndexExpr::Add(Box::new(l), Box::new(r)),
164            },
165            IndexExpr::Sub(l, r) => match (l.simplify(), r.simplify()) {
166                (e, IndexExpr::Const(0)) => e,
167                (IndexExpr::Const(a), IndexExpr::Const(b)) => IndexExpr::Const(a - b),
168                (l, r) if l == r => IndexExpr::Const(0),
169                (l, r) => IndexExpr::Sub(Box::new(l), Box::new(r)),
170            },
171            IndexExpr::Mul(l, r) => match (l.simplify(), r.simplify()) {
172                (IndexExpr::Const(0), _) | (_, IndexExpr::Const(0)) => IndexExpr::Const(0),
173                (IndexExpr::Const(1), e) | (e, IndexExpr::Const(1)) => e,
174                (IndexExpr::Const(a), IndexExpr::Const(b)) => IndexExpr::Const(a * b),
175                (l, r) => IndexExpr::Mul(Box::new(l), Box::new(r)),
176            },
177            IndexExpr::Div(l, r) => match (l.simplify(), r.simplify()) {
178                (IndexExpr::Const(0), _) => IndexExpr::Const(0),
179                (e, IndexExpr::Const(1)) => e,
180                (IndexExpr::Const(a), IndexExpr::Const(b)) if b != 0 => IndexExpr::Const(a / b),
181                (l, r) if l == r => IndexExpr::Const(1),
182                (l, r) => IndexExpr::Div(Box::new(l), Box::new(r)),
183            },
184            IndexExpr::Min(l, r) => match (l.simplify(), r.simplify()) {
185                (IndexExpr::Const(a), IndexExpr::Const(b)) => IndexExpr::Const(a.min(b)),
186                (l, r) if l == r => l,
187                (l, r) => IndexExpr::Min(Box::new(l), Box::new(r)),
188            },
189            IndexExpr::Max(l, r) => match (l.simplify(), r.simplify()) {
190                (IndexExpr::Const(a), IndexExpr::Const(b)) => IndexExpr::Const(a.max(b)),
191                (l, r) if l == r => l,
192                (l, r) => IndexExpr::Max(Box::new(l), Box::new(r)),
193            },
194            _ => self.clone(),
195        }
196    }
197
198    /// Try to evaluate to a constant value
199    pub fn try_eval(&self) -> Option<i64> {
200        match self {
201            IndexExpr::Const(v) => Some(*v),
202            IndexExpr::Add(l, r) => Some(l.try_eval()? + r.try_eval()?),
203            IndexExpr::Sub(l, r) => Some(l.try_eval()? - r.try_eval()?),
204            IndexExpr::Mul(l, r) => Some(l.try_eval()? * r.try_eval()?),
205            IndexExpr::Div(l, r) => {
206                let rv = r.try_eval()?;
207                if rv != 0 {
208                    Some(l.try_eval()? / rv)
209                } else {
210                    None
211                }
212            }
213            IndexExpr::Min(l, r) => Some(l.try_eval()?.min(r.try_eval()?)),
214            IndexExpr::Max(l, r) => Some(l.try_eval()?.max(r.try_eval()?)),
215            IndexExpr::Var(_) => None,
216        }
217    }
218}
219
220impl fmt::Display for IndexExpr {
221    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
222        match self {
223            IndexExpr::Var(name) => write!(f, "{}", name),
224            IndexExpr::Const(v) => write!(f, "{}", v),
225            IndexExpr::Add(l, r) => write!(f, "({} + {})", l, r),
226            IndexExpr::Sub(l, r) => write!(f, "({} - {})", l, r),
227            IndexExpr::Mul(l, r) => write!(f, "({} * {})", l, r),
228            IndexExpr::Div(l, r) => write!(f, "({} / {})", l, r),
229            IndexExpr::Min(l, r) => write!(f, "min({}, {})", l, r),
230            IndexExpr::Max(l, r) => write!(f, "max({}, {})", l, r),
231        }
232    }
233}
234
235/// Dimension constraints for dependent types.
236#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
237pub enum DimConstraint {
238    /// Equality: e1 == e2
239    Eq(IndexExpr, IndexExpr),
240    /// Less than: e1 < e2
241    Lt(IndexExpr, IndexExpr),
242    /// Less than or equal: e1 <= e2
243    Lte(IndexExpr, IndexExpr),
244    /// Greater than: e1 > e2
245    Gt(IndexExpr, IndexExpr),
246    /// Greater than or equal: e1 >= e2
247    Gte(IndexExpr, IndexExpr),
248    /// Conjunction: c1 ∧ c2
249    And(Box<DimConstraint>, Box<DimConstraint>),
250    /// Disjunction: c1 ∨ c2
251    Or(Box<DimConstraint>, Box<DimConstraint>),
252    /// Negation: ¬c
253    Not(Box<DimConstraint>),
254}
255
256impl DimConstraint {
257    pub fn eq(left: IndexExpr, right: IndexExpr) -> Self {
258        DimConstraint::Eq(left, right)
259    }
260
261    pub fn lt(left: IndexExpr, right: IndexExpr) -> Self {
262        DimConstraint::Lt(left, right)
263    }
264
265    pub fn lte(left: IndexExpr, right: IndexExpr) -> Self {
266        DimConstraint::Lte(left, right)
267    }
268
269    pub fn gt(left: IndexExpr, right: IndexExpr) -> Self {
270        DimConstraint::Gt(left, right)
271    }
272
273    pub fn gte(left: IndexExpr, right: IndexExpr) -> Self {
274        DimConstraint::Gte(left, right)
275    }
276
277    pub fn and(left: DimConstraint, right: DimConstraint) -> Self {
278        DimConstraint::And(Box::new(left), Box::new(right))
279    }
280
281    pub fn or(left: DimConstraint, right: DimConstraint) -> Self {
282        DimConstraint::Or(Box::new(left), Box::new(right))
283    }
284
285    #[allow(clippy::should_implement_trait)]
286    pub fn not(constraint: DimConstraint) -> Self {
287        DimConstraint::Not(Box::new(constraint))
288    }
289
290    /// Get all index variables referenced in this constraint
291    pub fn referenced_vars(&self) -> HashSet<String> {
292        let mut vars = HashSet::new();
293        self.collect_referenced_vars(&mut vars);
294        vars
295    }
296
297    fn collect_referenced_vars(&self, vars: &mut HashSet<String>) {
298        match self {
299            DimConstraint::Eq(l, r)
300            | DimConstraint::Lt(l, r)
301            | DimConstraint::Lte(l, r)
302            | DimConstraint::Gt(l, r)
303            | DimConstraint::Gte(l, r) => {
304                vars.extend(l.free_vars());
305                vars.extend(r.free_vars());
306            }
307            DimConstraint::And(l, r) | DimConstraint::Or(l, r) => {
308                l.collect_referenced_vars(vars);
309                r.collect_referenced_vars(vars);
310            }
311            DimConstraint::Not(c) => c.collect_referenced_vars(vars),
312        }
313    }
314
315    /// Simplify the constraint
316    pub fn simplify(&self) -> DimConstraint {
317        match self {
318            DimConstraint::Eq(l, r) => DimConstraint::Eq(l.simplify(), r.simplify()),
319            DimConstraint::Lt(l, r) => DimConstraint::Lt(l.simplify(), r.simplify()),
320            DimConstraint::Lte(l, r) => DimConstraint::Lte(l.simplify(), r.simplify()),
321            DimConstraint::Gt(l, r) => DimConstraint::Gt(l.simplify(), r.simplify()),
322            DimConstraint::Gte(l, r) => DimConstraint::Gte(l.simplify(), r.simplify()),
323            DimConstraint::And(l, r) => {
324                DimConstraint::And(Box::new(l.simplify()), Box::new(r.simplify()))
325            }
326            DimConstraint::Or(l, r) => {
327                DimConstraint::Or(Box::new(l.simplify()), Box::new(r.simplify()))
328            }
329            DimConstraint::Not(c) => DimConstraint::Not(Box::new(c.simplify())),
330        }
331    }
332}
333
334impl fmt::Display for DimConstraint {
335    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
336        match self {
337            DimConstraint::Eq(l, r) => write!(f, "{} == {}", l, r),
338            DimConstraint::Lt(l, r) => write!(f, "{} < {}", l, r),
339            DimConstraint::Lte(l, r) => write!(f, "{} <= {}", l, r),
340            DimConstraint::Gt(l, r) => write!(f, "{} > {}", l, r),
341            DimConstraint::Gte(l, r) => write!(f, "{} >= {}", l, r),
342            DimConstraint::And(l, r) => write!(f, "({} ∧ {})", l, r),
343            DimConstraint::Or(l, r) => write!(f, "({} ∨ {})", l, r),
344            DimConstraint::Not(c) => write!(f, "¬{}", c),
345        }
346    }
347}
348
349/// Dependent type: types that depend on runtime values.
350///
351/// Examples:
352/// - `Vec<n, T>`: Vector of length n with elements of type T
353/// - `Matrix<m, n, T>`: Matrix with dimensions m×n
354/// - `(x: Int) -> Vec<x, Bool>`: Function returning a vector of length x
355#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
356pub enum DependentType {
357    /// Base parametric type (non-dependent)
358    Base(ParametricType),
359    /// Vector with dependent length: Vec<n, T>
360    Vector {
361        length: IndexExpr,
362        element_type: Box<DependentType>,
363    },
364    /// Matrix with dependent dimensions: Matrix<rows, cols, T>
365    Matrix {
366        rows: IndexExpr,
367        cols: IndexExpr,
368        element_type: Box<DependentType>,
369    },
370    /// Tensor with dependent shape: Tensor<[d1, d2, ...], T>
371    Tensor {
372        shape: Vec<IndexExpr>,
373        element_type: Box<DependentType>,
374    },
375    /// Dependent function type: (x: T1) -> T2(x)
376    DependentFunction {
377        param_name: String,
378        param_type: Box<DependentType>,
379        return_type: Box<DependentType>,
380    },
381    /// Refinement type: {x: T | P(x)}
382    Refinement {
383        var_name: String,
384        base_type: Box<DependentType>,
385        predicate: Term,
386    },
387    /// Constrained type: T where C
388    Constrained {
389        base_type: Box<DependentType>,
390        constraints: Vec<DimConstraint>,
391    },
392}
393
394impl DependentType {
395    /// Create a base non-dependent type
396    pub fn base(param_type: ParametricType) -> Self {
397        DependentType::Base(param_type)
398    }
399
400    /// Create a dependent vector type
401    pub fn vector(length: IndexExpr, element_type: impl Into<String>) -> Self {
402        DependentType::Vector {
403            length,
404            element_type: Box::new(DependentType::Base(ParametricType::concrete(element_type))),
405        }
406    }
407
408    /// Create a dependent matrix type
409    pub fn matrix(rows: IndexExpr, cols: IndexExpr, element_type: impl Into<String>) -> Self {
410        DependentType::Matrix {
411            rows,
412            cols,
413            element_type: Box::new(DependentType::Base(ParametricType::concrete(element_type))),
414        }
415    }
416
417    /// Create a dependent tensor type
418    pub fn tensor(shape: Vec<IndexExpr>, element_type: impl Into<String>) -> Self {
419        DependentType::Tensor {
420            shape,
421            element_type: Box::new(DependentType::Base(ParametricType::concrete(element_type))),
422        }
423    }
424
425    /// Create a dependent function type
426    pub fn dependent_function(
427        param_name: impl Into<String>,
428        param_type: DependentType,
429        return_type: DependentType,
430    ) -> Self {
431        DependentType::DependentFunction {
432            param_name: param_name.into(),
433            param_type: Box::new(param_type),
434            return_type: Box::new(return_type),
435        }
436    }
437
438    /// Create a refinement type
439    pub fn refinement(
440        var_name: impl Into<String>,
441        base_type: DependentType,
442        predicate: Term,
443    ) -> Self {
444        DependentType::Refinement {
445            var_name: var_name.into(),
446            base_type: Box::new(base_type),
447            predicate,
448        }
449    }
450
451    /// Add constraints to a type
452    pub fn with_constraints(self, constraints: Vec<DimConstraint>) -> Self {
453        DependentType::Constrained {
454            base_type: Box::new(self),
455            constraints,
456        }
457    }
458
459    /// Get all free index variables
460    pub fn free_index_vars(&self) -> HashSet<String> {
461        let mut vars = HashSet::new();
462        self.collect_free_index_vars(&mut vars, &HashSet::new());
463        vars
464    }
465
466    fn collect_free_index_vars(&self, vars: &mut HashSet<String>, bound: &HashSet<String>) {
467        match self {
468            DependentType::Base(_) => {}
469            DependentType::Vector {
470                length,
471                element_type,
472            } => {
473                vars.extend(length.free_vars().difference(bound).cloned());
474                element_type.collect_free_index_vars(vars, bound);
475            }
476            DependentType::Matrix {
477                rows,
478                cols,
479                element_type,
480            } => {
481                vars.extend(rows.free_vars().difference(bound).cloned());
482                vars.extend(cols.free_vars().difference(bound).cloned());
483                element_type.collect_free_index_vars(vars, bound);
484            }
485            DependentType::Tensor {
486                shape,
487                element_type,
488            } => {
489                for dim in shape {
490                    vars.extend(dim.free_vars().difference(bound).cloned());
491                }
492                element_type.collect_free_index_vars(vars, bound);
493            }
494            DependentType::DependentFunction {
495                param_name,
496                param_type,
497                return_type,
498            } => {
499                param_type.collect_free_index_vars(vars, bound);
500                let mut new_bound = bound.clone();
501                new_bound.insert(param_name.clone());
502                return_type.collect_free_index_vars(vars, &new_bound);
503            }
504            DependentType::Refinement {
505                var_name: _,
506                base_type,
507                predicate: _,
508            } => {
509                base_type.collect_free_index_vars(vars, bound);
510            }
511            DependentType::Constrained {
512                base_type,
513                constraints,
514            } => {
515                base_type.collect_free_index_vars(vars, bound);
516                for constraint in constraints {
517                    vars.extend(constraint.referenced_vars().difference(bound).cloned());
518                }
519            }
520        }
521    }
522
523    /// Check if this type is well-formed (no unbound index variables)
524    pub fn is_well_formed(&self) -> bool {
525        self.free_index_vars().is_empty()
526    }
527}
528
529impl fmt::Display for DependentType {
530    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
531        match self {
532            DependentType::Base(t) => write!(f, "{}", t),
533            DependentType::Vector {
534                length,
535                element_type,
536            } => write!(f, "Vec<{}, {}>", length, element_type),
537            DependentType::Matrix {
538                rows,
539                cols,
540                element_type,
541            } => write!(f, "Matrix<{}, {}, {}>", rows, cols, element_type),
542            DependentType::Tensor {
543                shape,
544                element_type,
545            } => {
546                write!(f, "Tensor<[")?;
547                for (i, dim) in shape.iter().enumerate() {
548                    if i > 0 {
549                        write!(f, ", ")?;
550                    }
551                    write!(f, "{}", dim)?;
552                }
553                write!(f, "], {}>", element_type)
554            }
555            DependentType::DependentFunction {
556                param_name,
557                param_type,
558                return_type,
559            } => write!(f, "({}: {}) -> {}", param_name, param_type, return_type),
560            DependentType::Refinement {
561                var_name,
562                base_type,
563                predicate,
564            } => write!(f, "{{{}:{} | {}}}", var_name, base_type, predicate),
565            DependentType::Constrained {
566                base_type,
567                constraints,
568            } => {
569                write!(f, "{} where ", base_type)?;
570                for (i, c) in constraints.iter().enumerate() {
571                    if i > 0 {
572                        write!(f, ", ")?;
573                    }
574                    write!(f, "{}", c)?;
575                }
576                Ok(())
577            }
578        }
579    }
580}
581
582/// Type checking context for dependent types.
583#[derive(Clone, Debug, Default)]
584pub struct DependentTypeContext {
585    /// Index variable bindings
586    index_bindings: HashMap<String, i64>,
587    /// Dimension constraints
588    constraints: Vec<DimConstraint>,
589}
590
591impl DependentTypeContext {
592    pub fn new() -> Self {
593        Self::default()
594    }
595
596    /// Bind an index variable to a value
597    pub fn bind_index(&mut self, name: impl Into<String>, value: i64) {
598        self.index_bindings.insert(name.into(), value);
599    }
600
601    /// Add a dimension constraint
602    pub fn add_constraint(&mut self, constraint: DimConstraint) {
603        self.constraints.push(constraint);
604    }
605
606    /// Check if constraints are satisfiable (simplified check)
607    pub fn is_satisfiable(&self) -> bool {
608        // For now, just check if we can evaluate all constraints with current bindings
609        for constraint in &self.constraints {
610            if !self.check_constraint(constraint) {
611                return false;
612            }
613        }
614        true
615    }
616
617    fn check_constraint(&self, constraint: &DimConstraint) -> bool {
618        match constraint {
619            DimConstraint::Eq(l, r) => {
620                let lv = self.eval_index(l);
621                let rv = self.eval_index(r);
622                match (lv, rv) {
623                    (Some(a), Some(b)) => a == b,
624                    _ => true, // Unknown, assume satisfiable
625                }
626            }
627            DimConstraint::Lt(l, r) => {
628                let lv = self.eval_index(l);
629                let rv = self.eval_index(r);
630                match (lv, rv) {
631                    (Some(a), Some(b)) => a < b,
632                    _ => true,
633                }
634            }
635            DimConstraint::Lte(l, r) => {
636                let lv = self.eval_index(l);
637                let rv = self.eval_index(r);
638                match (lv, rv) {
639                    (Some(a), Some(b)) => a <= b,
640                    _ => true,
641                }
642            }
643            DimConstraint::Gt(l, r) => {
644                let lv = self.eval_index(l);
645                let rv = self.eval_index(r);
646                match (lv, rv) {
647                    (Some(a), Some(b)) => a > b,
648                    _ => true,
649                }
650            }
651            DimConstraint::Gte(l, r) => {
652                let lv = self.eval_index(l);
653                let rv = self.eval_index(r);
654                match (lv, rv) {
655                    (Some(a), Some(b)) => a >= b,
656                    _ => true,
657                }
658            }
659            DimConstraint::And(l, r) => self.check_constraint(l) && self.check_constraint(r),
660            DimConstraint::Or(l, r) => self.check_constraint(l) || self.check_constraint(r),
661            DimConstraint::Not(c) => !self.check_constraint(c),
662        }
663    }
664
665    fn eval_index(&self, expr: &IndexExpr) -> Option<i64> {
666        match expr {
667            IndexExpr::Var(name) => self.index_bindings.get(name).copied(),
668            IndexExpr::Const(v) => Some(*v),
669            IndexExpr::Add(l, r) => Some(self.eval_index(l)? + self.eval_index(r)?),
670            IndexExpr::Sub(l, r) => Some(self.eval_index(l)? - self.eval_index(r)?),
671            IndexExpr::Mul(l, r) => Some(self.eval_index(l)? * self.eval_index(r)?),
672            IndexExpr::Div(l, r) => {
673                let rv = self.eval_index(r)?;
674                if rv != 0 {
675                    Some(self.eval_index(l)? / rv)
676                } else {
677                    None
678                }
679            }
680            IndexExpr::Min(l, r) => Some(self.eval_index(l)?.min(self.eval_index(r)?)),
681            IndexExpr::Max(l, r) => Some(self.eval_index(l)?.max(self.eval_index(r)?)),
682        }
683    }
684}
685
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned set of index-variable names for comparisons.
    fn name_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Shorthand for the non-dependent base type with the given name.
    fn scalar(name: &'static str) -> DependentType {
        DependentType::base(ParametricType::concrete(name))
    }

    #[test]
    fn test_index_expr_basics() {
        let (n, m) = (IndexExpr::var("n"), IndexExpr::var("m"));

        assert_eq!(n.to_string(), "n");
        assert_eq!(IndexExpr::constant(10).to_string(), "10");
        assert_eq!(IndexExpr::add(n, m).to_string(), "(n + m)");
    }

    #[test]
    fn test_index_expr_simplification() {
        let n = IndexExpr::var("n");
        let zero = IndexExpr::constant(0);

        let cases = [
            // n + 0 = n
            (IndexExpr::add(n.clone(), zero.clone()), n.clone()),
            // n * 1 = n
            (IndexExpr::mul(n.clone(), IndexExpr::constant(1)), n.clone()),
            // n * 0 = 0
            (IndexExpr::mul(n.clone(), zero.clone()), zero.clone()),
            // 5 + 3 = 8
            (
                IndexExpr::add(IndexExpr::constant(5), IndexExpr::constant(3)),
                IndexExpr::constant(8),
            ),
        ];
        for (expr, expected) in cases.iter() {
            assert_eq!(&expr.simplify(), expected);
        }
    }

    #[test]
    fn test_index_expr_eval() {
        let sum = IndexExpr::add(IndexExpr::constant(5), IndexExpr::constant(3));
        let product = IndexExpr::mul(IndexExpr::constant(4), IndexExpr::constant(7));
        let symbolic = IndexExpr::add(IndexExpr::var("n"), IndexExpr::constant(5));

        assert_eq!(sum.try_eval(), Some(8));
        assert_eq!(product.try_eval(), Some(28));
        assert_eq!(symbolic.try_eval(), None);
    }

    #[test]
    fn test_dependent_vector_type() {
        let vec_type = DependentType::vector(IndexExpr::var("n"), "Int");

        assert_eq!(vec_type.to_string(), "Vec<n, Int>");
        assert_eq!(vec_type.free_index_vars(), name_set(&["n"]));
    }

    #[test]
    fn test_dependent_matrix_type() {
        let matrix_type =
            DependentType::matrix(IndexExpr::var("m"), IndexExpr::var("n"), "Float");

        assert_eq!(matrix_type.to_string(), "Matrix<m, n, Float>");
    }

    #[test]
    fn test_dependent_tensor_type() {
        let shape = vec![
            IndexExpr::var("d1"),
            IndexExpr::var("d2"),
            IndexExpr::constant(10),
        ];
        let tensor_type = DependentType::tensor(shape, "Float");

        assert_eq!(tensor_type.to_string(), "Tensor<[d1, d2, 10], Float>");
    }

    #[test]
    fn test_dependent_function_type() {
        let result = DependentType::vector(IndexExpr::var("n"), "Bool");
        let func_type = DependentType::dependent_function("n", scalar("Int"), result);

        assert_eq!(func_type.to_string(), "(n: Int) -> Vec<n, Bool>");
    }

    #[test]
    fn test_dimension_constraints() {
        let n = IndexExpr::var("n");
        let upper = DimConstraint::lt(n.clone(), IndexExpr::constant(100));
        let lower = DimConstraint::gte(n.clone(), IndexExpr::constant(0));
        let same = DimConstraint::eq(n, IndexExpr::var("m"));

        assert_eq!(upper.to_string(), "n < 100");
        assert_eq!(lower.to_string(), "n >= 0");
        assert_eq!(same.to_string(), "n == m");
        assert_eq!(
            DimConstraint::and(upper, lower).to_string(),
            "(n < 100 ∧ n >= 0)"
        );
    }

    #[test]
    fn test_constrained_type() {
        let n = IndexExpr::var("n");
        let bound = DimConstraint::lte(n.clone(), IndexExpr::constant(100));
        let constrained = DependentType::vector(n, "Int").with_constraints(vec![bound]);

        assert_eq!(constrained.to_string(), "Vec<n, Int> where n <= 100");
    }

    #[test]
    fn test_type_context_satisfiability() {
        let mut ctx = DependentTypeContext::new();
        ctx.bind_index("n", 50);

        ctx.add_constraint(DimConstraint::lte(IndexExpr::var("n"), IndexExpr::constant(100)));
        assert!(ctx.is_satisfiable());

        ctx.add_constraint(DimConstraint::gt(IndexExpr::var("n"), IndexExpr::constant(100)));
        assert!(!ctx.is_satisfiable());
    }

    #[test]
    fn test_refinement_type() {
        // The predicate is only a placeholder; this test checks the rendering.
        let refined = DependentType::refinement("x", scalar("Int"), Term::var("x"));

        assert!(refined.to_string().contains("{x:Int |"));
    }

    #[test]
    fn test_free_index_vars_in_complex_type() {
        // (n: Int) -> Matrix<n, n, Float>: `n` is bound by the function parameter,
        // so the type has no free index variables.
        let n = IndexExpr::var("n");
        let body = DependentType::matrix(n.clone(), n, "Float");
        let func_type = DependentType::dependent_function("n", scalar("Int"), body);

        assert!(func_type.is_well_formed());
    }

    #[test]
    fn test_index_substitution() {
        let expr = IndexExpr::add(IndexExpr::var("n"), IndexExpr::var("m"));
        let subst: HashMap<String, IndexExpr> = [("n".to_string(), IndexExpr::constant(10))]
            .iter()
            .cloned()
            .collect();

        assert_eq!(expr.substitute(&subst).to_string(), "(10 + m)");
    }
}