1use tracing::debug;
7
8use crate::teststubs::{
9 generate_valid_args_for_function, to_pascal_case, ContractInfo, FunctionInfo,
10 SolidityTestContract, SolidityTestContractBuilder,
11};
12use anyhow::Result;
13use traverse_graph::cg::{CallGraph, EdgeType, NodeType, ParameterInfo};
14use traverse_solidity::ast::*;
15use traverse_solidity::builder::*;
16
17pub fn generate_access_control_tests_from_cfg(
18 graph: &CallGraph,
19 contract_name: &str,
20 function_name: &str,
21 function_params: &[ParameterInfo],
22 constructor_params: &[ParameterInfo],
23) -> Result<Vec<SolidityTestContract>> {
24 let mut test_contracts = Vec::new();
25
26 let func_node = graph.nodes.iter().find(|n| {
27 n.contract_name.as_deref() == Some(contract_name)
28 && n.name == function_name
29 && n.node_type == NodeType::Function
30 });
31
32 if let Some(func_node) = func_node {
33 let require_edges: Vec<_> = graph
34 .edges
35 .iter()
36 .filter(|edge| {
37 edge.source_node_id == func_node.id && edge.edge_type == EdgeType::Require
38 })
39 .collect();
40
41 for (edge_idx, edge) in require_edges.iter().enumerate() {
42 let condition_node = &graph.nodes[edge.target_node_id];
43 let condition_text = condition_node
44 .condition_expression
45 .as_ref()
46 .filter(|s| !s.is_empty())
47 .cloned()
48 .unwrap_or_else(|| condition_node.name.clone());
49
50 if is_access_control_condition(&condition_text) {
51 let test_contract = create_access_control_test_contract(
52 contract_name,
53 function_name,
54 function_params,
55 constructor_params,
56 &condition_text,
57 &condition_node.revert_message.clone().unwrap_or_default(),
58 edge_idx,
59 )?;
60
61 test_contracts.push(test_contract);
62 }
63 }
64 }
65
66 Ok(test_contracts)
67}
68
69fn create_access_control_test_contract(
70 contract_name: &str,
71 function_name: &str,
72 function_params: &[ParameterInfo],
73 constructor_params: &[ParameterInfo],
74 _condition: &str,
75 error_message: &str,
76 edge_idx: usize,
77) -> Result<SolidityTestContract> {
78 let test_contract_name = format!(
79 "{}{}AccessControlTest{}",
80 contract_name,
81 to_pascal_case(function_name),
82 edge_idx
83 );
84
85 let test_function_name = format!("test_{}_access_control_{}", function_name, edge_idx);
86
87 let test_contract = SolidityTestContractBuilder::new(test_contract_name.clone())
88 .add_import(format!("../src/{}.sol", contract_name))
89 .build_with_contract(|contract| {
90 contract.state_variable(
92 user_type(contract_name),
93 "contractInstance",
94 Some(Visibility::Private),
95 None,
96 );
97
98 contract.function("setUp", |func| {
99 func.visibility(Visibility::Public).body(|body| {
100 let constructor_args =
102 generate_constructor_args_as_expressions(constructor_params);
103
104 body.expression(Expression::Assignment(AssignmentExpression {
105 left: Box::new(identifier("contractInstance")),
106 operator: AssignmentOperator::Assign,
107 right: Box::new(Expression::FunctionCall(FunctionCallExpression {
108 function: Box::new(identifier(format!("new {}", contract_name))),
109 arguments: constructor_args,
110 })),
111 }));
112 });
113 });
114
115 contract.function(&test_function_name, |func| {
116 func.visibility(Visibility::Public).body(|body| {
117 body.expression(Expression::FunctionCall(FunctionCallExpression {
118 function: Box::new(Expression::MemberAccess(MemberAccessExpression {
119 object: Box::new(identifier("vm")),
120 member: "prank".to_string(),
121 })),
122 arguments: vec![Expression::FunctionCall(FunctionCallExpression {
123 function: Box::new(identifier("address")),
124 arguments: vec![number("0x1")],
125 })],
126 }));
127
128 if !error_message.is_empty() {
130 body.expression(Expression::FunctionCall(FunctionCallExpression {
131 function: Box::new(Expression::MemberAccess(MemberAccessExpression {
132 object: Box::new(identifier("vm")),
133 member: "expectRevert".to_string(),
134 })),
135 arguments: vec![Expression::FunctionCall(FunctionCallExpression {
136 function: Box::new(identifier("bytes")),
137 arguments: vec![string_literal(error_message)],
138 })],
139 }));
140 } else {
141 body.expression(Expression::FunctionCall(FunctionCallExpression {
142 function: Box::new(Expression::MemberAccess(MemberAccessExpression {
143 object: Box::new(identifier("vm")),
144 member: "expectRevert".to_string(),
145 })),
146 arguments: vec![],
147 }));
148 }
149
150 match generate_valid_args_for_function(function_params, None) {
151 Ok(function_args) => {
152 body.expression(Expression::FunctionCall(FunctionCallExpression {
153 function: Box::new(Expression::MemberAccess(
154 MemberAccessExpression {
155 object: Box::new(identifier("contractInstance")),
156 member: function_name.to_string(),
157 },
158 )),
159 arguments: function_args,
160 }));
161 }
162 Err(e) => {
163 debug!("Failed to generate function arguments: {}", e);
164 body.expression(Expression::FunctionCall(FunctionCallExpression {
166 function: Box::new(identifier("// Failed to generate arguments")),
167 arguments: vec![],
168 }));
169 }
170 }
171
172 body.expression(Expression::FunctionCall(FunctionCallExpression {
173 function: Box::new(Expression::MemberAccess(MemberAccessExpression {
174 object: Box::new(identifier("vm")),
175 member: "stopPrank".to_string(),
176 })),
177 arguments: vec![],
178 }));
179 });
180 });
181 });
182
183 Ok(test_contract)
184}
185
/// Heuristically decides whether a `require` condition is an access-control check.
///
/// The condition is lower-cased and searched for any of a fixed set of substrings
/// (`msg.sender`, `owner`, `admin`, `role`, `authorized`, `permission`). Because this is a
/// plain substring test, identifiers such as `onlyOwner`, `ADMIN_ROLE` or `unauthorized` also
/// match.
fn is_access_control_condition(condition: &str) -> bool {
    const ACCESS_CONTROL_MARKERS: [&str; 6] = [
        "msg.sender",
        "owner",
        "admin",
        "role",
        "authorized",
        "permission",
    ];

    let haystack = condition.to_lowercase();
    ACCESS_CONTROL_MARKERS
        .iter()
        .any(|marker| haystack.contains(*marker))
}
195
196fn generate_constructor_args_as_expressions(params: &[ParameterInfo]) -> Vec<Expression> {
197 params
198 .iter()
199 .map(|param| match param.param_type.as_str() {
200 "string" => string_literal("test"),
201 "address" => Expression::FunctionCall(FunctionCallExpression {
202 function: Box::new(identifier("address")),
203 arguments: vec![number("1")],
204 }),
205 "bool" => boolean(true),
206 t if t.starts_with("uint") => number("42"),
207 t if t.starts_with("int") => number("42"),
208 _ => number("0"), })
210 .collect()
211}
212
213pub fn create_comprehensive_access_control_test_contract(
214 contract_info: &ContractInfo,
215 function_info: &FunctionInfo,
216 graph: &CallGraph,
217) -> Result<SolidityTestContract> {
218 let test_contracts = generate_access_control_tests_from_cfg(
219 graph,
220 &contract_info.name,
221 &function_info.name,
222 &function_info.parameters,
223 &contract_info.constructor_params,
224 )?;
225
226 if test_contracts.is_empty() {
227 return Err(anyhow::anyhow!(
228 "No access control tests could be generated for function {}",
229 function_info.name
230 ));
231 }
232
233 Ok(test_contracts.into_iter().next().unwrap())
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the keyword heuristic against known access-control and non-access-control
    /// conditions.
    #[test]
    fn test_is_access_control_condition() {
        let access_control_conditions = [
            "msg.sender == owner",
            "onlyOwner",
            "hasRole(ADMIN_ROLE, msg.sender)",
        ];
        let unrelated_conditions = ["balance > 0", "amount < 1000"];

        for condition in access_control_conditions.iter() {
            assert!(
                is_access_control_condition(condition),
                "expected an access-control match for {:?}",
                condition
            );
        }
        for condition in unrelated_conditions.iter() {
            assert!(
                !is_access_control_condition(condition),
                "expected no access-control match for {:?}",
                condition
            );
        }
    }
}