traverse_codegen/
deployer_stub.rs1use crate::teststubs::ContractInfo;
2use anyhow::Result;
3use traverse_solidity::ast::*;
4use traverse_solidity::builder::*;
5
6pub fn generate_foundry_deployer_test_contract(contract_info: &ContractInfo) -> Result<SourceUnit> {
7 let mut builder = SolidityBuilder::new();
8 let test_name = format!("{}DeployerTest", contract_info.name);
9 let contract_import_path = format!("../src/{}.sol", contract_info.name);
10
11 builder
12 .pragma("solidity", "^0.8.0")
13 .import("forge-std/Test.sol")
14 .import(contract_import_path)
15 .contract(test_name, |contract| {
16 contract.inherits("Test").function("testDeploy", |func| {
17 func.visibility(Visibility::Public).body(|body| {
18 if contract_info.has_constructor {
20 for (i, param) in contract_info.constructor_params.iter().enumerate() {
22 let var_name = format!("arg{}", i);
23 let type_name = param_type_to_ast(¶m.param_type);
24 let default_value = generate_default_value(¶m.param_type);
25
26 if traverse_solidity::builder::requires_data_location(&type_name) {
28 let data_location =
29 traverse_solidity::builder::get_default_data_location(&type_name);
30 body.variable_declaration_with_location(
31 type_name,
32 var_name,
33 data_location,
34 Some(default_value),
35 );
36 } else {
37 body.variable_declaration(type_name, var_name, Some(default_value));
38 }
39 }
40 }
41
42 let deploy_expr = if contract_info.has_constructor {
44 let args: Vec<Expression> = (0..contract_info.constructor_params.len())
45 .map(|i| identifier(format!("arg{}", i)))
46 .collect();
47
48 Expression::FunctionCall(FunctionCallExpression {
49 function: Box::new(identifier(format!("new {}", contract_info.name))),
50 arguments: args,
51 })
52 } else {
53 Expression::FunctionCall(FunctionCallExpression {
54 function: Box::new(identifier(format!("new {}", contract_info.name))),
55 arguments: vec![],
56 })
57 };
58
59 body.variable_declaration(
60 user_type(&contract_info.name),
61 "instance",
62 Some(deploy_expr),
63 );
64
65 body.expression(Expression::FunctionCall(FunctionCallExpression {
67 function: Box::new(identifier("assertTrue")),
68 arguments: vec![Expression::Binary(BinaryExpression {
69 left: Box::new(identifier("address(instance)")),
70 operator: BinaryOperator::NotEqual,
71 right: Box::new(identifier("address(0)")),
72 })],
73 }));
74 });
75 });
76 });
77
78 Ok(builder.build())
79}
80
81fn param_type_to_ast(param_type: &str) -> TypeName {
82 match param_type {
83 "address" => address(),
84 "bool" => bool(),
85 "string" => string(),
86 t if t.starts_with("uint") => {
87 if let Some(size_str) = t.strip_prefix("uint") {
88 if size_str.is_empty() {
89 uint256()
90 } else if let Ok(size) = size_str.parse::<u16>() {
91 uint(size)
92 } else {
93 uint256()
94 }
95 } else {
96 uint256()
97 }
98 }
99 t if t.starts_with("int") => {
100 if let Some(size_str) = t.strip_prefix("int") {
101 if size_str.is_empty() {
102 int256()
103 } else if let Ok(size) = size_str.parse::<u16>() {
104 int(size)
105 } else {
106 int256()
107 }
108 } else {
109 int256()
110 }
111 }
112 t if t.starts_with("bytes") => {
113 if t == "bytes" {
114 bytes()
115 } else if let Some(size_str) = t.strip_prefix("bytes") {
116 if let Ok(size) = size_str.parse::<u8>() {
117 bytes_fixed(size)
118 } else {
119 bytes()
120 }
121 } else {
122 bytes()
123 }
124 }
125 _ => user_type(param_type), }
127}
128
129fn generate_default_value(param_type: &str) -> Expression {
130 match param_type {
131 "address" => identifier("address(0x1)"),
132 "bool" => boolean(true),
133 "string" => string_literal("test"),
134 t if t.starts_with("uint") => number("1"),
135 t if t.starts_with("int") => number("1"),
136 t if t.starts_with("bytes") => string_literal("0x01"),
137 _ => number("0"),
138 }
139}