traverse_codegen/
state_change_stub.rs

1//! Enhanced state change test generation using proper Solidity AST
2//!
3//! This module generates state change tests using the new AST-based approach with
4//! SolidityTestContractBuilder and proper type-safe Solidity code generation.
5
6use tracing::debug;
7
8use crate::teststubs::{
9    capitalize_first_letter, to_pascal_case, ContractInfo, FunctionInfo, SolidityTestContract,
10    SolidityTestContractBuilder,
11};
12use anyhow::Result;
13use traverse_graph::cg::{CallGraph, EdgeType, NodeType, ParameterInfo};
14use traverse_solidity::ast::*;
15use traverse_solidity::builder::*;
16
17pub fn generate_state_change_tests_from_cfg(
18    graph: &CallGraph,
19    ctx: &traverse_graph::cg::CallGraphGeneratorContext,
20    contract_name: &str,
21    function_name: &str,
22    function_params: &[ParameterInfo],
23) -> Result<Vec<SolidityTestContract>> {
24    let mut test_contracts = Vec::new();
25
26    let func_node = graph.nodes.iter().find(|n| {
27        n.contract_name.as_deref() == Some(contract_name)
28            && n.name == function_name
29            && n.node_type == NodeType::Function
30    });
31
32    if let Some(func_node) = func_node {
33        for edge in &graph.edges {
34            if edge.source_node_id == func_node.id && edge.edge_type == EdgeType::StorageWrite {
35                if let Some(var_node) = graph.nodes.get(edge.target_node_id) {
36                    if var_node.node_type == NodeType::StorageVariable {
37                        let test_contract = create_state_change_test_contract(
38                            contract_name,
39                            function_name,
40                            function_params,
41                            var_node,
42                            ctx,
43                            graph,
44                        )?;
45                        test_contracts.push(test_contract);
46                    }
47                }
48            }
49        }
50    }
51
52    Ok(test_contracts)
53}
54
55fn create_state_change_test_contract(
56    contract_name: &str,
57    function_name: &str,
58    function_params: &[ParameterInfo],
59    var_node: &traverse_graph::cg::Node,
60    ctx: &traverse_graph::cg::CallGraphGeneratorContext,
61    graph: &CallGraph,
62) -> Result<SolidityTestContract> {
63    let var_name = &var_node.name;
64    let test_contract_name = format!(
65        "{}{}StateChangeTest",
66        to_pascal_case(contract_name),
67        to_pascal_case(function_name)
68    );
69
70    let test_function_name = format!("test_{}_changes_{}", function_name, var_name);
71
72    // Get variable type information
73    let var_contract_scope = var_node
74        .contract_name
75        .as_ref()
76        .ok_or_else(|| anyhow::anyhow!("Storage variable {} missing contract scope", var_name))?;
77
78    let actual_var_type = ctx
79        .state_var_types
80        .get(&(var_contract_scope.clone(), var_name.clone()))
81        .cloned()
82        .unwrap_or_else(|| {
83            debug!(
84                "Warning: Type for state variable {}.{} not found in ctx.state_var_types. Defaulting to uint256.",
85                var_contract_scope, var_name
86            );
87            "uint256".to_string()
88        });
89
90    // Determine getter function name
91    let getter_name = if var_node.visibility == traverse_graph::cg::Visibility::Public {
92        var_name.clone()
93    } else {
94        format!("get{}", capitalize_first_letter(var_name))
95    };
96
97    let test_contract = SolidityTestContractBuilder::new(test_contract_name.clone())
98        .add_import(format!("../src/{}.sol", contract_name))
99        .build_with_contract(|contract| {
100            // Add state variable for contract instance
101            contract.state_variable(
102                user_type(contract_name),
103                "contractInstance",
104                Some(traverse_solidity::ast::Visibility::Private),
105                None,
106            );
107
108            contract.function("setUp", |func| {
109                func.visibility(traverse_solidity::ast::Visibility::Public)
110                    .body(|body| {
111                        // Deploy contract instance - use constructor parameters from context
112                        let constructor_args = if let Some(constructor_node) =
113                            graph.nodes.iter().find(|n| {
114                                n.contract_name.as_deref() == Some(contract_name)
115                                    && n.node_type == NodeType::Constructor
116                            }) {
117                            generate_constructor_args_as_expressions(&constructor_node.parameters)
118                        } else {
119                            // Fallback to empty args if no constructor found
120                            vec![]
121                        };
122
123                        body.expression(Expression::Assignment(AssignmentExpression {
124                            left: Box::new(identifier("contractInstance")),
125                            operator: AssignmentOperator::Assign,
126                            right: Box::new(Expression::FunctionCall(FunctionCallExpression {
127                                function: Box::new(identifier(format!("new {}", contract_name))),
128                                arguments: constructor_args,
129                            })),
130                        }));
131                    });
132            });
133
134            contract.function(&test_function_name, |func| {
135                func.visibility(traverse_solidity::ast::Visibility::Public)
136                    .body(|body| {
137                        // Get initial value
138                        let type_name = get_type_name_for_variable(&actual_var_type);
139                        let initial_value_expr = Expression::FunctionCall(FunctionCallExpression {
140                            function: Box::new(Expression::MemberAccess(MemberAccessExpression {
141                                object: Box::new(identifier("contractInstance")),
142                                member: getter_name.clone(),
143                            })),
144                            arguments: vec![],
145                        });
146
147                        // Use variable_declaration_with_location for types that require data location
148                        if traverse_solidity::builder::requires_data_location(&type_name) {
149                            let data_location =
150                                traverse_solidity::builder::get_default_data_location(&type_name);
151                            body.variable_declaration_with_location(
152                                type_name,
153                                "initialValue",
154                                data_location,
155                                Some(initial_value_expr),
156                            );
157                        } else {
158                            body.variable_declaration(
159                                type_name,
160                                "initialValue",
161                                Some(initial_value_expr),
162                            );
163                        }
164
165                        // Call the function that should change state
166                        // Use different values than constructor to ensure state actually changes
167                        match generate_different_args_for_function(function_params) {
168                            Ok(function_args) => {
169                                body.expression(Expression::FunctionCall(FunctionCallExpression {
170                                    function: Box::new(Expression::MemberAccess(
171                                        MemberAccessExpression {
172                                            object: Box::new(identifier("contractInstance")),
173                                            member: function_name.to_string(),
174                                        },
175                                    )),
176                                    arguments: function_args,
177                                }));
178                            }
179                            Err(e) => {
180                                debug!("Failed to generate function arguments: {}", e);
181                                // Add a comment about the error
182                                body.expression(Expression::FunctionCall(FunctionCallExpression {
183                                    function: Box::new(identifier(
184                                        "// Failed to generate arguments",
185                                    )),
186                                    arguments: vec![],
187                                }));
188                            }
189                        }
190
191                        // Assert that the value has changed
192                        let assert_condition = if actual_var_type == "string" {
193                            // For strings, compare keccak256 hashes
194                            Expression::Binary(BinaryExpression {
195                                left: Box::new(Expression::FunctionCall(FunctionCallExpression {
196                                    function: Box::new(identifier("keccak256")),
197                                    arguments: vec![Expression::FunctionCall(
198                                        FunctionCallExpression {
199                                            function: Box::new(identifier("abi.encodePacked")),
200                                            arguments: vec![Expression::FunctionCall(
201                                                FunctionCallExpression {
202                                                    function: Box::new(Expression::MemberAccess(
203                                                        MemberAccessExpression {
204                                                            object: Box::new(identifier(
205                                                                "contractInstance",
206                                                            )),
207                                                            member: getter_name.clone(),
208                                                        },
209                                                    )),
210                                                    arguments: vec![],
211                                                },
212                                            )],
213                                        },
214                                    )],
215                                })),
216                                operator: BinaryOperator::NotEqual,
217                                right: Box::new(Expression::FunctionCall(FunctionCallExpression {
218                                    function: Box::new(identifier("keccak256")),
219                                    arguments: vec![Expression::FunctionCall(
220                                        FunctionCallExpression {
221                                            function: Box::new(identifier("abi.encodePacked")),
222                                            arguments: vec![identifier("initialValue")],
223                                        },
224                                    )],
225                                })),
226                            })
227                        } else {
228                            // For other types, direct comparison
229                            Expression::Binary(BinaryExpression {
230                                left: Box::new(Expression::FunctionCall(FunctionCallExpression {
231                                    function: Box::new(Expression::MemberAccess(
232                                        MemberAccessExpression {
233                                            object: Box::new(identifier("contractInstance")),
234                                            member: getter_name.clone(),
235                                        },
236                                    )),
237                                    arguments: vec![],
238                                })),
239                                operator: BinaryOperator::NotEqual,
240                                right: Box::new(identifier("initialValue")),
241                            })
242                        };
243
244                        body.expression(Expression::FunctionCall(FunctionCallExpression {
245                            function: Box::new(identifier("assertTrue")),
246                            arguments: vec![assert_condition],
247                        }));
248                    });
249            });
250        });
251
252    Ok(test_contract)
253}
254
255fn generate_constructor_args_as_expressions(params: &[ParameterInfo]) -> Vec<Expression> {
256    params
257        .iter()
258        .map(|param| match param.param_type.as_str() {
259            "string" => string_literal("test"),
260            "address" => Expression::FunctionCall(FunctionCallExpression {
261                function: Box::new(identifier("address")),
262                arguments: vec![number("1")],
263            }),
264            "bool" => boolean(true),
265            t if t.starts_with("uint") => number("42"),
266            t if t.starts_with("int") => number("42"),
267            _ => number("0"), // Default fallback
268        })
269        .collect()
270}
271
272fn generate_different_args_for_function(params: &[ParameterInfo]) -> Result<Vec<Expression>> {
273    let args = params
274        .iter()
275        .map(|param| match param.param_type.as_str() {
276            "string" => string_literal("updated value"),
277            "address" => Expression::FunctionCall(FunctionCallExpression {
278                function: Box::new(identifier("address")),
279                arguments: vec![number("2")],
280            }),
281            "bool" => boolean(false), // Different from constructor's true
282            t if t.starts_with("uint") => number("100"), // Different from constructor's 42
283            t if t.starts_with("int") => number("100"), // Different from constructor's 42
284            _ => number("1"),         // Different from constructor's 0
285        })
286        .collect();
287
288    Ok(args)
289}
290
291fn get_type_name_for_variable(type_str: &str) -> TypeName {
292    let is_value_type = type_str == "bool"
293        || type_str == "address"
294        || type_str.starts_with("uint")
295        || type_str.starts_with("int")
296        || (type_str.starts_with("bytes")
297            && type_str.len() > 5
298            && type_str.chars().skip(5).all(|c| c.is_ascii_digit()));
299
300    if is_value_type {
301        match type_str {
302            "bool" => bool(),
303            "address" => address(),
304            t if t.starts_with("uint") => {
305                if let Some(size_str) = t.strip_prefix("uint") {
306                    if size_str.is_empty() {
307                        uint256()
308                    } else if let Ok(size) = size_str.parse::<u16>() {
309                        uint(size)
310                    } else {
311                        uint256()
312                    }
313                } else {
314                    uint256()
315                }
316            }
317            t if t.starts_with("int") => {
318                if let Some(size_str) = t.strip_prefix("int") {
319                    if size_str.is_empty() {
320                        int256()
321                    } else if let Ok(size) = size_str.parse::<u16>() {
322                        int(size)
323                    } else {
324                        int256()
325                    }
326                } else {
327                    int256()
328                }
329            }
330            _ => user_type(type_str),
331        }
332    } else {
333        // For reference types, we need to specify memory location
334        user_type(type_str)
335    }
336}
337
338pub fn create_comprehensive_state_change_test_contract(
339    contract_info: &ContractInfo,
340    function_info: &FunctionInfo,
341    graph: &CallGraph,
342    ctx: &traverse_graph::cg::CallGraphGeneratorContext,
343) -> Result<SolidityTestContract> {
344    let test_contracts = generate_state_change_tests_from_cfg(
345        graph,
346        ctx,
347        &contract_info.name,
348        &function_info.name,
349        &function_info.parameters,
350    )?;
351
352    if test_contracts.is_empty() {
353        return Err(anyhow::anyhow!(
354            "No state change tests could be generated for function {}",
355            function_info.name
356        ));
357    }
358
359    // For now, return the first test contract
360    // In a more sophisticated implementation, we might combine multiple test contracts
361    Ok(test_contracts.into_iter().next().unwrap())
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the elementary and pass-through mappings of `get_type_name_for_variable`.
    #[test]
    fn test_get_type_name_for_variable() {
        // A sized unsigned integer maps to the elementary `uint<N>` type.
        assert!(matches!(
            get_type_name_for_variable("uint256"),
            TypeName::Elementary(ElementaryTypeName::UnsignedInteger(Some(256)))
        ));

        // `bool` is an elementary type too.
        assert!(matches!(
            get_type_name_for_variable("bool"),
            TypeName::Elementary(ElementaryTypeName::Bool)
        ));

        // `string` is passed through as a user-defined type name.
        assert!(matches!(
            get_type_name_for_variable("string"),
            TypeName::UserDefined(_)
        ));
    }
}