// solscript_parser/lib.rs

1//! SolScript Parser
2//!
3//! This crate parses SolScript source code into an AST using pest.
4
5#![allow(unused_assignments)] // Suppress false positives from derive macros
6
7mod error;
8mod parser;
9
10pub use error::*;
11pub use parser::*;
12
13use pest_derive::Parser;
14
15#[derive(Parser)]
16#[grammar = "src/solscript.pest"]
17pub struct SolScriptParser;
18
19/// Parse SolScript source code into an AST
20pub fn parse(source: &str) -> Result<solscript_ast::Program, ParseError> {
21    parser::parse_program(source)
22}
23
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` and returns the resulting program, failing the calling
    /// test with the parser's error when the input is rejected.
    ///
    /// A plain `match` is used (rather than a closure-based combinator) so the
    /// panic is raised inside this `#[track_caller]` function and is therefore
    /// reported at the test's call site.
    #[track_caller]
    fn parse_ok(source: &str) -> solscript_ast::Program {
        match parse(source) {
            Ok(program) => program,
            Err(err) => panic!("Failed to parse: {:?}", err),
        }
    }

    /// Borrows the contract stored at `$index` in `$program.items`, panicking
    /// if that item is not a contract.
    ///
    /// Written as a macro so the tests do not have to name the concrete AST
    /// struct that `Item::Contract` wraps.
    macro_rules! contract_at {
        ($program:expr, $index:expr) => {
            match &$program.items[$index] {
                solscript_ast::Item::Contract(contract) => contract,
                _ => panic!("Expected contract"),
            }
        };
    }

    /// Borrows the function member of `$contract` whose name equals `$name`,
    /// panicking if the contract has no such function.
    ///
    /// Written as a macro for the same reason as `contract_at!`.
    macro_rules! function_named {
        ($contract:expr, $name:expr) => {
            $contract
                .members
                .iter()
                .find_map(|member| match member {
                    solscript_ast::ContractMember::Function(func)
                        if func.name.name.as_str() == $name =>
                    {
                        Some(func)
                    }
                    _ => None,
                })
                .unwrap_or_else(|| panic!("Should have {} function", $name))
        };
    }

    /// A contract with an empty body parses.
    #[test]
    fn test_parse_empty_contract() {
        let source = r#"
            contract Empty {
            }
        "#;
        parse_ok(source);
    }

    /// A contract with a single public state variable parses.
    #[test]
    fn test_parse_contract_with_state() {
        let source = r#"
            contract Counter {
                uint256 public count;
            }
        "#;
        parse_ok(source);
    }

    /// A contract with a state variable and a function using `+=` parses.
    #[test]
    fn test_parse_contract_with_function() {
        let source = r#"
            contract Counter {
                uint256 public count;

                function increment() public {
                    count += 1;
                }
            }
        "#;
        parse_ok(source);
    }

    /// A top-level struct definition parses.
    #[test]
    fn test_parse_struct() {
        let source = r#"
            struct Point {
                uint256 x;
                uint256 y;
            }
        "#;
        parse_ok(source);
    }

    /// A top-level enum definition parses.
    #[test]
    fn test_parse_enum() {
        let source = r#"
            enum Status {
                Pending,
                Active,
                Completed
            }
        "#;
        parse_ok(source);
    }

    /// A named-import statement parses.
    #[test]
    fn test_parse_import() {
        let source = r#"
            import { Token, PDA } from "@solana/token";
        "#;
        parse_ok(source);
    }

    /// Top-level event (with `indexed` parameters) and error declarations parse.
    #[test]
    fn test_parse_event_and_error() {
        let source = r#"
            event Transfer(address indexed from, address indexed to, uint256 amount);
            error InsufficientBalance(uint256 available, uint256 required);
        "#;
        parse_ok(source);
    }

    /// Flat and nested `mapping` types parse.
    #[test]
    fn test_parse_mapping_types() {
        let source = r#"
            contract Storage {
                mapping(address => uint256) public balances;
                mapping(address => mapping(address => uint256)) public allowances;
            }
        "#;
        parse_ok(source);
    }

    /// A function with parameters and a `returns` clause parses.
    #[test]
    fn test_parse_function_with_params() {
        let source = r#"
            contract Math {
                function add(uint256 a, uint256 b) public pure returns (uint256) {
                    return a + b;
                }
            }
        "#;
        parse_ok(source);
    }

    /// An `if` / `else` statement parses.
    #[test]
    fn test_parse_if_statement() {
        let source = r#"
            contract Logic {
                function check(uint256 x) public pure returns (bool) {
                    if (x > 10) {
                        return true;
                    } else {
                        return false;
                    }
                }
            }
        "#;
        parse_ok(source);
    }

    /// Local variable declarations with initializers parse.
    #[test]
    fn test_parse_var_declaration() {
        let source = r#"
            contract Vars {
                function compute() public pure returns (uint256) {
                    uint256 x = 10;
                    uint256 y = 20;
                    uint256 result = x + y;
                    return result;
                }
            }
        "#;
        parse_ok(source);
    }

    // NOTE(review): despite its name, the source below contains no chained
    // method calls; it only exercises index access on a mapping (`data[key]`).
    // Consider renaming the test or adding a real chain once the intended
    // syntax is confirmed against the grammar.
    #[test]
    fn test_parse_method_chaining() {
        let source = r#"
            contract Chain {
                mapping(bytes32 => uint256) public data;

                function process(bytes32 key) public view returns (uint256) {
                    uint256 result = data[key];
                    return result;
                }
            }
        "#;
        parse_ok(source);
    }

    /// An interface with body-less function signatures parses.
    #[test]
    fn test_parse_interface_definition() {
        let source = r#"
            interface IERC20 {
                function transfer(address to, uint256 amount) external returns (bool);
                function balanceOf(address account) external view returns (uint256);
            }
        "#;
        parse_ok(source);
    }

    /// A contract with a constructor parses.
    #[test]
    fn test_parse_constructor() {
        let source = r#"
            contract Token {
                address public owner;
                uint256 public totalSupply;

                constructor(uint256 initialSupply) {
                    owner = msg.sender;
                    totalSupply = initialSupply;
                }
            }
        "#;
        parse_ok(source);
    }

    /// A modifier definition (with the `_;` placeholder) and its use parse.
    #[test]
    fn test_parse_modifier() {
        let source = r#"
            contract Owned {
                address public owner;

                modifier onlyOwner() {
                    require(msg.sender == owner, "Not owner");
                    _;
                }

                function withdraw() public onlyOwner {
                    return;
                }
            }
        "#;
        parse_ok(source);
    }

    /// A C-style `for` loop parses.
    #[test]
    fn test_parse_for_loop() {
        let source = r#"
            contract Loops {
                function sum(uint256 n) public pure returns (uint256) {
                    uint256 total = 0;
                    for (uint256 i = 0; i < n; i++) {
                        total += i;
                    }
                    return total;
                }
            }
        "#;
        parse_ok(source);
    }

    /// A `while` loop parses.
    #[test]
    fn test_parse_while_loop() {
        let source = r#"
            contract Loops {
                function countdown(uint256 n) public pure returns (uint256) {
                    while (n > 0) {
                        n -= 1;
                    }
                    return n;
                }
            }
        "#;
        parse_ok(source);
    }

    /// An `emit` statement parses.
    #[test]
    fn test_parse_emit_statement() {
        let source = r#"
            event Transfer(address from, address to, uint256 amount);

            contract Events {
                function doTransfer(address to, uint256 amount) public {
                    emit Transfer(msg.sender, to, amount);
                }
            }
        "#;
        parse_ok(source);
    }

    /// `require(...)` and `revert(...)` statements parse.
    #[test]
    fn test_parse_require_and_revert() {
        let source = r#"
            contract Validation {
                function transfer(address to, uint256 amount) public {
                    require(amount > 0, "Amount must be positive");
                    require(to != msg.sender, "Cannot transfer to self");
                    if (amount > 1000) {
                        revert("Amount too large");
                    }
                }
            }
        "#;
        parse_ok(source);
    }

    /// Dynamic and fixed-size array types, plus array indexing, parse.
    #[test]
    fn test_parse_array_types() {
        let source = r#"
            contract Arrays {
                uint256[] public dynamicArray;
                uint256[10] public fixedArray;

                function getElement(uint256 i) public view returns (uint256) {
                    return dynamicArray[i];
                }
            }
        "#;
        parse_ok(source);
    }

    /// A ternary (`cond ? a : b`) expression parses.
    #[test]
    fn test_parse_ternary_expression() {
        let source = r#"
            contract Ternary {
                function max(uint256 a, uint256 b) public pure returns (uint256) {
                    return a > b ? a : b;
                }
            }
        "#;
        parse_ok(source);
    }

    /// A `new` expression parses.
    #[test]
    fn test_parse_new_expression() {
        let source = r#"
            contract Factory {
                function createToken() public returns (Token) {
                    Token token = new Token(1000);
                    return token;
                }
            }
        "#;
        parse_ok(source);
    }

    /// A contract declared with `is <Interface>` parses.
    #[test]
    fn test_parse_inheritance() {
        let source = r#"
            interface IERC20 {
                function transfer(address to, uint256 amount) external returns (bool);
            }

            contract Token is IERC20 {
                mapping(address => uint256) public balances;

                function transfer(address to, uint256 amount) external returns (bool) {
                    balances[msg.sender] -= amount;
                    balances[to] += amount;
                    return true;
                }
            }
        "#;
        parse_ok(source);
    }

    /// A function carrying several modifiers before `returns` parses.
    #[test]
    fn test_parse_multiple_modifiers() {
        let source = r#"
            contract MultiMod {
                function restricted() public view onlyOwner whenNotPaused returns (uint256) {
                    return 42;
                }
            }
        "#;
        parse_ok(source);
    }

    /// Hexadecimal address literals parse in initializers and comparisons.
    #[test]
    fn test_parse_address_literal() {
        let source = r#"
            contract Addresses {
                address public ZERO = 0x0000000000000000000000000000000000000000;

                function isZero(address addr) public pure returns (bool) {
                    return addr == 0x0000000000000000000000000000000000000000;
                }
            }
        "#;
        parse_ok(source);
    }

    /// The bitwise operators `&`, `|`, `^`, `~`, `<<` and `>>` parse.
    #[test]
    fn test_parse_bitwise_operations() {
        let source = r#"
            contract Bitwise {
                function operations(uint256 a, uint256 b) public pure returns (uint256) {
                    uint256 and_result = a & b;
                    uint256 or_result = a | b;
                    uint256 xor_result = a ^ b;
                    uint256 not_result = ~a;
                    uint256 shift_left = a << 2;
                    uint256 shift_right = a >> 2;
                    return and_result + or_result + xor_result;
                }
            }
        "#;
        parse_ok(source);
    }

    /// Every visibility keyword parses on state variables and functions.
    #[test]
    fn test_parse_visibility_modifiers() {
        let source = r#"
            contract Visibility {
                uint256 public publicVar;
                uint256 private privateVar;
                uint256 internal internalVar;

                function publicFunc() public {}
                function privateFunc() private {}
                function internalFunc() internal {}
                function externalFunc() external {}
            }
        "#;
        parse_ok(source);
    }

    /// Events and errors declared inside a contract become contract members.
    #[test]
    fn test_parse_events_errors_in_contract() {
        let source = r#"
            contract Token {
                uint256 public totalSupply;

                event Transfer(address from, address to, uint256 amount);
                event Approval(address owner, address spender, uint256 amount);
                error InsufficientBalance(uint256 available, uint256 required);
                error Unauthorized(address caller);

                function transfer(address to, uint256 amount) public {
                    totalSupply -= amount;
                }
            }
        "#;
        let program = parse_ok(source);
        let contract = contract_at!(program, 0);

        // Verify we have 2 events and 2 errors in the contract
        let event_count = contract
            .members
            .iter()
            .filter(|m| matches!(m, solscript_ast::ContractMember::Event(_)))
            .count();
        let error_count = contract
            .members
            .iter()
            .filter(|m| matches!(m, solscript_ast::ContractMember::Error(_)))
            .count();
        assert_eq!(event_count, 2, "Expected 2 events");
        assert_eq!(error_count, 2, "Expected 2 errors");
    }

    /// `abstract contract` sets `is_abstract`, and body-less functions inside
    /// it are parsed with `body == None`.
    #[test]
    fn test_parse_abstract_contract() {
        let source = r#"
            abstract contract Base {
                uint256 public value;

                // Abstract function (no body)
                function getValue() public view returns (uint256);

                // Implemented function
                function setValue(uint256 newValue) public {
                    value = newValue;
                }
            }

            contract Derived is Base {
                constructor() {
                    value = 0;
                }

                function getValue() public view returns (uint256) {
                    return value;
                }
            }
        "#;
        let program = parse_ok(source);

        // First contract should be abstract
        let base = contract_at!(program, 0);
        assert!(base.is_abstract, "Base should be abstract");
        assert_eq!(base.name.name.as_str(), "Base");

        // Check that abstract function has no body
        let get_value_fn = function_named!(base, "getValue");
        assert!(
            get_value_fn.body.is_none(),
            "Abstract function should have no body"
        );

        // Check that implemented function has a body
        let set_value_fn = function_named!(base, "setValue");
        assert!(
            set_value_fn.body.is_some(),
            "Implemented function should have body"
        );

        // Second contract should not be abstract
        let derived = contract_at!(program, 1);
        assert!(!derived.is_abstract, "Derived should not be abstract");
        assert_eq!(derived.name.name.as_str(), "Derived");
    }

    /// `selfdestruct(...)` is parsed as a dedicated statement variant.
    #[test]
    fn test_parse_selfdestruct() {
        let source = r#"
            contract Closeable {
                address public owner;

                function destroy() public {
                    selfdestruct(msg.sender);
                }
            }
        "#;
        let program = parse_ok(source);
        let contract = contract_at!(program, 0);

        // Find the destroy function
        let destroy_fn = function_named!(contract, "destroy");

        // Check that body contains a selfdestruct statement
        let body = destroy_fn
            .body
            .as_ref()
            .expect("destroy should have a body");
        assert_eq!(body.stmts.len(), 1);
        assert!(matches!(
            body.stmts[0],
            solscript_ast::Stmt::Selfdestruct(_)
        ));
    }

    /// `IERC20(addr).transfer(...)` is parsed as a method call whose receiver
    /// is a call expression with an identifier callee.
    #[test]
    fn test_parse_interface_cpi() {
        let source = r#"
            interface IERC20 {
                function transfer(address to, uint256 amount) external returns (bool);
                function balanceOf(address account) external view returns (uint256);
            }

            contract TokenUser {
                address public tokenProgram;

                function doTransfer(address to, uint256 amount) public {
                    IERC20(tokenProgram).transfer(to, amount);
                }

                function checkBalance(address account) public view returns (uint256) {
                    return IERC20(tokenProgram).balanceOf(account);
                }
            }
        "#;
        let program = parse_ok(source);

        // First item should be the interface
        let interface = match &program.items[0] {
            solscript_ast::Item::Interface(i) => i,
            _ => panic!("Expected interface"),
        };
        assert_eq!(interface.name.name.as_str(), "IERC20");
        assert_eq!(interface.members.len(), 2);

        // Second item should be the contract
        let contract = contract_at!(program, 1);
        assert_eq!(contract.name.name.as_str(), "TokenUser");

        // Find the doTransfer function and verify it has a method call expression
        let do_transfer_fn = function_named!(contract, "doTransfer");
        let body = do_transfer_fn
            .body
            .as_ref()
            .expect("doTransfer should have a body");
        assert_eq!(body.stmts.len(), 1);

        // The statement should be an expression statement with a method call
        let expr_stmt = match &body.stmts[0] {
            solscript_ast::Stmt::Expr(expr_stmt) => expr_stmt,
            _ => panic!("Expected expression statement"),
        };
        let method_call = match &expr_stmt.expr {
            solscript_ast::Expr::MethodCall(mc) => mc,
            _ => panic!("Expected method call expression"),
        };
        assert_eq!(method_call.method.name.as_str(), "transfer");

        // The receiver should be a Call expression: IERC20(tokenProgram)
        let receiver_call = match &method_call.receiver {
            solscript_ast::Expr::Call(call) => call,
            _ => panic!("Expected call expression as receiver"),
        };
        match &receiver_call.callee {
            solscript_ast::Expr::Ident(ident) => assert_eq!(ident.name.as_str(), "IERC20"),
            _ => panic!("Expected identifier callee"),
        }
    }
}