// traverse_codegen/deployer_stub.rs

1use crate::teststubs::ContractInfo;
2use anyhow::Result;
3use traverse_solidity::ast::*;
4use traverse_solidity::builder::*;
5
6pub fn generate_foundry_deployer_test_contract(contract_info: &ContractInfo) -> Result<SourceUnit> {
7    let mut builder = SolidityBuilder::new();
8    let test_name = format!("{}DeployerTest", contract_info.name);
9    let contract_import_path = format!("../src/{}.sol", contract_info.name);
10
11    builder
12        .pragma("solidity", "^0.8.0")
13        .import("forge-std/Test.sol")
14        .import(contract_import_path)
15        .contract(test_name, |contract| {
16            contract.inherits("Test").function("testDeploy", |func| {
17                func.visibility(Visibility::Public).body(|body| {
18                    // Generate deployment test logic
19                    if contract_info.has_constructor {
20                        // Add constructor args generation
21                        for (i, param) in contract_info.constructor_params.iter().enumerate() {
22                            let var_name = format!("arg{}", i);
23                            let type_name = param_type_to_ast(&param.param_type);
24                            let default_value = generate_default_value(&param.param_type);
25
26                            // Use variable_declaration_with_location for types that require data location
27                            if traverse_solidity::builder::requires_data_location(&type_name) {
28                                let data_location =
29                                    traverse_solidity::builder::get_default_data_location(&type_name);
30                                body.variable_declaration_with_location(
31                                    type_name,
32                                    var_name,
33                                    data_location,
34                                    Some(default_value),
35                                );
36                            } else {
37                                body.variable_declaration(type_name, var_name, Some(default_value));
38                            }
39                        }
40                    }
41
42                    // Add deployment statement
43                    let deploy_expr = if contract_info.has_constructor {
44                        let args: Vec<Expression> = (0..contract_info.constructor_params.len())
45                            .map(|i| identifier(format!("arg{}", i)))
46                            .collect();
47
48                        Expression::FunctionCall(FunctionCallExpression {
49                            function: Box::new(identifier(format!("new {}", contract_info.name))),
50                            arguments: args,
51                        })
52                    } else {
53                        Expression::FunctionCall(FunctionCallExpression {
54                            function: Box::new(identifier(format!("new {}", contract_info.name))),
55                            arguments: vec![],
56                        })
57                    };
58
59                    body.variable_declaration(
60                        user_type(&contract_info.name),
61                        "instance",
62                        Some(deploy_expr),
63                    );
64
65                    // Add assertion
66                    body.expression(Expression::FunctionCall(FunctionCallExpression {
67                        function: Box::new(identifier("assertTrue")),
68                        arguments: vec![Expression::Binary(BinaryExpression {
69                            left: Box::new(identifier("address(instance)")),
70                            operator: BinaryOperator::NotEqual,
71                            right: Box::new(identifier("address(0)")),
72                        })],
73                    }));
74                });
75            });
76        });
77
78    Ok(builder.build())
79}
80
81fn param_type_to_ast(param_type: &str) -> TypeName {
82    match param_type {
83        "address" => address(),
84        "bool" => bool(),
85        "string" => string(),
86        t if t.starts_with("uint") => {
87            if let Some(size_str) = t.strip_prefix("uint") {
88                if size_str.is_empty() {
89                    uint256()
90                } else if let Ok(size) = size_str.parse::<u16>() {
91                    uint(size)
92                } else {
93                    uint256()
94                }
95            } else {
96                uint256()
97            }
98        }
99        t if t.starts_with("int") => {
100            if let Some(size_str) = t.strip_prefix("int") {
101                if size_str.is_empty() {
102                    int256()
103                } else if let Ok(size) = size_str.parse::<u16>() {
104                    int(size)
105                } else {
106                    int256()
107                }
108            } else {
109                int256()
110            }
111        }
112        t if t.starts_with("bytes") => {
113            if t == "bytes" {
114                bytes()
115            } else if let Some(size_str) = t.strip_prefix("bytes") {
116                if let Ok(size) = size_str.parse::<u8>() {
117                    bytes_fixed(size)
118                } else {
119                    bytes()
120                }
121            } else {
122                bytes()
123            }
124        }
125        _ => user_type(param_type), // Custom type
126    }
127}
128
129fn generate_default_value(param_type: &str) -> Expression {
130    match param_type {
131        "address" => identifier("address(0x1)"),
132        "bool" => boolean(true),
133        "string" => string_literal("test"),
134        t if t.starts_with("uint") => number("1"),
135        t if t.starts_with("int") => number("1"),
136        t if t.starts_with("bytes") => string_literal("0x01"),
137        _ => number("0"),
138    }
139}