1use indexmap::IndexMap;
4use smol_str::SmolStr;
5use solscript_ast::{self as ast, Span};
6
7use crate::error::TypeError;
8use crate::scope::{ScopeKind, SymbolTable};
9use crate::types::{
10 ContractDef, EnumDef, ErrorDef, ErrorParam, EventDef, EventParam, FunctionType, InterfaceDef,
11 ModifierType, NamedType, PrimitiveType, StructDef, Type, TypeDef, TypeVar,
12};
13
14pub struct TypeChecker {
16 symbols: SymbolTable,
18 source: String,
20 next_type_var: u32,
22 errors: Vec<TypeError>,
24 return_type: Option<Type>,
26 self_type: Option<Type>,
28 contracts: std::collections::HashMap<String, ast::ContractDef>,
30}
31
32impl TypeChecker {
33 pub fn new(source: String) -> Self {
34 Self {
35 symbols: SymbolTable::new(),
36 source,
37 next_type_var: 0,
38 errors: Vec::new(),
39 return_type: None,
40 self_type: None,
41 contracts: std::collections::HashMap::new(),
42 }
43 }
44
45 pub fn check_program(&mut self, program: &ast::Program) -> Result<(), Vec<TypeError>> {
47 for item in &program.items {
49 self.collect_type_def(item);
50 }
51
52 for item in &program.items {
54 if let ast::Item::Contract(c) = item {
55 self.contracts.insert(c.name.name.to_string(), c.clone());
56 }
57 }
58
59 for item in &program.items {
61 self.check_item(item);
62 }
63
64 if self.errors.is_empty() {
65 Ok(())
66 } else {
67 Err(std::mem::take(&mut self.errors))
68 }
69 }
70
71 fn fresh_type_var(&mut self) -> Type {
73 let var = TypeVar(self.next_type_var);
74 self.next_type_var += 1;
75 Type::Var(var)
76 }
77
78 fn span(&self, span: Span) -> (usize, usize) {
80 (span.start, span.end)
81 }
82
83 fn error(&mut self, err: TypeError) {
85 self.errors.push(err);
86 }
87
88 fn collect_type_def(&mut self, item: &ast::Item) {
93 match item {
94 ast::Item::Struct(s) => {
95 let def = self.build_struct_def(s);
96 self.symbols
97 .define_type(s.name.name.clone(), TypeDef::Struct(def));
98 }
99 ast::Item::Enum(e) => {
100 let def = self.build_enum_def(e);
101 self.symbols
102 .define_type(e.name.name.clone(), TypeDef::Enum(def));
103 }
104 ast::Item::Contract(c) => {
105 let def = self.build_contract_def(c);
106 self.symbols
107 .define_type(c.name.name.clone(), TypeDef::Contract(def));
108
109 for member in &c.members {
111 match member {
112 ast::ContractMember::Event(e) => {
113 let event_def = self.build_event_def(e);
114 self.symbols
115 .define_type(e.name.name.clone(), TypeDef::Event(event_def));
116 }
117 ast::ContractMember::Error(e) => {
118 let error_def = self.build_error_def(e);
119 self.symbols
120 .define_type(e.name.name.clone(), TypeDef::Error(error_def));
121 }
122 ast::ContractMember::Struct(s) => {
123 let struct_def = self.build_struct_def(s);
124 self.symbols
125 .define_type(s.name.name.clone(), TypeDef::Struct(struct_def));
126 }
127 ast::ContractMember::Enum(e) => {
128 let enum_def = self.build_enum_def(e);
129 self.symbols
130 .define_type(e.name.name.clone(), TypeDef::Enum(enum_def));
131 }
132 _ => {}
133 }
134 }
135 }
136 ast::Item::Interface(i) => {
137 let def = self.build_interface_def(i);
138 self.symbols
139 .define_type(i.name.name.clone(), TypeDef::Interface(def));
140 }
141 ast::Item::Event(e) => {
142 let def = self.build_event_def(e);
143 self.symbols
144 .define_type(e.name.name.clone(), TypeDef::Event(def));
145 }
146 ast::Item::Error(e) => {
147 let def = self.build_error_def(e);
148 self.symbols
149 .define_type(e.name.name.clone(), TypeDef::Error(def));
150 }
151 _ => {}
152 }
153 }
154
155 fn build_struct_def(&mut self, s: &ast::StructDef) -> StructDef {
156 let type_params = s
157 .generic_params
158 .as_ref()
159 .map(|g| g.params.iter().map(|p| p.name.name.clone()).collect())
160 .unwrap_or_default();
161
162 let mut fields = IndexMap::new();
163 for field in &s.fields {
164 let ty = self.resolve_type_expr(&field.ty);
165 fields.insert(field.name.name.clone(), ty);
166 }
167
168 StructDef {
169 name: s.name.name.clone(),
170 type_params,
171 fields,
172 }
173 }
174
175 fn build_enum_def(&mut self, e: &ast::EnumDef) -> EnumDef {
176 let variants = e.variants.iter().map(|v| v.name.name.clone()).collect();
178
179 EnumDef {
180 name: e.name.name.clone(),
181 variants,
182 }
183 }
184
185 fn build_contract_def(&mut self, c: &ast::ContractDef) -> ContractDef {
186 let type_params = Vec::new(); let bases: Vec<SmolStr> = c.bases.iter().map(|b| b.name().clone()).collect();
189
190 let mut state_fields = IndexMap::new();
191 let mut methods = IndexMap::new();
192 let mut modifiers = IndexMap::new();
193
194 for member in &c.members {
195 match member {
196 ast::ContractMember::StateVar(f) => {
197 state_fields.insert(f.name.name.clone(), self.resolve_type_expr(&f.ty));
198 }
199 ast::ContractMember::Function(f) => {
200 let fn_ty = self.build_function_type(f);
201 methods.insert(f.name.name.clone(), fn_ty);
202 }
203 ast::ContractMember::Constructor(_) => {} ast::ContractMember::Modifier(m) => {
205 let mod_ty = self.build_modifier_type(m);
206 modifiers.insert(m.name.name.clone(), mod_ty);
207 }
208 ast::ContractMember::Event(_) | ast::ContractMember::Error(_) => {
209 }
211 ast::ContractMember::Struct(_) | ast::ContractMember::Enum(_) => {
212 }
214 }
215 }
216
217 ContractDef {
218 name: c.name.name.clone(),
219 type_params,
220 bases,
221 state_fields,
222 methods,
223 modifiers,
224 }
225 }
226
227 fn build_modifier_type(&mut self, m: &ast::ModifierDef) -> ModifierType {
228 let params: Vec<Type> = m
229 .params
230 .iter()
231 .map(|p| self.resolve_type_expr(&p.ty))
232 .collect();
233 ModifierType {
234 name: m.name.name.clone(),
235 params,
236 }
237 }
238
239 fn build_event_def(&mut self, e: &ast::EventDef) -> EventDef {
240 let params: Vec<EventParam> = e
241 .params
242 .iter()
243 .map(|p| EventParam {
244 name: p.name.name.clone(),
245 ty: self.resolve_type_expr(&p.ty),
246 indexed: p.indexed,
247 })
248 .collect();
249 EventDef {
250 name: e.name.name.clone(),
251 params,
252 }
253 }
254
255 fn build_error_def(&mut self, e: &ast::ErrorDef) -> ErrorDef {
256 let params: Vec<ErrorParam> = e
257 .params
258 .iter()
259 .map(|p| ErrorParam {
260 name: p.name.name.clone(),
261 ty: self.resolve_type_expr(&p.ty),
262 })
263 .collect();
264 ErrorDef {
265 name: e.name.name.clone(),
266 params,
267 }
268 }
269
270 fn build_interface_def(&mut self, i: &ast::InterfaceDef) -> InterfaceDef {
271 let bases: Vec<SmolStr> = i.bases.iter().map(|b| b.name().clone()).collect();
272
273 let mut methods = IndexMap::new();
274 for sig in &i.members {
275 let fn_ty = self.build_fn_sig_type(sig);
276 methods.insert(sig.name.name.clone(), fn_ty);
277 }
278
279 InterfaceDef {
280 name: i.name.name.clone(),
281 bases,
282 methods,
283 }
284 }
285
286 fn build_function_type(&mut self, f: &ast::FnDef) -> FunctionType {
287 let params: Vec<Type> = f
288 .params
289 .iter()
290 .map(|p| self.resolve_type_expr(&p.ty))
291 .collect();
292
293 let return_type = if f.return_params.is_empty() {
295 Type::Unit
296 } else if f.return_params.len() == 1 {
297 self.resolve_type_expr(&f.return_params[0].ty)
298 } else {
299 let types: Vec<Type> = f
301 .return_params
302 .iter()
303 .map(|rp| self.resolve_type_expr(&rp.ty))
304 .collect();
305 Type::Tuple(types)
306 };
307
308 FunctionType {
309 params,
310 return_type: Box::new(return_type),
311 }
312 }
313
314 fn build_fn_sig_type(&mut self, sig: &ast::FnSig) -> FunctionType {
315 let params: Vec<Type> = sig
316 .params
317 .iter()
318 .map(|p| self.resolve_type_expr(&p.ty))
319 .collect();
320
321 let return_type = if sig.return_params.is_empty() {
322 Type::Unit
323 } else if sig.return_params.len() == 1 {
324 self.resolve_type_expr(&sig.return_params[0].ty)
325 } else {
326 let types: Vec<Type> = sig
327 .return_params
328 .iter()
329 .map(|rp| self.resolve_type_expr(&rp.ty))
330 .collect();
331 Type::Tuple(types)
332 };
333
334 FunctionType {
335 params,
336 return_type: Box::new(return_type),
337 }
338 }
339
340 fn resolve_type_expr(&mut self, ty: &ast::TypeExpr) -> Type {
345 match ty {
346 ast::TypeExpr::Path(path) => self.resolve_type_path(path),
347 ast::TypeExpr::Array(arr) => {
348 let elem = self.resolve_type_path(&arr.element);
349 let mut current_type = elem;
351 for size in arr.sizes.iter().rev() {
352 current_type = match size {
353 Some(n) => Type::Array(Box::new(current_type), *n),
354 None => Type::DynamicArray(Box::new(current_type)),
355 };
356 }
357 current_type
358 }
359 ast::TypeExpr::Mapping(mapping) => {
360 let key = self.resolve_type_expr(&mapping.key);
361 let value = self.resolve_type_expr(&mapping.value);
362 Type::Mapping(Box::new(key), Box::new(value))
363 }
364 ast::TypeExpr::Tuple(tuple) => {
365 let elems: Vec<Type> = tuple
366 .elements
367 .iter()
368 .map(|t| self.resolve_type_expr(t))
369 .collect();
370 Type::Tuple(elems)
371 }
372 }
373 }
374
375 fn resolve_type_path(&mut self, path: &ast::TypePath) -> Type {
376 let name = path.name();
377
378 if let Some(prim) = PrimitiveType::parse(name.as_str()) {
380 return Type::Primitive(prim);
381 }
382
383 if self.symbols.lookup_type(name).is_some() {
385 let type_args = path
386 .generic_args
387 .as_ref()
388 .map(|g| g.args.iter().map(|a| self.resolve_type_expr(a)).collect())
389 .unwrap_or_default();
390 Type::Named(NamedType::with_args(name.clone(), type_args))
391 } else {
392 self.error(TypeError::undefined_type(
393 name,
394 self.span(path.span),
395 &self.source,
396 ));
397 Type::Error
398 }
399 }
400
401 fn check_item(&mut self, item: &ast::Item) {
406 match item {
407 ast::Item::Contract(c) => self.check_contract(c),
408 ast::Item::Struct(s) => self.check_struct(s),
409 ast::Item::Enum(e) => self.check_enum(e),
410 ast::Item::Function(f) => self.check_function(f),
411 ast::Item::Interface(_) => {} ast::Item::Import(_) => {} ast::Item::Event(_) => {} ast::Item::Error(_) => {} }
416 }
417
418 fn check_contract(&mut self, contract: &ast::ContractDef) {
419 let contract_type = Type::Named(NamedType::new(contract.name.name.clone()));
420 self.self_type = Some(contract_type);
421
422 self.symbols.push_scope(ScopeKind::Contract);
423
424 for base in &contract.bases {
426 let base_name = base.segments.first().map(|s| s.name.as_str()).unwrap_or("");
427 if let Some(base_contract) = self.contracts.get(base_name).cloned() {
428 for member in &base_contract.members {
430 if let ast::ContractMember::StateVar(f) = member {
431 let ty = self.resolve_type_expr(&f.ty);
432 self.symbols.define_variable(f.name.name.clone(), ty, true);
433 }
434 }
435 }
436 }
437
438 for member in &contract.members {
440 if let ast::ContractMember::StateVar(f) = member {
441 let ty = self.resolve_type_expr(&f.ty);
442 self.symbols.define_variable(f.name.name.clone(), ty, true);
443 }
444 }
445
446 for member in &contract.members {
448 if let ast::ContractMember::Function(f) = member {
449 let fn_ty = self.build_function_type(f);
450 self.symbols
451 .define_variable(f.name.name.clone(), Type::Function(fn_ty), false);
452 }
453 }
454
455 for member in &contract.members {
457 match member {
458 ast::ContractMember::Constructor(c) => self.check_constructor(c),
459 ast::ContractMember::Function(f) => self.check_function(f),
460 ast::ContractMember::Modifier(m) => self.check_modifier_def(m),
461 ast::ContractMember::StateVar(_) => {} ast::ContractMember::Event(_) => {} ast::ContractMember::Error(_) => {} ast::ContractMember::Struct(s) => self.check_struct(s),
465 ast::ContractMember::Enum(e) => self.check_enum(e),
466 }
467 }
468
469 self.symbols.pop_scope();
470 self.self_type = None;
471 }
472
473 fn check_struct(&mut self, s: &ast::StructDef) {
474 let mut seen_fields = std::collections::HashSet::new();
476 for field in &s.fields {
477 let field_name = field.name.name.as_str();
478 if seen_fields.contains(field_name) {
479 self.error(TypeError::DuplicateDefinition {
480 name: field_name.to_string(),
481 span: miette::SourceSpan::new(
482 field.name.span.start.into(),
483 field.name.span.end - field.name.span.start,
484 ),
485 src: self.source.clone(),
486 });
487 } else {
488 seen_fields.insert(field_name.to_string());
489 }
490
491 let _ = self.resolve_type_expr(&field.ty);
493 }
494 }
495
496 fn check_enum(&mut self, e: &ast::EnumDef) {
497 let mut seen_variants = std::collections::HashSet::new();
499 for variant in &e.variants {
500 let variant_name = variant.name.name.as_str();
501 if seen_variants.contains(variant_name) {
502 self.error(TypeError::DuplicateDefinition {
503 name: variant_name.to_string(),
504 span: miette::SourceSpan::new(
505 variant.name.span.start.into(),
506 variant.name.span.end - variant.name.span.start,
507 ),
508 src: self.source.clone(),
509 });
510 } else {
511 seen_variants.insert(variant_name.to_string());
512 }
513 }
514 }
515
516 fn check_function(&mut self, f: &ast::FnDef) {
517 let fn_ty = self.build_function_type(f);
518 self.return_type = Some((*fn_ty.return_type).clone());
519
520 for modifier in &f.modifiers {
522 self.check_modifier_invocation(modifier);
523 }
524
525 self.symbols.push_scope(ScopeKind::Function);
526
527 for param in &f.params {
529 let ty = self.resolve_type_expr(¶m.ty);
530 self.symbols
531 .define_variable(param.name.name.clone(), ty, false);
532 }
533
534 if let Some(body) = &f.body {
536 self.check_block(body);
537 }
538
539 self.symbols.pop_scope();
540 self.return_type = None;
541 }
542
543 fn check_modifier_invocation(&mut self, modifier: &ast::ModifierInvocation) {
544 let modifier_name = &modifier.name.name;
545
546 if let Some(Type::Named(named)) = &self.self_type {
548 if let Some(TypeDef::Contract(contract_def)) = self.symbols.lookup_type(&named.name) {
549 if let Some(mod_type) = contract_def.modifiers.get(modifier_name).cloned() {
551 self.validate_modifier_args(modifier, &mod_type);
552 return;
553 }
554
555 for base_name in &contract_def.bases {
557 if let Some(TypeDef::Contract(base_def)) = self.symbols.lookup_type(base_name) {
558 if let Some(mod_type) = base_def.modifiers.get(modifier_name).cloned() {
559 self.validate_modifier_args(modifier, &mod_type);
560 return;
561 }
562 }
563 }
564 }
565 }
566
567 self.error(TypeError::UndefinedModifier {
569 name: modifier_name.to_string(),
570 span: miette::SourceSpan::new(
571 modifier.name.span.start.into(),
572 modifier.name.span.end - modifier.name.span.start,
573 ),
574 src: self.source.clone(),
575 });
576 }
577
578 fn validate_modifier_args(
579 &mut self,
580 modifier: &ast::ModifierInvocation,
581 mod_type: &ModifierType,
582 ) {
583 if modifier.args.len() != mod_type.params.len() {
585 self.error(TypeError::wrong_arg_count(
586 mod_type.params.len(),
587 modifier.args.len(),
588 self.span(modifier.span),
589 &self.source,
590 ));
591 return;
592 }
593
594 for (arg, param_ty) in modifier.args.iter().zip(mod_type.params.iter()) {
596 let arg_ty = self.check_expr(&arg.value);
597 if !self.types_compatible(param_ty, &arg_ty) {
598 self.error(TypeError::type_mismatch(
599 param_ty,
600 &arg_ty,
601 self.span(arg.value.span()),
602 &self.source,
603 ));
604 }
605 }
606 }
607
608 fn check_constructor(&mut self, c: &ast::ConstructorDef) {
609 self.return_type = Some(Type::Unit);
610
611 self.symbols.push_scope(ScopeKind::Function);
612
613 for param in &c.params {
615 let ty = self.resolve_type_expr(¶m.ty);
616 self.symbols
617 .define_variable(param.name.name.clone(), ty, false);
618 }
619
620 self.check_block(&c.body);
622
623 self.symbols.pop_scope();
624 self.return_type = None;
625 }
626
627 fn check_modifier_def(&mut self, m: &ast::ModifierDef) {
628 self.return_type = Some(Type::Unit);
629
630 self.symbols.push_scope(ScopeKind::Function);
631
632 for param in &m.params {
634 let ty = self.resolve_type_expr(¶m.ty);
635 self.symbols
636 .define_variable(param.name.name.clone(), ty, false);
637 }
638
639 self.check_block(&m.body);
641
642 self.symbols.pop_scope();
643 self.return_type = None;
644 }
645
646 fn check_block(&mut self, block: &ast::Block) {
651 self.symbols.push_scope(ScopeKind::Block);
652
653 for stmt in &block.stmts {
654 self.check_stmt(stmt);
655 }
656
657 self.symbols.pop_scope();
658 }
659
660 fn check_stmt(&mut self, stmt: &ast::Stmt) {
661 match stmt {
662 ast::Stmt::VarDecl(v) => self.check_var_decl_stmt(v),
663 ast::Stmt::Return(r) => self.check_return_stmt(r),
664 ast::Stmt::If(i) => self.check_if_stmt(i),
665 ast::Stmt::While(w) => self.check_while_stmt(w),
666 ast::Stmt::For(f) => self.check_for_stmt(f),
667 ast::Stmt::Emit(e) => self.check_emit_stmt(e),
668 ast::Stmt::Require(r) => self.check_require_stmt(r),
669 ast::Stmt::Revert(r) => self.check_revert_stmt(r),
670 ast::Stmt::Delete(d) => {
671 self.check_expr(&d.target);
673 }
674 ast::Stmt::Selfdestruct(s) => {
675 let recipient_ty = self.check_expr(&s.recipient);
677 if !matches!(recipient_ty, Type::Primitive(PrimitiveType::Address)) {
678 self.error(TypeError::type_mismatch(
679 &Type::Primitive(PrimitiveType::Address),
680 &recipient_ty,
681 self.span(s.span),
682 &self.source,
683 ));
684 }
685 }
686 ast::Stmt::Placeholder(_) => {} ast::Stmt::Expr(e) => {
688 self.check_expr(&e.expr);
689 }
690 }
691 }
692
693 fn check_var_decl_stmt(&mut self, v: &ast::VarDeclStmt) {
694 let declared_ty = self.resolve_type_expr(&v.ty);
695
696 if let Some(init) = &v.initializer {
697 let value_ty = self.check_expr(init);
698
699 if !self.types_compatible(&declared_ty, &value_ty) {
700 self.error(TypeError::type_mismatch(
701 &declared_ty,
702 &value_ty,
703 self.span(v.span),
704 &self.source,
705 ));
706 }
707 }
708
709 self.symbols
711 .define_variable(v.name.name.clone(), declared_ty, true);
712 }
713
714 fn check_return_stmt(&mut self, r: &ast::ReturnStmt) {
715 let value_ty = r
716 .value
717 .as_ref()
718 .map(|v| self.check_expr(v))
719 .unwrap_or(Type::Unit);
720
721 if let Some(expected) = &self.return_type {
722 if !self.types_compatible(expected, &value_ty) {
723 self.error(TypeError::type_mismatch(
724 expected,
725 &value_ty,
726 self.span(r.span),
727 &self.source,
728 ));
729 }
730 }
731 }
732
733 fn check_if_stmt(&mut self, i: &ast::IfStmt) {
734 let cond_ty = self.check_expr(&i.condition);
735 if !cond_ty.is_bool() && !matches!(cond_ty, Type::Error) {
736 self.error(TypeError::type_mismatch(
737 &Type::Primitive(PrimitiveType::Bool),
738 &cond_ty,
739 self.span(i.condition.span()),
740 &self.source,
741 ));
742 }
743
744 self.check_block(&i.then_block);
745
746 if let Some(else_branch) = &i.else_branch {
747 match else_branch {
748 ast::ElseBranch::Else(block) => self.check_block(block),
749 ast::ElseBranch::ElseIf(elif) => self.check_if_stmt(elif),
750 }
751 }
752 }
753
754 fn check_while_stmt(&mut self, w: &ast::WhileStmt) {
755 let cond_ty = self.check_expr(&w.condition);
756 if !cond_ty.is_bool() && !matches!(cond_ty, Type::Error) {
757 self.error(TypeError::type_mismatch(
758 &Type::Primitive(PrimitiveType::Bool),
759 &cond_ty,
760 self.span(w.condition.span()),
761 &self.source,
762 ));
763 }
764
765 self.check_block(&w.body);
766 }
767
768 fn check_for_stmt(&mut self, f: &ast::ForStmt) {
769 self.symbols.push_scope(ScopeKind::Block);
770
771 if let Some(init) = &f.init {
773 match init {
774 ast::ForInit::VarDecl(v) => self.check_var_decl_stmt(v),
775 ast::ForInit::Expr(e) => {
776 self.check_expr(e);
777 }
778 }
779 }
780
781 if let Some(cond) = &f.condition {
783 let cond_ty = self.check_expr(cond);
784 if !cond_ty.is_bool() && !matches!(cond_ty, Type::Error) {
785 self.error(TypeError::type_mismatch(
786 &Type::Primitive(PrimitiveType::Bool),
787 &cond_ty,
788 self.span(cond.span()),
789 &self.source,
790 ));
791 }
792 }
793
794 if let Some(update) = &f.update {
796 self.check_expr(update);
797 }
798
799 self.check_block(&f.body);
801
802 self.symbols.pop_scope();
803 }
804
805 fn check_emit_stmt(&mut self, e: &ast::EmitStmt) {
806 let event_name = &e.event.name;
807
808 if let Some(TypeDef::Event(event_def)) = self.symbols.lookup_type(event_name) {
810 if e.args.len() != event_def.params.len() {
812 self.error(TypeError::wrong_arg_count(
813 event_def.params.len(),
814 e.args.len(),
815 self.span(e.span),
816 &self.source,
817 ));
818 return;
819 }
820
821 let event_params = event_def.params.clone();
823 for (arg, param) in e.args.iter().zip(event_params.iter()) {
824 let arg_ty = self.check_expr(&arg.value);
825 if !self.types_compatible(¶m.ty, &arg_ty) {
826 self.error(TypeError::type_mismatch(
827 ¶m.ty,
828 &arg_ty,
829 self.span(arg.value.span()),
830 &self.source,
831 ));
832 }
833 }
834 } else {
835 self.error(TypeError::UndefinedEvent {
836 name: event_name.to_string(),
837 span: miette::SourceSpan::new(
838 e.event.span.start.into(),
839 e.event.span.end - e.event.span.start,
840 ),
841 src: self.source.clone(),
842 });
843 }
844 }
845
846 fn check_require_stmt(&mut self, r: &ast::RequireStmt) {
847 let cond_ty = self.check_expr(&r.condition);
848 if !cond_ty.is_bool() && !matches!(cond_ty, Type::Error) {
849 self.error(TypeError::type_mismatch(
850 &Type::Primitive(PrimitiveType::Bool),
851 &cond_ty,
852 self.span(r.condition.span()),
853 &self.source,
854 ));
855 }
856 }
857
858 fn check_revert_stmt(&mut self, r: &ast::RevertStmt) {
859 match &r.kind {
860 ast::RevertKind::Message(_) => {
861 }
863 ast::RevertKind::Error { name, args } => {
864 let error_name = &name.name;
866
867 if let Some(type_def) = self.symbols.lookup_type(error_name) {
868 if let TypeDef::Error(error_def) = type_def {
869 if args.len() != error_def.params.len() {
871 self.error(TypeError::wrong_arg_count(
872 error_def.params.len(),
873 args.len(),
874 self.span(r.span),
875 &self.source,
876 ));
877 return;
878 }
879 let error_params = error_def.params.clone();
881 for (arg, param) in args.iter().zip(error_params.iter()) {
882 let arg_ty = self.check_expr(&arg.value);
883 if !self.types_compatible(¶m.ty, &arg_ty) {
884 self.error(TypeError::type_mismatch(
885 ¶m.ty,
886 &arg_ty,
887 self.span(arg.span),
888 &self.source,
889 ));
890 }
891 }
892 } else {
893 self.error(TypeError::not_callable(
895 &Type::Named(NamedType::new(error_name.clone())),
896 self.span(name.span),
897 &self.source,
898 ));
899 }
900 } else {
901 self.error(TypeError::undefined_type(
902 error_name.as_str(),
903 self.span(name.span),
904 &self.source,
905 ));
906 }
907 }
908 }
909 }
910
911 fn check_expr(&mut self, expr: &ast::Expr) -> Type {
916 match expr {
917 ast::Expr::Literal(lit) => self.check_literal(lit),
918 ast::Expr::Ident(ident) => self.check_ident_expr(ident),
919 ast::Expr::Binary(bin) => self.check_binary_expr(bin),
920 ast::Expr::Unary(un) => self.check_unary_expr(un),
921 ast::Expr::Call(call) => self.check_call_expr(call),
922 ast::Expr::MethodCall(mc) => self.check_method_call(mc),
923 ast::Expr::FieldAccess(fa) => self.check_field_access(fa),
924 ast::Expr::Index(idx) => self.check_index_expr(idx),
925 ast::Expr::If(if_expr) => self.check_if_expr(if_expr),
926 ast::Expr::Array(arr) => self.check_array_expr(arr),
927 ast::Expr::Tuple(tuple) => self.check_tuple_expr(tuple),
928 ast::Expr::Assign(a) => self.check_assign_expr(a),
929 ast::Expr::Ternary(t) => self.check_ternary_expr(t),
930 ast::Expr::New(n) => self.check_new_expr(n),
931 ast::Expr::Paren(e) => self.check_expr(e),
932 }
933 }
934
935 fn check_literal(&mut self, lit: &ast::Literal) -> Type {
936 match lit {
937 ast::Literal::Bool(_, _) => Type::Primitive(PrimitiveType::Bool),
938 ast::Literal::Int(_, _) => Type::Primitive(PrimitiveType::Uint256), ast::Literal::HexInt(_, _) => Type::Primitive(PrimitiveType::Uint256),
940 ast::Literal::String(_, _) => Type::Primitive(PrimitiveType::String),
941 ast::Literal::HexString(_, _) => Type::Primitive(PrimitiveType::Bytes),
942 ast::Literal::Address(_, _) => Type::Primitive(PrimitiveType::Address),
943 }
944 }
945
946 fn check_ident_expr(&mut self, ident: &ast::Ident) -> Type {
947 let name = &ident.name;
948
949 match name.as_str() {
951 "msg" | "block" | "tx" | "token" | "clock" | "rent" => {
952 return Type::Named(NamedType::new(name.clone()));
955 }
956 _ => {}
957 }
958
959 if let Some(var) = self.symbols.lookup_variable(name) {
961 return var.ty.clone();
962 }
963
964 if let Some(func) = self.symbols.lookup_function(name) {
966 return Type::Function(func.ty.clone());
967 }
968
969 self.error(TypeError::undefined_variable(
970 name,
971 self.span(ident.span),
972 &self.source,
973 ));
974 Type::Error
975 }
976
977 fn check_binary_expr(&mut self, bin: &ast::BinaryExpr) -> Type {
978 let left_ty = self.check_expr(&bin.left);
979 let right_ty = self.check_expr(&bin.right);
980
981 if matches!(left_ty, Type::Error) || matches!(right_ty, Type::Error) {
983 return Type::Error;
984 }
985
986 match bin.op {
987 ast::BinaryOp::Add
989 | ast::BinaryOp::Sub
990 | ast::BinaryOp::Mul
991 | ast::BinaryOp::Div
992 | ast::BinaryOp::Rem
993 | ast::BinaryOp::Exp => {
994 if left_ty.is_integer() && self.types_compatible(&left_ty, &right_ty) {
995 left_ty
996 } else {
997 self.error(TypeError::invalid_binary_op(
998 &format!("{:?}", bin.op),
999 &left_ty,
1000 &right_ty,
1001 self.span(bin.span),
1002 &self.source,
1003 ));
1004 Type::Error
1005 }
1006 }
1007 ast::BinaryOp::Eq
1009 | ast::BinaryOp::Ne
1010 | ast::BinaryOp::Lt
1011 | ast::BinaryOp::Le
1012 | ast::BinaryOp::Gt
1013 | ast::BinaryOp::Ge => {
1014 if self.types_compatible(&left_ty, &right_ty) {
1015 Type::Primitive(PrimitiveType::Bool)
1016 } else {
1017 self.error(TypeError::invalid_binary_op(
1018 &format!("{:?}", bin.op),
1019 &left_ty,
1020 &right_ty,
1021 self.span(bin.span),
1022 &self.source,
1023 ));
1024 Type::Error
1025 }
1026 }
1027 ast::BinaryOp::And | ast::BinaryOp::Or => {
1029 if left_ty.is_bool() && right_ty.is_bool() {
1030 Type::Primitive(PrimitiveType::Bool)
1031 } else {
1032 self.error(TypeError::invalid_binary_op(
1033 &format!("{:?}", bin.op),
1034 &left_ty,
1035 &right_ty,
1036 self.span(bin.span),
1037 &self.source,
1038 ));
1039 Type::Error
1040 }
1041 }
1042 ast::BinaryOp::BitAnd
1044 | ast::BinaryOp::BitOr
1045 | ast::BinaryOp::BitXor
1046 | ast::BinaryOp::Shl
1047 | ast::BinaryOp::Shr => {
1048 if left_ty.is_integer() && right_ty.is_integer() {
1049 left_ty
1050 } else {
1051 self.error(TypeError::invalid_binary_op(
1052 &format!("{:?}", bin.op),
1053 &left_ty,
1054 &right_ty,
1055 self.span(bin.span),
1056 &self.source,
1057 ));
1058 Type::Error
1059 }
1060 }
1061 }
1062 }
1063
1064 fn check_unary_expr(&mut self, un: &ast::UnaryExpr) -> Type {
1065 let expr_ty = self.check_expr(&un.expr);
1066
1067 if matches!(expr_ty, Type::Error) {
1068 return Type::Error;
1069 }
1070
1071 match un.op {
1072 ast::UnaryOp::Neg => {
1073 if expr_ty.is_integer() {
1074 expr_ty
1075 } else {
1076 self.error(TypeError::InvalidUnaryOp {
1077 op: "-".to_string(),
1078 ty: expr_ty.to_string(),
1079 span: miette::SourceSpan::new(
1080 un.span.start.into(),
1081 un.span.end - un.span.start,
1082 ),
1083 src: self.source.clone(),
1084 });
1085 Type::Error
1086 }
1087 }
1088 ast::UnaryOp::Not => {
1089 if expr_ty.is_bool() {
1090 Type::Primitive(PrimitiveType::Bool)
1091 } else {
1092 self.error(TypeError::InvalidUnaryOp {
1093 op: "!".to_string(),
1094 ty: expr_ty.to_string(),
1095 span: miette::SourceSpan::new(
1096 un.span.start.into(),
1097 un.span.end - un.span.start,
1098 ),
1099 src: self.source.clone(),
1100 });
1101 Type::Error
1102 }
1103 }
1104 ast::UnaryOp::BitNot => {
1105 if expr_ty.is_integer() {
1106 expr_ty
1107 } else {
1108 self.error(TypeError::InvalidUnaryOp {
1109 op: "~".to_string(),
1110 ty: expr_ty.to_string(),
1111 span: miette::SourceSpan::new(
1112 un.span.start.into(),
1113 un.span.end - un.span.start,
1114 ),
1115 src: self.source.clone(),
1116 });
1117 Type::Error
1118 }
1119 }
1120 ast::UnaryOp::PreInc
1121 | ast::UnaryOp::PreDec
1122 | ast::UnaryOp::PostInc
1123 | ast::UnaryOp::PostDec => {
1124 if expr_ty.is_integer() {
1125 expr_ty
1126 } else {
1127 self.error(TypeError::InvalidUnaryOp {
1128 op: "++/--".to_string(),
1129 ty: expr_ty.to_string(),
1130 span: miette::SourceSpan::new(
1131 un.span.start.into(),
1132 un.span.end - un.span.start,
1133 ),
1134 src: self.source.clone(),
1135 });
1136 Type::Error
1137 }
1138 }
1139 }
1140 }
1141
1142 fn check_call_expr(&mut self, call: &ast::CallExpr) -> Type {
1143 if let ast::Expr::Ident(ident) = &call.callee {
1145 let name = ident.name.as_str();
1146
1147 match name {
1149 "assert" => {
1150 if call.args.is_empty() || call.args.len() > 2 {
1152 self.error(TypeError::wrong_arg_count(
1153 1,
1154 call.args.len(),
1155 self.span(call.span),
1156 &self.source,
1157 ));
1158 return Type::Unit;
1159 }
1160 let cond_ty = self.check_expr(&call.args[0].value);
1161 if !cond_ty.is_bool() && !matches!(cond_ty, Type::Error) {
1162 self.error(TypeError::type_mismatch(
1163 &Type::Primitive(PrimitiveType::Bool),
1164 &cond_ty,
1165 self.span(call.args[0].value.span()),
1166 &self.source,
1167 ));
1168 }
1169 if call.args.len() == 2 {
1171 let msg_ty = self.check_expr(&call.args[1].value);
1172 if !matches!(msg_ty, Type::Primitive(PrimitiveType::String))
1173 && !matches!(msg_ty, Type::Error)
1174 {
1175 self.error(TypeError::type_mismatch(
1176 &Type::Primitive(PrimitiveType::String),
1177 &msg_ty,
1178 self.span(call.args[1].value.span()),
1179 &self.source,
1180 ));
1181 }
1182 }
1183 return Type::Unit;
1184 }
1185 "assertEq" => {
1186 if call.args.len() < 2 || call.args.len() > 3 {
1188 self.error(TypeError::wrong_arg_count(
1189 2,
1190 call.args.len(),
1191 self.span(call.span),
1192 &self.source,
1193 ));
1194 return Type::Unit;
1195 }
1196 let left_ty = self.check_expr(&call.args[0].value);
1197 let right_ty = self.check_expr(&call.args[1].value);
1198 if !self.types_compatible(&left_ty, &right_ty) {
1199 self.error(TypeError::type_mismatch(
1200 &left_ty,
1201 &right_ty,
1202 self.span(call.args[1].value.span()),
1203 &self.source,
1204 ));
1205 }
1206 if call.args.len() == 3 {
1208 let msg_ty = self.check_expr(&call.args[2].value);
1209 if !matches!(msg_ty, Type::Primitive(PrimitiveType::String))
1210 && !matches!(msg_ty, Type::Error)
1211 {
1212 self.error(TypeError::type_mismatch(
1213 &Type::Primitive(PrimitiveType::String),
1214 &msg_ty,
1215 self.span(call.args[2].value.span()),
1216 &self.source,
1217 ));
1218 }
1219 }
1220 return Type::Unit;
1221 }
1222 "assertNe" => {
1223 if call.args.len() < 2 || call.args.len() > 3 {
1225 self.error(TypeError::wrong_arg_count(
1226 2,
1227 call.args.len(),
1228 self.span(call.span),
1229 &self.source,
1230 ));
1231 return Type::Unit;
1232 }
1233 let left_ty = self.check_expr(&call.args[0].value);
1234 let right_ty = self.check_expr(&call.args[1].value);
1235 if !self.types_compatible(&left_ty, &right_ty) {
1236 self.error(TypeError::type_mismatch(
1237 &left_ty,
1238 &right_ty,
1239 self.span(call.args[1].value.span()),
1240 &self.source,
1241 ));
1242 }
1243 if call.args.len() == 3 {
1245 let msg_ty = self.check_expr(&call.args[2].value);
1246 if !matches!(msg_ty, Type::Primitive(PrimitiveType::String))
1247 && !matches!(msg_ty, Type::Error)
1248 {
1249 self.error(TypeError::type_mismatch(
1250 &Type::Primitive(PrimitiveType::String),
1251 &msg_ty,
1252 self.span(call.args[2].value.span()),
1253 &self.source,
1254 ));
1255 }
1256 }
1257 return Type::Unit;
1258 }
1259 "assertGt" | "assertGe" | "assertLt" | "assertLe" => {
1260 if call.args.len() < 2 || call.args.len() > 3 {
1262 self.error(TypeError::wrong_arg_count(
1263 2,
1264 call.args.len(),
1265 self.span(call.span),
1266 &self.source,
1267 ));
1268 return Type::Unit;
1269 }
1270 let left_ty = self.check_expr(&call.args[0].value);
1271 let right_ty = self.check_expr(&call.args[1].value);
1272 if !left_ty.is_integer() && !matches!(left_ty, Type::Error) {
1274 self.error(TypeError::type_mismatch(
1275 &Type::Primitive(PrimitiveType::Uint256),
1276 &left_ty,
1277 self.span(call.args[0].value.span()),
1278 &self.source,
1279 ));
1280 }
1281 if !right_ty.is_integer() && !matches!(right_ty, Type::Error) {
1282 self.error(TypeError::type_mismatch(
1283 &Type::Primitive(PrimitiveType::Uint256),
1284 &right_ty,
1285 self.span(call.args[1].value.span()),
1286 &self.source,
1287 ));
1288 }
1289 return Type::Unit;
1290 }
1291 "transfer" => {
1292 if call.args.len() != 2 {
1294 self.error(TypeError::wrong_arg_count(
1295 2,
1296 call.args.len(),
1297 self.span(call.span),
1298 &self.source,
1299 ));
1300 return Type::Unit;
1301 }
1302 let to_ty = self.check_expr(&call.args[0].value);
1303 let amount_ty = self.check_expr(&call.args[1].value);
1304 if !matches!(to_ty, Type::Primitive(PrimitiveType::Address))
1306 && !matches!(to_ty, Type::Error)
1307 {
1308 self.error(TypeError::type_mismatch(
1309 &Type::Primitive(PrimitiveType::Address),
1310 &to_ty,
1311 self.span(call.args[0].value.span()),
1312 &self.source,
1313 ));
1314 }
1315 if !amount_ty.is_integer() && !matches!(amount_ty, Type::Error) {
1317 self.error(TypeError::type_mismatch(
1318 &Type::Primitive(PrimitiveType::Uint64),
1319 &amount_ty,
1320 self.span(call.args[1].value.span()),
1321 &self.source,
1322 ));
1323 }
1324 return Type::Unit;
1325 }
1326 _ => {}
1327 }
1328
1329 if name == "address" {
1331 if call.args.len() != 1 {
1332 self.error(TypeError::wrong_arg_count(
1333 1,
1334 call.args.len(),
1335 self.span(call.span),
1336 &self.source,
1337 ));
1338 return Type::Error;
1339 }
1340 self.check_expr(&call.args[0].value);
1342 return Type::Primitive(PrimitiveType::Address);
1343 }
1344
1345 if let Some(prim) = PrimitiveType::parse(name) {
1348 if prim.is_integer() || prim.is_fixed_bytes() {
1349 if call.args.len() != 1 {
1350 self.error(TypeError::wrong_arg_count(
1351 1,
1352 call.args.len(),
1353 self.span(call.span),
1354 &self.source,
1355 ));
1356 return Type::Error;
1357 }
1358 self.check_expr(&call.args[0].value);
1359 return Type::Primitive(prim);
1360 }
1361 }
1362
1363 if let Some(TypeDef::Interface(_)) = self.symbols.lookup_type(&SmolStr::from(name)) {
1366 if call.args.len() != 1 {
1367 self.error(TypeError::wrong_arg_count(
1368 1,
1369 call.args.len(),
1370 self.span(call.span),
1371 &self.source,
1372 ));
1373 return Type::Error;
1374 }
1375 let arg_ty = self.check_expr(&call.args[0].value);
1377 if !matches!(arg_ty, Type::Primitive(PrimitiveType::Address))
1378 && !matches!(arg_ty, Type::Error)
1379 {
1380 self.error(TypeError::type_mismatch(
1381 &Type::Primitive(PrimitiveType::Address),
1382 &arg_ty,
1383 self.span(call.args[0].span),
1384 &self.source,
1385 ));
1386 }
1387 return Type::Named(NamedType {
1389 name: SmolStr::from(name),
1390 type_args: Vec::new(),
1391 });
1392 }
1393 }
1394
1395 let callee_ty = self.check_expr(&call.callee);
1396
1397 if let Type::Function(fn_ty) = callee_ty {
1398 if call.args.len() != fn_ty.params.len() {
1400 self.error(TypeError::wrong_arg_count(
1401 fn_ty.params.len(),
1402 call.args.len(),
1403 self.span(call.span),
1404 &self.source,
1405 ));
1406 }
1407
1408 for (arg, expected_ty) in call.args.iter().zip(fn_ty.params.iter()) {
1410 let arg_ty = self.check_expr(&arg.value);
1411 if !self.types_compatible(expected_ty, &arg_ty) {
1412 self.error(TypeError::type_mismatch(
1413 expected_ty,
1414 &arg_ty,
1415 self.span(arg.span),
1416 &self.source,
1417 ));
1418 }
1419 }
1420
1421 *fn_ty.return_type
1422 } else if matches!(callee_ty, Type::Error) {
1423 Type::Error
1424 } else {
1425 self.error(TypeError::not_callable(
1426 &callee_ty,
1427 self.span(call.span),
1428 &self.source,
1429 ));
1430 Type::Error
1431 }
1432 }
1433
1434 fn check_method_call(&mut self, mc: &ast::MethodCallExpr) -> Type {
1435 let receiver_ty = self.check_expr(&mc.receiver);
1436 let method_name = mc.method.name.clone();
1437
1438 let arg_types: Vec<Type> = mc
1440 .args
1441 .iter()
1442 .map(|arg| self.check_expr(&arg.value))
1443 .collect();
1444
1445 if let Type::Named(named) = &receiver_ty {
1447 let type_name = named.name.as_str();
1448
1449 match type_name {
1451 "msg" => match method_name.as_str() {
1452 "sender" => return Type::Primitive(PrimitiveType::Address),
1453 "value" => return Type::Primitive(PrimitiveType::Uint256),
1454 "data" => return Type::Primitive(PrimitiveType::Bytes),
1455 _ => {}
1456 },
1457 "block" => match method_name.as_str() {
1458 "timestamp" => return Type::Primitive(PrimitiveType::Uint256),
1459 "number" => return Type::Primitive(PrimitiveType::Uint256),
1460 _ => {}
1461 },
1462 "tx" => match method_name.as_str() {
1463 "origin" => return Type::Primitive(PrimitiveType::Address),
1464 "gasprice" => return Type::Primitive(PrimitiveType::Uint256),
1465 _ => {}
1466 },
1467 "token" => {
1468 match method_name.as_str() {
1471 "transfer" | "mint" | "burn" => {
1472 if arg_types.len() != 4 {
1474 self.error(TypeError::wrong_arg_count(
1475 4,
1476 arg_types.len(),
1477 self.span(mc.span),
1478 &self.source,
1479 ));
1480 return Type::Error;
1481 }
1482 return Type::Unit;
1483 }
1484 "getATA" => {
1485 if arg_types.len() != 2 {
1487 self.error(TypeError::wrong_arg_count(
1488 2,
1489 arg_types.len(),
1490 self.span(mc.span),
1491 &self.source,
1492 ));
1493 return Type::Error;
1494 }
1495 return Type::Primitive(PrimitiveType::Address);
1496 }
1497 _ => {}
1498 }
1499 }
1500 "rent" => {
1502 match method_name.as_str() {
1503 "minimumBalance" => {
1504 if arg_types.len() != 1 {
1506 self.error(TypeError::wrong_arg_count(
1507 1,
1508 arg_types.len(),
1509 self.span(mc.span),
1510 &self.source,
1511 ));
1512 return Type::Error;
1513 }
1514 return Type::Primitive(PrimitiveType::Uint64);
1515 }
1516 "isExempt" => {
1517 if arg_types.len() != 2 {
1519 self.error(TypeError::wrong_arg_count(
1520 2,
1521 arg_types.len(),
1522 self.span(mc.span),
1523 &self.source,
1524 ));
1525 return Type::Error;
1526 }
1527 return Type::Primitive(PrimitiveType::Bool);
1528 }
1529 _ => {}
1530 }
1531 }
1532 "clock" => {
1534 if method_name.as_str() == "get" {
1535 return Type::Named(NamedType::new(SmolStr::from("clock")));
1537 }
1538 }
1539 _ => {}
1540 }
1541
1542 let method_info = self
1544 .symbols
1545 .lookup_type(&SmolStr::from(type_name))
1546 .and_then(|type_def| match type_def {
1547 TypeDef::Contract(c) => c.methods.get(&method_name).cloned(),
1548 TypeDef::Interface(i) => i.methods.get(&method_name).cloned(),
1549 _ => None,
1550 });
1551
1552 if let Some(fn_ty) = method_info {
1553 if arg_types.len() != fn_ty.params.len() {
1555 self.error(TypeError::wrong_arg_count(
1556 fn_ty.params.len(),
1557 arg_types.len(),
1558 self.span(mc.span),
1559 &self.source,
1560 ));
1561 return Type::Error;
1562 }
1563
1564 for (i, (arg_ty, param_ty)) in arg_types.iter().zip(fn_ty.params.iter()).enumerate()
1566 {
1567 if !self.types_compatible(param_ty, arg_ty) {
1568 self.error(TypeError::type_mismatch(
1569 param_ty,
1570 arg_ty,
1571 self.span(mc.args[i].value.span()),
1572 &self.source,
1573 ));
1574 }
1575 }
1576
1577 return (*fn_ty.return_type).clone();
1578 }
1579
1580 self.error(TypeError::undefined_method(
1582 &method_name,
1583 &receiver_ty,
1584 self.span(mc.span),
1585 &self.source,
1586 ));
1587 return Type::Error;
1588 }
1589
1590 if matches!(receiver_ty, Type::Error) {
1591 return Type::Error;
1592 }
1593
1594 if let Type::DynamicArray(elem_ty) = &receiver_ty {
1596 match method_name.as_str() {
1597 "push" => {
1598 if arg_types.len() != 1 {
1600 self.error(TypeError::wrong_arg_count(
1601 1,
1602 arg_types.len(),
1603 self.span(mc.span),
1604 &self.source,
1605 ));
1606 return Type::Error;
1607 }
1608 if !self.types_compatible(elem_ty, &arg_types[0]) {
1610 self.error(TypeError::type_mismatch(
1611 elem_ty,
1612 &arg_types[0],
1613 self.span(mc.span),
1614 &self.source,
1615 ));
1616 }
1617 return Type::Unit;
1618 }
1619 "pop" => {
1620 if !arg_types.is_empty() {
1621 self.error(TypeError::wrong_arg_count(
1622 0,
1623 arg_types.len(),
1624 self.span(mc.span),
1625 &self.source,
1626 ));
1627 return Type::Error;
1628 }
1629 return (**elem_ty).clone();
1630 }
1631 _ => {}
1632 }
1633 }
1634
1635 self.error(TypeError::undefined_method(
1637 &method_name,
1638 &receiver_ty,
1639 self.span(mc.span),
1640 &self.source,
1641 ));
1642 Type::Error
1643 }
1644
1645 fn check_field_access(&mut self, fa: &ast::FieldAccessExpr) -> Type {
1646 let expr_ty = self.check_expr(&fa.expr);
1647
1648 if matches!(expr_ty, Type::Error) {
1649 return Type::Error;
1650 }
1651
1652 if let Type::Named(named) = &expr_ty {
1654 let type_name = named.name.as_str();
1655 let field_name = fa.field.name.as_str();
1656
1657 match type_name {
1658 "msg" => match field_name {
1659 "sender" => return Type::Primitive(PrimitiveType::Address),
1660 "value" => return Type::Primitive(PrimitiveType::Uint256),
1661 "data" => return Type::Primitive(PrimitiveType::Bytes),
1662 _ => {}
1663 },
1664 "block" => match field_name {
1665 "timestamp" => return Type::Primitive(PrimitiveType::Uint256),
1666 "number" => return Type::Primitive(PrimitiveType::Uint256),
1667 _ => {}
1668 },
1669 "tx" => match field_name {
1670 "origin" => return Type::Primitive(PrimitiveType::Address),
1671 "gasprice" => return Type::Primitive(PrimitiveType::Uint256),
1672 _ => {}
1673 },
1674 "clock" => match field_name {
1676 "timestamp" => return Type::Primitive(PrimitiveType::Int64),
1677 "slot" => return Type::Primitive(PrimitiveType::Uint64),
1678 "epoch" => return Type::Primitive(PrimitiveType::Uint64),
1679 "unix_timestamp" => return Type::Primitive(PrimitiveType::Int64),
1680 _ => {}
1681 },
1682 _ => {
1683 let field_ty = self
1685 .symbols
1686 .lookup_type(&SmolStr::from(type_name))
1687 .and_then(|type_def| match type_def {
1688 TypeDef::Struct(s) => s.fields.get(field_name).cloned(),
1689 TypeDef::Contract(c) => c.state_fields.get(field_name).cloned(),
1690 _ => None,
1691 });
1692
1693 if let Some(ty) = field_ty {
1694 return ty;
1695 }
1696 }
1697 }
1698
1699 self.error(TypeError::undefined_field(
1701 &fa.field.name,
1702 &expr_ty,
1703 self.span(fa.span),
1704 &self.source,
1705 ));
1706 return Type::Error;
1707 }
1708
1709 let field_name = fa.field.name.as_str();
1711 if field_name == "length" {
1712 match &expr_ty {
1713 Type::Array(_, _) | Type::DynamicArray(_) => {
1714 return Type::Primitive(PrimitiveType::Uint256);
1715 }
1716 _ => {}
1717 }
1718 }
1719
1720 self.error(TypeError::undefined_field(
1722 &fa.field.name,
1723 &expr_ty,
1724 self.span(fa.span),
1725 &self.source,
1726 ));
1727 Type::Error
1728 }
1729
1730 fn check_index_expr(&mut self, idx: &ast::IndexExpr) -> Type {
1731 let expr_ty = self.check_expr(&idx.expr);
1732 let index_ty = self.check_expr(&idx.index);
1733
1734 match expr_ty {
1736 Type::Array(elem, _) | Type::DynamicArray(elem) => {
1737 if !index_ty.is_integer() && !matches!(index_ty, Type::Error) {
1739 self.error(TypeError::type_mismatch(
1740 &Type::Primitive(PrimitiveType::Uint256),
1741 &index_ty,
1742 self.span(idx.index.span()),
1743 &self.source,
1744 ));
1745 }
1746 *elem
1747 }
1748 Type::Mapping(key, value) => {
1749 if !self.types_compatible(&key, &index_ty) && !matches!(index_ty, Type::Error) {
1751 self.error(TypeError::type_mismatch(
1752 &key,
1753 &index_ty,
1754 self.span(idx.index.span()),
1755 &self.source,
1756 ));
1757 }
1758 *value
1759 }
1760 Type::Error => Type::Error,
1761 _ => {
1762 self.error(TypeError::NotIndexable {
1763 ty: expr_ty.to_string(),
1764 span: miette::SourceSpan::new(
1765 idx.span.start.into(),
1766 idx.span.end - idx.span.start,
1767 ),
1768 src: self.source.clone(),
1769 });
1770 Type::Error
1771 }
1772 }
1773 }
1774
1775 fn check_if_expr(&mut self, if_expr: &ast::IfExpr) -> Type {
1776 let cond_ty = self.check_expr(&if_expr.condition);
1777 if !cond_ty.is_bool() && !matches!(cond_ty, Type::Error) {
1778 self.error(TypeError::type_mismatch(
1779 &Type::Primitive(PrimitiveType::Bool),
1780 &cond_ty,
1781 self.span(if_expr.condition.span()),
1782 &self.source,
1783 ));
1784 }
1785
1786 self.check_block(&if_expr.then_block);
1787
1788 match &*if_expr.else_branch {
1789 ast::IfExprElse::Else(block) => self.check_block(block),
1790 ast::IfExprElse::ElseIf(elif) => {
1791 self.check_if_expr(elif);
1792 }
1793 }
1794
1795 Type::Unit
1797 }
1798
1799 fn check_array_expr(&mut self, arr: &ast::ArrayExpr) -> Type {
1800 if arr.elements.is_empty() {
1801 return Type::DynamicArray(Box::new(self.fresh_type_var()));
1802 }
1803
1804 let first_ty = self.check_expr(&arr.elements[0]);
1805
1806 for elem in arr.elements.iter().skip(1) {
1807 let elem_ty = self.check_expr(elem);
1808 if !self.types_compatible(&first_ty, &elem_ty) {
1809 self.error(TypeError::type_mismatch(
1810 &first_ty,
1811 &elem_ty,
1812 self.span(elem.span()),
1813 &self.source,
1814 ));
1815 }
1816 }
1817
1818 Type::Array(Box::new(first_ty), arr.elements.len() as u64)
1819 }
1820
1821 fn check_tuple_expr(&mut self, tuple: &ast::TupleExpr) -> Type {
1822 let elem_types: Vec<Type> = tuple.elements.iter().map(|e| self.check_expr(e)).collect();
1823 Type::Tuple(elem_types)
1824 }
1825
1826 fn check_assign_expr(&mut self, a: &ast::AssignExpr) -> Type {
1827 let target_ty = self.check_expr(&a.target);
1828 let value_ty = self.check_expr(&a.value);
1829
1830 match a.op {
1831 ast::AssignOp::Assign => {
1832 if !self.types_compatible(&target_ty, &value_ty) {
1834 self.error(TypeError::type_mismatch(
1835 &target_ty,
1836 &value_ty,
1837 self.span(a.span),
1838 &self.source,
1839 ));
1840 }
1841 }
1842 ast::AssignOp::AddAssign
1843 | ast::AssignOp::SubAssign
1844 | ast::AssignOp::MulAssign
1845 | ast::AssignOp::DivAssign
1846 | ast::AssignOp::RemAssign => {
1847 if (!target_ty.is_integer() || !self.types_compatible(&target_ty, &value_ty))
1849 && !matches!(target_ty, Type::Error)
1850 && !matches!(value_ty, Type::Error)
1851 {
1852 self.error(TypeError::invalid_binary_op(
1853 &format!("{:?}", a.op),
1854 &target_ty,
1855 &value_ty,
1856 self.span(a.span),
1857 &self.source,
1858 ));
1859 }
1860 }
1861 ast::AssignOp::BitAndAssign
1862 | ast::AssignOp::BitOrAssign
1863 | ast::AssignOp::BitXorAssign => {
1864 if (!target_ty.is_integer() || !value_ty.is_integer())
1866 && !matches!(target_ty, Type::Error)
1867 && !matches!(value_ty, Type::Error)
1868 {
1869 self.error(TypeError::invalid_binary_op(
1870 &format!("{:?}", a.op),
1871 &target_ty,
1872 &value_ty,
1873 self.span(a.span),
1874 &self.source,
1875 ));
1876 }
1877 }
1878 }
1879
1880 Type::Unit
1881 }
1882
1883 fn check_ternary_expr(&mut self, t: &ast::TernaryExpr) -> Type {
1884 let cond_ty = self.check_expr(&t.condition);
1885 if !cond_ty.is_bool() && !matches!(cond_ty, Type::Error) {
1886 self.error(TypeError::type_mismatch(
1887 &Type::Primitive(PrimitiveType::Bool),
1888 &cond_ty,
1889 self.span(t.condition.span()),
1890 &self.source,
1891 ));
1892 }
1893
1894 let then_ty = self.check_expr(&t.then_expr);
1895 let else_ty = self.check_expr(&t.else_expr);
1896
1897 if !self.types_compatible(&then_ty, &else_ty) {
1898 self.error(TypeError::type_mismatch(
1899 &then_ty,
1900 &else_ty,
1901 self.span(t.span),
1902 &self.source,
1903 ));
1904 return Type::Error;
1905 }
1906
1907 then_ty
1908 }
1909
1910 fn check_new_expr(&mut self, n: &ast::NewExpr) -> Type {
1911 let type_name = n.ty.name();
1912
1913 if self.symbols.lookup_type(type_name).is_none() {
1915 self.error(TypeError::undefined_type(
1916 type_name,
1917 self.span(n.span),
1918 &self.source,
1919 ));
1920 return Type::Error;
1921 }
1922
1923 for arg in &n.args {
1927 self.check_expr(&arg.value);
1928 }
1929
1930 Type::Named(NamedType::new(type_name.clone()))
1931 }
1932
1933 fn types_compatible(&self, expected: &Type, found: &Type) -> bool {
1938 match (expected, found) {
1939 (Type::Error, _) | (_, Type::Error) => true,
1940 (Type::Var(_), _) | (_, Type::Var(_)) => true, (Type::Primitive(a), Type::Primitive(b)) if a.is_integer() && b.is_integer() => true,
1943 (Type::Primitive(PrimitiveType::Address), Type::Primitive(PrimitiveType::Signer)) => {
1945 true
1946 }
1947 (Type::Primitive(PrimitiveType::Signer), Type::Primitive(PrimitiveType::Address)) => {
1948 true
1949 }
1950 (Type::Primitive(a), Type::Primitive(b)) => a == b,
1951 (Type::Unit, Type::Unit) => true,
1952 (Type::Never, _) => true, (Type::Named(a), Type::Named(b)) => {
1954 a.name == b.name
1955 && a.type_args.len() == b.type_args.len()
1956 && a.type_args
1957 .iter()
1958 .zip(b.type_args.iter())
1959 .all(|(x, y)| self.types_compatible(x, y))
1960 }
1961 (Type::Array(a, n1), Type::Array(b, n2)) => n1 == n2 && self.types_compatible(a, b),
1962 (Type::DynamicArray(a), Type::DynamicArray(b)) => self.types_compatible(a, b),
1963 (Type::Tuple(a), Type::Tuple(b)) => {
1964 a.len() == b.len()
1965 && a.iter()
1966 .zip(b.iter())
1967 .all(|(x, y)| self.types_compatible(x, y))
1968 }
1969 (Type::Mapping(k1, v1), Type::Mapping(k2, v2)) => {
1970 self.types_compatible(k1, k2) && self.types_compatible(v1, v2)
1971 }
1972 (Type::Function(a), Type::Function(b)) => {
1973 a.params.len() == b.params.len()
1974 && a.params
1975 .iter()
1976 .zip(b.params.iter())
1977 .all(|(x, y)| self.types_compatible(x, y))
1978 && self.types_compatible(&a.return_type, &b.return_type)
1979 }
1980 _ => false,
1981 }
1982 }
1983}