1use crate::env::{ActionSig, TypeEnv};
4use crate::error::{TypeError, TypeResult};
5use crate::types::{Substitution, Type, TypeVarGen};
6use specl_syntax::*;
7
8pub fn check_module(module: &Module) -> TypeResult<TypeEnv> {
10 let mut checker = TypeChecker::new();
11 checker.check_module(module)?;
12 Ok(checker.env)
13}
14
/// Two-pass type checker for a specl [`Module`].
///
/// Holds the typing environment being built up and the generator used to
/// mint fresh type variables.
pub struct TypeChecker {
    /// Constants, state variables, type aliases, actions, functions and local
    /// scopes known to the checker; returned by the free `check_module`.
    pub env: TypeEnv,
    /// Source of fresh type variables (for `_` types, function parameters,
    /// empty collection literals and results whose type is not tracked).
    var_gen: TypeVarGen,
}
22
23impl TypeChecker {
24 pub fn new() -> Self {
26 Self {
27 env: TypeEnv::new(),
28 var_gen: TypeVarGen::new(),
29 }
30 }
31
32 pub fn check_module(&mut self, module: &Module) -> TypeResult<()> {
34 for decl in &module.decls {
36 self.collect_decl(decl)?;
37 }
38
39 for decl in &module.decls {
41 self.check_decl(decl)?;
42 }
43
44 Ok(())
45 }
46
47 fn collect_decl(&mut self, decl: &Decl) -> TypeResult<()> {
49 match decl {
50 Decl::Const(d) => {
51 let ty = match &d.value {
52 specl_syntax::ConstValue::Type(type_expr) => {
53 self.convert_type_expr(type_expr)?
54 }
55 specl_syntax::ConstValue::Scalar(n) => {
56 if *n >= 0 {
58 Type::Nat
59 } else {
60 Type::Int
61 }
62 }
63 };
64 self.env.define_const(d.name.name.clone(), ty);
65 }
66 Decl::Var(d) => {
67 let ty = self.convert_type_expr(&d.ty)?;
68 self.env.define_var(d.name.name.clone(), ty);
69 }
70 Decl::Type(d) => {
71 let ty = self.convert_type_expr(&d.ty)?;
72 self.env.define_type_alias(d.name.name.clone(), ty);
73 }
74 Decl::Action(d) => {
75 let params: Vec<(String, Type)> = d
76 .params
77 .iter()
78 .map(|p| {
79 let ty = self.convert_type_expr(&p.ty)?;
80 Ok((p.name.name.clone(), ty))
81 })
82 .collect::<TypeResult<_>>()?;
83 self.env
84 .define_action(d.name.name.clone(), ActionSig { params });
85 }
86 Decl::Func(d) => {
87 let param_types: Vec<Type> =
89 d.params.iter().map(|_| self.var_gen.fresh_type()).collect();
90 self.env.define_func(
91 d.name.name.clone(),
92 d.params.iter().map(|p| p.name.name.clone()).collect(),
93 param_types,
94 );
95 }
96 _ => {}
97 }
98 Ok(())
99 }
100
101 fn check_decl(&mut self, decl: &Decl) -> TypeResult<()> {
103 match decl {
104 Decl::Init(d) => {
105 let ty = self.infer_expr(&d.body)?;
106 self.expect_bool(&ty, d.body.span)?;
107 }
108 Decl::Action(d) => {
109 self.env.push_scope();
110
111 for param in &d.params {
113 let ty = self.convert_type_expr(¶m.ty)?;
114 self.env.bind_local(param.name.name.clone(), ty);
115 }
116
117 for req in &d.body.requires {
119 let ty = self.infer_expr(req)?;
120 self.expect_bool(&ty, req.span)?;
121 }
122
123 if let Some(effect) = &d.body.effect {
125 let ty = self.infer_expr(effect)?;
126 self.expect_bool(&ty, effect.span)?;
127 }
128
129 self.env.pop_scope();
130 }
131 Decl::Invariant(d) => {
132 let ty = self.infer_expr(&d.body)?;
133 self.expect_bool(&ty, d.body.span)?;
134 }
135 Decl::Property(d) => {
136 let ty = self.infer_expr(&d.body)?;
137 self.expect_bool(&ty, d.body.span)?;
138 }
139 Decl::Func(d) => {
140 self.env.push_scope();
141 let param_types = self
143 .env
144 .lookup_func(&d.name.name)
145 .map(|f| f.param_types.clone())
146 .unwrap_or_default();
147 for (param, ty) in d.params.iter().zip(param_types.iter()) {
148 self.env.bind_local(param.name.name.clone(), ty.clone());
149 }
150 let _body_ty = self.infer_expr(&d.body)?;
152 self.env.pop_scope();
153 }
154 _ => {}
155 }
156 Ok(())
157 }
158
159 fn convert_type_expr(&mut self, ty_expr: &TypeExpr) -> TypeResult<Type> {
161 match ty_expr {
162 TypeExpr::Named(id) => {
163 let name = &id.name;
164 match name.as_str() {
165 "Bool" => Ok(Type::Bool),
166 "Nat" => Ok(Type::Nat),
167 "Int" => Ok(Type::Int),
168 "String" => Ok(Type::String),
169 "_" | "Inferred" => Ok(self.var_gen.fresh_type()),
171 _ => {
172 if self.env.lookup_type_alias(name).is_some() {
174 Ok(Type::Named(name.clone()))
175 } else {
176 Err(TypeError::UndefinedType {
177 name: name.clone(),
178 span: id.span,
179 })
180 }
181 }
182 }
183 }
184 TypeExpr::Set(inner, _) => {
185 let inner_ty = self.convert_type_expr(inner)?;
186 Ok(Type::Set(Box::new(inner_ty)))
187 }
188 TypeExpr::Seq(inner, _) => {
189 let inner_ty = self.convert_type_expr(inner)?;
190 Ok(Type::Seq(Box::new(inner_ty)))
191 }
192 TypeExpr::Dict(key, value, _) => {
193 let key_ty = self.convert_type_expr(key)?;
194 let value_ty = self.convert_type_expr(value)?;
195 Ok(Type::Fn(Box::new(key_ty), Box::new(value_ty)))
196 }
197 TypeExpr::Option(inner, _) => {
198 let inner_ty = self.convert_type_expr(inner)?;
199 Ok(Type::Option(Box::new(inner_ty)))
200 }
201 TypeExpr::Range(lo, hi, _span) => {
202 let lo_ty = self.infer_expr(lo)?;
204 let hi_ty = self.infer_expr(hi)?;
205
206 if !lo_ty.is_numeric() {
207 return Err(TypeError::ExpectedNumeric {
208 found: lo_ty,
209 span: lo.span,
210 });
211 }
212 if !hi_ty.is_numeric() {
213 return Err(TypeError::ExpectedNumeric {
214 found: hi_ty,
215 span: hi.span,
216 });
217 }
218
219 let lo_val = self.extract_int_literal(lo);
221 let hi_val = self.extract_int_literal(hi);
222
223 match (lo_val, hi_val) {
224 (Some(lo), Some(hi)) => Ok(Type::Range(lo, hi)),
225 _ => Ok(Type::Int), }
227 }
228 TypeExpr::Tuple(elems, _) => {
229 let elem_types: Vec<Type> = elems
230 .iter()
231 .map(|ty| self.convert_type_expr(ty))
232 .collect::<TypeResult<_>>()?;
233 Ok(Type::Tuple(elem_types))
234 }
235 }
236 }
237
238 fn extract_int_literal(&self, expr: &Expr) -> Option<i64> {
240 match &expr.kind {
241 ExprKind::Int(n) => Some(*n),
242 ExprKind::Ident(_name) => {
243 None
246 }
247 _ => None,
248 }
249 }
250
251 pub fn infer_expr(&mut self, expr: &Expr) -> TypeResult<Type> {
253 let ty = match &expr.kind {
254 ExprKind::Bool(_) => Type::Bool,
255 ExprKind::Int(_) => Type::Int,
256 ExprKind::String(_) => Type::String,
257
258 ExprKind::Ident(name) => {
259 if let Some(ty) = self.env.lookup_ident(name) {
260 ty.clone()
261 } else {
262 return Err(TypeError::UndefinedVariable {
263 name: name.clone(),
264 span: expr.span,
265 });
266 }
267 }
268
269 ExprKind::Primed(name) => {
270 if let Some(ty) = self.env.lookup_var(name) {
272 ty.clone()
273 } else {
274 return Err(TypeError::InvalidPrime { span: expr.span });
275 }
276 }
277
278 ExprKind::Binary { op, left, right } => self.infer_binary(*op, left, right)?,
279
280 ExprKind::Unary { op, operand } => self.infer_unary(*op, operand)?,
281
282 ExprKind::Index { base, index } => {
283 if let ExprKind::Range { lo, hi } = &index.kind {
286 let base_ty_raw = self.infer_expr(base)?;
287 let base_ty = self.env.resolve_type(&base_ty_raw);
288 let lo_ty = self.infer_expr(lo)?;
289 let hi_ty = self.infer_expr(hi)?;
290
291 self.expect_numeric(&lo_ty, lo.span)?;
292 self.expect_numeric(&hi_ty, hi.span)?;
293
294 match base_ty {
295 Type::Seq(_) => base_ty, Type::Var(_) => self.var_gen.fresh_type(),
297 _ => {
298 return Err(TypeError::NotIndexable {
299 ty: base_ty,
300 span: base.span,
301 });
302 }
303 }
304 } else {
305 let base_ty_raw = self.infer_expr(base)?;
306 let base_ty = self.env.resolve_type(&base_ty_raw);
307 let index_ty = self.infer_expr(index)?;
308
309 match base_ty {
310 Type::Seq(elem_ty) => {
311 self.expect_numeric(&index_ty, index.span)?;
312 *elem_ty
313 }
314 Type::Fn(key_ty, value_ty) => {
315 self.unify(&key_ty, &index_ty, index.span)?;
316 *value_ty
317 }
318 Type::Var(_) => self.var_gen.fresh_type(),
320 _ => {
321 return Err(TypeError::NotIndexable {
322 ty: base_ty,
323 span: base.span,
324 });
325 }
326 }
327 }
328 }
329
330 ExprKind::Slice { base, lo, hi } => {
331 let base_ty_raw = self.infer_expr(base)?;
332 let base_ty = self.env.resolve_type(&base_ty_raw);
333 let lo_ty = self.infer_expr(lo)?;
334 let hi_ty = self.infer_expr(hi)?;
335
336 self.expect_numeric(&lo_ty, lo.span)?;
337 self.expect_numeric(&hi_ty, hi.span)?;
338
339 match base_ty {
340 Type::Seq(_) => base_ty,
341 _ => {
342 return Err(TypeError::NotIndexable {
343 ty: base_ty,
344 span: base.span,
345 });
346 }
347 }
348 }
349
350 ExprKind::Field { base, field } => {
351 let base_ty_raw = self.infer_expr(base)?;
352 let base_ty = self.env.resolve_type(&base_ty_raw);
353
354 match &base_ty {
355 Type::Record(rec) => {
356 if let Some(field_ty) = rec.get_field(&field.name) {
357 field_ty.clone()
358 } else {
359 return Err(TypeError::InvalidField {
360 ty: base_ty,
361 field: field.name.clone(),
362 span: field.span,
363 });
364 }
365 }
366 Type::Int => Type::Int,
369 Type::Var(_) => self.var_gen.fresh_type(),
372 _ => {
373 return Err(TypeError::InvalidField {
374 ty: base_ty,
375 field: field.name.clone(),
376 span: field.span,
377 });
378 }
379 }
380 }
381
382 ExprKind::Call { func, args } => {
383 if let ExprKind::Ident(name) = &func.kind {
385 if let Some(sig) = self.env.lookup_action(name).cloned() {
387 if args.len() != sig.params.len() {
388 return Err(TypeError::ArityMismatch {
389 expected: sig.params.len(),
390 found: args.len(),
391 span: expr.span,
392 });
393 }
394
395 for (arg, (_, param_ty)) in args.iter().zip(sig.params.iter()) {
396 let arg_ty = self.infer_expr(arg)?;
397 self.unify(param_ty, &arg_ty, arg.span)?;
398 }
399
400 return Ok(Type::Bool); }
402
403 if let Some(func_info) = self.env.lookup_func(name).cloned() {
405 if args.len() != func_info.param_names.len() {
406 return Err(TypeError::ArityMismatch {
407 expected: func_info.param_names.len(),
408 found: args.len(),
409 span: expr.span,
410 });
411 }
412
413 for (arg, param_ty) in args.iter().zip(func_info.param_types.iter()) {
415 let arg_ty = self.infer_expr(arg)?;
416 self.unify(param_ty, &arg_ty, arg.span)?;
417 }
418
419 return Ok(self.var_gen.fresh_type());
421 }
422 }
423
424 let func_ty = self.infer_expr(func)?;
426 let arg_types: Vec<Type> = args
427 .iter()
428 .map(|a| self.infer_expr(a))
429 .collect::<TypeResult<_>>()?;
430
431 let _ = (func_ty, arg_types);
434 self.var_gen.fresh_type()
435 }
436
437 ExprKind::SetLit(elements) => {
438 if elements.is_empty() {
439 Type::Set(Box::new(self.var_gen.fresh_type()))
441 } else {
442 let elem_ty = self.infer_expr(&elements[0])?;
443 for elem in elements.iter().skip(1) {
444 let ty = self.infer_expr(elem)?;
445 self.unify(&elem_ty, &ty, elem.span)?;
446 }
447 Type::Set(Box::new(elem_ty))
448 }
449 }
450
451 ExprKind::SeqLit(elements) => {
452 if elements.is_empty() {
453 Type::Seq(Box::new(self.var_gen.fresh_type()))
454 } else {
455 let elem_ty = self.infer_expr(&elements[0])?;
456 for elem in elements.iter().skip(1) {
457 let ty = self.infer_expr(elem)?;
458 self.unify(&elem_ty, &ty, elem.span)?;
459 }
460 Type::Seq(Box::new(elem_ty))
461 }
462 }
463
464 ExprKind::TupleLit(elements) => {
465 let elem_types: Vec<Type> = elements
467 .iter()
468 .map(|e| self.infer_expr(e))
469 .collect::<TypeResult<_>>()?;
470 Type::Tuple(elem_types)
471 }
472
473 ExprKind::DictLit(entries) => {
474 if entries.is_empty() {
475 Type::Fn(
477 Box::new(self.var_gen.fresh_type()),
478 Box::new(self.var_gen.fresh_type()),
479 )
480 } else {
481 let (first_key, first_val) = &entries[0];
483 let key_ty = self.infer_expr(first_key)?;
484 let val_ty = self.infer_expr(first_val)?;
485
486 for (key, val) in entries.iter().skip(1) {
487 let k_ty = self.infer_expr(key)?;
488 let v_ty = self.infer_expr(val)?;
489 self.unify(&key_ty, &k_ty, key.span)?;
490 self.unify(&val_ty, &v_ty, val.span)?;
491 }
492
493 Type::Fn(Box::new(key_ty), Box::new(val_ty))
494 }
495 }
496
497 ExprKind::FnLit { var, domain, body } => {
498 let domain_ty = self.infer_expr(domain)?;
499 let elem_ty = self.element_type(&domain_ty, domain.span)?;
500
501 self.env.push_scope();
502 self.env.bind_local(var.name.clone(), elem_ty.clone());
503 let body_ty = self.infer_expr(body)?;
504 self.env.pop_scope();
505
506 Type::Fn(Box::new(elem_ty), Box::new(body_ty))
507 }
508
509 ExprKind::SetComprehension {
510 element,
511 var,
512 domain,
513 filter,
514 } => {
515 let domain_ty = self.infer_expr(domain)?;
516 let elem_ty = self.element_type(&domain_ty, domain.span)?;
517
518 self.env.push_scope();
519 self.env.bind_local(var.name.clone(), elem_ty);
520
521 if let Some(f) = filter {
522 let filter_ty = self.infer_expr(f)?;
523 self.expect_bool(&filter_ty, f.span)?;
524 }
525
526 let element_ty = self.infer_expr(element)?;
527 self.env.pop_scope();
528
529 Type::Set(Box::new(element_ty))
530 }
531
532 ExprKind::RecordUpdate { base, updates } => {
533 let base_ty_raw = self.infer_expr(base)?;
534 let base_ty = self.env.resolve_type(&base_ty_raw);
535
536 match &base_ty {
537 Type::Record(rec) => {
538 for update in updates {
539 match update {
540 RecordFieldUpdate::Field { name, value } => {
541 let expected = rec.get_field(&name.name).ok_or_else(|| {
542 TypeError::InvalidField {
543 ty: base_ty.clone(),
544 field: name.name.clone(),
545 span: name.span,
546 }
547 })?;
548 let value_ty = self.infer_expr(value)?;
549 self.unify(expected, &value_ty, value.span)?;
550 }
551 RecordFieldUpdate::Dynamic { key, value } => {
552 let _ = self.infer_expr(key)?;
553 let _ = self.infer_expr(value)?;
554 }
555 }
556 }
557 base_ty
558 }
559 Type::Fn(key_ty, value_ty) => {
560 for update in updates {
562 match update {
563 RecordFieldUpdate::Dynamic { key, value } => {
564 let key_actual = self.infer_expr(key)?;
565 let value_actual = self.infer_expr(value)?;
566 self.unify(key_ty, &key_actual, key.span)?;
567 self.unify(value_ty, &value_actual, value.span)?;
568 }
569 RecordFieldUpdate::Field { name, .. } => {
570 return Err(TypeError::InvalidField {
571 ty: base_ty.clone(),
572 field: name.name.clone(),
573 span: name.span,
574 });
575 }
576 }
577 }
578 base_ty
579 }
580 _ => {
581 return Err(TypeError::InvalidField {
582 ty: base_ty,
583 field: "<update>".to_string(),
584 span: expr.span,
585 });
586 }
587 }
588 }
589
590 ExprKind::Quantifier {
591 kind: _,
592 bindings,
593 body,
594 } => {
595 self.env.push_scope();
596
597 for binding in bindings {
598 let domain_ty = self.infer_expr(&binding.domain)?;
599 let elem_ty = self.element_type(&domain_ty, binding.domain.span)?;
600 self.env.bind_local(binding.var.name.clone(), elem_ty);
601 }
602
603 let body_ty = self.infer_expr(body)?;
604 self.expect_bool(&body_ty, body.span)?;
605
606 self.env.pop_scope();
607 Type::Bool
608 }
609
610 ExprKind::Choose {
611 var,
612 domain,
613 predicate,
614 } => {
615 let domain_ty = self.infer_expr(domain)?;
616 let elem_ty = self.element_type(&domain_ty, domain.span)?;
617
618 self.env.push_scope();
619 self.env.bind_local(var.name.clone(), elem_ty.clone());
620
621 let pred_ty = self.infer_expr(predicate)?;
622 self.expect_bool(&pred_ty, predicate.span)?;
623
624 self.env.pop_scope();
625 elem_ty
626 }
627
628 ExprKind::Fix { var, predicate } => {
629 self.env.push_scope();
633 self.env.bind_local(var.name.clone(), Type::Int);
634
635 let pred_ty = self.infer_expr(predicate)?;
636 self.expect_bool(&pred_ty, predicate.span)?;
637
638 self.env.pop_scope();
639 Type::Int
640 }
641
642 ExprKind::Let { var, value, body } => {
643 let value_ty = self.infer_expr(value)?;
644
645 self.env.push_scope();
646 self.env.bind_local(var.name.clone(), value_ty);
647
648 let body_ty = self.infer_expr(body)?;
649 self.env.pop_scope();
650
651 body_ty
652 }
653
654 ExprKind::If {
655 cond,
656 then_branch,
657 else_branch,
658 } => {
659 let cond_ty = self.infer_expr(cond)?;
660 self.expect_bool(&cond_ty, cond.span)?;
661
662 let then_ty = self.infer_expr(then_branch)?;
663 let else_ty = self.infer_expr(else_branch)?;
664 self.unify(&then_ty, &else_ty, else_branch.span)?;
665
666 then_ty
667 }
668
669 ExprKind::Require(inner) => {
670 let inner_ty = self.infer_expr(inner)?;
671 self.expect_bool(&inner_ty, inner.span)?;
672 Type::Bool
673 }
674
675 ExprKind::Changes(_var) => {
676 Type::Bool
678 }
679
680 ExprKind::Enabled(action) => {
681 if self.env.lookup_action(&action.name).is_none() {
683 return Err(TypeError::UndefinedAction {
684 name: action.name.clone(),
685 span: action.span,
686 });
687 }
688 Type::Bool
689 }
690
691 ExprKind::SeqHead(seq_expr) => {
692 let seq_ty_raw = self.infer_expr(seq_expr)?;
693 let seq_ty = self.env.resolve_type(&seq_ty_raw);
694 match seq_ty {
695 Type::Seq(elem_ty) => *elem_ty,
696 _ => {
697 return Err(TypeError::TypeMismatch {
698 expected: Type::Seq(Box::new(self.var_gen.fresh_type())),
699 found: seq_ty,
700 span: seq_expr.span,
701 });
702 }
703 }
704 }
705
706 ExprKind::SeqTail(seq_expr) => {
707 let seq_ty_raw = self.infer_expr(seq_expr)?;
708 let seq_ty = self.env.resolve_type(&seq_ty_raw);
709 match &seq_ty {
710 Type::Seq(_) => seq_ty,
711 _ => {
712 return Err(TypeError::TypeMismatch {
713 expected: Type::Seq(Box::new(self.var_gen.fresh_type())),
714 found: seq_ty,
715 span: seq_expr.span,
716 });
717 }
718 }
719 }
720
721 ExprKind::Len(expr) => {
722 let ty_raw = self.infer_expr(expr)?;
723 let ty = self.env.resolve_type(&ty_raw);
724 match ty {
725 Type::Seq(_) | Type::Set(_) | Type::Fn(_, _) | Type::Var(_) => Type::Nat,
727 _ => {
728 return Err(TypeError::TypeMismatch {
729 expected: Type::Seq(Box::new(self.var_gen.fresh_type())),
730 found: ty,
731 span: expr.span,
732 });
733 }
734 }
735 }
736
737 ExprKind::Keys(expr) => {
738 let ty_raw = self.infer_expr(expr)?;
739 let ty = self.env.resolve_type(&ty_raw);
740 match ty {
741 Type::Fn(key_ty, _) => Type::Set(key_ty),
742 Type::Seq(_) => Type::Set(Box::new(Type::Int)),
744 _ => {
745 return Err(TypeError::TypeMismatch {
746 expected: Type::Fn(
747 Box::new(self.var_gen.fresh_type()),
748 Box::new(self.var_gen.fresh_type()),
749 ),
750 found: ty,
751 span: expr.span,
752 });
753 }
754 }
755 }
756
757 ExprKind::Values(expr) => {
758 let ty_raw = self.infer_expr(expr)?;
759 let ty = self.env.resolve_type(&ty_raw);
760 match ty {
761 Type::Fn(_, val_ty) => Type::Set(val_ty),
762 _ => {
763 return Err(TypeError::TypeMismatch {
764 expected: Type::Fn(
765 Box::new(self.var_gen.fresh_type()),
766 Box::new(self.var_gen.fresh_type()),
767 ),
768 found: ty,
769 span: expr.span,
770 });
771 }
772 }
773 }
774
775 ExprKind::BigUnion(expr) => {
776 let ty_raw = self.infer_expr(expr)?;
777 let ty = self.env.resolve_type(&ty_raw);
778 match ty {
780 Type::Set(inner) => match *inner {
781 Type::Set(elem_ty) => Type::Set(elem_ty),
782 _ => {
783 return Err(TypeError::TypeMismatch {
784 expected: Type::Set(Box::new(self.var_gen.fresh_type())),
785 found: *inner,
786 span: expr.span,
787 });
788 }
789 },
790 _ => {
791 return Err(TypeError::TypeMismatch {
792 expected: Type::Set(Box::new(Type::Set(Box::new(
793 self.var_gen.fresh_type(),
794 )))),
795 found: ty,
796 span: expr.span,
797 });
798 }
799 }
800 }
801
802 ExprKind::Powerset(expr) => {
803 let ty_raw = self.infer_expr(expr)?;
804 let ty = self.env.resolve_type(&ty_raw);
805 match ty {
807 Type::Set(elem_ty) => Type::Set(Box::new(Type::Set(elem_ty))),
808 _ => {
809 return Err(TypeError::TypeMismatch {
810 expected: Type::Set(Box::new(self.var_gen.fresh_type())),
811 found: ty,
812 span: expr.span,
813 });
814 }
815 }
816 }
817
818 ExprKind::Always(inner) | ExprKind::Eventually(inner) => {
819 let inner_ty = self.infer_expr(inner)?;
820 self.expect_bool(&inner_ty, inner.span)?;
821 Type::Bool
822 }
823
824 ExprKind::LeadsTo { left, right } => {
825 let left_ty = self.infer_expr(left)?;
826 let right_ty = self.infer_expr(right)?;
827 self.expect_bool(&left_ty, left.span)?;
828 self.expect_bool(&right_ty, right.span)?;
829 Type::Bool
830 }
831
832 ExprKind::Range { lo, hi } => {
833 let lo_ty = self.infer_expr(lo)?;
834 let hi_ty = self.infer_expr(hi)?;
835 self.expect_numeric(&lo_ty, lo.span)?;
836 self.expect_numeric(&hi_ty, hi.span)?;
837 Type::Set(Box::new(Type::Int))
839 }
840
841 ExprKind::Paren(inner) => self.infer_expr(inner)?,
842 };
843
844 Ok(ty)
845 }
846
847 fn infer_binary(&mut self, op: BinOp, left: &Expr, right: &Expr) -> TypeResult<Type> {
849 let left_ty = self.infer_expr(left)?;
850 let right_ty = self.infer_expr(right)?;
851
852 match op {
853 BinOp::And | BinOp::Or | BinOp::Implies | BinOp::Iff => {
855 self.expect_bool(&left_ty, left.span)?;
856 self.expect_bool(&right_ty, right.span)?;
857 Ok(Type::Bool)
858 }
859
860 BinOp::Eq | BinOp::Ne => {
862 self.unify(&left_ty, &right_ty, right.span)?;
863 Ok(Type::Bool)
864 }
865
866 BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge => {
867 self.expect_numeric(&left_ty, left.span)?;
868 self.expect_numeric(&right_ty, right.span)?;
869 Ok(Type::Bool)
870 }
871
872 BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Mod => {
874 self.expect_numeric(&left_ty, left.span)?;
875 self.expect_numeric(&right_ty, right.span)?;
876 if matches!(left_ty, Type::Int) || matches!(right_ty, Type::Int) {
878 Ok(Type::Int)
879 } else {
880 Ok(Type::Nat)
881 }
882 }
883
884 BinOp::In | BinOp::NotIn => {
886 let resolved = self.env.resolve_type(&right_ty);
887 match resolved {
888 Type::Set(elem_ty) => {
889 self.unify(&left_ty, &elem_ty, left.span)?;
890 Ok(Type::Bool)
891 }
892 _ => Err(TypeError::NotIterable {
893 ty: right_ty,
894 span: right.span,
895 }),
896 }
897 }
898
899 BinOp::Union | BinOp::Intersect | BinOp::Diff => {
901 let resolved_left = self.env.resolve_type(&left_ty);
902 let resolved_right = self.env.resolve_type(&right_ty);
903
904 match (&resolved_left, &resolved_right) {
905 (Type::Set(_), Type::Set(_)) => {
906 self.unify(&left_ty, &right_ty, right.span)?;
907 Ok(left_ty)
908 }
909 (Type::Fn(_, _), Type::Fn(_, _)) => {
910 self.unify(&left_ty, &right_ty, right.span)?;
912 Ok(left_ty)
913 }
914 _ => Err(TypeError::TypeMismatch {
915 expected: Type::Set(Box::new(self.var_gen.fresh_type())),
916 found: left_ty,
917 span: left.span,
918 }),
919 }
920 }
921
922 BinOp::SubsetOf => {
923 let resolved_left = self.env.resolve_type(&left_ty);
924 let resolved_right = self.env.resolve_type(&right_ty);
925
926 match (&resolved_left, &resolved_right) {
927 (Type::Set(_), Type::Set(_)) => {
928 self.unify(&left_ty, &right_ty, right.span)?;
929 Ok(Type::Bool)
930 }
931 _ => Err(TypeError::TypeMismatch {
932 expected: Type::Set(Box::new(self.var_gen.fresh_type())),
933 found: left_ty,
934 span: left.span,
935 }),
936 }
937 }
938
939 BinOp::Concat => {
941 let resolved_left = self.env.resolve_type(&left_ty);
942 let resolved_right = self.env.resolve_type(&right_ty);
943
944 match (&resolved_left, &resolved_right) {
945 (Type::Seq(_), Type::Seq(_)) => {
946 self.unify(&left_ty, &right_ty, right.span)?;
947 Ok(left_ty)
948 }
949 _ => Err(TypeError::TypeMismatch {
950 expected: Type::Seq(Box::new(self.var_gen.fresh_type())),
951 found: left_ty,
952 span: left.span,
953 }),
954 }
955 }
956 }
957 }
958
959 fn infer_unary(&mut self, op: UnaryOp, operand: &Expr) -> TypeResult<Type> {
961 let operand_ty = self.infer_expr(operand)?;
962
963 match op {
964 UnaryOp::Not => {
965 self.expect_bool(&operand_ty, operand.span)?;
966 Ok(Type::Bool)
967 }
968 UnaryOp::Neg => {
969 self.expect_numeric(&operand_ty, operand.span)?;
970 Ok(Type::Int) }
972 }
973 }
974
975 fn element_type(&mut self, ty: &Type, span: Span) -> TypeResult<Type> {
977 let resolved = self.env.resolve_type(ty);
978 match resolved {
979 Type::Set(elem) | Type::Seq(elem) => Ok(*elem),
980 Type::Fn(key, _) => Ok(*key), Type::Range(_, _) => Ok(Type::Int),
982 Type::Var(_) => {
983 Ok(self.var_gen.fresh_type())
986 }
987 _ => Err(TypeError::NotIterable { ty: resolved, span }),
988 }
989 }
990
991 fn unify(&mut self, a: &Type, b: &Type, span: Span) -> TypeResult<Substitution> {
993 let a = self.env.resolve_type(a);
994 let b = self.env.resolve_type(b);
995
996 if a == b {
997 return Ok(Substitution::new());
998 }
999
1000 match (&a, &b) {
1001 (Type::Var(v), _) => {
1003 if b.has_vars() && matches!(b, Type::Var(v2) if *v == v2) {
1004 Ok(Substitution::new())
1006 } else if self.occurs_in(v, &b) {
1007 Err(TypeError::OccursCheck { span })
1008 } else {
1009 let mut subst = Substitution::new();
1010 subst.insert(*v, b.clone());
1011 Ok(subst)
1012 }
1013 }
1014 (_, Type::Var(v)) => {
1015 if self.occurs_in(v, &a) {
1016 Err(TypeError::OccursCheck { span })
1017 } else {
1018 let mut subst = Substitution::new();
1019 subst.insert(*v, a.clone());
1020 Ok(subst)
1021 }
1022 }
1023
1024 (Type::Set(a_elem), Type::Set(b_elem)) => self.unify(a_elem, b_elem, span),
1026 (Type::Seq(a_elem), Type::Seq(b_elem)) => self.unify(a_elem, b_elem, span),
1027 (Type::Option(a_elem), Type::Option(b_elem)) => self.unify(a_elem, b_elem, span),
1028 (Type::Fn(a_key, a_val), Type::Fn(b_key, b_val)) => {
1029 let s1 = self.unify(a_key, b_key, span)?;
1030 let a_val = a_val.substitute(&s1);
1031 let b_val = b_val.substitute(&s1);
1032 let s2 = self.unify(&a_val, &b_val, span)?;
1033 Ok(s1.compose(&s2))
1034 }
1035 (Type::Record(a_rec), Type::Record(b_rec)) => {
1036 if a_rec.fields.keys().collect::<Vec<_>>()
1037 != b_rec.fields.keys().collect::<Vec<_>>()
1038 {
1039 return Err(TypeError::UnificationFailure {
1040 a: a.clone(),
1041 b: b.clone(),
1042 span,
1043 });
1044 }
1045
1046 let mut subst = Substitution::new();
1047 for (name, a_ty) in &a_rec.fields {
1048 let b_ty = b_rec.fields.get(name).unwrap();
1049 let a_ty = a_ty.substitute(&subst);
1050 let b_ty = b_ty.substitute(&subst);
1051 let s = self.unify(&a_ty, &b_ty, span)?;
1052 subst = subst.compose(&s);
1053 }
1054 Ok(subst)
1055 }
1056 (Type::Tuple(a_elems), Type::Tuple(b_elems)) => {
1057 if a_elems.len() != b_elems.len() {
1058 return Err(TypeError::UnificationFailure {
1059 a: a.clone(),
1060 b: b.clone(),
1061 span,
1062 });
1063 }
1064
1065 let mut subst = Substitution::new();
1066 for (a_ty, b_ty) in a_elems.iter().zip(b_elems.iter()) {
1067 let a_ty = a_ty.substitute(&subst);
1068 let b_ty = b_ty.substitute(&subst);
1069 let s = self.unify(&a_ty, &b_ty, span)?;
1070 subst = subst.compose(&s);
1071 }
1072 Ok(subst)
1073 }
1074
1075 (Type::Int, Type::Nat) | (Type::Nat, Type::Int) => Ok(Substitution::new()),
1077 (Type::Int, Type::Range(_, _)) | (Type::Range(_, _), Type::Int) => {
1078 Ok(Substitution::new())
1079 }
1080 (Type::Nat, Type::Range(lo, _)) | (Type::Range(lo, _), Type::Nat) if *lo >= 0 => {
1081 Ok(Substitution::new())
1082 }
1083 (Type::Range(a_lo, a_hi), Type::Range(b_lo, b_hi)) if a_lo == b_lo && a_hi == b_hi => {
1084 Ok(Substitution::new())
1085 }
1086
1087 (Type::Error, _) | (_, Type::Error) => Ok(Substitution::new()),
1089
1090 (Type::Int, Type::Bool) | (Type::Bool, Type::Int) => Ok(Substitution::new()),
1092
1093 _ => Err(TypeError::UnificationFailure {
1094 a: a.clone(),
1095 b: b.clone(),
1096 span,
1097 }),
1098 }
1099 }
1100
1101 fn occurs_in(&self, var: &crate::types::TypeVar, ty: &Type) -> bool {
1103 Self::occurs_in_impl(var, ty)
1104 }
1105
1106 fn occurs_in_impl(var: &crate::types::TypeVar, ty: &Type) -> bool {
1107 match ty {
1108 Type::Var(v) => v == var,
1109 Type::Set(t) | Type::Seq(t) | Type::Option(t) => Self::occurs_in_impl(var, t),
1110 Type::Fn(k, v) => Self::occurs_in_impl(var, k) || Self::occurs_in_impl(var, v),
1111 Type::Record(r) => r.fields.values().any(|t| Self::occurs_in_impl(var, t)),
1112 Type::Tuple(elems) => elems.iter().any(|t| Self::occurs_in_impl(var, t)),
1113 _ => false,
1114 }
1115 }
1116
1117 fn expect_bool(&self, ty: &Type, span: Span) -> TypeResult<()> {
1119 let resolved = self.env.resolve_type(ty);
1120 match resolved {
1121 Type::Bool | Type::Var(_) | Type::Error => Ok(()),
1122 _ => Err(TypeError::ExpectedBool {
1123 found: resolved,
1124 span,
1125 }),
1126 }
1127 }
1128
1129 fn expect_numeric(&self, ty: &Type, span: Span) -> TypeResult<()> {
1131 let resolved = self.env.resolve_type(ty);
1132 match resolved {
1133 Type::Nat | Type::Int | Type::Range(_, _) | Type::Var(_) | Type::Error => Ok(()),
1134 _ => Err(TypeError::ExpectedNumeric {
1135 found: resolved,
1136 span,
1137 }),
1138 }
1139 }
1140}
1141
1142impl Default for TypeChecker {
1143 fn default() -> Self {
1144 Self::new()
1145 }
1146}
1147
#[cfg(test)]
mod tests {
    use super::*;
    use specl_syntax::parse;

    /// Parses `source` (panicking on a syntax error) and type-checks it.
    fn check(source: &str) -> TypeResult<TypeEnv> {
        let module = parse(source).expect("parse error");
        check_module(&module)
    }

    /// Asserts that `source` type-checks without errors.
    fn assert_accepted(source: &str) {
        assert!(check(source).is_ok());
    }

    #[test]
    fn test_simple_types() {
        assert_accepted(
            r#"
module Test
var x: Nat
var y: Bool
init { x == 0 and y == true }
"#,
        );
    }

    #[test]
    fn test_type_mismatch() {
        let outcome = check(
            r#"
module Test
var x: Nat
init { x == true }
"#,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_undefined_variable() {
        let outcome = check(
            r#"
module Test
init { undefined_var == 0 }
"#,
        );
        assert!(matches!(outcome, Err(TypeError::UndefinedVariable { .. })));
    }

    #[test]
    fn test_set_operations() {
        assert_accepted(
            r#"
module Test
var s: Set[Nat]
init { 1 in s and s union {} == s }
"#,
        );
    }

    #[test]
    fn test_quantifier() {
        assert_accepted(
            r#"
module Test
var s: Set[Nat]
init { all x in s: x >= 0 }
"#,
        );
    }

    #[test]
    fn test_action_type_check() {
        assert_accepted(
            r#"
module Test
var x: Nat
const MAX: Nat
action Inc() {
    require x < MAX
    x = x + 1
}
"#,
        );
    }

    #[test]
    fn test_dict_access() {
        assert_accepted(
            r#"
module Test
var d: Dict[Nat, Bool]
init { d[0] == false and d[1] == true }
"#,
        );
    }

    #[test]
    fn test_dict_comprehension() {
        assert_accepted(
            r#"
module Test
var d: Dict[0..2, Nat]
init { d == {x: 0 for x in 0..2} }
"#,
        );
    }
}