Skip to main content

tl_types/
lib.rs

// ThinkingLanguage — Type System
// Licensed under MIT OR Apache-2.0
//
// Provides the internal Type representation, type environment,
// and type checker for gradual static typing.
6
7pub mod checker;
8pub mod convert;
9pub mod infer;
10
11use std::collections::HashMap;
12use std::fmt;
13
/// Internal type representation used by the type checker.
///
/// Separate from `TypeExpr` (AST surface syntax).
#[derive(Debug, Clone, PartialEq)]
pub enum Type {
    /// Gradual typing: compatible with everything.
    Any,
    /// Void return type.
    Unit,
    /// Primitive integer type.
    Int,
    /// Primitive floating-point type.
    Float,
    /// Primitive string type.
    String,
    /// Primitive boolean type.
    Bool,
    /// The type of the `none` value.
    None,
    /// Fixed-point decimal for financial data.
    Decimal,
    /// List whose elements have type `T`.
    List(Box<Type>),
    /// Map carrying a single type parameter `T` (rendered `map<T>`); the key
    /// type is not represented here.
    Map(Box<Type>),
    /// Set whose elements have type `T`.
    Set(Box<Type>),
    /// Option: `T` or `none`.
    Option(Box<Type>),
    /// Result: `Ok(T)` or `Err(E)`.
    Result(Box<Type>, Box<Type>),
    /// Function type: positional parameter types and a return type.
    Function {
        params: Vec<Type>,
        ret: Box<Type>,
    },
    /// Named struct type.
    Struct(std::string::String),
    /// Named enum type.
    Enum(std::string::String),
    /// Table with optional schema name and typed columns.
    ///
    /// `columns: None` means the column list is not known statically.
    Table {
        name: Option<std::string::String>,
        columns: Option<Vec<(std::string::String, Type)>>,
    },
    /// Generator yielding `T`.
    Generator(Box<Type>),
    /// Task returning `T`.
    Task(Box<Type>),
    /// Channel carrying `T`.
    Channel(Box<Type>),
    /// Tensor type (shape is runtime-only).
    Tensor,
    /// Typed stream carrying elements of type `T`.
    Stream(Box<Type>),
    /// Opaque pipeline type.
    Pipeline,
    /// Type parameter (generic): `T`, `U`, etc.
    TypeParam(std::string::String),
    /// Inference variable (unresolved), identified by a numeric id.
    Var(u32),
    /// An opaque Python object.
    PyObject,
    /// Read-only reference type.
    Ref(Box<Type>),
    /// Poison type — suppresses further errors.
    Error,
}
75
76impl fmt::Display for Type {
77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        match self {
79            Type::Any => write!(f, "any"),
80            Type::Unit => write!(f, "unit"),
81            Type::Int => write!(f, "int"),
82            Type::Float => write!(f, "float"),
83            Type::String => write!(f, "string"),
84            Type::Bool => write!(f, "bool"),
85            Type::None => write!(f, "none"),
86            Type::Decimal => write!(f, "decimal"),
87            Type::List(t) => write!(f, "list<{t}>"),
88            Type::Map(t) => write!(f, "map<{t}>"),
89            Type::Set(t) => write!(f, "set<{t}>"),
90            Type::Option(t) => write!(f, "{t}?"),
91            Type::Result(ok, err) => write!(f, "result<{ok}, {err}>"),
92            Type::Function { params, ret } => {
93                write!(f, "fn(")?;
94                for (i, p) in params.iter().enumerate() {
95                    if i > 0 {
96                        write!(f, ", ")?;
97                    }
98                    write!(f, "{p}")?;
99                }
100                write!(f, ") -> {ret}")
101            }
102            Type::Struct(name) => write!(f, "{name}"),
103            Type::Enum(name) => write!(f, "{name}"),
104            Type::Table {
105                name: Some(name), ..
106            } => write!(f, "table<{name}>"),
107            Type::Table { name: None, .. } => write!(f, "table"),
108            Type::Generator(t) => write!(f, "generator<{t}>"),
109            Type::Task(t) => write!(f, "task<{t}>"),
110            Type::Channel(t) => write!(f, "channel<{t}>"),
111            Type::Tensor => write!(f, "tensor"),
112            Type::Stream(t) => write!(f, "stream<{t}>"),
113            Type::Pipeline => write!(f, "pipeline"),
114            Type::TypeParam(name) => write!(f, "{name}"),
115            Type::Var(id) => write!(f, "?T{id}"),
116            Type::PyObject => write!(f, "pyobject"),
117            Type::Ref(t) => write!(f, "&{t}"),
118            Type::Error => write!(f, "<error>"),
119        }
120    }
121}
122
123/// Information about a trait definition.
124#[derive(Debug, Clone)]
125pub struct TraitInfo {
126    pub name: std::string::String,
127    pub methods: Vec<(std::string::String, Vec<Type>, Type)>, // (name, param_types, return_type)
128    pub supertrait: Option<std::string::String>,
129}
130
131/// Type environment — tracks variable types across scopes.
132pub struct TypeEnv {
133    scopes: Vec<Scope>,
134    /// Function signatures: name -> (param types, return type)
135    functions: std::collections::HashMap<std::string::String, FnSig>,
136    /// Struct definitions: name -> field types
137    structs: std::collections::HashMap<std::string::String, Vec<(std::string::String, Type)>>,
138    /// Enum definitions: name -> variant list
139    enums: std::collections::HashMap<std::string::String, Vec<(std::string::String, Vec<Type>)>>,
140    /// Trait definitions: name -> trait info
141    traits: std::collections::HashMap<std::string::String, TraitInfo>,
142    /// Trait implementations: (trait_name, type_name) -> method names
143    trait_impls: std::collections::HashMap<
144        (std::string::String, std::string::String),
145        Vec<std::string::String>,
146    >,
147    /// Type aliases: name -> (type_params, TypeExpr)
148    type_aliases: std::collections::HashMap<
149        std::string::String,
150        (Vec<std::string::String>, tl_ast::TypeExpr),
151    >,
152    /// Sensitive field annotations: type_name -> [(field_name, annotation)]
153    sensitive_fields: std::collections::HashMap<
154        std::string::String,
155        Vec<(std::string::String, std::string::String)>,
156    >,
157    /// Currently resolving aliases (cycle detection)
158    resolving_aliases: std::collections::HashSet<std::string::String>,
159    /// Next inference variable ID
160    next_var: u32,
161}
162
163/// A function signature.
164#[derive(Debug, Clone)]
165pub struct FnSig {
166    pub params: Vec<(std::string::String, Type)>,
167    pub ret: Type,
168}
169
170struct Scope {
171    vars: std::collections::HashMap<std::string::String, Type>,
172}
173
174impl TypeEnv {
175    pub fn new() -> Self {
176        let mut env = TypeEnv {
177            scopes: vec![Scope {
178                vars: std::collections::HashMap::new(),
179            }],
180            functions: std::collections::HashMap::new(),
181            structs: std::collections::HashMap::new(),
182            enums: std::collections::HashMap::new(),
183            traits: std::collections::HashMap::new(),
184            trait_impls: std::collections::HashMap::new(),
185            type_aliases: std::collections::HashMap::new(),
186            sensitive_fields: std::collections::HashMap::new(),
187            resolving_aliases: std::collections::HashSet::new(),
188            next_var: 0,
189        };
190        env.register_builtin_traits();
191        env
192    }
193
194    /// Register built-in trait hierarchy.
195    fn register_builtin_traits(&mut self) {
196        // Hashable — int, float, string, bool
197        self.traits.insert(
198            "Hashable".into(),
199            TraitInfo {
200                name: "Hashable".into(),
201                methods: vec![],
202                supertrait: None,
203            },
204        );
205        // Comparable — int, float, string (implies Hashable)
206        self.traits.insert(
207            "Comparable".into(),
208            TraitInfo {
209                name: "Comparable".into(),
210                methods: vec![],
211                supertrait: Some("Hashable".into()),
212            },
213        );
214        // Numeric — int, float (implies Comparable)
215        self.traits.insert(
216            "Numeric".into(),
217            TraitInfo {
218                name: "Numeric".into(),
219                methods: vec![],
220                supertrait: Some("Comparable".into()),
221            },
222        );
223        // Displayable — all primitives
224        self.traits.insert(
225            "Displayable".into(),
226            TraitInfo {
227                name: "Displayable".into(),
228                methods: vec![("to_string".into(), vec![], Type::String)],
229                supertrait: None,
230            },
231        );
232        // Serializable — all primitives, structs
233        self.traits.insert(
234            "Serializable".into(),
235            TraitInfo {
236                name: "Serializable".into(),
237                methods: vec![],
238                supertrait: None,
239            },
240        );
241        // Default — all primitives, list, map, set
242        self.traits.insert(
243            "Default".into(),
244            TraitInfo {
245                name: "Default".into(),
246                methods: vec![],
247                supertrait: None,
248            },
249        );
250    }
251
252    pub fn scope_depth(&self) -> u32 {
253        self.scopes.len() as u32 - 1
254    }
255
256    pub fn push_scope(&mut self) {
257        self.scopes.push(Scope {
258            vars: std::collections::HashMap::new(),
259        });
260    }
261
262    pub fn pop_scope(&mut self) {
263        if self.scopes.len() > 1 {
264            self.scopes.pop();
265        }
266    }
267
268    pub fn define(&mut self, name: std::string::String, ty: Type) {
269        if let Some(scope) = self.scopes.last_mut() {
270            scope.vars.insert(name, ty);
271        }
272    }
273
274    pub fn lookup(&self, name: &str) -> Option<&Type> {
275        for scope in self.scopes.iter().rev() {
276            if let Some(ty) = scope.vars.get(name) {
277                return Some(ty);
278            }
279        }
280        None
281    }
282
283    pub fn define_fn(&mut self, name: std::string::String, sig: FnSig) {
284        self.functions.insert(name, sig);
285    }
286
287    pub fn lookup_fn(&self, name: &str) -> Option<&FnSig> {
288        self.functions.get(name)
289    }
290
291    pub fn define_struct(
292        &mut self,
293        name: std::string::String,
294        fields: Vec<(std::string::String, Type)>,
295    ) {
296        self.structs.insert(name, fields);
297    }
298
299    pub fn lookup_struct(&self, name: &str) -> Option<&Vec<(std::string::String, Type)>> {
300        self.structs.get(name)
301    }
302
303    pub fn define_enum(
304        &mut self,
305        name: std::string::String,
306        variants: Vec<(std::string::String, Vec<Type>)>,
307    ) {
308        self.enums.insert(name, variants);
309    }
310
311    pub fn lookup_enum(&self, name: &str) -> Option<&Vec<(std::string::String, Vec<Type>)>> {
312        self.enums.get(name)
313    }
314
315    pub fn fresh_var(&mut self) -> Type {
316        let id = self.next_var;
317        self.next_var += 1;
318        Type::Var(id)
319    }
320
321    pub fn define_trait(&mut self, name: std::string::String, info: TraitInfo) {
322        self.traits.insert(name, info);
323    }
324
325    pub fn lookup_trait(&self, name: &str) -> Option<&TraitInfo> {
326        self.traits.get(name)
327    }
328
329    pub fn register_type_alias(
330        &mut self,
331        name: std::string::String,
332        type_params: Vec<std::string::String>,
333        value: tl_ast::TypeExpr,
334    ) {
335        self.type_aliases.insert(name, (type_params, value));
336    }
337
338    pub fn lookup_type_alias(
339        &self,
340        name: &str,
341    ) -> Option<&(Vec<std::string::String>, tl_ast::TypeExpr)> {
342        self.type_aliases.get(name)
343    }
344
345    /// Check if an alias is currently being resolved (cycle detection).
346    pub fn is_resolving_alias(&self, name: &str) -> bool {
347        self.resolving_aliases.contains(name)
348    }
349
350    /// Mark an alias as being resolved.
351    pub fn start_resolving_alias(&mut self, name: std::string::String) {
352        self.resolving_aliases.insert(name);
353    }
354
355    /// Unmark an alias after resolution.
356    pub fn stop_resolving_alias(&mut self, name: &str) {
357        self.resolving_aliases.remove(name);
358    }
359
360    /// Register a sensitive field annotation.
361    pub fn register_sensitive_field(
362        &mut self,
363        type_name: std::string::String,
364        field_name: std::string::String,
365        annotation: std::string::String,
366    ) {
367        self.sensitive_fields
368            .entry(type_name)
369            .or_default()
370            .push((field_name, annotation));
371    }
372
373    /// Check if a field is annotated as sensitive.
374    pub fn is_field_sensitive(&self, type_name: &str, field_name: &str) -> bool {
375        self.sensitive_fields
376            .get(type_name)
377            .map(|fields| fields.iter().any(|(f, _)| f == field_name))
378            .unwrap_or(false)
379    }
380
381    /// Get sensitive annotations for a type's field.
382    pub fn get_field_annotations(
383        &self,
384        type_name: &str,
385    ) -> Option<&Vec<(std::string::String, std::string::String)>> {
386        self.sensitive_fields.get(type_name)
387    }
388
389    pub fn register_trait_impl(
390        &mut self,
391        trait_name: std::string::String,
392        type_name: std::string::String,
393        method_names: Vec<std::string::String>,
394    ) {
395        self.trait_impls
396            .insert((trait_name, type_name), method_names);
397    }
398
399    pub fn lookup_trait_impl(
400        &self,
401        trait_name: &str,
402        type_name: &str,
403    ) -> Option<&Vec<std::string::String>> {
404        self.trait_impls
405            .get(&(trait_name.to_string(), type_name.to_string()))
406    }
407
408    /// Check if a type satisfies a trait bound (including built-in trait hierarchy).
409    pub fn type_satisfies_trait(&self, ty: &Type, trait_name: &str) -> bool {
410        // any always satisfies
411        if matches!(ty, Type::Any | Type::Error | Type::TypeParam(_)) {
412            return true;
413        }
414        // Check built-in trait implementations
415        match trait_name {
416            "Numeric" => matches!(ty, Type::Int | Type::Float | Type::Decimal),
417            "Comparable" => {
418                matches!(ty, Type::Int | Type::Float | Type::String | Type::Decimal)
419                    || self.type_satisfies_trait(ty, "Numeric")
420            }
421            "Hashable" => {
422                matches!(
423                    ty,
424                    Type::Int | Type::Float | Type::String | Type::Bool | Type::Decimal
425                ) || self.type_satisfies_trait(ty, "Comparable")
426            }
427            "Displayable" => matches!(
428                ty,
429                Type::Int | Type::Float | Type::String | Type::Bool | Type::None | Type::Decimal
430            ),
431            "Default" => matches!(
432                ty,
433                Type::Int
434                    | Type::Float
435                    | Type::String
436                    | Type::Bool
437                    | Type::None
438                    | Type::List(_)
439                    | Type::Map(_)
440                    | Type::Set(_)
441            ),
442            "Serializable" => matches!(
443                ty,
444                Type::Int
445                    | Type::Float
446                    | Type::String
447                    | Type::Bool
448                    | Type::None
449                    | Type::Decimal
450                    | Type::Struct(_)
451            ),
452            _ => {
453                // Check user-defined trait impls
454                let type_name = match ty {
455                    Type::Struct(n) | Type::Enum(n) => n.as_str(),
456                    _ => return false,
457                };
458                self.lookup_trait_impl(trait_name, type_name).is_some()
459            }
460        }
461    }
462}
463
464impl Default for TypeEnv {
465    fn default() -> Self {
466        Self::new()
467    }
468}
469
470/// Check if two types are compatible under gradual typing.
471/// `any` is compatible with everything. `none` is compatible with `option<T>`.
472pub fn is_compatible(expected: &Type, found: &Type) -> bool {
473    // Any is compatible with everything (both directions)
474    if matches!(expected, Type::Any) || matches!(found, Type::Any) {
475        return true;
476    }
477    // Error poison type suppresses further errors
478    if matches!(expected, Type::Error) || matches!(found, Type::Error) {
479        return true;
480    }
481    // Type parameters are compatible with anything (generics are type-erased)
482    if matches!(expected, Type::TypeParam(_)) || matches!(found, Type::TypeParam(_)) {
483        return true;
484    }
485    // Same type
486    if expected == found {
487        return true;
488    }
489    // int promotes to float
490    if matches!(expected, Type::Float) && matches!(found, Type::Int) {
491        return true;
492    }
493    // decimal promotes to float
494    if matches!(expected, Type::Float) && matches!(found, Type::Decimal) {
495        return true;
496    }
497    // int promotes to decimal
498    if matches!(expected, Type::Decimal) && matches!(found, Type::Int) {
499        return true;
500    }
501    // none is compatible with option<T>
502    if matches!(found, Type::None) && matches!(expected, Type::Option(_)) {
503        return true;
504    }
505    // T is compatible with option<T>
506    if let Type::Option(inner) = expected
507        && is_compatible(inner, found)
508    {
509        return true;
510    }
511    // Structural compatibility for compound types
512    match (expected, found) {
513        (Type::List(a), Type::List(b)) => is_compatible(a, b),
514        (Type::Map(a), Type::Map(b)) => is_compatible(a, b),
515        (Type::Set(a), Type::Set(b)) => is_compatible(a, b),
516        (Type::Option(a), Type::Option(b)) => is_compatible(a, b),
517        (Type::Result(ok1, err1), Type::Result(ok2, err2)) => {
518            is_compatible(ok1, ok2) && is_compatible(err1, err2)
519        }
520        (Type::Generator(a), Type::Generator(b)) => is_compatible(a, b),
521        (Type::Task(a), Type::Task(b)) => is_compatible(a, b),
522        (Type::Channel(a), Type::Channel(b)) => is_compatible(a, b),
523        (Type::Stream(a), Type::Stream(b)) => is_compatible(a, b),
524        // Tables: None columns = compatible with any columns
525        (
526            Type::Table {
527                name: n1,
528                columns: c1,
529            },
530            Type::Table {
531                name: n2,
532                columns: c2,
533            },
534        ) => {
535            let name_ok = match (n1, n2) {
536                (Some(a), Some(b)) => a == b,
537                _ => true,
538            };
539            let cols_ok = match (c1, c2) {
540                (None, _) | (_, None) => true,
541                (Some(a), Some(b)) => {
542                    a.len() == b.len()
543                        && a.iter()
544                            .zip(b.iter())
545                            .all(|((n1, t1), (n2, t2))| n1 == n2 && is_compatible(t1, t2))
546                }
547            };
548            name_ok && cols_ok
549        }
550        (
551            Type::Function {
552                params: p1,
553                ret: r1,
554            },
555            Type::Function {
556                params: p2,
557                ret: r2,
558            },
559        ) => {
560            p1.len() == p2.len()
561                && p1.iter().zip(p2.iter()).all(|(a, b)| is_compatible(a, b))
562                && is_compatible(r1, r2)
563        }
564        _ => false,
565    }
566}
567
568/// A substitution mapping inference variables to concrete types.
569#[derive(Debug, Clone, Default)]
570pub struct Substitution {
571    pub mappings: HashMap<u32, Type>,
572}
573
574impl Substitution {
575    pub fn new() -> Self {
576        Self {
577            mappings: HashMap::new(),
578        }
579    }
580
581    /// Compose this substitution with another.
582    pub fn compose(&mut self, other: &Substitution) {
583        // Apply other to all existing mappings
584        for ty in self.mappings.values_mut() {
585            *ty = apply_substitution(ty, other);
586        }
587        // Add new mappings from other
588        for (k, v) in &other.mappings {
589            self.mappings.entry(*k).or_insert_with(|| v.clone());
590        }
591    }
592}
593
594/// Apply a substitution to a type, replacing Var(id) with mapped types.
595pub fn apply_substitution(ty: &Type, subst: &Substitution) -> Type {
596    match ty {
597        Type::Var(id) => {
598            if let Some(replacement) = subst.mappings.get(id) {
599                apply_substitution(replacement, subst)
600            } else {
601                ty.clone()
602            }
603        }
604        Type::List(inner) => Type::List(Box::new(apply_substitution(inner, subst))),
605        Type::Map(inner) => Type::Map(Box::new(apply_substitution(inner, subst))),
606        Type::Set(inner) => Type::Set(Box::new(apply_substitution(inner, subst))),
607        Type::Option(inner) => Type::Option(Box::new(apply_substitution(inner, subst))),
608        Type::Result(ok, err) => Type::Result(
609            Box::new(apply_substitution(ok, subst)),
610            Box::new(apply_substitution(err, subst)),
611        ),
612        Type::Generator(inner) => Type::Generator(Box::new(apply_substitution(inner, subst))),
613        Type::Task(inner) => Type::Task(Box::new(apply_substitution(inner, subst))),
614        Type::Channel(inner) => Type::Channel(Box::new(apply_substitution(inner, subst))),
615        Type::Stream(inner) => Type::Stream(Box::new(apply_substitution(inner, subst))),
616        Type::Function { params, ret } => Type::Function {
617            params: params
618                .iter()
619                .map(|p| apply_substitution(p, subst))
620                .collect(),
621            ret: Box::new(apply_substitution(ret, subst)),
622        },
623        _ => ty.clone(),
624    }
625}
626
627/// Occurs check: does Var(id) appear in ty?
628fn occurs_in(id: u32, ty: &Type) -> bool {
629    match ty {
630        Type::Var(v) => *v == id,
631        Type::List(inner)
632        | Type::Map(inner)
633        | Type::Set(inner)
634        | Type::Option(inner)
635        | Type::Generator(inner)
636        | Type::Task(inner)
637        | Type::Channel(inner)
638        | Type::Stream(inner) => occurs_in(id, inner),
639        Type::Result(ok, err) => occurs_in(id, ok) || occurs_in(id, err),
640        Type::Function { params, ret } => {
641            params.iter().any(|p| occurs_in(id, p)) || occurs_in(id, ret)
642        }
643        _ => false,
644    }
645}
646
647/// Unify two types, producing a substitution or an error.
648pub fn unify(a: &Type, b: &Type) -> Result<Substitution, std::string::String> {
649    // Any produces no constraint
650    if matches!(a, Type::Any) || matches!(b, Type::Any) {
651        return Ok(Substitution::new());
652    }
653    // Error suppresses
654    if matches!(a, Type::Error) || matches!(b, Type::Error) {
655        return Ok(Substitution::new());
656    }
657    // TypeParam is type-erased
658    if matches!(a, Type::TypeParam(_)) || matches!(b, Type::TypeParam(_)) {
659        return Ok(Substitution::new());
660    }
661    // Var unification
662    if let Type::Var(id) = a {
663        if occurs_in(*id, b) {
664            return Err(format!("infinite type: ?T{id} occurs in {b}"));
665        }
666        let mut s = Substitution::new();
667        s.mappings.insert(*id, b.clone());
668        return Ok(s);
669    }
670    if let Type::Var(id) = b {
671        if occurs_in(*id, a) {
672            return Err(format!("infinite type: ?T{id} occurs in {a}"));
673        }
674        let mut s = Substitution::new();
675        s.mappings.insert(*id, a.clone());
676        return Ok(s);
677    }
678    // Same type
679    if a == b {
680        return Ok(Substitution::new());
681    }
682    // int promotes to float
683    if matches!(a, Type::Float) && matches!(b, Type::Int) {
684        return Ok(Substitution::new());
685    }
686    if matches!(a, Type::Int) && matches!(b, Type::Float) {
687        return Ok(Substitution::new());
688    }
689    // Decimal promotions
690    if matches!(a, Type::Float) && matches!(b, Type::Decimal) {
691        return Ok(Substitution::new());
692    }
693    if matches!(a, Type::Decimal) && matches!(b, Type::Int) {
694        return Ok(Substitution::new());
695    }
696    // Structural recursion
697    match (a, b) {
698        (Type::List(a_inner), Type::List(b_inner)) => unify(a_inner, b_inner),
699        (Type::Map(a_inner), Type::Map(b_inner)) => unify(a_inner, b_inner),
700        (Type::Set(a_inner), Type::Set(b_inner)) => unify(a_inner, b_inner),
701        (Type::Option(a_inner), Type::Option(b_inner)) => unify(a_inner, b_inner),
702        (Type::Generator(a_inner), Type::Generator(b_inner)) => unify(a_inner, b_inner),
703        (Type::Task(a_inner), Type::Task(b_inner)) => unify(a_inner, b_inner),
704        (Type::Channel(a_inner), Type::Channel(b_inner)) => unify(a_inner, b_inner),
705        (Type::Stream(a_inner), Type::Stream(b_inner)) => unify(a_inner, b_inner),
706        (Type::Result(ok1, err1), Type::Result(ok2, err2)) => {
707            let mut s = unify(ok1, ok2)?;
708            let s2 = unify(&apply_substitution(err1, &s), &apply_substitution(err2, &s))?;
709            s.compose(&s2);
710            Ok(s)
711        }
712        (
713            Type::Function {
714                params: p1,
715                ret: r1,
716            },
717            Type::Function {
718                params: p2,
719                ret: r2,
720            },
721        ) => {
722            if p1.len() != p2.len() {
723                return Err(format!(
724                    "function arity mismatch: {} vs {}",
725                    p1.len(),
726                    p2.len()
727                ));
728            }
729            let mut s = Substitution::new();
730            for (a_p, b_p) in p1.iter().zip(p2.iter()) {
731                let s2 = unify(&apply_substitution(a_p, &s), &apply_substitution(b_p, &s))?;
732                s.compose(&s2);
733            }
734            let s2 = unify(&apply_substitution(r1, &s), &apply_substitution(r2, &s))?;
735            s.compose(&s2);
736            Ok(s)
737        }
738        _ => Err(format!("cannot unify `{a}` with `{b}`")),
739    }
740}
741
742#[cfg(test)]
743mod tests {
744    use super::*;
745
746    #[test]
747    fn test_type_display() {
748        assert_eq!(Type::Int.to_string(), "int");
749        assert_eq!(Type::Option(Box::new(Type::Int)).to_string(), "int?");
750        assert_eq!(
751            Type::Result(Box::new(Type::Int), Box::new(Type::String)).to_string(),
752            "result<int, string>"
753        );
754        assert_eq!(Type::List(Box::new(Type::Any)).to_string(), "list<any>");
755    }
756
757    #[test]
758    fn test_type_equality() {
759        assert_eq!(Type::Int, Type::Int);
760        assert_ne!(Type::Int, Type::Float);
761        assert_eq!(
762            Type::List(Box::new(Type::Int)),
763            Type::List(Box::new(Type::Int))
764        );
765        assert_ne!(
766            Type::List(Box::new(Type::Int)),
767            Type::List(Box::new(Type::Float))
768        );
769    }
770
771    #[test]
772    fn test_type_env_push_pop_scope() {
773        let mut env = TypeEnv::new();
774        env.define("x".into(), Type::Int);
775        assert_eq!(env.lookup("x"), Some(&Type::Int));
776
777        env.push_scope();
778        env.define("y".into(), Type::String);
779        assert_eq!(env.lookup("y"), Some(&Type::String));
780        assert_eq!(env.lookup("x"), Some(&Type::Int)); // parent scope visible
781
782        env.pop_scope();
783        assert_eq!(env.lookup("y"), None); // y gone
784        assert_eq!(env.lookup("x"), Some(&Type::Int));
785    }
786
787    #[test]
788    fn test_type_env_variable_shadowing() {
789        let mut env = TypeEnv::new();
790        env.define("x".into(), Type::Int);
791        env.push_scope();
792        env.define("x".into(), Type::String);
793        assert_eq!(env.lookup("x"), Some(&Type::String)); // shadowed
794
795        env.pop_scope();
796        assert_eq!(env.lookup("x"), Some(&Type::Int)); // original restored
797    }
798
799    #[test]
800    fn test_compatibility_any() {
801        assert!(is_compatible(&Type::Any, &Type::Int));
802        assert!(is_compatible(&Type::Int, &Type::Any));
803        assert!(is_compatible(&Type::Any, &Type::Any));
804    }
805
806    #[test]
807    fn test_compatibility_option_none() {
808        assert!(is_compatible(
809            &Type::Option(Box::new(Type::Int)),
810            &Type::None
811        ));
812        assert!(is_compatible(
813            &Type::Option(Box::new(Type::Int)),
814            &Type::Int
815        ));
816        assert!(!is_compatible(&Type::Int, &Type::None));
817    }
818
819    #[test]
820    fn test_compatibility_int_float_promotion() {
821        assert!(is_compatible(&Type::Float, &Type::Int));
822        assert!(!is_compatible(&Type::Int, &Type::Float));
823    }
824
825    #[test]
826    fn test_compatibility_error_poison() {
827        assert!(is_compatible(&Type::Error, &Type::Int));
828        assert!(is_compatible(&Type::Int, &Type::Error));
829    }
830
831    // ── Phase 22: Advanced Type System ──────────────────────
832
833    #[test]
834    fn test_new_type_display() {
835        assert_eq!(Type::Decimal.to_string(), "decimal");
836        assert_eq!(Type::Tensor.to_string(), "tensor");
837        assert_eq!(Type::Pipeline.to_string(), "pipeline");
838        assert_eq!(Type::Stream(Box::new(Type::Int)).to_string(), "stream<int>");
839        assert_eq!(
840            Type::Table {
841                name: Some("User".into()),
842                columns: None
843            }
844            .to_string(),
845            "table<User>"
846        );
847        assert_eq!(
848            Type::Table {
849                name: None,
850                columns: None
851            }
852            .to_string(),
853            "table"
854        );
855    }
856
857    #[test]
858    fn test_decimal_compatibility() {
859        // Decimal promotes to Float
860        assert!(is_compatible(&Type::Float, &Type::Decimal));
861        // Int promotes to Decimal
862        assert!(is_compatible(&Type::Decimal, &Type::Int));
863        // Decimal == Decimal
864        assert!(is_compatible(&Type::Decimal, &Type::Decimal));
865        // Decimal does NOT promote to Int
866        assert!(!is_compatible(&Type::Int, &Type::Decimal));
867    }
868
869    #[test]
870    fn test_stream_compatibility() {
871        assert!(is_compatible(
872            &Type::Stream(Box::new(Type::Int)),
873            &Type::Stream(Box::new(Type::Int))
874        ));
875        assert!(!is_compatible(
876            &Type::Stream(Box::new(Type::Int)),
877            &Type::Stream(Box::new(Type::String))
878        ));
879        // Any element type is compatible
880        assert!(is_compatible(
881            &Type::Stream(Box::new(Type::Any)),
882            &Type::Stream(Box::new(Type::Int))
883        ));
884    }
885
886    #[test]
887    fn test_table_column_compatibility() {
888        let t1 = Type::Table {
889            name: None,
890            columns: Some(vec![
891                ("id".into(), Type::Int),
892                ("name".into(), Type::String),
893            ]),
894        };
895        let t2 = Type::Table {
896            name: None,
897            columns: Some(vec![
898                ("id".into(), Type::Int),
899                ("name".into(), Type::String),
900            ]),
901        };
902        assert!(is_compatible(&t1, &t2));
903
904        // None columns = compatible with any
905        let t3 = Type::Table {
906            name: None,
907            columns: None,
908        };
909        assert!(is_compatible(&t1, &t3));
910        assert!(is_compatible(&t3, &t1));
911    }
912
913    #[test]
914    fn test_decimal_satisfies_traits() {
915        let env = TypeEnv::new();
916        assert!(env.type_satisfies_trait(&Type::Decimal, "Numeric"));
917        assert!(env.type_satisfies_trait(&Type::Decimal, "Comparable"));
918        assert!(env.type_satisfies_trait(&Type::Decimal, "Hashable"));
919        assert!(env.type_satisfies_trait(&Type::Decimal, "Displayable"));
920        assert!(env.type_satisfies_trait(&Type::Decimal, "Serializable"));
921    }
922
923    #[test]
924    fn test_unify_basic() {
925        // Same types unify
926        assert!(unify(&Type::Int, &Type::Int).is_ok());
927        // Any unifies with everything
928        assert!(unify(&Type::Any, &Type::Int).is_ok());
929        // Different concrete types fail
930        assert!(unify(&Type::Int, &Type::String).is_err());
931    }
932
933    #[test]
934    fn test_unify_var() {
935        let s = unify(&Type::Var(0), &Type::Int).unwrap();
936        assert_eq!(s.mappings.get(&0), Some(&Type::Int));
937    }
938
939    #[test]
940    fn test_unify_occurs_check() {
941        // Var(0) cannot unify with List(Var(0)) — infinite type
942        let result = unify(&Type::Var(0), &Type::List(Box::new(Type::Var(0))));
943        assert!(result.is_err());
944    }
945
946    #[test]
947    fn test_unify_structural() {
948        let s = unify(
949            &Type::List(Box::new(Type::Var(0))),
950            &Type::List(Box::new(Type::Int)),
951        )
952        .unwrap();
953        assert_eq!(s.mappings.get(&0), Some(&Type::Int));
954    }
955
956    #[test]
957    fn test_unify_function() {
958        let s = unify(
959            &Type::Function {
960                params: vec![Type::Var(0)],
961                ret: Box::new(Type::Var(1)),
962            },
963            &Type::Function {
964                params: vec![Type::Int],
965                ret: Box::new(Type::String),
966            },
967        )
968        .unwrap();
969        assert_eq!(s.mappings.get(&0), Some(&Type::Int));
970        assert_eq!(s.mappings.get(&1), Some(&Type::String));
971    }
972
973    #[test]
974    fn test_apply_substitution() {
975        let mut s = Substitution::new();
976        s.mappings.insert(0, Type::Int);
977        s.mappings.insert(1, Type::String);
978
979        let ty = Type::List(Box::new(Type::Var(0)));
980        assert_eq!(apply_substitution(&ty, &s), Type::List(Box::new(Type::Int)));
981
982        let ty2 = Type::Function {
983            params: vec![Type::Var(0)],
984            ret: Box::new(Type::Var(1)),
985        };
986        assert_eq!(
987            apply_substitution(&ty2, &s),
988            Type::Function {
989                params: vec![Type::Int],
990                ret: Box::new(Type::String)
991            }
992        );
993    }
994
995    #[test]
996    fn test_sensitive_fields() {
997        let mut env = TypeEnv::new();
998        env.register_sensitive_field("User".into(), "ssn".into(), "sensitive".into());
999        env.register_sensitive_field("User".into(), "email".into(), "pii".into());
1000
1001        assert!(env.is_field_sensitive("User", "ssn"));
1002        assert!(env.is_field_sensitive("User", "email"));
1003        assert!(!env.is_field_sensitive("User", "name"));
1004        assert!(!env.is_field_sensitive("Other", "ssn"));
1005    }
1006
1007    #[test]
1008    fn test_resolving_aliases_cycle_detection() {
1009        let mut env = TypeEnv::new();
1010        assert!(!env.is_resolving_alias("Foo"));
1011        env.start_resolving_alias("Foo".into());
1012        assert!(env.is_resolving_alias("Foo"));
1013        env.stop_resolving_alias("Foo");
1014        assert!(!env.is_resolving_alias("Foo"));
1015    }
1016
1017    #[test]
1018    fn test_unify_promotions() {
1019        // int <-> float is OK
1020        assert!(unify(&Type::Float, &Type::Int).is_ok());
1021        assert!(unify(&Type::Int, &Type::Float).is_ok());
1022        // decimal <-> float
1023        assert!(unify(&Type::Float, &Type::Decimal).is_ok());
1024        // int -> decimal
1025        assert!(unify(&Type::Decimal, &Type::Int).is_ok());
1026    }
1027}