tensorlogic_adapters/
dependent.rs

//! Dependent types for expressing value-dependent type constraints.
//!
//! Dependent types allow types to depend on values, enabling precise specification
//! of tensor dimensions, vector lengths, and other parameterized types.
//!
//! # Examples
//!
//! ```rust
//! use tensorlogic_adapters::{DependentType, DimExpr, DependentTypeContext};
//!
//! // Create a vector type with dependent length
//! let vec_n = DependentType::vector("T", DimExpr::var("n"));
//!
//! // Create a matrix type with dependent dimensions
//! let matrix = DependentType::matrix("Float", DimExpr::var("m"), DimExpr::var("n"));
//!
//! // Evaluate dimensions in context
//! let mut ctx = DependentTypeContext::new();
//! ctx.set_dim("n", 10);
//! ctx.set_dim("m", 5);
//!
//! assert_eq!(vec_n.eval_shape(&ctx), Some(vec![10]));
//! assert_eq!(matrix.eval_shape(&ctx), Some(vec![5, 10]));
//! ```
25
26use std::collections::HashMap;
27use std::fmt;
28
/// Symbolic expression describing the size of a single tensor dimension.
///
/// Leaves are either literal sizes (`Const`) or named size variables (`Var`);
/// every other variant combines two boxed sub-expressions with an operator.
#[derive(Debug, Clone, PartialEq)]
pub enum DimExpr {
    /// Literal, fully known dimension size.
    Const(usize),
    /// Named dimension variable, resolved through an evaluation context.
    Var(String),
    /// Sum `lhs + rhs`.
    Add(Box<DimExpr>, Box<DimExpr>),
    /// Difference `lhs - rhs`.
    Sub(Box<DimExpr>, Box<DimExpr>),
    /// Product `lhs * rhs`.
    Mul(Box<DimExpr>, Box<DimExpr>),
    /// Integer (truncating) quotient `lhs / rhs`.
    Div(Box<DimExpr>, Box<DimExpr>),
    /// Larger of the two operands.
    Max(Box<DimExpr>, Box<DimExpr>),
    /// Smaller of the two operands.
    Min(Box<DimExpr>, Box<DimExpr>),
    /// Rounded-up quotient `ceil(lhs / rhs)`, handy for padding and strides.
    CeilDiv(Box<DimExpr>, Box<DimExpr>),
}
51
52impl DimExpr {
53    /// Create a constant dimension.
54    pub fn constant(value: usize) -> Self {
55        DimExpr::Const(value)
56    }
57
58    /// Create a dimension variable.
59    pub fn var(name: impl Into<String>) -> Self {
60        DimExpr::Var(name.into())
61    }
62
63    /// Add two dimension expressions.
64    #[allow(clippy::should_implement_trait)]
65    pub fn add(self, other: DimExpr) -> Self {
66        DimExpr::Add(Box::new(self), Box::new(other))
67    }
68
69    /// Subtract a dimension expression from this one.
70    #[allow(clippy::should_implement_trait)]
71    pub fn sub(self, other: DimExpr) -> Self {
72        DimExpr::Sub(Box::new(self), Box::new(other))
73    }
74
75    /// Multiply two dimension expressions.
76    #[allow(clippy::should_implement_trait)]
77    pub fn mul(self, other: DimExpr) -> Self {
78        DimExpr::Mul(Box::new(self), Box::new(other))
79    }
80
81    /// Divide this dimension expression by another.
82    #[allow(clippy::should_implement_trait)]
83    pub fn div(self, other: DimExpr) -> Self {
84        DimExpr::Div(Box::new(self), Box::new(other))
85    }
86
87    /// Take the maximum of two dimension expressions.
88    pub fn max(self, other: DimExpr) -> Self {
89        DimExpr::Max(Box::new(self), Box::new(other))
90    }
91
92    /// Take the minimum of two dimension expressions.
93    pub fn min(self, other: DimExpr) -> Self {
94        DimExpr::Min(Box::new(self), Box::new(other))
95    }
96
97    /// Ceiling division.
98    pub fn ceil_div(self, other: DimExpr) -> Self {
99        DimExpr::CeilDiv(Box::new(self), Box::new(other))
100    }
101
102    /// Evaluate the dimension expression in a context.
103    pub fn eval(&self, ctx: &DependentTypeContext) -> Option<usize> {
104        match self {
105            DimExpr::Const(n) => Some(*n),
106            DimExpr::Var(name) => ctx.get_dim(name).copied(),
107            DimExpr::Add(a, b) => Some(a.eval(ctx)? + b.eval(ctx)?),
108            DimExpr::Sub(a, b) => {
109                let av = a.eval(ctx)?;
110                let bv = b.eval(ctx)?;
111                av.checked_sub(bv)
112            }
113            DimExpr::Mul(a, b) => Some(a.eval(ctx)? * b.eval(ctx)?),
114            DimExpr::Div(a, b) => {
115                let bv = b.eval(ctx)?;
116                if bv == 0 {
117                    return None;
118                }
119                Some(a.eval(ctx)? / bv)
120            }
121            DimExpr::Max(a, b) => Some(a.eval(ctx)?.max(b.eval(ctx)?)),
122            DimExpr::Min(a, b) => Some(a.eval(ctx)?.min(b.eval(ctx)?)),
123            DimExpr::CeilDiv(a, b) => {
124                let av = a.eval(ctx)?;
125                let bv = b.eval(ctx)?;
126                if bv == 0 {
127                    return None;
128                }
129                Some(av.div_ceil(bv))
130            }
131        }
132    }
133
134    /// Get all free variables in this expression.
135    pub fn free_variables(&self) -> Vec<String> {
136        match self {
137            DimExpr::Const(_) => vec![],
138            DimExpr::Var(name) => vec![name.clone()],
139            DimExpr::Add(a, b)
140            | DimExpr::Sub(a, b)
141            | DimExpr::Mul(a, b)
142            | DimExpr::Div(a, b)
143            | DimExpr::Max(a, b)
144            | DimExpr::Min(a, b)
145            | DimExpr::CeilDiv(a, b) => {
146                let mut vars = a.free_variables();
147                vars.extend(b.free_variables());
148                vars.sort();
149                vars.dedup();
150                vars
151            }
152        }
153    }
154
155    /// Substitute a variable with an expression.
156    pub fn substitute(&self, var: &str, expr: &DimExpr) -> DimExpr {
157        match self {
158            DimExpr::Const(n) => DimExpr::Const(*n),
159            DimExpr::Var(name) => {
160                if name == var {
161                    expr.clone()
162                } else {
163                    DimExpr::Var(name.clone())
164                }
165            }
166            DimExpr::Add(a, b) => DimExpr::Add(
167                Box::new(a.substitute(var, expr)),
168                Box::new(b.substitute(var, expr)),
169            ),
170            DimExpr::Sub(a, b) => DimExpr::Sub(
171                Box::new(a.substitute(var, expr)),
172                Box::new(b.substitute(var, expr)),
173            ),
174            DimExpr::Mul(a, b) => DimExpr::Mul(
175                Box::new(a.substitute(var, expr)),
176                Box::new(b.substitute(var, expr)),
177            ),
178            DimExpr::Div(a, b) => DimExpr::Div(
179                Box::new(a.substitute(var, expr)),
180                Box::new(b.substitute(var, expr)),
181            ),
182            DimExpr::Max(a, b) => DimExpr::Max(
183                Box::new(a.substitute(var, expr)),
184                Box::new(b.substitute(var, expr)),
185            ),
186            DimExpr::Min(a, b) => DimExpr::Min(
187                Box::new(a.substitute(var, expr)),
188                Box::new(b.substitute(var, expr)),
189            ),
190            DimExpr::CeilDiv(a, b) => DimExpr::CeilDiv(
191                Box::new(a.substitute(var, expr)),
192                Box::new(b.substitute(var, expr)),
193            ),
194        }
195    }
196
197    /// Simplify the expression.
198    pub fn simplify(&self) -> DimExpr {
199        match self {
200            DimExpr::Add(a, b) => {
201                let a = a.simplify();
202                let b = b.simplify();
203                match (&a, &b) {
204                    (DimExpr::Const(x), DimExpr::Const(y)) => DimExpr::Const(x + y),
205                    (DimExpr::Const(0), _) => b,
206                    (_, DimExpr::Const(0)) => a,
207                    _ => DimExpr::Add(Box::new(a), Box::new(b)),
208                }
209            }
210            DimExpr::Sub(a, b) => {
211                let a = a.simplify();
212                let b = b.simplify();
213                match (&a, &b) {
214                    (DimExpr::Const(x), DimExpr::Const(y)) => DimExpr::Const(x.saturating_sub(*y)),
215                    (_, DimExpr::Const(0)) => a,
216                    _ => DimExpr::Sub(Box::new(a), Box::new(b)),
217                }
218            }
219            DimExpr::Mul(a, b) => {
220                let a = a.simplify();
221                let b = b.simplify();
222                match (&a, &b) {
223                    (DimExpr::Const(x), DimExpr::Const(y)) => DimExpr::Const(x * y),
224                    (DimExpr::Const(0), _) | (_, DimExpr::Const(0)) => DimExpr::Const(0),
225                    (DimExpr::Const(1), _) => b,
226                    (_, DimExpr::Const(1)) => a,
227                    _ => DimExpr::Mul(Box::new(a), Box::new(b)),
228                }
229            }
230            DimExpr::Div(a, b) => {
231                let a = a.simplify();
232                let b = b.simplify();
233                match (&a, &b) {
234                    (DimExpr::Const(x), DimExpr::Const(y)) if *y != 0 => DimExpr::Const(x / y),
235                    (DimExpr::Const(0), _) => DimExpr::Const(0),
236                    (_, DimExpr::Const(1)) => a,
237                    _ => DimExpr::Div(Box::new(a), Box::new(b)),
238                }
239            }
240            DimExpr::Max(a, b) => {
241                let a = a.simplify();
242                let b = b.simplify();
243                match (&a, &b) {
244                    (DimExpr::Const(x), DimExpr::Const(y)) => DimExpr::Const((*x).max(*y)),
245                    _ => DimExpr::Max(Box::new(a), Box::new(b)),
246                }
247            }
248            DimExpr::Min(a, b) => {
249                let a = a.simplify();
250                let b = b.simplify();
251                match (&a, &b) {
252                    (DimExpr::Const(x), DimExpr::Const(y)) => DimExpr::Const((*x).min(*y)),
253                    _ => DimExpr::Min(Box::new(a), Box::new(b)),
254                }
255            }
256            DimExpr::CeilDiv(a, b) => {
257                let a = a.simplify();
258                let b = b.simplify();
259                match (&a, &b) {
260                    (DimExpr::Const(x), DimExpr::Const(y)) if *y != 0 => {
261                        DimExpr::Const(x.div_ceil(*y))
262                    }
263                    _ => DimExpr::CeilDiv(Box::new(a), Box::new(b)),
264                }
265            }
266            other => other.clone(),
267        }
268    }
269
270    /// Check if two dimension expressions are equal.
271    pub fn is_equal(&self, other: &DimExpr, ctx: &DependentTypeContext) -> bool {
272        // Try symbolic equality first
273        if self == other {
274            return true;
275        }
276
277        // Try numeric equality
278        match (self.eval(ctx), other.eval(ctx)) {
279            (Some(a), Some(b)) => a == b,
280            _ => false,
281        }
282    }
283}
284
285impl fmt::Display for DimExpr {
286    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287        match self {
288            DimExpr::Const(n) => write!(f, "{}", n),
289            DimExpr::Var(name) => write!(f, "{}", name),
290            DimExpr::Add(a, b) => write!(f, "({} + {})", a, b),
291            DimExpr::Sub(a, b) => write!(f, "({} - {})", a, b),
292            DimExpr::Mul(a, b) => write!(f, "({} * {})", a, b),
293            DimExpr::Div(a, b) => write!(f, "({} / {})", a, b),
294            DimExpr::Max(a, b) => write!(f, "max({}, {})", a, b),
295            DimExpr::Min(a, b) => write!(f, "min({}, {})", a, b),
296            DimExpr::CeilDiv(a, b) => write!(f, "ceil({} / {})", a, b),
297        }
298    }
299}
300
301/// A dependent type with dimension parameters.
302#[derive(Debug, Clone)]
303pub struct DependentType {
304    /// Base type name
305    pub base_type: String,
306    /// Type parameters (other types)
307    pub type_params: Vec<String>,
308    /// Dimension parameters
309    pub dim_params: Vec<DimExpr>,
310    /// Optional name for the dependent type
311    pub name: Option<String>,
312    /// Description
313    pub description: Option<String>,
314    /// Constraints between dimensions
315    pub constraints: Vec<DimConstraint>,
316}
317
318/// A constraint between dimension expressions.
319#[derive(Debug, Clone)]
320pub struct DimConstraint {
321    /// Left-hand side expression
322    pub lhs: DimExpr,
323    /// Constraint relation
324    pub relation: DimRelation,
325    /// Right-hand side expression
326    pub rhs: DimExpr,
327    /// Error message if constraint is violated
328    pub message: Option<String>,
329}
330
/// Comparison used by a dimension constraint (`lhs <relation> rhs`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DimRelation {
    /// `lhs == rhs`
    Equal,
    /// `lhs != rhs`
    NotEqual,
    /// `lhs < rhs`
    LessThan,
    /// `lhs <= rhs`
    LessThanOrEqual,
    /// `lhs > rhs`
    GreaterThan,
    /// `lhs >= rhs`
    GreaterThanOrEqual,
    /// `lhs` is an exact multiple of `rhs`
    DivisibleBy,
}
349
350impl DependentType {
351    /// Create a new dependent type with a base type.
352    pub fn new(base_type: impl Into<String>) -> Self {
353        DependentType {
354            base_type: base_type.into(),
355            type_params: Vec::new(),
356            dim_params: Vec::new(),
357            name: None,
358            description: None,
359            constraints: Vec::new(),
360        }
361    }
362
363    /// Create a scalar type.
364    pub fn scalar(element_type: impl Into<String>) -> Self {
365        DependentType::new(element_type)
366    }
367
368    /// Create a vector type with a dimension.
369    pub fn vector(element_type: impl Into<String>, length: DimExpr) -> Self {
370        DependentType {
371            base_type: "Vector".to_string(),
372            type_params: vec![element_type.into()],
373            dim_params: vec![length],
374            name: None,
375            description: None,
376            constraints: Vec::new(),
377        }
378    }
379
380    /// Create a matrix type with dimensions.
381    pub fn matrix(element_type: impl Into<String>, rows: DimExpr, cols: DimExpr) -> Self {
382        DependentType {
383            base_type: "Matrix".to_string(),
384            type_params: vec![element_type.into()],
385            dim_params: vec![rows, cols],
386            name: None,
387            description: None,
388            constraints: Vec::new(),
389        }
390    }
391
392    /// Create a tensor type with arbitrary dimensions.
393    pub fn tensor(element_type: impl Into<String>, dims: Vec<DimExpr>) -> Self {
394        DependentType {
395            base_type: "Tensor".to_string(),
396            type_params: vec![element_type.into()],
397            dim_params: dims,
398            name: None,
399            description: None,
400            constraints: Vec::new(),
401        }
402    }
403
404    /// Set the name of this type.
405    pub fn with_name(mut self, name: impl Into<String>) -> Self {
406        self.name = Some(name.into());
407        self
408    }
409
410    /// Set the description.
411    pub fn with_description(mut self, description: impl Into<String>) -> Self {
412        self.description = Some(description.into());
413        self
414    }
415
416    /// Add a type parameter.
417    pub fn with_type_param(mut self, param: impl Into<String>) -> Self {
418        self.type_params.push(param.into());
419        self
420    }
421
422    /// Add a dimension parameter.
423    pub fn with_dim_param(mut self, dim: DimExpr) -> Self {
424        self.dim_params.push(dim);
425        self
426    }
427
428    /// Add a dimension constraint.
429    pub fn with_constraint(mut self, constraint: DimConstraint) -> Self {
430        self.constraints.push(constraint);
431        self
432    }
433
434    /// Get the effective name.
435    pub fn type_name(&self) -> String {
436        if let Some(name) = &self.name {
437            return name.clone();
438        }
439
440        if self.dim_params.is_empty() && self.type_params.is_empty() {
441            return self.base_type.clone();
442        }
443
444        let mut result = self.base_type.clone();
445        if !self.type_params.is_empty() || !self.dim_params.is_empty() {
446            result.push('<');
447
448            let mut parts = Vec::new();
449            for tp in &self.type_params {
450                parts.push(tp.clone());
451            }
452            for dp in &self.dim_params {
453                parts.push(format!("{}", dp));
454            }
455
456            result.push_str(&parts.join(", "));
457            result.push('>');
458        }
459        result
460    }
461
462    /// Evaluate the shape of this type in a context.
463    pub fn eval_shape(&self, ctx: &DependentTypeContext) -> Option<Vec<usize>> {
464        self.dim_params.iter().map(|d| d.eval(ctx)).collect()
465    }
466
467    /// Get the rank (number of dimensions).
468    pub fn rank(&self) -> usize {
469        self.dim_params.len()
470    }
471
472    /// Get all free dimension variables.
473    pub fn free_variables(&self) -> Vec<String> {
474        let mut vars = Vec::new();
475        for dim in &self.dim_params {
476            vars.extend(dim.free_variables());
477        }
478        for constraint in &self.constraints {
479            vars.extend(constraint.lhs.free_variables());
480            vars.extend(constraint.rhs.free_variables());
481        }
482        vars.sort();
483        vars.dedup();
484        vars
485    }
486
487    /// Check if constraints are satisfied in a context.
488    pub fn check_constraints(&self, ctx: &DependentTypeContext) -> Result<(), String> {
489        for constraint in &self.constraints {
490            if !constraint.check(ctx) {
491                let msg = constraint.message.clone().unwrap_or_else(|| {
492                    format!(
493                        "Constraint violated: {} {:?} {}",
494                        constraint.lhs, constraint.relation, constraint.rhs
495                    )
496                });
497                return Err(msg);
498            }
499        }
500        Ok(())
501    }
502
503    /// Check if this type is compatible with another for assignment.
504    pub fn is_compatible_with(&self, other: &DependentType, ctx: &DependentTypeContext) -> bool {
505        // Base types must match
506        if self.base_type != other.base_type {
507            return false;
508        }
509
510        // Type parameters must match
511        if self.type_params != other.type_params {
512            return false;
513        }
514
515        // Dimension parameters must be equal
516        if self.dim_params.len() != other.dim_params.len() {
517            return false;
518        }
519
520        for (a, b) in self.dim_params.iter().zip(&other.dim_params) {
521            if !a.is_equal(b, ctx) {
522                return false;
523            }
524        }
525
526        true
527    }
528}
529
530impl DimConstraint {
531    /// Create a new dimension constraint.
532    pub fn new(lhs: DimExpr, relation: DimRelation, rhs: DimExpr) -> Self {
533        DimConstraint {
534            lhs,
535            relation,
536            rhs,
537            message: None,
538        }
539    }
540
541    /// Set the error message.
542    pub fn with_message(mut self, message: impl Into<String>) -> Self {
543        self.message = Some(message.into());
544        self
545    }
546
547    /// Check if the constraint is satisfied.
548    pub fn check(&self, ctx: &DependentTypeContext) -> bool {
549        let lhs_val = match self.lhs.eval(ctx) {
550            Some(v) => v,
551            None => return false,
552        };
553        let rhs_val = match self.rhs.eval(ctx) {
554            Some(v) => v,
555            None => return false,
556        };
557
558        match self.relation {
559            DimRelation::Equal => lhs_val == rhs_val,
560            DimRelation::NotEqual => lhs_val != rhs_val,
561            DimRelation::LessThan => lhs_val < rhs_val,
562            DimRelation::LessThanOrEqual => lhs_val <= rhs_val,
563            DimRelation::GreaterThan => lhs_val > rhs_val,
564            DimRelation::GreaterThanOrEqual => lhs_val >= rhs_val,
565            DimRelation::DivisibleBy => rhs_val != 0 && lhs_val % rhs_val == 0,
566        }
567    }
568}
569
570/// Context for evaluating dependent types.
571#[derive(Debug, Clone, Default)]
572pub struct DependentTypeContext {
573    /// Dimension variable values
574    dims: HashMap<String, usize>,
575    /// Type definitions
576    types: HashMap<String, DependentType>,
577}
578
579impl DependentTypeContext {
580    /// Create a new empty context.
581    pub fn new() -> Self {
582        DependentTypeContext {
583            dims: HashMap::new(),
584            types: HashMap::new(),
585        }
586    }
587
588    /// Set a dimension variable.
589    pub fn set_dim(&mut self, name: impl Into<String>, value: usize) {
590        self.dims.insert(name.into(), value);
591    }
592
593    /// Get a dimension variable.
594    pub fn get_dim(&self, name: &str) -> Option<&usize> {
595        self.dims.get(name)
596    }
597
598    /// Register a type definition.
599    pub fn set_type(&mut self, name: impl Into<String>, ty: DependentType) {
600        self.types.insert(name.into(), ty);
601    }
602
603    /// Get a type definition.
604    pub fn get_type(&self, name: &str) -> Option<&DependentType> {
605        self.types.get(name)
606    }
607
608    /// Check if a dimension variable exists.
609    pub fn has_dim(&self, name: &str) -> bool {
610        self.dims.contains_key(name)
611    }
612
613    /// Get all dimension variable names.
614    pub fn dim_names(&self) -> Vec<&str> {
615        self.dims.keys().map(|s| s.as_str()).collect()
616    }
617
618    /// Clear all dimension bindings.
619    pub fn clear_dims(&mut self) {
620        self.dims.clear();
621    }
622
623    /// Unify two dimension expressions if possible.
624    ///
625    /// Returns true if unification succeeds and updates the context.
626    pub fn unify(&mut self, a: &DimExpr, b: &DimExpr) -> bool {
627        match (a, b) {
628            (DimExpr::Var(va), DimExpr::Var(vb)) if va == vb => true,
629            (DimExpr::Var(va), expr) | (expr, DimExpr::Var(va)) => {
630                if let Some(&existing) = self.dims.get(va) {
631                    if let Some(val) = expr.eval(self) {
632                        existing == val
633                    } else {
634                        false
635                    }
636                } else if let Some(val) = expr.eval(self) {
637                    self.dims.insert(va.clone(), val);
638                    true
639                } else {
640                    false
641                }
642            }
643            (DimExpr::Const(ca), DimExpr::Const(cb)) => ca == cb,
644            _ => {
645                // Try to evaluate both
646                match (a.eval(self), b.eval(self)) {
647                    (Some(va), Some(vb)) => va == vb,
648                    _ => false,
649                }
650            }
651        }
652    }
653}
654
655/// Registry for managing dependent types.
656#[derive(Debug, Clone, Default)]
657pub struct DependentTypeRegistry {
658    /// Named dependent types
659    types: HashMap<String, DependentType>,
660}
661
662impl DependentTypeRegistry {
663    /// Create a new empty registry.
664    pub fn new() -> Self {
665        DependentTypeRegistry {
666            types: HashMap::new(),
667        }
668    }
669
670    /// Register a dependent type.
671    pub fn register(&mut self, ty: DependentType) {
672        let name = ty.type_name();
673        self.types.insert(name, ty);
674    }
675
676    /// Get a type by name.
677    pub fn get(&self, name: &str) -> Option<&DependentType> {
678        self.types.get(name)
679    }
680
681    /// Check if a type exists.
682    pub fn contains(&self, name: &str) -> bool {
683        self.types.contains_key(name)
684    }
685
686    /// Get all type names.
687    pub fn type_names(&self) -> Vec<&str> {
688        self.types.keys().map(|s| s.as_str()).collect()
689    }
690
691    /// Get the number of registered types.
692    pub fn len(&self) -> usize {
693        self.types.len()
694    }
695
696    /// Check if the registry is empty.
697    pub fn is_empty(&self) -> bool {
698        self.types.is_empty()
699    }
700}
701
702/// Common dependent type patterns.
703pub mod patterns {
704    use super::*;
705
706    /// Square matrix type.
707    pub fn square_matrix(element_type: impl Into<String>, size: DimExpr) -> DependentType {
708        DependentType::matrix(element_type, size.clone(), size).with_name("SquareMatrix")
709    }
710
711    /// Identity matrix type (square with ones on diagonal).
712    pub fn identity_matrix(size: DimExpr) -> DependentType {
713        DependentType::matrix("Float", size.clone(), size).with_name("IdentityMatrix")
714    }
715
716    /// Batch of vectors.
717    pub fn batch_vector(
718        element_type: impl Into<String>,
719        batch: DimExpr,
720        length: DimExpr,
721    ) -> DependentType {
722        DependentType::tensor(element_type, vec![batch, length]).with_name("BatchVector")
723    }
724
725    /// Batch of matrices.
726    pub fn batch_matrix(
727        element_type: impl Into<String>,
728        batch: DimExpr,
729        rows: DimExpr,
730        cols: DimExpr,
731    ) -> DependentType {
732        DependentType::tensor(element_type, vec![batch, rows, cols]).with_name("BatchMatrix")
733    }
734
735    /// 4D image tensor (batch, channels, height, width).
736    pub fn image_tensor(
737        batch: DimExpr,
738        channels: DimExpr,
739        height: DimExpr,
740        width: DimExpr,
741    ) -> DependentType {
742        DependentType::tensor("Float", vec![batch, channels, height, width])
743            .with_name("ImageTensor")
744    }
745
746    /// Attention tensor (batch, heads, seq_len, head_dim).
747    pub fn attention_tensor(
748        batch: DimExpr,
749        heads: DimExpr,
750        seq_len: DimExpr,
751        head_dim: DimExpr,
752    ) -> DependentType {
753        DependentType::tensor("Float", vec![batch, heads, seq_len, head_dim])
754            .with_name("AttentionTensor")
755    }
756}
757
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a context holding the given `(variable, value)` bindings.
    fn ctx_with(bindings: &[(&str, usize)]) -> DependentTypeContext {
        let mut ctx = DependentTypeContext::new();
        for &(name, value) in bindings {
            ctx.set_dim(name, value);
        }
        ctx
    }

    #[test]
    fn test_dim_expr_const() {
        let empty = DependentTypeContext::new();
        assert_eq!(DimExpr::Const(42).eval(&empty), Some(42));
    }

    #[test]
    fn test_dim_expr_var() {
        let ctx = ctx_with(&[("n", 10)]);
        assert_eq!(DimExpr::Var("n".to_string()).eval(&ctx), Some(10));
    }

    #[test]
    fn test_dim_expr_arithmetic() {
        let ctx = ctx_with(&[("n", 10), ("m", 3)]);
        let cases = [
            (DimExpr::var("n").add(DimExpr::var("m")), 13),
            (DimExpr::var("n").mul(DimExpr::var("m")), 30),
            (DimExpr::var("n").div(DimExpr::var("m")), 3),
        ];
        for (expr, expected) in cases.iter() {
            assert_eq!(expr.eval(&ctx), Some(*expected));
        }
    }

    #[test]
    fn test_dim_expr_max_min() {
        let ctx = ctx_with(&[("a", 5), ("b", 10)]);
        let a = || DimExpr::var("a");
        let b = || DimExpr::var("b");
        assert_eq!(a().max(b()).eval(&ctx), Some(10));
        assert_eq!(a().min(b()).eval(&ctx), Some(5));
    }

    #[test]
    fn test_dim_expr_simplify() {
        let folded = DimExpr::constant(5).add(DimExpr::constant(3)).simplify();
        assert_eq!(folded, DimExpr::Const(8));

        let additive_identity = DimExpr::var("x").add(DimExpr::constant(0)).simplify();
        assert_eq!(additive_identity, DimExpr::Var("x".to_string()));
    }

    #[test]
    fn test_vector_type() {
        let ctx = ctx_with(&[("n", 100)]);
        let ty = DependentType::vector("Float", DimExpr::var("n"));
        assert_eq!(ty.rank(), 1);
        assert_eq!(ty.eval_shape(&ctx), Some(vec![100]));
    }

    #[test]
    fn test_matrix_type() {
        let ctx = ctx_with(&[("m", 10), ("n", 20)]);
        let ty = DependentType::matrix("Float", DimExpr::var("m"), DimExpr::var("n"));
        assert_eq!(ty.rank(), 2);
        assert_eq!(ty.eval_shape(&ctx), Some(vec![10, 20]));
    }

    #[test]
    fn test_dim_constraint() {
        let positive = DimConstraint::new(
            DimExpr::var("n"),
            DimRelation::GreaterThan,
            DimExpr::constant(0),
        );
        assert!(positive.check(&ctx_with(&[("n", 10)])));
        assert!(!positive.check(&ctx_with(&[("n", 0)])));
    }

    #[test]
    fn test_type_with_constraints() {
        let square = DimConstraint::new(DimExpr::var("m"), DimRelation::Equal, DimExpr::var("n"))
            .with_message("Matrix must be square");
        let ty = DependentType::matrix("Float", DimExpr::var("m"), DimExpr::var("n"))
            .with_constraint(square);

        assert!(ty.check_constraints(&ctx_with(&[("m", 10), ("n", 10)])).is_ok());
        assert!(ty.check_constraints(&ctx_with(&[("m", 10), ("n", 20)])).is_err());
    }

    #[test]
    fn test_type_compatibility() {
        let ctx = ctx_with(&[("m", 10), ("n", 20)]);
        let float_matrix = || DependentType::matrix("Float", DimExpr::var("m"), DimExpr::var("n"));
        let int_matrix = DependentType::matrix("Int", DimExpr::var("m"), DimExpr::var("n"));

        assert!(float_matrix().is_compatible_with(&float_matrix(), &ctx));
        assert!(!float_matrix().is_compatible_with(&int_matrix, &ctx));
    }

    #[test]
    fn test_free_variables() {
        let expr = DimExpr::var("n")
            .add(DimExpr::var("m"))
            .mul(DimExpr::var("k"));
        let vars = expr.free_variables();
        assert_eq!(vars.len(), 3);
        for name in ["k", "m", "n"].iter() {
            assert!(vars.contains(&name.to_string()));
        }
    }

    #[test]
    fn test_substitute() {
        let expr = DimExpr::var("n").add(DimExpr::constant(5));
        let replaced = expr.substitute("n", &DimExpr::constant(10));
        assert_eq!(replaced.eval(&DependentTypeContext::new()), Some(15));
    }

    #[test]
    fn test_ceil_div() {
        // ceil(10 / 3) = 4
        let expr = DimExpr::constant(10).ceil_div(DimExpr::constant(3));
        assert_eq!(expr.eval(&DependentTypeContext::new()), Some(4));
    }

    #[test]
    fn test_context_unify() {
        let mut ctx = DependentTypeContext::new();
        let n = DimExpr::var("n");

        // The first unification binds `n`.
        assert!(ctx.unify(&n, &DimExpr::constant(10)));
        assert_eq!(ctx.get_dim("n"), Some(&10));

        // Re-unifying with the same value is consistent.
        assert!(ctx.unify(&n, &DimExpr::constant(10)));

        // A conflicting value is rejected.
        assert!(!ctx.unify(&n, &DimExpr::constant(20)));
    }

    #[test]
    fn test_patterns() {
        let ctx = ctx_with(&[
            ("n", 64),
            ("batch", 32),
            ("heads", 8),
            ("seq_len", 512),
            ("head_dim", 64),
        ]);

        let square = patterns::square_matrix("Float", DimExpr::var("n"));
        assert_eq!(square.eval_shape(&ctx), Some(vec![64, 64]));

        let attention = patterns::attention_tensor(
            DimExpr::var("batch"),
            DimExpr::var("heads"),
            DimExpr::var("seq_len"),
            DimExpr::var("head_dim"),
        );
        assert_eq!(attention.eval_shape(&ctx), Some(vec![32, 8, 512, 64]));
    }

    #[test]
    fn test_registry() {
        let mut registry = DependentTypeRegistry::new();
        let float_vector =
            DependentType::vector("Float", DimExpr::var("n")).with_name("FloatVector");
        registry.register(float_vector);

        assert_eq!(registry.len(), 1);
        assert!(registry.contains("FloatVector"));
        let stored = registry
            .get("FloatVector")
            .expect("FloatVector was just registered");
        assert_eq!(stored.base_type, "Vector");
    }

    #[test]
    fn test_dim_display() {
        let sum = DimExpr::var("n").add(DimExpr::var("m"));
        assert_eq!(sum.to_string(), "(n + m)");

        let product = DimExpr::var("a").mul(DimExpr::constant(2));
        assert_eq!(product.to_string(), "(a * 2)");
    }

    #[test]
    fn test_type_name() {
        let anonymous = DependentType::matrix("Float", DimExpr::var("m"), DimExpr::var("n"));
        assert_eq!(anonymous.type_name(), "Matrix<Float, m, n>");
        assert_eq!(anonymous.with_name("MyMatrix").type_name(), "MyMatrix");
    }
}