1pub mod checker;
8pub mod convert;
9pub mod infer;
10
11use std::collections::HashMap;
12use std::fmt;
13
/// A type in the checker's type language.
///
/// The `Display` impl renders types in surface syntax, e.g. `list<int>`,
/// `int?` or `fn(int) -> string`.
#[derive(Debug, Clone, PartialEq)]
pub enum Type {
    /// Dynamic type; `is_compatible` and `unify` accept it against any type.
    Any,
    /// The unit type, displayed as `unit`.
    Unit,
    /// Integer; `is_compatible` lets it promote to `Float` and `Decimal`.
    Int,
    /// Floating-point number.
    Float,
    /// Text string.
    String,
    /// Boolean.
    Bool,
    /// Type of the `none` value; accepted wherever an `Option` is expected.
    None,
    /// Decimal number; promotes to `Float`, and `Int` promotes to it.
    Decimal,
    /// List with the given element type.
    List(Box<Type>),
    /// Map with a single tracked type parameter (displayed `map<T>`).
    /// NOTE(review): presumably the value type, with keys untracked — confirm.
    Map(Box<Type>),
    /// Set with the given element type.
    Set(Box<Type>),
    /// Optional value of the inner type, displayed as `T?`.
    Option(Box<Type>),
    /// Result type: success type first, error type second.
    Result(Box<Type>, Box<Type>),
    /// Function type with positional parameter types and a return type.
    Function {
        params: Vec<Type>,
        ret: Box<Type>,
    },
    /// Nominal struct type referenced by name; its fields are stored in `TypeEnv`.
    Struct(std::string::String),
    /// Nominal enum type referenced by name; its variants are stored in `TypeEnv`.
    Enum(std::string::String),
    /// Table type. A missing `name` or `columns` is treated as unknown and
    /// matches any table in `is_compatible`.
    Table {
        name: Option<std::string::String>,
        columns: Option<Vec<(std::string::String, Type)>>,
    },
    /// Generator parameterised by the inner type (presumably the yielded value type).
    Generator(Box<Type>),
    /// Task parameterised by the inner type (presumably its result type).
    Task(Box<Type>),
    /// Channel parameterised by the inner type (presumably the message type).
    Channel(Box<Type>),
    /// Tensor value; no element type or shape is tracked here.
    Tensor,
    /// Stream parameterised by the inner type (presumably the element type).
    Stream(Box<Type>),
    /// Pipeline value; carries no type parameters.
    Pipeline,
    /// Named generic type parameter; treated as compatible with every type.
    TypeParam(std::string::String),
    /// Inference variable created by `TypeEnv::fresh_var`, displayed as
    /// `?T<id>` and bound through `unify`.
    Var(u32),
    /// Opaque Python object, displayed as `pyobject`.
    PyObject,
    /// Reference to a value of the inner type, displayed as `&T`.
    Ref(Box<Type>),
    /// Poison type; compatible with every type, presumably so that one
    /// reported error does not trigger follow-on diagnostics.
    Error,
}
75
76impl fmt::Display for Type {
77 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78 match self {
79 Type::Any => write!(f, "any"),
80 Type::Unit => write!(f, "unit"),
81 Type::Int => write!(f, "int"),
82 Type::Float => write!(f, "float"),
83 Type::String => write!(f, "string"),
84 Type::Bool => write!(f, "bool"),
85 Type::None => write!(f, "none"),
86 Type::Decimal => write!(f, "decimal"),
87 Type::List(t) => write!(f, "list<{t}>"),
88 Type::Map(t) => write!(f, "map<{t}>"),
89 Type::Set(t) => write!(f, "set<{t}>"),
90 Type::Option(t) => write!(f, "{t}?"),
91 Type::Result(ok, err) => write!(f, "result<{ok}, {err}>"),
92 Type::Function { params, ret } => {
93 write!(f, "fn(")?;
94 for (i, p) in params.iter().enumerate() {
95 if i > 0 {
96 write!(f, ", ")?;
97 }
98 write!(f, "{p}")?;
99 }
100 write!(f, ") -> {ret}")
101 }
102 Type::Struct(name) => write!(f, "{name}"),
103 Type::Enum(name) => write!(f, "{name}"),
104 Type::Table {
105 name: Some(name), ..
106 } => write!(f, "table<{name}>"),
107 Type::Table { name: None, .. } => write!(f, "table"),
108 Type::Generator(t) => write!(f, "generator<{t}>"),
109 Type::Task(t) => write!(f, "task<{t}>"),
110 Type::Channel(t) => write!(f, "channel<{t}>"),
111 Type::Tensor => write!(f, "tensor"),
112 Type::Stream(t) => write!(f, "stream<{t}>"),
113 Type::Pipeline => write!(f, "pipeline"),
114 Type::TypeParam(name) => write!(f, "{name}"),
115 Type::Var(id) => write!(f, "?T{id}"),
116 Type::PyObject => write!(f, "pyobject"),
117 Type::Ref(t) => write!(f, "&{t}"),
118 Type::Error => write!(f, "<error>"),
119 }
120 }
121}
122
/// Description of a trait known to the checker, either built in (registered
/// by `TypeEnv::new`) or declared by the user.
#[derive(Debug, Clone)]
pub struct TraitInfo {
    /// Trait name.
    pub name: std::string::String,
    /// Required methods as `(name, parameter types, return type)` tuples;
    /// e.g. the built-in `Displayable` lists `to_string` with no parameters
    /// returning `string`. NOTE(review): whether the receiver is part of the
    /// parameter list is not visible here — confirm.
    pub methods: Vec<(std::string::String, Vec<Type>, Type)>,
    /// Name of the trait this one extends, if any.
    pub supertrait: Option<std::string::String>,
}
130
/// Typing environment: a stack of lexical variable scopes plus global
/// registries of functions, structs, enums, traits, trait impls, type
/// aliases and sensitive-field annotations.
pub struct TypeEnv {
    /// Variable scopes, innermost last. The first (global) scope is never
    /// removed by `pop_scope`.
    scopes: Vec<Scope>,
    /// Function signatures by name; this namespace is not scoped.
    functions: std::collections::HashMap<std::string::String, FnSig>,
    /// Struct fields as `(field name, type)` pairs, keyed by struct name.
    structs: std::collections::HashMap<std::string::String, Vec<(std::string::String, Type)>>,
    /// Enum variants as `(variant name, types)` pairs keyed by enum name
    /// (presumably the variant's payload types).
    enums: std::collections::HashMap<std::string::String, Vec<(std::string::String, Vec<Type>)>>,
    /// Trait definitions by name, including the built-ins registered by `new`.
    traits: std::collections::HashMap<std::string::String, TraitInfo>,
    /// Implemented method names, keyed by `(trait name, type name)`.
    trait_impls: std::collections::HashMap<
        (std::string::String, std::string::String),
        Vec<std::string::String>,
    >,
    /// Type aliases: name -> (type parameter names, aliased type expression).
    type_aliases: std::collections::HashMap<
        std::string::String,
        (Vec<std::string::String>, tl_ast::TypeExpr),
    >,
    /// Per type name, the `(field name, annotation)` pairs registered as sensitive.
    sensitive_fields: std::collections::HashMap<
        std::string::String,
        Vec<(std::string::String, std::string::String)>,
    >,
    /// Names of aliases currently being expanded, used to detect alias cycles.
    resolving_aliases: std::collections::HashSet<std::string::String>,
    /// Id that the next call to `fresh_var` hands out.
    next_var: u32,
}
162
/// Signature of a named function, as registered with `TypeEnv::define_fn`.
#[derive(Debug, Clone)]
pub struct FnSig {
    /// Parameters as `(name, type)` pairs.
    pub params: Vec<(std::string::String, Type)>,
    /// Return type.
    pub ret: Type,
}
169
/// One lexical scope of a `TypeEnv`.
struct Scope {
    // Variable name -> declared type; redefining a name in the same scope
    // overwrites the earlier entry (see `TypeEnv::define`).
    vars: std::collections::HashMap<std::string::String, Type>,
}
173
174impl TypeEnv {
175 pub fn new() -> Self {
176 let mut env = TypeEnv {
177 scopes: vec![Scope {
178 vars: std::collections::HashMap::new(),
179 }],
180 functions: std::collections::HashMap::new(),
181 structs: std::collections::HashMap::new(),
182 enums: std::collections::HashMap::new(),
183 traits: std::collections::HashMap::new(),
184 trait_impls: std::collections::HashMap::new(),
185 type_aliases: std::collections::HashMap::new(),
186 sensitive_fields: std::collections::HashMap::new(),
187 resolving_aliases: std::collections::HashSet::new(),
188 next_var: 0,
189 };
190 env.register_builtin_traits();
191 env
192 }
193
194 fn register_builtin_traits(&mut self) {
196 self.traits.insert(
198 "Hashable".into(),
199 TraitInfo {
200 name: "Hashable".into(),
201 methods: vec![],
202 supertrait: None,
203 },
204 );
205 self.traits.insert(
207 "Comparable".into(),
208 TraitInfo {
209 name: "Comparable".into(),
210 methods: vec![],
211 supertrait: Some("Hashable".into()),
212 },
213 );
214 self.traits.insert(
216 "Numeric".into(),
217 TraitInfo {
218 name: "Numeric".into(),
219 methods: vec![],
220 supertrait: Some("Comparable".into()),
221 },
222 );
223 self.traits.insert(
225 "Displayable".into(),
226 TraitInfo {
227 name: "Displayable".into(),
228 methods: vec![("to_string".into(), vec![], Type::String)],
229 supertrait: None,
230 },
231 );
232 self.traits.insert(
234 "Serializable".into(),
235 TraitInfo {
236 name: "Serializable".into(),
237 methods: vec![],
238 supertrait: None,
239 },
240 );
241 self.traits.insert(
243 "Default".into(),
244 TraitInfo {
245 name: "Default".into(),
246 methods: vec![],
247 supertrait: None,
248 },
249 );
250 }
251
252 pub fn scope_depth(&self) -> u32 {
253 self.scopes.len() as u32 - 1
254 }
255
256 pub fn push_scope(&mut self) {
257 self.scopes.push(Scope {
258 vars: std::collections::HashMap::new(),
259 });
260 }
261
262 pub fn pop_scope(&mut self) {
263 if self.scopes.len() > 1 {
264 self.scopes.pop();
265 }
266 }
267
268 pub fn define(&mut self, name: std::string::String, ty: Type) {
269 if let Some(scope) = self.scopes.last_mut() {
270 scope.vars.insert(name, ty);
271 }
272 }
273
274 pub fn lookup(&self, name: &str) -> Option<&Type> {
275 for scope in self.scopes.iter().rev() {
276 if let Some(ty) = scope.vars.get(name) {
277 return Some(ty);
278 }
279 }
280 None
281 }
282
283 pub fn define_fn(&mut self, name: std::string::String, sig: FnSig) {
284 self.functions.insert(name, sig);
285 }
286
287 pub fn lookup_fn(&self, name: &str) -> Option<&FnSig> {
288 self.functions.get(name)
289 }
290
291 pub fn define_struct(
292 &mut self,
293 name: std::string::String,
294 fields: Vec<(std::string::String, Type)>,
295 ) {
296 self.structs.insert(name, fields);
297 }
298
299 pub fn lookup_struct(&self, name: &str) -> Option<&Vec<(std::string::String, Type)>> {
300 self.structs.get(name)
301 }
302
303 pub fn define_enum(
304 &mut self,
305 name: std::string::String,
306 variants: Vec<(std::string::String, Vec<Type>)>,
307 ) {
308 self.enums.insert(name, variants);
309 }
310
311 pub fn lookup_enum(&self, name: &str) -> Option<&Vec<(std::string::String, Vec<Type>)>> {
312 self.enums.get(name)
313 }
314
315 pub fn fresh_var(&mut self) -> Type {
316 let id = self.next_var;
317 self.next_var += 1;
318 Type::Var(id)
319 }
320
321 pub fn define_trait(&mut self, name: std::string::String, info: TraitInfo) {
322 self.traits.insert(name, info);
323 }
324
325 pub fn lookup_trait(&self, name: &str) -> Option<&TraitInfo> {
326 self.traits.get(name)
327 }
328
329 pub fn register_type_alias(
330 &mut self,
331 name: std::string::String,
332 type_params: Vec<std::string::String>,
333 value: tl_ast::TypeExpr,
334 ) {
335 self.type_aliases.insert(name, (type_params, value));
336 }
337
338 pub fn lookup_type_alias(
339 &self,
340 name: &str,
341 ) -> Option<&(Vec<std::string::String>, tl_ast::TypeExpr)> {
342 self.type_aliases.get(name)
343 }
344
345 pub fn is_resolving_alias(&self, name: &str) -> bool {
347 self.resolving_aliases.contains(name)
348 }
349
350 pub fn start_resolving_alias(&mut self, name: std::string::String) {
352 self.resolving_aliases.insert(name);
353 }
354
355 pub fn stop_resolving_alias(&mut self, name: &str) {
357 self.resolving_aliases.remove(name);
358 }
359
360 pub fn register_sensitive_field(
362 &mut self,
363 type_name: std::string::String,
364 field_name: std::string::String,
365 annotation: std::string::String,
366 ) {
367 self.sensitive_fields
368 .entry(type_name)
369 .or_default()
370 .push((field_name, annotation));
371 }
372
373 pub fn is_field_sensitive(&self, type_name: &str, field_name: &str) -> bool {
375 self.sensitive_fields
376 .get(type_name)
377 .map(|fields| fields.iter().any(|(f, _)| f == field_name))
378 .unwrap_or(false)
379 }
380
381 pub fn get_field_annotations(
383 &self,
384 type_name: &str,
385 ) -> Option<&Vec<(std::string::String, std::string::String)>> {
386 self.sensitive_fields.get(type_name)
387 }
388
389 pub fn register_trait_impl(
390 &mut self,
391 trait_name: std::string::String,
392 type_name: std::string::String,
393 method_names: Vec<std::string::String>,
394 ) {
395 self.trait_impls
396 .insert((trait_name, type_name), method_names);
397 }
398
399 pub fn lookup_trait_impl(
400 &self,
401 trait_name: &str,
402 type_name: &str,
403 ) -> Option<&Vec<std::string::String>> {
404 self.trait_impls
405 .get(&(trait_name.to_string(), type_name.to_string()))
406 }
407
408 pub fn type_satisfies_trait(&self, ty: &Type, trait_name: &str) -> bool {
410 if matches!(ty, Type::Any | Type::Error | Type::TypeParam(_)) {
412 return true;
413 }
414 match trait_name {
416 "Numeric" => matches!(ty, Type::Int | Type::Float | Type::Decimal),
417 "Comparable" => {
418 matches!(ty, Type::Int | Type::Float | Type::String | Type::Decimal)
419 || self.type_satisfies_trait(ty, "Numeric")
420 }
421 "Hashable" => {
422 matches!(
423 ty,
424 Type::Int | Type::Float | Type::String | Type::Bool | Type::Decimal
425 ) || self.type_satisfies_trait(ty, "Comparable")
426 }
427 "Displayable" => matches!(
428 ty,
429 Type::Int | Type::Float | Type::String | Type::Bool | Type::None | Type::Decimal
430 ),
431 "Default" => matches!(
432 ty,
433 Type::Int
434 | Type::Float
435 | Type::String
436 | Type::Bool
437 | Type::None
438 | Type::List(_)
439 | Type::Map(_)
440 | Type::Set(_)
441 ),
442 "Serializable" => matches!(
443 ty,
444 Type::Int
445 | Type::Float
446 | Type::String
447 | Type::Bool
448 | Type::None
449 | Type::Decimal
450 | Type::Struct(_)
451 ),
452 _ => {
453 let type_name = match ty {
455 Type::Struct(n) | Type::Enum(n) => n.as_str(),
456 _ => return false,
457 };
458 self.lookup_trait_impl(trait_name, type_name).is_some()
459 }
460 }
461 }
462}
463
464impl Default for TypeEnv {
465 fn default() -> Self {
466 Self::new()
467 }
468}
469
470pub fn is_compatible(expected: &Type, found: &Type) -> bool {
473 if matches!(expected, Type::Any) || matches!(found, Type::Any) {
475 return true;
476 }
477 if matches!(expected, Type::Error) || matches!(found, Type::Error) {
479 return true;
480 }
481 if matches!(expected, Type::TypeParam(_)) || matches!(found, Type::TypeParam(_)) {
483 return true;
484 }
485 if expected == found {
487 return true;
488 }
489 if matches!(expected, Type::Float) && matches!(found, Type::Int) {
491 return true;
492 }
493 if matches!(expected, Type::Float) && matches!(found, Type::Decimal) {
495 return true;
496 }
497 if matches!(expected, Type::Decimal) && matches!(found, Type::Int) {
499 return true;
500 }
501 if matches!(found, Type::None) && matches!(expected, Type::Option(_)) {
503 return true;
504 }
505 if let Type::Option(inner) = expected
507 && is_compatible(inner, found)
508 {
509 return true;
510 }
511 match (expected, found) {
513 (Type::List(a), Type::List(b)) => is_compatible(a, b),
514 (Type::Map(a), Type::Map(b)) => is_compatible(a, b),
515 (Type::Set(a), Type::Set(b)) => is_compatible(a, b),
516 (Type::Option(a), Type::Option(b)) => is_compatible(a, b),
517 (Type::Result(ok1, err1), Type::Result(ok2, err2)) => {
518 is_compatible(ok1, ok2) && is_compatible(err1, err2)
519 }
520 (Type::Generator(a), Type::Generator(b)) => is_compatible(a, b),
521 (Type::Task(a), Type::Task(b)) => is_compatible(a, b),
522 (Type::Channel(a), Type::Channel(b)) => is_compatible(a, b),
523 (Type::Stream(a), Type::Stream(b)) => is_compatible(a, b),
524 (
526 Type::Table {
527 name: n1,
528 columns: c1,
529 },
530 Type::Table {
531 name: n2,
532 columns: c2,
533 },
534 ) => {
535 let name_ok = match (n1, n2) {
536 (Some(a), Some(b)) => a == b,
537 _ => true,
538 };
539 let cols_ok = match (c1, c2) {
540 (None, _) | (_, None) => true,
541 (Some(a), Some(b)) => {
542 a.len() == b.len()
543 && a.iter()
544 .zip(b.iter())
545 .all(|((n1, t1), (n2, t2))| n1 == n2 && is_compatible(t1, t2))
546 }
547 };
548 name_ok && cols_ok
549 }
550 (
551 Type::Function {
552 params: p1,
553 ret: r1,
554 },
555 Type::Function {
556 params: p2,
557 ret: r2,
558 },
559 ) => {
560 p1.len() == p2.len()
561 && p1.iter().zip(p2.iter()).all(|(a, b)| is_compatible(a, b))
562 && is_compatible(r1, r2)
563 }
564 _ => false,
565 }
566}
567
/// A mapping from inference-variable ids (`Type::Var`) to the types they are bound to.
#[derive(Debug, Clone, Default)]
pub struct Substitution {
    /// Variable id -> bound type. A bound type may itself mention other
    /// variables; `apply_substitution` follows such chains.
    pub mappings: HashMap<u32, Type>,
}
573
574impl Substitution {
575 pub fn new() -> Self {
576 Self {
577 mappings: HashMap::new(),
578 }
579 }
580
581 pub fn compose(&mut self, other: &Substitution) {
583 for ty in self.mappings.values_mut() {
585 *ty = apply_substitution(ty, other);
586 }
587 for (k, v) in &other.mappings {
589 self.mappings.entry(*k).or_insert_with(|| v.clone());
590 }
591 }
592}
593
594pub fn apply_substitution(ty: &Type, subst: &Substitution) -> Type {
596 match ty {
597 Type::Var(id) => {
598 if let Some(replacement) = subst.mappings.get(id) {
599 apply_substitution(replacement, subst)
600 } else {
601 ty.clone()
602 }
603 }
604 Type::List(inner) => Type::List(Box::new(apply_substitution(inner, subst))),
605 Type::Map(inner) => Type::Map(Box::new(apply_substitution(inner, subst))),
606 Type::Set(inner) => Type::Set(Box::new(apply_substitution(inner, subst))),
607 Type::Option(inner) => Type::Option(Box::new(apply_substitution(inner, subst))),
608 Type::Result(ok, err) => Type::Result(
609 Box::new(apply_substitution(ok, subst)),
610 Box::new(apply_substitution(err, subst)),
611 ),
612 Type::Generator(inner) => Type::Generator(Box::new(apply_substitution(inner, subst))),
613 Type::Task(inner) => Type::Task(Box::new(apply_substitution(inner, subst))),
614 Type::Channel(inner) => Type::Channel(Box::new(apply_substitution(inner, subst))),
615 Type::Stream(inner) => Type::Stream(Box::new(apply_substitution(inner, subst))),
616 Type::Function { params, ret } => Type::Function {
617 params: params
618 .iter()
619 .map(|p| apply_substitution(p, subst))
620 .collect(),
621 ret: Box::new(apply_substitution(ret, subst)),
622 },
623 _ => ty.clone(),
624 }
625}
626
627fn occurs_in(id: u32, ty: &Type) -> bool {
629 match ty {
630 Type::Var(v) => *v == id,
631 Type::List(inner)
632 | Type::Map(inner)
633 | Type::Set(inner)
634 | Type::Option(inner)
635 | Type::Generator(inner)
636 | Type::Task(inner)
637 | Type::Channel(inner)
638 | Type::Stream(inner) => occurs_in(id, inner),
639 Type::Result(ok, err) => occurs_in(id, ok) || occurs_in(id, err),
640 Type::Function { params, ret } => {
641 params.iter().any(|p| occurs_in(id, p)) || occurs_in(id, ret)
642 }
643 _ => false,
644 }
645}
646
647pub fn unify(a: &Type, b: &Type) -> Result<Substitution, std::string::String> {
649 if matches!(a, Type::Any) || matches!(b, Type::Any) {
651 return Ok(Substitution::new());
652 }
653 if matches!(a, Type::Error) || matches!(b, Type::Error) {
655 return Ok(Substitution::new());
656 }
657 if matches!(a, Type::TypeParam(_)) || matches!(b, Type::TypeParam(_)) {
659 return Ok(Substitution::new());
660 }
661 if let Type::Var(id) = a {
663 if occurs_in(*id, b) {
664 return Err(format!("infinite type: ?T{id} occurs in {b}"));
665 }
666 let mut s = Substitution::new();
667 s.mappings.insert(*id, b.clone());
668 return Ok(s);
669 }
670 if let Type::Var(id) = b {
671 if occurs_in(*id, a) {
672 return Err(format!("infinite type: ?T{id} occurs in {a}"));
673 }
674 let mut s = Substitution::new();
675 s.mappings.insert(*id, a.clone());
676 return Ok(s);
677 }
678 if a == b {
680 return Ok(Substitution::new());
681 }
682 if matches!(a, Type::Float) && matches!(b, Type::Int) {
684 return Ok(Substitution::new());
685 }
686 if matches!(a, Type::Int) && matches!(b, Type::Float) {
687 return Ok(Substitution::new());
688 }
689 if matches!(a, Type::Float) && matches!(b, Type::Decimal) {
691 return Ok(Substitution::new());
692 }
693 if matches!(a, Type::Decimal) && matches!(b, Type::Int) {
694 return Ok(Substitution::new());
695 }
696 match (a, b) {
698 (Type::List(a_inner), Type::List(b_inner)) => unify(a_inner, b_inner),
699 (Type::Map(a_inner), Type::Map(b_inner)) => unify(a_inner, b_inner),
700 (Type::Set(a_inner), Type::Set(b_inner)) => unify(a_inner, b_inner),
701 (Type::Option(a_inner), Type::Option(b_inner)) => unify(a_inner, b_inner),
702 (Type::Generator(a_inner), Type::Generator(b_inner)) => unify(a_inner, b_inner),
703 (Type::Task(a_inner), Type::Task(b_inner)) => unify(a_inner, b_inner),
704 (Type::Channel(a_inner), Type::Channel(b_inner)) => unify(a_inner, b_inner),
705 (Type::Stream(a_inner), Type::Stream(b_inner)) => unify(a_inner, b_inner),
706 (Type::Result(ok1, err1), Type::Result(ok2, err2)) => {
707 let mut s = unify(ok1, ok2)?;
708 let s2 = unify(&apply_substitution(err1, &s), &apply_substitution(err2, &s))?;
709 s.compose(&s2);
710 Ok(s)
711 }
712 (
713 Type::Function {
714 params: p1,
715 ret: r1,
716 },
717 Type::Function {
718 params: p2,
719 ret: r2,
720 },
721 ) => {
722 if p1.len() != p2.len() {
723 return Err(format!(
724 "function arity mismatch: {} vs {}",
725 p1.len(),
726 p2.len()
727 ));
728 }
729 let mut s = Substitution::new();
730 for (a_p, b_p) in p1.iter().zip(p2.iter()) {
731 let s2 = unify(&apply_substitution(a_p, &s), &apply_substitution(b_p, &s))?;
732 s.compose(&s2);
733 }
734 let s2 = unify(&apply_substitution(r1, &s), &apply_substitution(r2, &s))?;
735 s.compose(&s2);
736 Ok(s)
737 }
738 _ => Err(format!("cannot unify `{a}` with `{b}`")),
739 }
740}
741
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_type_display() {
        assert_eq!(Type::Int.to_string(), "int");
        assert_eq!(Type::Option(Box::new(Type::Int)).to_string(), "int?");
        assert_eq!(
            Type::Result(Box::new(Type::Int), Box::new(Type::String)).to_string(),
            "result<int, string>"
        );
        assert_eq!(Type::List(Box::new(Type::Any)).to_string(), "list<any>");
    }

    #[test]
    fn test_type_equality() {
        assert_eq!(Type::Int, Type::Int);
        assert_ne!(Type::Int, Type::Float);
        assert_eq!(
            Type::List(Box::new(Type::Int)),
            Type::List(Box::new(Type::Int))
        );
        assert_ne!(
            Type::List(Box::new(Type::Int)),
            Type::List(Box::new(Type::Float))
        );
    }

    #[test]
    fn test_type_env_push_pop_scope() {
        let mut env = TypeEnv::new();
        env.define("x".into(), Type::Int);
        assert_eq!(env.lookup("x"), Some(&Type::Int));

        env.push_scope();
        env.define("y".into(), Type::String);
        assert_eq!(env.lookup("y"), Some(&Type::String));
        assert_eq!(env.lookup("x"), Some(&Type::Int));
        env.pop_scope();

        assert_eq!(env.lookup("y"), None);
        assert_eq!(env.lookup("x"), Some(&Type::Int));
    }

    #[test]
    fn test_type_env_variable_shadowing() {
        let mut env = TypeEnv::new();
        env.define("x".into(), Type::Int);
        env.push_scope();
        env.define("x".into(), Type::String);
        assert_eq!(env.lookup("x"), Some(&Type::String));
        env.pop_scope();
        assert_eq!(env.lookup("x"), Some(&Type::Int));
    }

    #[test]
    fn test_pop_scope_never_removes_global() {
        let mut env = TypeEnv::new();
        assert_eq!(env.scope_depth(), 0);
        env.define("x".into(), Type::Int);
        env.pop_scope();
        assert_eq!(env.scope_depth(), 0);
        assert_eq!(env.lookup("x"), Some(&Type::Int));
        env.push_scope();
        assert_eq!(env.scope_depth(), 1);
    }

    #[test]
    fn test_fresh_var_is_unique() {
        let mut env = TypeEnv::new();
        assert_eq!(env.fresh_var(), Type::Var(0));
        assert_eq!(env.fresh_var(), Type::Var(1));
    }

    #[test]
    fn test_compatibility_any() {
        assert!(is_compatible(&Type::Any, &Type::Int));
        assert!(is_compatible(&Type::Int, &Type::Any));
        assert!(is_compatible(&Type::Any, &Type::Any));
    }

    #[test]
    fn test_compatibility_option_none() {
        assert!(is_compatible(
            &Type::Option(Box::new(Type::Int)),
            &Type::None
        ));
        assert!(is_compatible(
            &Type::Option(Box::new(Type::Int)),
            &Type::Int
        ));
        assert!(!is_compatible(&Type::Int, &Type::None));
    }

    #[test]
    fn test_compatibility_int_float_promotion() {
        assert!(is_compatible(&Type::Float, &Type::Int));
        assert!(!is_compatible(&Type::Int, &Type::Float));
    }

    #[test]
    fn test_compatibility_error_poison() {
        assert!(is_compatible(&Type::Error, &Type::Int));
        assert!(is_compatible(&Type::Int, &Type::Error));
    }

    #[test]
    fn test_new_type_display() {
        assert_eq!(Type::Decimal.to_string(), "decimal");
        assert_eq!(Type::Tensor.to_string(), "tensor");
        assert_eq!(Type::Pipeline.to_string(), "pipeline");
        assert_eq!(Type::Stream(Box::new(Type::Int)).to_string(), "stream<int>");
        assert_eq!(
            Type::Table {
                name: Some("User".into()),
                columns: None
            }
            .to_string(),
            "table<User>"
        );
        assert_eq!(
            Type::Table {
                name: None,
                columns: None
            }
            .to_string(),
            "table"
        );
    }

    #[test]
    fn test_decimal_compatibility() {
        assert!(is_compatible(&Type::Float, &Type::Decimal));
        assert!(is_compatible(&Type::Decimal, &Type::Int));
        assert!(is_compatible(&Type::Decimal, &Type::Decimal));
        assert!(!is_compatible(&Type::Int, &Type::Decimal));
    }

    #[test]
    fn test_stream_compatibility() {
        assert!(is_compatible(
            &Type::Stream(Box::new(Type::Int)),
            &Type::Stream(Box::new(Type::Int))
        ));
        assert!(!is_compatible(
            &Type::Stream(Box::new(Type::Int)),
            &Type::Stream(Box::new(Type::String))
        ));
        assert!(is_compatible(
            &Type::Stream(Box::new(Type::Any)),
            &Type::Stream(Box::new(Type::Int))
        ));
    }

    #[test]
    fn test_table_column_compatibility() {
        let t1 = Type::Table {
            name: None,
            columns: Some(vec![
                ("id".into(), Type::Int),
                ("name".into(), Type::String),
            ]),
        };
        let t2 = Type::Table {
            name: None,
            columns: Some(vec![
                ("id".into(), Type::Int),
                ("name".into(), Type::String),
            ]),
        };
        assert!(is_compatible(&t1, &t2));

        let t3 = Type::Table {
            name: None,
            columns: None,
        };
        assert!(is_compatible(&t1, &t3));
        assert!(is_compatible(&t3, &t1));
    }

    #[test]
    fn test_decimal_satisfies_traits() {
        let env = TypeEnv::new();
        assert!(env.type_satisfies_trait(&Type::Decimal, "Numeric"));
        assert!(env.type_satisfies_trait(&Type::Decimal, "Comparable"));
        assert!(env.type_satisfies_trait(&Type::Decimal, "Hashable"));
        assert!(env.type_satisfies_trait(&Type::Decimal, "Displayable"));
        assert!(env.type_satisfies_trait(&Type::Decimal, "Serializable"));
    }

    #[test]
    fn test_builtin_trait_hierarchy() {
        let env = TypeEnv::new();
        let numeric = env.lookup_trait("Numeric").expect("Numeric is built in");
        assert_eq!(numeric.supertrait.as_deref(), Some("Comparable"));
        let comparable = env.lookup_trait("Comparable").expect("Comparable is built in");
        assert_eq!(comparable.supertrait.as_deref(), Some("Hashable"));
        let displayable = env.lookup_trait("Displayable").expect("Displayable is built in");
        assert_eq!(displayable.methods.len(), 1);
    }

    #[test]
    fn test_user_trait_impl_satisfies_bound() {
        let mut env = TypeEnv::new();
        env.register_trait_impl("Printable".into(), "Point".into(), vec!["print".into()]);
        assert!(env.type_satisfies_trait(&Type::Struct("Point".into()), "Printable"));
        assert!(!env.type_satisfies_trait(&Type::Struct("Other".into()), "Printable"));
        assert!(!env.type_satisfies_trait(&Type::Int, "Printable"));
        assert_eq!(
            env.lookup_trait_impl("Printable", "Point"),
            Some(&vec!["print".to_string()])
        );
    }

    #[test]
    fn test_unify_basic() {
        assert!(unify(&Type::Int, &Type::Int).is_ok());
        assert!(unify(&Type::Any, &Type::Int).is_ok());
        assert!(unify(&Type::Int, &Type::String).is_err());
    }

    #[test]
    fn test_unify_var() {
        let s = unify(&Type::Var(0), &Type::Int).unwrap();
        assert_eq!(s.mappings.get(&0), Some(&Type::Int));
    }

    #[test]
    fn test_unify_occurs_check() {
        let result = unify(&Type::Var(0), &Type::List(Box::new(Type::Var(0))));
        assert!(result.is_err());
    }

    #[test]
    fn test_unify_structural() {
        let s = unify(
            &Type::List(Box::new(Type::Var(0))),
            &Type::List(Box::new(Type::Int)),
        )
        .unwrap();
        assert_eq!(s.mappings.get(&0), Some(&Type::Int));
    }

    #[test]
    fn test_unify_function() {
        let s = unify(
            &Type::Function {
                params: vec![Type::Var(0)],
                ret: Box::new(Type::Var(1)),
            },
            &Type::Function {
                params: vec![Type::Int],
                ret: Box::new(Type::String),
            },
        )
        .unwrap();
        assert_eq!(s.mappings.get(&0), Some(&Type::Int));
        assert_eq!(s.mappings.get(&1), Some(&Type::String));
    }

    #[test]
    fn test_unify_errors() {
        assert!(unify(
            &Type::List(Box::new(Type::Int)),
            &Type::Set(Box::new(Type::Int))
        )
        .is_err());
        let arity = unify(
            &Type::Function {
                params: vec![Type::Int],
                ret: Box::new(Type::Int),
            },
            &Type::Function {
                params: vec![],
                ret: Box::new(Type::Int),
            },
        );
        assert!(arity.is_err());
    }

    #[test]
    fn test_apply_substitution() {
        let mut s = Substitution::new();
        s.mappings.insert(0, Type::Int);
        s.mappings.insert(1, Type::String);

        let ty = Type::List(Box::new(Type::Var(0)));
        assert_eq!(apply_substitution(&ty, &s), Type::List(Box::new(Type::Int)));

        let ty2 = Type::Function {
            params: vec![Type::Var(0)],
            ret: Box::new(Type::Var(1)),
        };
        assert_eq!(
            apply_substitution(&ty2, &s),
            Type::Function {
                params: vec![Type::Int],
                ret: Box::new(Type::String)
            }
        );
    }

    #[test]
    fn test_substitution_compose() {
        let mut s1 = Substitution::new();
        s1.mappings.insert(0, Type::Var(1));
        let mut s2 = Substitution::new();
        s2.mappings.insert(1, Type::Int);
        s1.compose(&s2);
        assert_eq!(s1.mappings.get(&0), Some(&Type::Int));
        assert_eq!(s1.mappings.get(&1), Some(&Type::Int));
    }

    #[test]
    fn test_sensitive_fields() {
        let mut env = TypeEnv::new();
        env.register_sensitive_field("User".into(), "ssn".into(), "sensitive".into());
        env.register_sensitive_field("User".into(), "email".into(), "pii".into());

        assert!(env.is_field_sensitive("User", "ssn"));
        assert!(env.is_field_sensitive("User", "email"));
        assert!(!env.is_field_sensitive("User", "name"));
        assert!(!env.is_field_sensitive("Other", "ssn"));
    }

    #[test]
    fn test_resolving_aliases_cycle_detection() {
        let mut env = TypeEnv::new();
        assert!(!env.is_resolving_alias("Foo"));
        env.start_resolving_alias("Foo".into());
        assert!(env.is_resolving_alias("Foo"));
        env.stop_resolving_alias("Foo");
        assert!(!env.is_resolving_alias("Foo"));
    }

    #[test]
    fn test_unify_promotions() {
        assert!(unify(&Type::Float, &Type::Int).is_ok());
        assert!(unify(&Type::Int, &Type::Float).is_ok());
        assert!(unify(&Type::Float, &Type::Decimal).is_ok());
        assert!(unify(&Type::Decimal, &Type::Int).is_ok());
    }
}