// traverse_codegen/access_control_stub.rs

//! Enhanced access-control test generation using a proper Solidity AST.
//!
//! This module generates access-control tests with the AST-based approach: test contracts are
//! assembled with `SolidityTestContractBuilder`, and the Solidity code is produced from typed
//! AST nodes.

6use tracing::debug;
7
8use crate::teststubs::{
9    generate_valid_args_for_function, to_pascal_case, ContractInfo, FunctionInfo,
10    SolidityTestContract, SolidityTestContractBuilder,
11};
12use anyhow::Result;
13use traverse_graph::cg::{CallGraph, EdgeType, NodeType, ParameterInfo};
14use traverse_solidity::ast::*;
15use traverse_solidity::builder::*;
16
17pub fn generate_access_control_tests_from_cfg(
18    graph: &CallGraph,
19    contract_name: &str,
20    function_name: &str,
21    function_params: &[ParameterInfo],
22    constructor_params: &[ParameterInfo],
23) -> Result<Vec<SolidityTestContract>> {
24    let mut test_contracts = Vec::new();
25
26    let func_node = graph.nodes.iter().find(|n| {
27        n.contract_name.as_deref() == Some(contract_name)
28            && n.name == function_name
29            && n.node_type == NodeType::Function
30    });
31
32    if let Some(func_node) = func_node {
33        let require_edges: Vec<_> = graph
34            .edges
35            .iter()
36            .filter(|edge| {
37                edge.source_node_id == func_node.id && edge.edge_type == EdgeType::Require
38            })
39            .collect();
40
41        for (edge_idx, edge) in require_edges.iter().enumerate() {
42            let condition_node = &graph.nodes[edge.target_node_id];
43            let condition_text = condition_node
44                .condition_expression
45                .as_ref()
46                .filter(|s| !s.is_empty())
47                .cloned()
48                .unwrap_or_else(|| condition_node.name.clone());
49
50            if is_access_control_condition(&condition_text) {
51                let test_contract = create_access_control_test_contract(
52                    contract_name,
53                    function_name,
54                    function_params,
55                    constructor_params,
56                    &condition_text,
57                    &condition_node.revert_message.clone().unwrap_or_default(),
58                    edge_idx,
59                )?;
60
61                test_contracts.push(test_contract);
62            }
63        }
64    }
65
66    Ok(test_contracts)
67}
68
69fn create_access_control_test_contract(
70    contract_name: &str,
71    function_name: &str,
72    function_params: &[ParameterInfo],
73    constructor_params: &[ParameterInfo],
74    _condition: &str,
75    error_message: &str,
76    edge_idx: usize,
77) -> Result<SolidityTestContract> {
78    let test_contract_name = format!(
79        "{}{}AccessControlTest{}",
80        contract_name,
81        to_pascal_case(function_name),
82        edge_idx
83    );
84
85    let test_function_name = format!("test_{}_access_control_{}", function_name, edge_idx);
86
87    let test_contract = SolidityTestContractBuilder::new(test_contract_name.clone())
88        .add_import(format!("../src/{}.sol", contract_name))
89        .build_with_contract(|contract| {
90            // Add state variable for contract instance
91            contract.state_variable(
92                user_type(contract_name),
93                "contractInstance",
94                Some(Visibility::Private),
95                None,
96            );
97
98            contract.function("setUp", |func| {
99                func.visibility(Visibility::Public).body(|body| {
100                    // Deploy contract instance
101                    let constructor_args =
102                        generate_constructor_args_as_expressions(constructor_params);
103
104                    body.expression(Expression::Assignment(AssignmentExpression {
105                        left: Box::new(identifier("contractInstance")),
106                        operator: AssignmentOperator::Assign,
107                        right: Box::new(Expression::FunctionCall(FunctionCallExpression {
108                            function: Box::new(identifier(format!("new {}", contract_name))),
109                            arguments: constructor_args,
110                        })),
111                    }));
112                });
113            });
114
115            contract.function(&test_function_name, |func| {
116                func.visibility(Visibility::Public).body(|body| {
117                    body.expression(Expression::FunctionCall(FunctionCallExpression {
118                        function: Box::new(Expression::MemberAccess(MemberAccessExpression {
119                            object: Box::new(identifier("vm")),
120                            member: "prank".to_string(),
121                        })),
122                        arguments: vec![Expression::FunctionCall(FunctionCallExpression {
123                            function: Box::new(identifier("address")),
124                            arguments: vec![number("0x1")],
125                        })],
126                    }));
127
128                    // Add expectRevert
129                    if !error_message.is_empty() {
130                        body.expression(Expression::FunctionCall(FunctionCallExpression {
131                            function: Box::new(Expression::MemberAccess(MemberAccessExpression {
132                                object: Box::new(identifier("vm")),
133                                member: "expectRevert".to_string(),
134                            })),
135                            arguments: vec![Expression::FunctionCall(FunctionCallExpression {
136                                function: Box::new(identifier("bytes")),
137                                arguments: vec![string_literal(error_message)],
138                            })],
139                        }));
140                    } else {
141                        body.expression(Expression::FunctionCall(FunctionCallExpression {
142                            function: Box::new(Expression::MemberAccess(MemberAccessExpression {
143                                object: Box::new(identifier("vm")),
144                                member: "expectRevert".to_string(),
145                            })),
146                            arguments: vec![],
147                        }));
148                    }
149
150                    match generate_valid_args_for_function(function_params, None) {
151                        Ok(function_args) => {
152                            body.expression(Expression::FunctionCall(FunctionCallExpression {
153                                function: Box::new(Expression::MemberAccess(
154                                    MemberAccessExpression {
155                                        object: Box::new(identifier("contractInstance")),
156                                        member: function_name.to_string(),
157                                    },
158                                )),
159                                arguments: function_args,
160                            }));
161                        }
162                        Err(e) => {
163                            debug!("Failed to generate function arguments: {}", e);
164                            // Add a comment about the error
165                            body.expression(Expression::FunctionCall(FunctionCallExpression {
166                                function: Box::new(identifier("// Failed to generate arguments")),
167                                arguments: vec![],
168                            }));
169                        }
170                    }
171
172                    body.expression(Expression::FunctionCall(FunctionCallExpression {
173                        function: Box::new(Expression::MemberAccess(MemberAccessExpression {
174                            object: Box::new(identifier("vm")),
175                            member: "stopPrank".to_string(),
176                        })),
177                        arguments: vec![],
178                    }));
179                });
180            });
181        });
182
183    Ok(test_contract)
184}
185
/// Heuristically decides whether a `require` condition is an access-control check.
///
/// The condition is lower-cased and searched for any of a fixed set of substrings:
/// `msg.sender`, `owner`, `admin`, `role`, `authorized` and `permission`. This is a plain
/// substring test, so a marker inside a longer identifier also counts. For example,
/// `onlyOwner` matches `owner` and `ADMIN_ROLE` matches `admin`.
fn is_access_control_condition(condition: &str) -> bool {
    const ACCESS_CONTROL_MARKERS: [&str; 6] = [
        "msg.sender",
        "owner",
        "admin",
        "role",
        "authorized",
        "permission",
    ];

    let haystack = condition.to_lowercase();
    ACCESS_CONTROL_MARKERS
        .iter()
        .any(|&marker| haystack.contains(marker))
}
195
196fn generate_constructor_args_as_expressions(params: &[ParameterInfo]) -> Vec<Expression> {
197    params
198        .iter()
199        .map(|param| match param.param_type.as_str() {
200            "string" => string_literal("test"),
201            "address" => Expression::FunctionCall(FunctionCallExpression {
202                function: Box::new(identifier("address")),
203                arguments: vec![number("1")],
204            }),
205            "bool" => boolean(true),
206            t if t.starts_with("uint") => number("42"),
207            t if t.starts_with("int") => number("42"),
208            _ => number("0"), // Default fallback
209        })
210        .collect()
211}
212
213pub fn create_comprehensive_access_control_test_contract(
214    contract_info: &ContractInfo,
215    function_info: &FunctionInfo,
216    graph: &CallGraph,
217) -> Result<SolidityTestContract> {
218    let test_contracts = generate_access_control_tests_from_cfg(
219        graph,
220        &contract_info.name,
221        &function_info.name,
222        &function_info.parameters,
223        &contract_info.constructor_params,
224    )?;
225
226    if test_contracts.is_empty() {
227        return Err(anyhow::anyhow!(
228            "No access control tests could be generated for function {}",
229            function_info.name
230        ));
231    }
232
233    // For now, return the first test contract
234    // In a more sophisticated implementation, we might combine multiple test contracts
235    Ok(test_contracts.into_iter().next().unwrap())
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_access_control_condition() {
        // Conditions mentioning the caller, an owner/admin, or a role must be flagged.
        let guarded = [
            "msg.sender == owner",
            "onlyOwner",
            "hasRole(ADMIN_ROLE, msg.sender)",
        ];
        // Plain value checks must not be flagged.
        let unguarded = ["balance > 0", "amount < 1000"];

        for condition in guarded.iter() {
            assert!(
                is_access_control_condition(condition),
                "expected access control: {}",
                condition
            );
        }
        for condition in unguarded.iter() {
            assert!(
                !is_access_control_condition(condition),
                "unexpected access control: {}",
                condition
            );
        }
    }
}