1use tracing::debug;
7
8use crate::teststubs::{
9 capitalize_first_letter, to_pascal_case, ContractInfo, FunctionInfo, SolidityTestContract,
10 SolidityTestContractBuilder,
11};
12use anyhow::Result;
13use traverse_graph::cg::{CallGraph, EdgeType, NodeType, ParameterInfo};
14use traverse_solidity::ast::*;
15use traverse_solidity::builder::*;
16
17pub fn generate_state_change_tests_from_cfg(
18 graph: &CallGraph,
19 ctx: &traverse_graph::cg::CallGraphGeneratorContext,
20 contract_name: &str,
21 function_name: &str,
22 function_params: &[ParameterInfo],
23) -> Result<Vec<SolidityTestContract>> {
24 let mut test_contracts = Vec::new();
25
26 let func_node = graph.nodes.iter().find(|n| {
27 n.contract_name.as_deref() == Some(contract_name)
28 && n.name == function_name
29 && n.node_type == NodeType::Function
30 });
31
32 if let Some(func_node) = func_node {
33 for edge in &graph.edges {
34 if edge.source_node_id == func_node.id && edge.edge_type == EdgeType::StorageWrite {
35 if let Some(var_node) = graph.nodes.get(edge.target_node_id) {
36 if var_node.node_type == NodeType::StorageVariable {
37 let test_contract = create_state_change_test_contract(
38 contract_name,
39 function_name,
40 function_params,
41 var_node,
42 ctx,
43 graph,
44 )?;
45 test_contracts.push(test_contract);
46 }
47 }
48 }
49 }
50 }
51
52 Ok(test_contracts)
53}
54
55fn create_state_change_test_contract(
56 contract_name: &str,
57 function_name: &str,
58 function_params: &[ParameterInfo],
59 var_node: &traverse_graph::cg::Node,
60 ctx: &traverse_graph::cg::CallGraphGeneratorContext,
61 graph: &CallGraph,
62) -> Result<SolidityTestContract> {
63 let var_name = &var_node.name;
64 let test_contract_name = format!(
65 "{}{}StateChangeTest",
66 to_pascal_case(contract_name),
67 to_pascal_case(function_name)
68 );
69
70 let test_function_name = format!("test_{}_changes_{}", function_name, var_name);
71
72 let var_contract_scope = var_node
74 .contract_name
75 .as_ref()
76 .ok_or_else(|| anyhow::anyhow!("Storage variable {} missing contract scope", var_name))?;
77
78 let actual_var_type = ctx
79 .state_var_types
80 .get(&(var_contract_scope.clone(), var_name.clone()))
81 .cloned()
82 .unwrap_or_else(|| {
83 debug!(
84 "Warning: Type for state variable {}.{} not found in ctx.state_var_types. Defaulting to uint256.",
85 var_contract_scope, var_name
86 );
87 "uint256".to_string()
88 });
89
90 let getter_name = if var_node.visibility == traverse_graph::cg::Visibility::Public {
92 var_name.clone()
93 } else {
94 format!("get{}", capitalize_first_letter(var_name))
95 };
96
97 let test_contract = SolidityTestContractBuilder::new(test_contract_name.clone())
98 .add_import(format!("../src/{}.sol", contract_name))
99 .build_with_contract(|contract| {
100 contract.state_variable(
102 user_type(contract_name),
103 "contractInstance",
104 Some(traverse_solidity::ast::Visibility::Private),
105 None,
106 );
107
108 contract.function("setUp", |func| {
109 func.visibility(traverse_solidity::ast::Visibility::Public)
110 .body(|body| {
111 let constructor_args = if let Some(constructor_node) =
113 graph.nodes.iter().find(|n| {
114 n.contract_name.as_deref() == Some(contract_name)
115 && n.node_type == NodeType::Constructor
116 }) {
117 generate_constructor_args_as_expressions(&constructor_node.parameters)
118 } else {
119 vec![]
121 };
122
123 body.expression(Expression::Assignment(AssignmentExpression {
124 left: Box::new(identifier("contractInstance")),
125 operator: AssignmentOperator::Assign,
126 right: Box::new(Expression::FunctionCall(FunctionCallExpression {
127 function: Box::new(identifier(format!("new {}", contract_name))),
128 arguments: constructor_args,
129 })),
130 }));
131 });
132 });
133
134 contract.function(&test_function_name, |func| {
135 func.visibility(traverse_solidity::ast::Visibility::Public)
136 .body(|body| {
137 let type_name = get_type_name_for_variable(&actual_var_type);
139 let initial_value_expr = Expression::FunctionCall(FunctionCallExpression {
140 function: Box::new(Expression::MemberAccess(MemberAccessExpression {
141 object: Box::new(identifier("contractInstance")),
142 member: getter_name.clone(),
143 })),
144 arguments: vec![],
145 });
146
147 if traverse_solidity::builder::requires_data_location(&type_name) {
149 let data_location =
150 traverse_solidity::builder::get_default_data_location(&type_name);
151 body.variable_declaration_with_location(
152 type_name,
153 "initialValue",
154 data_location,
155 Some(initial_value_expr),
156 );
157 } else {
158 body.variable_declaration(
159 type_name,
160 "initialValue",
161 Some(initial_value_expr),
162 );
163 }
164
165 match generate_different_args_for_function(function_params) {
168 Ok(function_args) => {
169 body.expression(Expression::FunctionCall(FunctionCallExpression {
170 function: Box::new(Expression::MemberAccess(
171 MemberAccessExpression {
172 object: Box::new(identifier("contractInstance")),
173 member: function_name.to_string(),
174 },
175 )),
176 arguments: function_args,
177 }));
178 }
179 Err(e) => {
180 debug!("Failed to generate function arguments: {}", e);
181 body.expression(Expression::FunctionCall(FunctionCallExpression {
183 function: Box::new(identifier(
184 "// Failed to generate arguments",
185 )),
186 arguments: vec![],
187 }));
188 }
189 }
190
191 let assert_condition = if actual_var_type == "string" {
193 Expression::Binary(BinaryExpression {
195 left: Box::new(Expression::FunctionCall(FunctionCallExpression {
196 function: Box::new(identifier("keccak256")),
197 arguments: vec![Expression::FunctionCall(
198 FunctionCallExpression {
199 function: Box::new(identifier("abi.encodePacked")),
200 arguments: vec![Expression::FunctionCall(
201 FunctionCallExpression {
202 function: Box::new(Expression::MemberAccess(
203 MemberAccessExpression {
204 object: Box::new(identifier(
205 "contractInstance",
206 )),
207 member: getter_name.clone(),
208 },
209 )),
210 arguments: vec![],
211 },
212 )],
213 },
214 )],
215 })),
216 operator: BinaryOperator::NotEqual,
217 right: Box::new(Expression::FunctionCall(FunctionCallExpression {
218 function: Box::new(identifier("keccak256")),
219 arguments: vec![Expression::FunctionCall(
220 FunctionCallExpression {
221 function: Box::new(identifier("abi.encodePacked")),
222 arguments: vec![identifier("initialValue")],
223 },
224 )],
225 })),
226 })
227 } else {
228 Expression::Binary(BinaryExpression {
230 left: Box::new(Expression::FunctionCall(FunctionCallExpression {
231 function: Box::new(Expression::MemberAccess(
232 MemberAccessExpression {
233 object: Box::new(identifier("contractInstance")),
234 member: getter_name.clone(),
235 },
236 )),
237 arguments: vec![],
238 })),
239 operator: BinaryOperator::NotEqual,
240 right: Box::new(identifier("initialValue")),
241 })
242 };
243
244 body.expression(Expression::FunctionCall(FunctionCallExpression {
245 function: Box::new(identifier("assertTrue")),
246 arguments: vec![assert_condition],
247 }));
248 });
249 });
250 });
251
252 Ok(test_contract)
253}
254
255fn generate_constructor_args_as_expressions(params: &[ParameterInfo]) -> Vec<Expression> {
256 params
257 .iter()
258 .map(|param| match param.param_type.as_str() {
259 "string" => string_literal("test"),
260 "address" => Expression::FunctionCall(FunctionCallExpression {
261 function: Box::new(identifier("address")),
262 arguments: vec![number("1")],
263 }),
264 "bool" => boolean(true),
265 t if t.starts_with("uint") => number("42"),
266 t if t.starts_with("int") => number("42"),
267 _ => number("0"), })
269 .collect()
270}
271
272fn generate_different_args_for_function(params: &[ParameterInfo]) -> Result<Vec<Expression>> {
273 let args = params
274 .iter()
275 .map(|param| match param.param_type.as_str() {
276 "string" => string_literal("updated value"),
277 "address" => Expression::FunctionCall(FunctionCallExpression {
278 function: Box::new(identifier("address")),
279 arguments: vec![number("2")],
280 }),
281 "bool" => boolean(false), t if t.starts_with("uint") => number("100"), t if t.starts_with("int") => number("100"), _ => number("1"), })
286 .collect();
287
288 Ok(args)
289}
290
291fn get_type_name_for_variable(type_str: &str) -> TypeName {
292 let is_value_type = type_str == "bool"
293 || type_str == "address"
294 || type_str.starts_with("uint")
295 || type_str.starts_with("int")
296 || (type_str.starts_with("bytes")
297 && type_str.len() > 5
298 && type_str.chars().skip(5).all(|c| c.is_ascii_digit()));
299
300 if is_value_type {
301 match type_str {
302 "bool" => bool(),
303 "address" => address(),
304 t if t.starts_with("uint") => {
305 if let Some(size_str) = t.strip_prefix("uint") {
306 if size_str.is_empty() {
307 uint256()
308 } else if let Ok(size) = size_str.parse::<u16>() {
309 uint(size)
310 } else {
311 uint256()
312 }
313 } else {
314 uint256()
315 }
316 }
317 t if t.starts_with("int") => {
318 if let Some(size_str) = t.strip_prefix("int") {
319 if size_str.is_empty() {
320 int256()
321 } else if let Ok(size) = size_str.parse::<u16>() {
322 int(size)
323 } else {
324 int256()
325 }
326 } else {
327 int256()
328 }
329 }
330 _ => user_type(type_str),
331 }
332 } else {
333 user_type(type_str)
335 }
336}
337
338pub fn create_comprehensive_state_change_test_contract(
339 contract_info: &ContractInfo,
340 function_info: &FunctionInfo,
341 graph: &CallGraph,
342 ctx: &traverse_graph::cg::CallGraphGeneratorContext,
343) -> Result<SolidityTestContract> {
344 let test_contracts = generate_state_change_tests_from_cfg(
345 graph,
346 ctx,
347 &contract_info.name,
348 &function_info.name,
349 &function_info.parameters,
350 )?;
351
352 if test_contracts.is_empty() {
353 return Err(anyhow::anyhow!(
354 "No state change tests could be generated for function {}",
355 function_info.name
356 ));
357 }
358
359 Ok(test_contracts.into_iter().next().unwrap())
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Elementary Solidity types map to `TypeName::Elementary` variants, while reference
    /// types such as `string` are kept as user-defined type names.
    #[test]
    fn test_get_type_name_for_variable() {
        let expectations: [(&str, fn(&TypeName) -> bool); 3] = [
            ("uint256", |ty| {
                matches!(
                    ty,
                    TypeName::Elementary(ElementaryTypeName::UnsignedInteger(Some(256)))
                )
            }),
            ("bool", |ty| {
                matches!(ty, TypeName::Elementary(ElementaryTypeName::Bool))
            }),
            ("string", |ty| matches!(ty, TypeName::UserDefined(_))),
        ];

        for (solidity_type, is_expected) in expectations {
            let mapped = get_type_name_for_variable(solidity_type);
            assert!(
                is_expected(&mapped),
                "unexpected TypeName for `{}`",
                solidity_type
            );
        }
    }
}