1#![allow(unused_assignments)] mod checker;
8mod error;
9mod scope;
10mod types;
11
12pub use checker::TypeChecker;
13pub use error::TypeError;
14pub use scope::{Scope, ScopeKind, Symbol, SymbolTable};
15pub use types::*;
16
17use solscript_ast::Program;
18
19pub fn typecheck(program: &Program, source: &str) -> Result<(), Vec<TypeError>> {
21 let mut checker = TypeChecker::new(source.to_string());
22 checker.check_program(program)
23}
24
/// Unit tests for [`typecheck`]: each test parses a small SolScript snippet and
/// asserts either that type checking accepts it or which diagnostic it reports.
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` and type-checks it, printing every reported type error to
    /// stderr so a failing test shows the diagnostics.
    ///
    /// # Panics
    ///
    /// Panics if `source` does not parse; every snippet in this module is meant to
    /// be syntactically valid so that only the type checker is under test.
    fn check(source: &str) -> Result<(), Vec<TypeError>> {
        let program = solscript_parser::parse(source).expect("parse error");
        let result = typecheck(&program, source);
        if let Err(ref errors) = result {
            for err in errors {
                eprintln!("Type error: {:?}", err);
            }
        }
        result
    }

    /// Asserts that `$result` is an `Err` whose diagnostics include at least one
    /// error matching the pattern `$pat`. On failure the message names the pattern
    /// and lists every error that was actually reported.
    macro_rules! assert_has_error {
        ($result:expr, $pat:pat) => {{
            let errors = $result.expect_err("expected type errors, but type checking succeeded");
            assert!(
                errors.iter().any(|e| matches!(e, $pat)),
                "expected an error matching `{}`, got: {:?}",
                stringify!($pat),
                errors
            );
        }};
    }

    #[test]
    fn test_empty_contract() {
        let result = check("contract Empty {}");
        assert!(result.is_ok());
    }

    #[test]
    fn test_contract_with_state() {
        let result = check(
            r#"
            contract Counter {
                uint256 public count;
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_contract_with_function() {
        let result = check(
            r#"
            contract Counter {
                uint256 public count;

                function increment() public {
                    count += 1;
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_function_with_return() {
        let result = check(
            r#"
            contract Math {
                function add(uint256 a, uint256 b) public pure returns (uint256) {
                    return a + b;
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_var_declaration() {
        let result = check(
            r#"
            contract Test {
                function test() public pure {
                    uint256 x = 10;
                    uint256 y = 20;
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_if_statement() {
        let result = check(
            r#"
            contract Test {
                function test(uint256 x) public pure returns (bool) {
                    if (x > 10) {
                        return true;
                    } else {
                        return false;
                    }
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_struct() {
        let result = check(
            r#"
            struct Point {
                uint256 x;
                uint256 y;
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_enum() {
        let result = check(
            r#"
            enum Status {
                Pending,
                Active,
                Complete
            }
            "#,
        );
        assert!(result.is_ok());
    }

    // ---- Programs that must be rejected ----

    #[test]
    fn test_undefined_variable() {
        let result = check(
            r#"
            contract Test {
                function test() public pure returns (uint256) {
                    return undefined_var;
                }
            }
            "#,
        );
        assert_has_error!(result, TypeError::UndefinedVariable { .. });
    }

    #[test]
    fn test_type_mismatch_return() {
        let result = check(
            r#"
            contract Test {
                function test() public pure returns (uint256) {
                    return true;
                }
            }
            "#,
        );
        assert_has_error!(result, TypeError::TypeMismatch { .. });
    }

    #[test]
    fn test_type_mismatch_assignment() {
        let result = check(
            r#"
            contract Test {
                function test() public pure {
                    uint256 x = true;
                }
            }
            "#,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_undefined_type() {
        let result = check(
            r#"
            contract Test {
                UndefinedType public data;
            }
            "#,
        );
        assert_has_error!(result, TypeError::UndefinedType { .. });
    }

    #[test]
    fn test_binary_op_type_mismatch() {
        let result = check(
            r#"
            contract Test {
                function test() public pure returns (bool) {
                    return 5 + true;
                }
            }
            "#,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_condition_must_be_bool() {
        let result = check(
            r#"
            contract Test {
                function test() public pure {
                    if (42) {
                        uint256 x = 1;
                    }
                }
            }
            "#,
        );
        assert_has_error!(result, TypeError::TypeMismatch { .. });
    }

    #[test]
    fn test_undefined_field() {
        let result = check(
            r#"
            struct Data {
                uint256 value;
            }

            contract Test {
                Data public data;

                function test() public view returns (uint256) {
                    return data.undefined_field;
                }
            }
            "#,
        );
        assert_has_error!(result, TypeError::UndefinedField { .. });
    }

    #[test]
    fn test_function_call_with_wrong_arity() {
        // `helper` takes two arguments but is called with one. Arity checking for
        // internal calls is not pinned down by this suite yet, so the test only
        // requires that checking completes without panicking and that a failing
        // result actually carries diagnostics (assumes `check_program` never
        // returns `Err` with an empty list).
        // TODO: assert `TypeError::WrongArgCount` once internal-call arity
        // checking is confirmed, matching the emit/modifier arity tests below.
        let result = check(
            r#"
            contract Test {
                function helper(uint256 a, uint256 b) internal pure returns (uint256) {
                    return a + b;
                }

                function test() public pure returns (uint256) {
                    return helper(1);
                }
            }
            "#,
        );
        if let Err(errors) = result {
            assert!(
                !errors.is_empty(),
                "type checking failed without reporting any error"
            );
        }
    }

    #[test]
    fn test_duplicate_field() {
        let result = check(
            r#"
            struct Point {
                uint256 x;
                uint256 x;
            }
            "#,
        );
        assert_has_error!(result, TypeError::DuplicateDefinition { .. });
    }

    #[test]
    fn test_while_loop() {
        let result = check(
            r#"
            contract Test {
                function test() public pure returns (uint256) {
                    uint256 i = 0;
                    while (i < 10) {
                        i += 1;
                    }
                    return i;
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_for_loop() {
        let result = check(
            r#"
            contract Test {
                function test() public pure returns (uint256) {
                    uint256 sum = 0;
                    for (uint256 i = 0; i < 10; i++) {
                        sum += i;
                    }
                    return sum;
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_nested_function_call() {
        // NOTE(review): despite the name, `calculate` never calls `add` or
        // `multiply`; consider `return add(multiply(x, y), z);` once internal
        // call checking is confirmed to accept it.
        let result = check(
            r#"
            contract Math {
                function add(uint256 a, uint256 b) internal pure returns (uint256) {
                    return a + b;
                }

                function multiply(uint256 a, uint256 b) internal pure returns (uint256) {
                    return a * b;
                }

                function calculate(uint256 x, uint256 y, uint256 z) public pure returns (uint256) {
                    return x + y + z;
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_constructor() {
        let result = check(
            r#"
            contract Token {
                address public owner;
                uint256 public totalSupply;

                constructor(uint256 initialSupply) {
                    owner = msg.sender;
                    totalSupply = initialSupply;
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_modifier() {
        let result = check(
            r#"
            contract Owned {
                address public owner;

                modifier onlyOwner() {
                    require(msg.sender == owner, "Not owner");
                    _;
                }

                function withdraw() public onlyOwner {
                    return;
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_mapping_type() {
        let result = check(
            r#"
            contract Token {
                mapping(address => uint256) public balances;

                function getBalance(address account) public view returns (uint256) {
                    return balances[account];
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_interface() {
        let result = check(
            r#"
            interface IERC20 {
                function transfer(address to, uint256 amount) external returns (bool);
                function balanceOf(address account) external view returns (uint256);
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_require_statement() {
        let result = check(
            r#"
            contract Test {
                function transfer(uint256 amount) public pure {
                    require(amount > 0, "Amount must be positive");
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_ternary_expression() {
        let result = check(
            r#"
            contract Test {
                function max(uint256 a, uint256 b) public pure returns (uint256) {
                    return a > b ? a : b;
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_emit_event_valid() {
        let result = check(
            r#"
            event Transfer(address from, address to, uint256 amount);

            contract Token {
                function transfer(address to, uint256 amount) public {
                    emit Transfer(msg.sender, to, amount);
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_emit_undefined_event() {
        let result = check(
            r#"
            contract Token {
                function transfer(address to, uint256 amount) public {
                    emit UndefinedEvent(msg.sender, to, amount);
                }
            }
            "#,
        );
        assert_has_error!(result, TypeError::UndefinedEvent { .. });
    }

    #[test]
    fn test_emit_wrong_arg_count() {
        let result = check(
            r#"
            event Transfer(address from, address to, uint256 amount);

            contract Token {
                function transfer(address to, uint256 amount) public {
                    emit Transfer(msg.sender, to);
                }
            }
            "#,
        );
        assert_has_error!(result, TypeError::WrongArgCount { .. });
    }

    #[test]
    fn test_emit_wrong_arg_type() {
        let result = check(
            r#"
            event Transfer(address from, address to, uint256 amount);

            contract Token {
                function transfer(address to, uint256 amount) public {
                    emit Transfer(msg.sender, to, true);
                }
            }
            "#,
        );
        assert_has_error!(result, TypeError::TypeMismatch { .. });
    }

    #[test]
    fn test_undefined_modifier() {
        let result = check(
            r#"
            contract Token {
                function withdraw() public undefinedModifier {
                    return;
                }
            }
            "#,
        );
        assert_has_error!(result, TypeError::UndefinedModifier { .. });
    }

    #[test]
    fn test_modifier_valid() {
        // NOTE(review): same snippet as `test_modifier` apart from the contract
        // name; one of the two could be dropped or varied.
        let result = check(
            r#"
            contract Token {
                address public owner;

                modifier onlyOwner() {
                    require(msg.sender == owner, "Not owner");
                    _;
                }

                function withdraw() public onlyOwner {
                    return;
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_modifier_with_args() {
        let result = check(
            r#"
            contract Token {
                modifier minAmount(uint256 min) {
                    require(msg.value >= min, "Below minimum");
                    _;
                }

                function deposit() public minAmount(100) {
                    return;
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_modifier_wrong_arg_count() {
        let result = check(
            r#"
            contract Token {
                modifier minAmount(uint256 min) {
                    require(msg.value >= min, "Below minimum");
                    _;
                }

                function deposit() public minAmount() {
                    return;
                }
            }
            "#,
        );
        assert_has_error!(result, TypeError::WrongArgCount { .. });
    }

    // ---- Interface casts and cross-program calls ----

    #[test]
    fn test_interface_cpi_valid() {
        let result = check(
            r#"
            interface IERC20 {
                function transfer(address to, uint256 amount) external returns (bool);
                function balanceOf(address account) external view returns (uint256);
            }

            contract TokenUser {
                address public tokenProgram;

                function doTransfer(address to, uint256 amount) public returns (bool) {
                    return IERC20(tokenProgram).transfer(to, amount);
                }

                function checkBalance(address account) public view returns (uint256) {
                    return IERC20(tokenProgram).balanceOf(account);
                }
            }
            "#,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_interface_cpi_wrong_method() {
        let result = check(
            r#"
            interface IERC20 {
                function transfer(address to, uint256 amount) external returns (bool);
            }

            contract TokenUser {
                address public tokenProgram;

                function doSomething() public {
                    IERC20(tokenProgram).undefinedMethod();
                }
            }
            "#,
        );
        assert_has_error!(result, TypeError::UndefinedMethod { .. });
    }

    #[test]
    fn test_interface_cpi_wrong_arg_type() {
        let result = check(
            r#"
            interface IERC20 {
                function transfer(address to, uint256 amount) external returns (bool);
            }

            contract TokenUser {
                address public tokenProgram;

                function doTransfer() public {
                    IERC20(tokenProgram).transfer(123, 100);
                }
            }
            "#,
        );
        assert_has_error!(result, TypeError::TypeMismatch { .. });
    }

    #[test]
    fn test_interface_cast_requires_address() {
        let result = check(
            r#"
            interface IERC20 {
                function transfer(address to, uint256 amount) external returns (bool);
            }

            contract TokenUser {
                function doTransfer() public {
                    IERC20(123).transfer(msg.sender, 100);
                }
            }
            "#,
        );
        assert_has_error!(result, TypeError::TypeMismatch { .. });
    }
}