Skip to main content

specl_types/
checker.rs

1//! Type checker implementation.
2
3use crate::env::{ActionSig, TypeEnv};
4use crate::error::{TypeError, TypeResult};
5use crate::types::{Substitution, Type, TypeVarGen};
6use specl_syntax::*;
7
8/// Type check a module.
9pub fn check_module(module: &Module) -> TypeResult<TypeEnv> {
10    let mut checker = TypeChecker::new();
11    checker.check_module(module)?;
12    Ok(checker.env)
13}
14
15/// The type checker.
16pub struct TypeChecker {
17    /// Type environment.
18    pub env: TypeEnv,
19    /// Type variable generator.
20    var_gen: TypeVarGen,
21}
22
23impl TypeChecker {
24    /// Create a new type checker.
25    pub fn new() -> Self {
26        Self {
27            env: TypeEnv::new(),
28            var_gen: TypeVarGen::new(),
29        }
30    }
31
32    /// Check a complete module.
33    pub fn check_module(&mut self, module: &Module) -> TypeResult<()> {
34        // First pass: collect all declarations
35        for decl in &module.decls {
36            self.collect_decl(decl)?;
37        }
38
39        // Second pass: check bodies
40        for decl in &module.decls {
41            self.check_decl(decl)?;
42        }
43
44        Ok(())
45    }
46
47    /// Collect declaration signatures (first pass).
48    fn collect_decl(&mut self, decl: &Decl) -> TypeResult<()> {
49        match decl {
50            Decl::Const(d) => {
51                let ty = match &d.value {
52                    specl_syntax::ConstValue::Type(type_expr) => {
53                        self.convert_type_expr(type_expr)?
54                    }
55                    specl_syntax::ConstValue::Scalar(n) => {
56                        // Scalar constants have type Int (or Nat if non-negative)
57                        if *n >= 0 {
58                            Type::Nat
59                        } else {
60                            Type::Int
61                        }
62                    }
63                };
64                self.env.define_const(d.name.name.clone(), ty);
65            }
66            Decl::Var(d) => {
67                let ty = self.convert_type_expr(&d.ty)?;
68                self.env.define_var(d.name.name.clone(), ty);
69            }
70            Decl::Type(d) => {
71                let ty = self.convert_type_expr(&d.ty)?;
72                self.env.define_type_alias(d.name.name.clone(), ty);
73            }
74            Decl::Action(d) => {
75                let params: Vec<(String, Type)> = d
76                    .params
77                    .iter()
78                    .map(|p| {
79                        let ty = self.convert_type_expr(&p.ty)?;
80                        Ok((p.name.name.clone(), ty))
81                    })
82                    .collect::<TypeResult<_>>()?;
83                self.env
84                    .define_action(d.name.name.clone(), ActionSig { params });
85            }
86            Decl::Func(d) => {
87                // Functions have polymorphic types - we use fresh type vars for params
88                let param_types: Vec<Type> =
89                    d.params.iter().map(|_| self.var_gen.fresh_type()).collect();
90                self.env.define_func(
91                    d.name.name.clone(),
92                    d.params.iter().map(|p| p.name.name.clone()).collect(),
93                    param_types,
94                );
95            }
96            _ => {}
97        }
98        Ok(())
99    }
100
101    /// Check declaration bodies (second pass).
102    fn check_decl(&mut self, decl: &Decl) -> TypeResult<()> {
103        match decl {
104            Decl::Init(d) => {
105                let ty = self.infer_expr(&d.body)?;
106                self.expect_bool(&ty, d.body.span)?;
107            }
108            Decl::Action(d) => {
109                self.env.push_scope();
110
111                // Bind parameters
112                for param in &d.params {
113                    let ty = self.convert_type_expr(&param.ty)?;
114                    self.env.bind_local(param.name.name.clone(), ty);
115                }
116
117                // Check requires
118                for req in &d.body.requires {
119                    let ty = self.infer_expr(req)?;
120                    self.expect_bool(&ty, req.span)?;
121                }
122
123                // Check effect
124                if let Some(effect) = &d.body.effect {
125                    let ty = self.infer_expr(effect)?;
126                    self.expect_bool(&ty, effect.span)?;
127                }
128
129                self.env.pop_scope();
130            }
131            Decl::Invariant(d) => {
132                let ty = self.infer_expr(&d.body)?;
133                self.expect_bool(&ty, d.body.span)?;
134            }
135            Decl::Property(d) => {
136                let ty = self.infer_expr(&d.body)?;
137                self.expect_bool(&ty, d.body.span)?;
138            }
139            Decl::Func(d) => {
140                self.env.push_scope();
141                // Bind parameters with fresh type variables
142                let param_types = self
143                    .env
144                    .lookup_func(&d.name.name)
145                    .map(|f| f.param_types.clone())
146                    .unwrap_or_default();
147                for (param, ty) in d.params.iter().zip(param_types.iter()) {
148                    self.env.bind_local(param.name.name.clone(), ty.clone());
149                }
150                // Infer body type (don't constrain it - functions can return anything)
151                let _body_ty = self.infer_expr(&d.body)?;
152                self.env.pop_scope();
153            }
154            _ => {}
155        }
156        Ok(())
157    }
158
159    /// Convert an AST type expression to a Type.
160    fn convert_type_expr(&mut self, ty_expr: &TypeExpr) -> TypeResult<Type> {
161        match ty_expr {
162            TypeExpr::Named(id) => {
163                let name = &id.name;
164                match name.as_str() {
165                    "Bool" => Ok(Type::Bool),
166                    "Nat" => Ok(Type::Nat),
167                    "Int" => Ok(Type::Int),
168                    "String" => Ok(Type::String),
169                    // Special marker for inferred types (used by TLA+ translator for empty domains)
170                    "_" | "Inferred" => Ok(self.var_gen.fresh_type()),
171                    _ => {
172                        // Check if it's a defined type alias
173                        if self.env.lookup_type_alias(name).is_some() {
174                            Ok(Type::Named(name.clone()))
175                        } else {
176                            Err(TypeError::UndefinedType {
177                                name: name.clone(),
178                                span: id.span,
179                            })
180                        }
181                    }
182                }
183            }
184            TypeExpr::Set(inner, _) => {
185                let inner_ty = self.convert_type_expr(inner)?;
186                Ok(Type::Set(Box::new(inner_ty)))
187            }
188            TypeExpr::Seq(inner, _) => {
189                let inner_ty = self.convert_type_expr(inner)?;
190                Ok(Type::Seq(Box::new(inner_ty)))
191            }
192            TypeExpr::Dict(key, value, _) => {
193                let key_ty = self.convert_type_expr(key)?;
194                let value_ty = self.convert_type_expr(value)?;
195                Ok(Type::Fn(Box::new(key_ty), Box::new(value_ty)))
196            }
197            TypeExpr::Option(inner, _) => {
198                let inner_ty = self.convert_type_expr(inner)?;
199                Ok(Type::Option(Box::new(inner_ty)))
200            }
201            TypeExpr::Range(lo, hi, _span) => {
202                // Check that both bounds are numeric
203                let lo_ty = self.infer_expr(lo)?;
204                let hi_ty = self.infer_expr(hi)?;
205
206                if !lo_ty.is_numeric() {
207                    return Err(TypeError::ExpectedNumeric {
208                        found: lo_ty,
209                        span: lo.span,
210                    });
211                }
212                if !hi_ty.is_numeric() {
213                    return Err(TypeError::ExpectedNumeric {
214                        found: hi_ty,
215                        span: hi.span,
216                    });
217                }
218
219                // Extract literal values if possible
220                let lo_val = self.extract_int_literal(lo);
221                let hi_val = self.extract_int_literal(hi);
222
223                match (lo_val, hi_val) {
224                    (Some(lo), Some(hi)) => Ok(Type::Range(lo, hi)),
225                    _ => Ok(Type::Int), // Fallback to Int if bounds aren't literals
226                }
227            }
228            TypeExpr::Tuple(elems, _) => {
229                let elem_types: Vec<Type> = elems
230                    .iter()
231                    .map(|ty| self.convert_type_expr(ty))
232                    .collect::<TypeResult<_>>()?;
233                Ok(Type::Tuple(elem_types))
234            }
235        }
236    }
237
238    /// Extract an integer literal value from an expression.
239    fn extract_int_literal(&self, expr: &Expr) -> Option<i64> {
240        match &expr.kind {
241            ExprKind::Int(n) => Some(*n),
242            ExprKind::Ident(_name) => {
243                // Check if it's a constant with a known value
244                // For now, we can't resolve constant values at type-checking time
245                None
246            }
247            _ => None,
248        }
249    }
250
251    /// Infer the type of an expression.
252    pub fn infer_expr(&mut self, expr: &Expr) -> TypeResult<Type> {
253        let ty = match &expr.kind {
254            ExprKind::Bool(_) => Type::Bool,
255            ExprKind::Int(_) => Type::Int,
256            ExprKind::String(_) => Type::String,
257
258            ExprKind::Ident(name) => {
259                if let Some(ty) = self.env.lookup_ident(name) {
260                    ty.clone()
261                } else {
262                    return Err(TypeError::UndefinedVariable {
263                        name: name.clone(),
264                        span: expr.span,
265                    });
266                }
267            }
268
269            ExprKind::Primed(name) => {
270                // Must be a state variable
271                if let Some(ty) = self.env.lookup_var(name) {
272                    ty.clone()
273                } else {
274                    return Err(TypeError::InvalidPrime { span: expr.span });
275                }
276            }
277
278            ExprKind::Binary { op, left, right } => self.infer_binary(*op, left, right)?,
279
280            ExprKind::Unary { op, operand } => self.infer_unary(*op, operand)?,
281
282            ExprKind::Index { base, index } => {
283                // Check if this is actually a slice (Seq indexed by Range)
284                // The parser parses `seq[1..3]` as Index with a Range expression
285                if let ExprKind::Range { lo, hi } = &index.kind {
286                    let base_ty_raw = self.infer_expr(base)?;
287                    let base_ty = self.env.resolve_type(&base_ty_raw);
288                    let lo_ty = self.infer_expr(lo)?;
289                    let hi_ty = self.infer_expr(hi)?;
290
291                    self.expect_numeric(&lo_ty, lo.span)?;
292                    self.expect_numeric(&hi_ty, hi.span)?;
293
294                    match base_ty {
295                        Type::Seq(_) => base_ty, // Slice returns same Seq type
296                        Type::Var(_) => self.var_gen.fresh_type(),
297                        _ => {
298                            return Err(TypeError::NotIndexable {
299                                ty: base_ty,
300                                span: base.span,
301                            });
302                        }
303                    }
304                } else {
305                    let base_ty_raw = self.infer_expr(base)?;
306                    let base_ty = self.env.resolve_type(&base_ty_raw);
307                    let index_ty = self.infer_expr(index)?;
308
309                    match base_ty {
310                        Type::Seq(elem_ty) => {
311                            self.expect_numeric(&index_ty, index.span)?;
312                            *elem_ty
313                        }
314                        Type::Fn(key_ty, value_ty) => {
315                            self.unify(&key_ty, &index_ty, index.span)?;
316                            *value_ty
317                        }
318                        // Accept type variables for polymorphic funcs
319                        Type::Var(_) => self.var_gen.fresh_type(),
320                        _ => {
321                            return Err(TypeError::NotIndexable {
322                                ty: base_ty,
323                                span: base.span,
324                            });
325                        }
326                    }
327                }
328            }
329
330            ExprKind::Slice { base, lo, hi } => {
331                let base_ty_raw = self.infer_expr(base)?;
332                let base_ty = self.env.resolve_type(&base_ty_raw);
333                let lo_ty = self.infer_expr(lo)?;
334                let hi_ty = self.infer_expr(hi)?;
335
336                self.expect_numeric(&lo_ty, lo.span)?;
337                self.expect_numeric(&hi_ty, hi.span)?;
338
339                match base_ty {
340                    Type::Seq(_) => base_ty,
341                    _ => {
342                        return Err(TypeError::NotIndexable {
343                            ty: base_ty,
344                            span: base.span,
345                        });
346                    }
347                }
348            }
349
350            ExprKind::Field { base, field } => {
351                let base_ty_raw = self.infer_expr(base)?;
352                let base_ty = self.env.resolve_type(&base_ty_raw);
353
354                match &base_ty {
355                    Type::Record(rec) => {
356                        if let Some(field_ty) = rec.get_field(&field.name) {
357                            field_ty.clone()
358                        } else {
359                            return Err(TypeError::InvalidField {
360                                ty: base_ty,
361                                field: field.name.clone(),
362                                span: field.span,
363                            });
364                        }
365                    }
366                    // MVP: Allow field access on Int (for TLA+ specs where elements are records
367                    // but type inference defaults to Int). Return Int as the field type.
368                    Type::Int => Type::Int,
369                    // Allow field access on type variables (for TLA+ specs with inferred types).
370                    // Return a fresh type variable for the field type.
371                    Type::Var(_) => self.var_gen.fresh_type(),
372                    _ => {
373                        return Err(TypeError::InvalidField {
374                            ty: base_ty,
375                            field: field.name.clone(),
376                            span: field.span,
377                        });
378                    }
379                }
380            }
381
382            ExprKind::Call { func, args } => {
383                // Check if it's a built-in function or an action call
384                if let ExprKind::Ident(name) = &func.kind {
385                    // Check for action call
386                    if let Some(sig) = self.env.lookup_action(name).cloned() {
387                        if args.len() != sig.params.len() {
388                            return Err(TypeError::ArityMismatch {
389                                expected: sig.params.len(),
390                                found: args.len(),
391                                span: expr.span,
392                            });
393                        }
394
395                        for (arg, (_, param_ty)) in args.iter().zip(sig.params.iter()) {
396                            let arg_ty = self.infer_expr(arg)?;
397                            self.unify(param_ty, &arg_ty, arg.span)?;
398                        }
399
400                        return Ok(Type::Bool); // Actions are predicates
401                    }
402
403                    // Check for user-defined function call
404                    if let Some(func_info) = self.env.lookup_func(name).cloned() {
405                        if args.len() != func_info.param_names.len() {
406                            return Err(TypeError::ArityMismatch {
407                                expected: func_info.param_names.len(),
408                                found: args.len(),
409                                span: expr.span,
410                            });
411                        }
412
413                        // Unify argument types with parameter types
414                        for (arg, param_ty) in args.iter().zip(func_info.param_types.iter()) {
415                            let arg_ty = self.infer_expr(arg)?;
416                            self.unify(param_ty, &arg_ty, arg.span)?;
417                        }
418
419                        // User-defined functions return a fresh type (polymorphic)
420                        return Ok(self.var_gen.fresh_type());
421                    }
422                }
423
424                // Generic function call
425                let func_ty = self.infer_expr(func)?;
426                let arg_types: Vec<Type> = args
427                    .iter()
428                    .map(|a| self.infer_expr(a))
429                    .collect::<TypeResult<_>>()?;
430
431                // For now, just return a fresh type variable
432                // Full inference would need to track function types
433                let _ = (func_ty, arg_types);
434                self.var_gen.fresh_type()
435            }
436
437            ExprKind::SetLit(elements) => {
438                if elements.is_empty() {
439                    // Empty set has a polymorphic element type
440                    Type::Set(Box::new(self.var_gen.fresh_type()))
441                } else {
442                    let elem_ty = self.infer_expr(&elements[0])?;
443                    for elem in elements.iter().skip(1) {
444                        let ty = self.infer_expr(elem)?;
445                        self.unify(&elem_ty, &ty, elem.span)?;
446                    }
447                    Type::Set(Box::new(elem_ty))
448                }
449            }
450
451            ExprKind::SeqLit(elements) => {
452                if elements.is_empty() {
453                    Type::Seq(Box::new(self.var_gen.fresh_type()))
454                } else {
455                    let elem_ty = self.infer_expr(&elements[0])?;
456                    for elem in elements.iter().skip(1) {
457                        let ty = self.infer_expr(elem)?;
458                        self.unify(&elem_ty, &ty, elem.span)?;
459                    }
460                    Type::Seq(Box::new(elem_ty))
461                }
462            }
463
464            ExprKind::TupleLit(elements) => {
465                // Each element can have a different type
466                let elem_types: Vec<Type> = elements
467                    .iter()
468                    .map(|e| self.infer_expr(e))
469                    .collect::<TypeResult<_>>()?;
470                Type::Tuple(elem_types)
471            }
472
473            ExprKind::DictLit(entries) => {
474                if entries.is_empty() {
475                    // Empty dict - use fresh type variables
476                    Type::Fn(
477                        Box::new(self.var_gen.fresh_type()),
478                        Box::new(self.var_gen.fresh_type()),
479                    )
480                } else {
481                    // Infer from first entry, unify with rest
482                    let (first_key, first_val) = &entries[0];
483                    let key_ty = self.infer_expr(first_key)?;
484                    let val_ty = self.infer_expr(first_val)?;
485
486                    for (key, val) in entries.iter().skip(1) {
487                        let k_ty = self.infer_expr(key)?;
488                        let v_ty = self.infer_expr(val)?;
489                        self.unify(&key_ty, &k_ty, key.span)?;
490                        self.unify(&val_ty, &v_ty, val.span)?;
491                    }
492
493                    Type::Fn(Box::new(key_ty), Box::new(val_ty))
494                }
495            }
496
497            ExprKind::FnLit { var, domain, body } => {
498                let domain_ty = self.infer_expr(domain)?;
499                let elem_ty = self.element_type(&domain_ty, domain.span)?;
500
501                self.env.push_scope();
502                self.env.bind_local(var.name.clone(), elem_ty.clone());
503                let body_ty = self.infer_expr(body)?;
504                self.env.pop_scope();
505
506                Type::Fn(Box::new(elem_ty), Box::new(body_ty))
507            }
508
509            ExprKind::SetComprehension {
510                element,
511                var,
512                domain,
513                filter,
514            } => {
515                let domain_ty = self.infer_expr(domain)?;
516                let elem_ty = self.element_type(&domain_ty, domain.span)?;
517
518                self.env.push_scope();
519                self.env.bind_local(var.name.clone(), elem_ty);
520
521                if let Some(f) = filter {
522                    let filter_ty = self.infer_expr(f)?;
523                    self.expect_bool(&filter_ty, f.span)?;
524                }
525
526                let element_ty = self.infer_expr(element)?;
527                self.env.pop_scope();
528
529                Type::Set(Box::new(element_ty))
530            }
531
532            ExprKind::RecordUpdate { base, updates } => {
533                let base_ty_raw = self.infer_expr(base)?;
534                let base_ty = self.env.resolve_type(&base_ty_raw);
535
536                match &base_ty {
537                    Type::Record(rec) => {
538                        for update in updates {
539                            match update {
540                                RecordFieldUpdate::Field { name, value } => {
541                                    let expected = rec.get_field(&name.name).ok_or_else(|| {
542                                        TypeError::InvalidField {
543                                            ty: base_ty.clone(),
544                                            field: name.name.clone(),
545                                            span: name.span,
546                                        }
547                                    })?;
548                                    let value_ty = self.infer_expr(value)?;
549                                    self.unify(expected, &value_ty, value.span)?;
550                                }
551                                RecordFieldUpdate::Dynamic { key, value } => {
552                                    let _ = self.infer_expr(key)?;
553                                    let _ = self.infer_expr(value)?;
554                                }
555                            }
556                        }
557                        base_ty
558                    }
559                    Type::Fn(key_ty, value_ty) => {
560                        // Function update { f with [k]: v }
561                        for update in updates {
562                            match update {
563                                RecordFieldUpdate::Dynamic { key, value } => {
564                                    let key_actual = self.infer_expr(key)?;
565                                    let value_actual = self.infer_expr(value)?;
566                                    self.unify(key_ty, &key_actual, key.span)?;
567                                    self.unify(value_ty, &value_actual, value.span)?;
568                                }
569                                RecordFieldUpdate::Field { name, .. } => {
570                                    return Err(TypeError::InvalidField {
571                                        ty: base_ty.clone(),
572                                        field: name.name.clone(),
573                                        span: name.span,
574                                    });
575                                }
576                            }
577                        }
578                        base_ty
579                    }
580                    _ => {
581                        return Err(TypeError::InvalidField {
582                            ty: base_ty,
583                            field: "<update>".to_string(),
584                            span: expr.span,
585                        });
586                    }
587                }
588            }
589
590            ExprKind::Quantifier {
591                kind: _,
592                bindings,
593                body,
594            } => {
595                self.env.push_scope();
596
597                for binding in bindings {
598                    let domain_ty = self.infer_expr(&binding.domain)?;
599                    let elem_ty = self.element_type(&domain_ty, binding.domain.span)?;
600                    self.env.bind_local(binding.var.name.clone(), elem_ty);
601                }
602
603                let body_ty = self.infer_expr(body)?;
604                self.expect_bool(&body_ty, body.span)?;
605
606                self.env.pop_scope();
607                Type::Bool
608            }
609
610            ExprKind::Choose {
611                var,
612                domain,
613                predicate,
614            } => {
615                let domain_ty = self.infer_expr(domain)?;
616                let elem_ty = self.element_type(&domain_ty, domain.span)?;
617
618                self.env.push_scope();
619                self.env.bind_local(var.name.clone(), elem_ty.clone());
620
621                let pred_ty = self.infer_expr(predicate)?;
622                self.expect_bool(&pred_ty, predicate.span)?;
623
624                self.env.pop_scope();
625                elem_ty
626            }
627
628            ExprKind::Fix { var, predicate } => {
629                // Fix expression: fix x: P - returns any value x satisfying P
630                // We can't determine the type statically without more context
631                // For now, return Int as a default
632                self.env.push_scope();
633                self.env.bind_local(var.name.clone(), Type::Int);
634
635                let pred_ty = self.infer_expr(predicate)?;
636                self.expect_bool(&pred_ty, predicate.span)?;
637
638                self.env.pop_scope();
639                Type::Int
640            }
641
642            ExprKind::Let { var, value, body } => {
643                let value_ty = self.infer_expr(value)?;
644
645                self.env.push_scope();
646                self.env.bind_local(var.name.clone(), value_ty);
647
648                let body_ty = self.infer_expr(body)?;
649                self.env.pop_scope();
650
651                body_ty
652            }
653
654            ExprKind::If {
655                cond,
656                then_branch,
657                else_branch,
658            } => {
659                let cond_ty = self.infer_expr(cond)?;
660                self.expect_bool(&cond_ty, cond.span)?;
661
662                let then_ty = self.infer_expr(then_branch)?;
663                let else_ty = self.infer_expr(else_branch)?;
664                self.unify(&then_ty, &else_ty, else_branch.span)?;
665
666                then_ty
667            }
668
669            ExprKind::Require(inner) => {
670                let inner_ty = self.infer_expr(inner)?;
671                self.expect_bool(&inner_ty, inner.span)?;
672                Type::Bool
673            }
674
675            ExprKind::Changes(_var) => {
676                // changes(v) is a boolean predicate
677                Type::Bool
678            }
679
680            ExprKind::Enabled(action) => {
681                // Check that action exists
682                if self.env.lookup_action(&action.name).is_none() {
683                    return Err(TypeError::UndefinedAction {
684                        name: action.name.clone(),
685                        span: action.span,
686                    });
687                }
688                Type::Bool
689            }
690
691            ExprKind::SeqHead(seq_expr) => {
692                let seq_ty_raw = self.infer_expr(seq_expr)?;
693                let seq_ty = self.env.resolve_type(&seq_ty_raw);
694                match seq_ty {
695                    Type::Seq(elem_ty) => *elem_ty,
696                    _ => {
697                        return Err(TypeError::TypeMismatch {
698                            expected: Type::Seq(Box::new(self.var_gen.fresh_type())),
699                            found: seq_ty,
700                            span: seq_expr.span,
701                        });
702                    }
703                }
704            }
705
706            ExprKind::SeqTail(seq_expr) => {
707                let seq_ty_raw = self.infer_expr(seq_expr)?;
708                let seq_ty = self.env.resolve_type(&seq_ty_raw);
709                match &seq_ty {
710                    Type::Seq(_) => seq_ty,
711                    _ => {
712                        return Err(TypeError::TypeMismatch {
713                            expected: Type::Seq(Box::new(self.var_gen.fresh_type())),
714                            found: seq_ty,
715                            span: seq_expr.span,
716                        });
717                    }
718                }
719            }
720
721            ExprKind::Len(expr) => {
722                let ty_raw = self.infer_expr(expr)?;
723                let ty = self.env.resolve_type(&ty_raw);
724                match ty {
725                    // Accept Seq, Set, Fn, or type variables (for polymorphic funcs)
726                    Type::Seq(_) | Type::Set(_) | Type::Fn(_, _) | Type::Var(_) => Type::Nat,
727                    _ => {
728                        return Err(TypeError::TypeMismatch {
729                            expected: Type::Seq(Box::new(self.var_gen.fresh_type())),
730                            found: ty,
731                            span: expr.span,
732                        });
733                    }
734                }
735            }
736
737            ExprKind::Keys(expr) => {
738                let ty_raw = self.infer_expr(expr)?;
739                let ty = self.env.resolve_type(&ty_raw);
740                match ty {
741                    Type::Fn(key_ty, _) => Type::Set(key_ty),
742                    // For sequences, keys returns 1..len(seq), which is Set[Int]
743                    Type::Seq(_) => Type::Set(Box::new(Type::Int)),
744                    _ => {
745                        return Err(TypeError::TypeMismatch {
746                            expected: Type::Fn(
747                                Box::new(self.var_gen.fresh_type()),
748                                Box::new(self.var_gen.fresh_type()),
749                            ),
750                            found: ty,
751                            span: expr.span,
752                        });
753                    }
754                }
755            }
756
757            ExprKind::Values(expr) => {
758                let ty_raw = self.infer_expr(expr)?;
759                let ty = self.env.resolve_type(&ty_raw);
760                match ty {
761                    Type::Fn(_, val_ty) => Type::Set(val_ty),
762                    _ => {
763                        return Err(TypeError::TypeMismatch {
764                            expected: Type::Fn(
765                                Box::new(self.var_gen.fresh_type()),
766                                Box::new(self.var_gen.fresh_type()),
767                            ),
768                            found: ty,
769                            span: expr.span,
770                        });
771                    }
772                }
773            }
774
775            ExprKind::BigUnion(expr) => {
776                let ty_raw = self.infer_expr(expr)?;
777                let ty = self.env.resolve_type(&ty_raw);
778                // BigUnion takes Set[Set[T]] and returns Set[T]
779                match ty {
780                    Type::Set(inner) => match *inner {
781                        Type::Set(elem_ty) => Type::Set(elem_ty),
782                        _ => {
783                            return Err(TypeError::TypeMismatch {
784                                expected: Type::Set(Box::new(self.var_gen.fresh_type())),
785                                found: *inner,
786                                span: expr.span,
787                            });
788                        }
789                    },
790                    _ => {
791                        return Err(TypeError::TypeMismatch {
792                            expected: Type::Set(Box::new(Type::Set(Box::new(
793                                self.var_gen.fresh_type(),
794                            )))),
795                            found: ty,
796                            span: expr.span,
797                        });
798                    }
799                }
800            }
801
802            ExprKind::Powerset(expr) => {
803                let ty_raw = self.infer_expr(expr)?;
804                let ty = self.env.resolve_type(&ty_raw);
805                // Powerset takes Set[T] and returns Set[Set[T]]
806                match ty {
807                    Type::Set(elem_ty) => Type::Set(Box::new(Type::Set(elem_ty))),
808                    _ => {
809                        return Err(TypeError::TypeMismatch {
810                            expected: Type::Set(Box::new(self.var_gen.fresh_type())),
811                            found: ty,
812                            span: expr.span,
813                        });
814                    }
815                }
816            }
817
818            ExprKind::Always(inner) | ExprKind::Eventually(inner) => {
819                let inner_ty = self.infer_expr(inner)?;
820                self.expect_bool(&inner_ty, inner.span)?;
821                Type::Bool
822            }
823
824            ExprKind::LeadsTo { left, right } => {
825                let left_ty = self.infer_expr(left)?;
826                let right_ty = self.infer_expr(right)?;
827                self.expect_bool(&left_ty, left.span)?;
828                self.expect_bool(&right_ty, right.span)?;
829                Type::Bool
830            }
831
832            ExprKind::Range { lo, hi } => {
833                let lo_ty = self.infer_expr(lo)?;
834                let hi_ty = self.infer_expr(hi)?;
835                self.expect_numeric(&lo_ty, lo.span)?;
836                self.expect_numeric(&hi_ty, hi.span)?;
837                // A range expression is a set of integers
838                Type::Set(Box::new(Type::Int))
839            }
840
841            ExprKind::Paren(inner) => self.infer_expr(inner)?,
842        };
843
844        Ok(ty)
845    }
846
847    /// Infer type of binary operation.
848    fn infer_binary(&mut self, op: BinOp, left: &Expr, right: &Expr) -> TypeResult<Type> {
849        let left_ty = self.infer_expr(left)?;
850        let right_ty = self.infer_expr(right)?;
851
852        match op {
853            // Logical operators
854            BinOp::And | BinOp::Or | BinOp::Implies | BinOp::Iff => {
855                self.expect_bool(&left_ty, left.span)?;
856                self.expect_bool(&right_ty, right.span)?;
857                Ok(Type::Bool)
858            }
859
860            // Comparison operators
861            BinOp::Eq | BinOp::Ne => {
862                self.unify(&left_ty, &right_ty, right.span)?;
863                Ok(Type::Bool)
864            }
865
866            BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge => {
867                self.expect_numeric(&left_ty, left.span)?;
868                self.expect_numeric(&right_ty, right.span)?;
869                Ok(Type::Bool)
870            }
871
872            // Arithmetic operators
873            BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Mod => {
874                self.expect_numeric(&left_ty, left.span)?;
875                self.expect_numeric(&right_ty, right.span)?;
876                // Return the more general numeric type
877                if matches!(left_ty, Type::Int) || matches!(right_ty, Type::Int) {
878                    Ok(Type::Int)
879                } else {
880                    Ok(Type::Nat)
881                }
882            }
883
884            // Set membership
885            BinOp::In | BinOp::NotIn => {
886                let resolved = self.env.resolve_type(&right_ty);
887                match resolved {
888                    Type::Set(elem_ty) => {
889                        self.unify(&left_ty, &elem_ty, left.span)?;
890                        Ok(Type::Bool)
891                    }
892                    _ => Err(TypeError::NotIterable {
893                        ty: right_ty,
894                        span: right.span,
895                    }),
896                }
897            }
898
899            // Set and dict operations
900            BinOp::Union | BinOp::Intersect | BinOp::Diff => {
901                let resolved_left = self.env.resolve_type(&left_ty);
902                let resolved_right = self.env.resolve_type(&right_ty);
903
904                match (&resolved_left, &resolved_right) {
905                    (Type::Set(_), Type::Set(_)) => {
906                        self.unify(&left_ty, &right_ty, right.span)?;
907                        Ok(left_ty)
908                    }
909                    (Type::Fn(_, _), Type::Fn(_, _)) => {
910                        // Dict merge with | operator
911                        self.unify(&left_ty, &right_ty, right.span)?;
912                        Ok(left_ty)
913                    }
914                    _ => Err(TypeError::TypeMismatch {
915                        expected: Type::Set(Box::new(self.var_gen.fresh_type())),
916                        found: left_ty,
917                        span: left.span,
918                    }),
919                }
920            }
921
922            BinOp::SubsetOf => {
923                let resolved_left = self.env.resolve_type(&left_ty);
924                let resolved_right = self.env.resolve_type(&right_ty);
925
926                match (&resolved_left, &resolved_right) {
927                    (Type::Set(_), Type::Set(_)) => {
928                        self.unify(&left_ty, &right_ty, right.span)?;
929                        Ok(Type::Bool)
930                    }
931                    _ => Err(TypeError::TypeMismatch {
932                        expected: Type::Set(Box::new(self.var_gen.fresh_type())),
933                        found: left_ty,
934                        span: left.span,
935                    }),
936                }
937            }
938
939            // Sequence concatenation
940            BinOp::Concat => {
941                let resolved_left = self.env.resolve_type(&left_ty);
942                let resolved_right = self.env.resolve_type(&right_ty);
943
944                match (&resolved_left, &resolved_right) {
945                    (Type::Seq(_), Type::Seq(_)) => {
946                        self.unify(&left_ty, &right_ty, right.span)?;
947                        Ok(left_ty)
948                    }
949                    _ => Err(TypeError::TypeMismatch {
950                        expected: Type::Seq(Box::new(self.var_gen.fresh_type())),
951                        found: left_ty,
952                        span: left.span,
953                    }),
954                }
955            }
956        }
957    }
958
959    /// Infer type of unary operation.
960    fn infer_unary(&mut self, op: UnaryOp, operand: &Expr) -> TypeResult<Type> {
961        let operand_ty = self.infer_expr(operand)?;
962
963        match op {
964            UnaryOp::Not => {
965                self.expect_bool(&operand_ty, operand.span)?;
966                Ok(Type::Bool)
967            }
968            UnaryOp::Neg => {
969                self.expect_numeric(&operand_ty, operand.span)?;
970                Ok(Type::Int) // Negation produces Int even from Nat
971            }
972        }
973    }
974
975    /// Get the element type of a collection or set.
976    fn element_type(&mut self, ty: &Type, span: Span) -> TypeResult<Type> {
977        let resolved = self.env.resolve_type(ty);
978        match resolved {
979            Type::Set(elem) | Type::Seq(elem) => Ok(*elem),
980            Type::Fn(key, _) => Ok(*key), // Iterating over function yields keys
981            Type::Range(_, _) => Ok(Type::Int),
982            Type::Var(_) => {
983                // Type variable used as iterable - assume it's a collection
984                // and return a fresh element type for polymorphism
985                Ok(self.var_gen.fresh_type())
986            }
987            _ => Err(TypeError::NotIterable { ty: resolved, span }),
988        }
989    }
990
991    /// Unify two types.
992    fn unify(&mut self, a: &Type, b: &Type, span: Span) -> TypeResult<Substitution> {
993        let a = self.env.resolve_type(a);
994        let b = self.env.resolve_type(b);
995
996        if a == b {
997            return Ok(Substitution::new());
998        }
999
1000        match (&a, &b) {
1001            // Type variables unify with anything
1002            (Type::Var(v), _) => {
1003                if b.has_vars() && matches!(b, Type::Var(v2) if *v == v2) {
1004                    // Same variable
1005                    Ok(Substitution::new())
1006                } else if self.occurs_in(v, &b) {
1007                    Err(TypeError::OccursCheck { span })
1008                } else {
1009                    let mut subst = Substitution::new();
1010                    subst.insert(*v, b.clone());
1011                    Ok(subst)
1012                }
1013            }
1014            (_, Type::Var(v)) => {
1015                if self.occurs_in(v, &a) {
1016                    Err(TypeError::OccursCheck { span })
1017                } else {
1018                    let mut subst = Substitution::new();
1019                    subst.insert(*v, a.clone());
1020                    Ok(subst)
1021                }
1022            }
1023
1024            // Structural unification
1025            (Type::Set(a_elem), Type::Set(b_elem)) => self.unify(a_elem, b_elem, span),
1026            (Type::Seq(a_elem), Type::Seq(b_elem)) => self.unify(a_elem, b_elem, span),
1027            (Type::Option(a_elem), Type::Option(b_elem)) => self.unify(a_elem, b_elem, span),
1028            (Type::Fn(a_key, a_val), Type::Fn(b_key, b_val)) => {
1029                let s1 = self.unify(a_key, b_key, span)?;
1030                let a_val = a_val.substitute(&s1);
1031                let b_val = b_val.substitute(&s1);
1032                let s2 = self.unify(&a_val, &b_val, span)?;
1033                Ok(s1.compose(&s2))
1034            }
1035            (Type::Record(a_rec), Type::Record(b_rec)) => {
1036                if a_rec.fields.keys().collect::<Vec<_>>()
1037                    != b_rec.fields.keys().collect::<Vec<_>>()
1038                {
1039                    return Err(TypeError::UnificationFailure {
1040                        a: a.clone(),
1041                        b: b.clone(),
1042                        span,
1043                    });
1044                }
1045
1046                let mut subst = Substitution::new();
1047                for (name, a_ty) in &a_rec.fields {
1048                    let b_ty = b_rec.fields.get(name).unwrap();
1049                    let a_ty = a_ty.substitute(&subst);
1050                    let b_ty = b_ty.substitute(&subst);
1051                    let s = self.unify(&a_ty, &b_ty, span)?;
1052                    subst = subst.compose(&s);
1053                }
1054                Ok(subst)
1055            }
1056            (Type::Tuple(a_elems), Type::Tuple(b_elems)) => {
1057                if a_elems.len() != b_elems.len() {
1058                    return Err(TypeError::UnificationFailure {
1059                        a: a.clone(),
1060                        b: b.clone(),
1061                        span,
1062                    });
1063                }
1064
1065                let mut subst = Substitution::new();
1066                for (a_ty, b_ty) in a_elems.iter().zip(b_elems.iter()) {
1067                    let a_ty = a_ty.substitute(&subst);
1068                    let b_ty = b_ty.substitute(&subst);
1069                    let s = self.unify(&a_ty, &b_ty, span)?;
1070                    subst = subst.compose(&s);
1071                }
1072                Ok(subst)
1073            }
1074
1075            // Numeric types: Int subsumes Nat and Range
1076            (Type::Int, Type::Nat) | (Type::Nat, Type::Int) => Ok(Substitution::new()),
1077            (Type::Int, Type::Range(_, _)) | (Type::Range(_, _), Type::Int) => {
1078                Ok(Substitution::new())
1079            }
1080            (Type::Nat, Type::Range(lo, _)) | (Type::Range(lo, _), Type::Nat) if *lo >= 0 => {
1081                Ok(Substitution::new())
1082            }
1083            (Type::Range(a_lo, a_hi), Type::Range(b_lo, b_hi)) if a_lo == b_lo && a_hi == b_hi => {
1084                Ok(Substitution::new())
1085            }
1086
1087            // Error type unifies with anything (for error recovery)
1088            (Type::Error, _) | (_, Type::Error) => Ok(Substitution::new()),
1089
1090            // MVP: Allow Int and Bool to unify (for TLA+ specs that mix Nil with true/false)
1091            (Type::Int, Type::Bool) | (Type::Bool, Type::Int) => Ok(Substitution::new()),
1092
1093            _ => Err(TypeError::UnificationFailure {
1094                a: a.clone(),
1095                b: b.clone(),
1096                span,
1097            }),
1098        }
1099    }
1100
1101    /// Check if a type variable occurs in a type.
1102    fn occurs_in(&self, var: &crate::types::TypeVar, ty: &Type) -> bool {
1103        Self::occurs_in_impl(var, ty)
1104    }
1105
1106    fn occurs_in_impl(var: &crate::types::TypeVar, ty: &Type) -> bool {
1107        match ty {
1108            Type::Var(v) => v == var,
1109            Type::Set(t) | Type::Seq(t) | Type::Option(t) => Self::occurs_in_impl(var, t),
1110            Type::Fn(k, v) => Self::occurs_in_impl(var, k) || Self::occurs_in_impl(var, v),
1111            Type::Record(r) => r.fields.values().any(|t| Self::occurs_in_impl(var, t)),
1112            Type::Tuple(elems) => elems.iter().any(|t| Self::occurs_in_impl(var, t)),
1113            _ => false,
1114        }
1115    }
1116
1117    /// Expect a boolean type.
1118    fn expect_bool(&self, ty: &Type, span: Span) -> TypeResult<()> {
1119        let resolved = self.env.resolve_type(ty);
1120        match resolved {
1121            Type::Bool | Type::Var(_) | Type::Error => Ok(()),
1122            _ => Err(TypeError::ExpectedBool {
1123                found: resolved,
1124                span,
1125            }),
1126        }
1127    }
1128
1129    /// Expect a numeric type.
1130    fn expect_numeric(&self, ty: &Type, span: Span) -> TypeResult<()> {
1131        let resolved = self.env.resolve_type(ty);
1132        match resolved {
1133            Type::Nat | Type::Int | Type::Range(_, _) | Type::Var(_) | Type::Error => Ok(()),
1134            _ => Err(TypeError::ExpectedNumeric {
1135                found: resolved,
1136                span,
1137            }),
1138        }
1139    }
1140}
1141
1142impl Default for TypeChecker {
1143    fn default() -> Self {
1144        Self::new()
1145    }
1146}
1147
#[cfg(test)]
mod tests {
    use super::*;
    use specl_syntax::parse;

    /// Parse `source` (panicking if it does not parse) and type check it.
    fn check(source: &str) -> TypeResult<TypeEnv> {
        let module = parse(source).expect("parse error");
        check_module(&module)
    }

    /// Assert that `source` parses and type checks without error.
    fn assert_well_typed(source: &str) {
        assert!(check(source).is_ok());
    }

    #[test]
    fn test_simple_types() {
        assert_well_typed(
            r#"
module Test
var x: Nat
var y: Bool
init { x == 0 and y == true }
"#,
        );
    }

    #[test]
    fn test_type_mismatch() {
        // Comparing a `Nat` variable with a boolean literal must be rejected.
        let outcome = check(
            r#"
module Test
var x: Nat
init { x == true }
"#,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_undefined_variable() {
        let outcome = check(
            r#"
module Test
init { undefined_var == 0 }
"#,
        );
        assert!(matches!(outcome, Err(TypeError::UndefinedVariable { .. })));
    }

    #[test]
    fn test_set_operations() {
        assert_well_typed(
            r#"
module Test
var s: Set[Nat]
init { 1 in s and s union {} == s }
"#,
        );
    }

    #[test]
    fn test_quantifier() {
        assert_well_typed(
            r#"
module Test
var s: Set[Nat]
init { all x in s: x >= 0 }
"#,
        );
    }

    #[test]
    fn test_action_type_check() {
        assert_well_typed(
            r#"
module Test
var x: Nat
const MAX: Nat
action Inc() {
    require x < MAX
    x = x + 1
}
"#,
        );
    }

    #[test]
    fn test_dict_access() {
        assert_well_typed(
            r#"
module Test
var d: Dict[Nat, Bool]
init { d[0] == false and d[1] == true }
"#,
        );
    }

    #[test]
    fn test_dict_comprehension() {
        assert_well_typed(
            r#"
module Test
var d: Dict[0..2, Nat]
init { d == {x: 0 for x in 0..2} }
"#,
        );
    }
}