1use tracing::debug;
7
8use crate::teststubs::{
9 sanitize_identifier, to_pascal_case, ContractInfo, FunctionInfo,
10 SolidityTestContract, SolidityTestContractBuilder,
11};
12use crate::invariant_breaker::{break_invariant, InvariantBreakerValue};
13use anyhow::Result;
14use traverse_graph::cg::{CallGraph, EdgeType, NodeType, ParameterInfo};
15use std::collections::HashMap;
16
17use traverse_solidity::ast::*;
18use traverse_solidity::builder::*;
19
20pub fn generate_revert_tests_from_cfg(
21 graph: &CallGraph,
22 contract_name: &str,
23 function_name: &str,
24 function_params: &[ParameterInfo],
25) -> Result<Vec<SolidityTestContract>> {
26 debug!(
27 "[REVERT DEBUG] Starting enhanced revert test generation for {}.{}",
28 contract_name, function_name
29 );
30 debug!(
31 "[REVERT DEBUG] Function has {} parameters",
32 function_params.len()
33 );
34 for (i, param) in function_params.iter().enumerate() {
35 debug!(
36 "[REVERT DEBUG] Param {}: {} {}",
37 i, param.param_type, param.name
38 );
39 }
40
41 let mut test_contracts = Vec::new();
42
43 let func_node = graph.nodes.iter().find(|n| {
44 n.contract_name.as_deref() == Some(contract_name)
45 && n.name == function_name
46 && n.node_type == NodeType::Function
47 });
48
49 if let Some(func_node) = func_node {
50 debug!(
51 "[REVERT DEBUG] Found function node with ID: {}",
52 func_node.id
53 );
54
55 let require_edges: Vec<_> = graph
56 .edges
57 .iter()
58 .filter(|edge| {
59 edge.source_node_id == func_node.id && edge.edge_type == EdgeType::Require
60 })
61 .collect();
62
63 debug!(
64 "[REVERT DEBUG] Found {} require edges for function",
65 require_edges.len()
66 );
67
68 for (edge_idx, edge) in require_edges.iter().enumerate() {
69 debug!(
70 "[REVERT DEBUG] Processing require edge {} of {}",
71 edge_idx + 1,
72 require_edges.len()
73 );
74
75 let condition_node = &graph.nodes[edge.target_node_id];
76 let original_condition_name = condition_node.name.clone();
77
78 let descriptive_condition_text = condition_node
79 .condition_expression
80 .as_ref()
81 .filter(|s| !s.is_empty())
82 .cloned()
83 .unwrap_or_else(|| original_condition_name.clone());
84
85 let error_message = condition_node.revert_message.clone().unwrap_or_default();
86
87 debug!(
88 "[REVERT DEBUG] Processing condition: '{}'",
89 descriptive_condition_text
90 );
91
92 let rt = tokio::runtime::Runtime::new().unwrap();
93 let invariant_result = match rt.block_on(break_invariant(&descriptive_condition_text)) {
94 Ok(result) => {
95 debug!(
96 "[REVERT DEBUG] Invariant breaker for '{}': success={}, entries={}",
97 descriptive_condition_text, result.success, result.entries.len()
98 );
99 if result.success && !result.entries.is_empty() {
100 Some(result)
101 } else {
102 debug!(
103 "[REVERT DEBUG] Invariant breaker did not find a counterexample for '{}'",
104 descriptive_condition_text
105 );
106 None
107 }
108 }
109 Err(e) => {
110 debug!(
111 "[REVERT DEBUG] Error calling invariant breaker for '{}': {}",
112 descriptive_condition_text, e
113 );
114 None
115 }
116 };
117
118 if let Some(invariant_result) = invariant_result {
119 let test_contract = create_revert_test_contract(
120 graph,
121 contract_name,
122 function_name,
123 function_params,
124 &descriptive_condition_text,
125 &error_message,
126 &invariant_result,
127 edge_idx,
128 )?;
129
130 test_contracts.push(test_contract);
131 } else {
132 debug!(
133 "[REVERT DEBUG] Skipping test for condition '{}' - no counterexample found",
134 descriptive_condition_text
135 );
136 }
137 }
138 } else {
139 debug!(
140 "[REVERT DEBUG] Function node not found in graph for {}.{}",
141 contract_name, function_name
142 );
143 }
144
145 debug!(
146 "[REVERT DEBUG] Generated {} enhanced revert test contracts",
147 test_contracts.len()
148 );
149 Ok(test_contracts)
150}
151
152fn create_revert_test_contract(
153 graph: &CallGraph,
154 contract_name: &str,
155 function_name: &str,
156 function_params: &[ParameterInfo],
157 condition: &str,
158 error_message: &str,
159 invariant_result: &crate::invariant_breaker::InvariantBreakerResult,
160 edge_idx: usize,
161) -> Result<SolidityTestContract> {
162 let mut test_name_condition_identifier = sanitize_identifier(condition);
163 if test_name_condition_identifier.len() > 40 {
164 test_name_condition_identifier.truncate(40);
165 while test_name_condition_identifier.ends_with('_') {
166 test_name_condition_identifier.pop();
167 }
168 }
169 if test_name_condition_identifier.is_empty() {
170 test_name_condition_identifier = "condition".to_string();
171 }
172
173 let test_contract_name = format!(
174 "{}{}RevertTest{}",
175 contract_name,
176 to_pascal_case(function_name),
177 edge_idx
178 );
179
180 let test_function_name = format!(
181 "test_{}_reverts_{}",
182 function_name, test_name_condition_identifier
183 );
184
185 debug!(
186 "[REVERT DEBUG] Creating test contract '{}' with function '{}'",
187 test_contract_name, test_function_name
188 );
189
190 let first_entry = &invariant_result.entries[0];
192 let (needs_prank, prank_address) = extract_prank_info(&first_entry.variables, condition);
193
194 let test_contract = SolidityTestContractBuilder::new(test_contract_name.clone())
196 .add_import(format!("../src/{}.sol", contract_name))
197 .build_with_contract(|contract| {
198 contract.state_variable(
200 user_type(contract_name),
201 "contractInstance",
202 Some(Visibility::Private),
203 None,
204 );
205
206 contract.function("setUp", |func| {
208 func.visibility(Visibility::Public).body(|body| {
209 let constructor_args = if let Some(constructor_node) = graph.nodes.iter().find(|n| {
211 n.contract_name.as_deref() == Some(contract_name) &&
212 n.node_type == NodeType::Constructor
213 }) {
214 generate_constructor_args_as_expressions(&constructor_node.parameters)
215 } else {
216 vec![]
218 };
219
220 body.expression(Expression::Assignment(AssignmentExpression {
221 left: Box::new(identifier("contractInstance")),
222 operator: AssignmentOperator::Assign,
223 right: Box::new(Expression::FunctionCall(FunctionCallExpression {
224 function: Box::new(identifier(format!("new {}", contract_name))),
225 arguments: constructor_args,
226 })),
227 }));
228 });
229 });
230
231 contract.function(&test_function_name, |func| {
233 func.visibility(Visibility::Public).body(|body| {
234 if needs_prank {
236 if let Some(prank_addr) = prank_address {
237 body.expression(Expression::FunctionCall(FunctionCallExpression {
238 function: Box::new(Expression::MemberAccess(MemberAccessExpression {
239 object: Box::new(identifier("vm")),
240 member: "prank".to_string(),
241 })),
242 arguments: vec![string_literal(&prank_addr)],
243 }));
244 }
245 }
246
247 if !error_message.is_empty() {
249 body.expression(Expression::FunctionCall(FunctionCallExpression {
250 function: Box::new(Expression::MemberAccess(MemberAccessExpression {
251 object: Box::new(identifier("vm")),
252 member: "expectRevert".to_string(),
253 })),
254 arguments: vec![Expression::FunctionCall(FunctionCallExpression {
255 function: Box::new(identifier("bytes")),
256 arguments: vec![string_literal(error_message)],
257 })],
258 }));
259 } else {
260 body.expression(Expression::FunctionCall(FunctionCallExpression {
261 function: Box::new(Expression::MemberAccess(MemberAccessExpression {
262 object: Box::new(identifier("vm")),
263 member: "expectRevert".to_string(),
264 })),
265 arguments: vec![],
266 }));
267 }
268
269 match generate_function_args_from_invariant(
271 function_params,
272 &first_entry.variables,
273 ) {
274 Ok(function_args) => {
275 body.expression(Expression::FunctionCall(FunctionCallExpression {
276 function: Box::new(Expression::MemberAccess(MemberAccessExpression {
277 object: Box::new(identifier("contractInstance")),
278 member: function_name.to_string(),
279 })),
280 arguments: function_args,
281 }));
282 }
283 Err(e) => {
284 debug!("[REVERT DEBUG] Failed to generate function arguments: {}", e);
285 body.expression(Expression::FunctionCall(FunctionCallExpression {
287 function: Box::new(identifier("// Failed to generate arguments")),
288 arguments: vec![],
289 }));
290 }
291 }
292
293 if needs_prank {
295 body.expression(Expression::FunctionCall(FunctionCallExpression {
296 function: Box::new(Expression::MemberAccess(MemberAccessExpression {
297 object: Box::new(identifier("vm")),
298 member: "stopPrank".to_string(),
299 })),
300 arguments: vec![],
301 }));
302 }
303 });
304 });
305 });
306
307 Ok(test_contract)
308}
309
310fn extract_prank_info(
311 variables: &HashMap<String, InvariantBreakerValue>,
312 condition: &str,
313) -> (bool, Option<String>) {
314 let lower_condition = condition.to_lowercase();
315 let involves_sender = lower_condition.contains("msg.sender") || lower_condition.contains("caller");
316 let involves_access_control = lower_condition.contains("owner")
317 || lower_condition.contains("admin")
318 || lower_condition.contains("role");
319
320 if involves_sender && involves_access_control {
321 for (var_name, var_value) in variables {
323 if let InvariantBreakerValue::Address(addr) = var_value {
324 debug!(
325 "[REVERT DEBUG] Found address variable '{}' = '{}' for prank",
326 var_name, addr
327 );
328 return (true, Some(addr.clone()));
329 }
330 }
331
332 debug!(
333 "[REVERT DEBUG] No address variable found in invariant breaker results for condition: {}",
334 condition
335 );
336 (true, None) } else {
338 (false, None)
339 }
340}
341
342fn generate_constructor_args_as_expressions(params: &[ParameterInfo]) -> Vec<Expression> {
343 params
344 .iter()
345 .map(|param| {
346 match param.param_type.as_str() {
347 "string" => string_literal("test"),
348 "address" => Expression::FunctionCall(FunctionCallExpression {
349 function: Box::new(identifier("address")),
350 arguments: vec![number("1")],
351 }),
352 "bool" => boolean(true),
353 t if t.starts_with("uint") => number("42"),
354 t if t.starts_with("int") => number("42"),
355 _ => number("0"), }
357 })
358 .collect()
359}
360
361fn generate_function_args_from_invariant(
362 function_params: &[ParameterInfo],
363 variables: &HashMap<String, InvariantBreakerValue>,
364) -> Result<Vec<Expression>> {
365 let mut args = Vec::new();
366 let mut missing_params = Vec::new();
367
368 for param in function_params {
369 if let Some(var_value) = variables.get(¶m.name) {
370 debug!(
371 "[REVERT DEBUG] Using invariant value for param '{}': {:?}",
372 param.name, var_value
373 );
374 args.push(invariant_value_to_expression_with_type(var_value, ¶m.param_type));
376 } else {
377 debug!(
378 "[REVERT DEBUG] No invariant value found for param '{}'",
379 param.name
380 );
381 missing_params.push(param.name.clone());
382 }
383 }
384
385 if !missing_params.is_empty() {
386 return Err(anyhow::anyhow!(
387 "Missing invariant values for parameters: {}",
388 missing_params.join(", ")
389 ));
390 }
391
392 Ok(args)
393}
394
395fn invariant_value_to_expression_with_type(value: &InvariantBreakerValue, solidity_type: &str) -> Expression {
396 match value {
397 InvariantBreakerValue::Bool(b) => boolean(*b),
398 InvariantBreakerValue::UInt(n) => {
399 number(n.to_string())
402 },
403 InvariantBreakerValue::Int(n) => {
404 if solidity_type.starts_with("uint") {
406 number(n.to_string())
408 } else {
409 number(n.to_string())
410 }
411 },
412 InvariantBreakerValue::String(s) => string_literal(s),
413 InvariantBreakerValue::Address(addr) => {
414 if addr.starts_with("0x") {
416 Expression::FunctionCall(FunctionCallExpression {
417 function: Box::new(identifier("address")),
418 arguments: vec![Expression::Literal(Literal::HexString(HexStringLiteral {
419 value: addr.strip_prefix("0x").unwrap_or(addr).to_string(),
420 }))],
421 })
422 } else {
423 string_literal(addr)
424 }
425 }
426 InvariantBreakerValue::Bytes(b) => {
427 Expression::Literal(Literal::HexString(HexStringLiteral {
428 value: hex::encode(b),
429 }))
430 }
431 }
432}
433
434#[allow(dead_code)]
435fn invariant_value_to_expression(value: &InvariantBreakerValue) -> Expression {
436 match value {
437 InvariantBreakerValue::Bool(b) => boolean(*b),
438 InvariantBreakerValue::UInt(n) => {
439 number(n.to_string())
442 },
443 InvariantBreakerValue::Int(n) => number(n.to_string()),
444 InvariantBreakerValue::String(s) => string_literal(s),
445 InvariantBreakerValue::Address(addr) => {
446 if addr.starts_with("0x") {
448 Expression::FunctionCall(FunctionCallExpression {
449 function: Box::new(identifier("address")),
450 arguments: vec![Expression::Literal(Literal::HexString(HexStringLiteral {
451 value: addr.strip_prefix("0x").unwrap_or(addr).to_string(),
452 }))],
453 })
454 } else {
455 string_literal(addr)
456 }
457 }
458 InvariantBreakerValue::Bytes(b) => {
459 Expression::Literal(Literal::HexString(HexStringLiteral {
460 value: hex::encode(b),
461 }))
462 }
463 }
464}
465
466pub fn create_comprehensive_revert_test_contract(
467 contract_info: &ContractInfo,
468 function_info: &FunctionInfo,
469 graph: &CallGraph,
470) -> Result<SolidityTestContract> {
471 debug!(
472 "[REVERT DEBUG] Creating comprehensive revert test contract for {}.{}",
473 contract_info.name, function_info.name
474 );
475
476 let test_contracts = generate_revert_tests_from_cfg(
477 graph,
478 &contract_info.name,
479 &function_info.name,
480 &function_info.parameters,
481 )?;
482
483 if test_contracts.is_empty() {
484 return Err(anyhow::anyhow!(
485 "No revert tests could be generated for function {}",
486 function_info.name
487 ));
488 }
489
490 Ok(test_contracts.into_iter().next().unwrap())
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// Bool, UInt and String invariant values map to the matching Solidity literal kinds.
    #[test]
    fn test_invariant_value_to_expression() {
        let bool_expr = invariant_value_to_expression(&InvariantBreakerValue::Bool(true));
        assert!(matches!(
            bool_expr,
            Expression::Literal(Literal::Boolean(true))
        ));

        match invariant_value_to_expression(&InvariantBreakerValue::UInt(42)) {
            Expression::Literal(Literal::Number(num)) => assert_eq!(num.value, "42"),
            _ => panic!("Expected number literal"),
        }

        let text_value = InvariantBreakerValue::String("test".to_string());
        match invariant_value_to_expression(&text_value) {
            Expression::Literal(Literal::String(s)) => assert_eq!(s.value, "test"),
            _ => panic!("Expected string literal"),
        }
    }

    /// An owner-guarded `msg.sender` check requires a prank using the breaker's address.
    #[test]
    fn test_extract_prank_info() {
        const CALLER_ADDRESS: &str = "0x1234567890123456789012345678901234567890";

        let variables: HashMap<String, InvariantBreakerValue> = std::iter::once((
            "caller".to_string(),
            InvariantBreakerValue::Address(CALLER_ADDRESS.to_string()),
        ))
        .collect();

        let (needs_prank, prank_address) = extract_prank_info(&variables, "msg.sender == owner");

        assert!(needs_prank);
        assert_eq!(prank_address.as_deref(), Some(CALLER_ADDRESS));
    }
}