traverse_codegen/
revert_stub.rs

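//! Revert test generation built on the typed Solidity AST.
//!
//! For each `require` condition of a function in the call graph, this module asks the
//! invariant breaker for a counterexample and, when one is found, emits a Foundry-style test
//! contract (via `SolidityTestContractBuilder` and the Solidity AST builders) that expects
//! the call to revert.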
5
6use tracing::debug;
7
8use crate::teststubs::{
9    sanitize_identifier, to_pascal_case, ContractInfo, FunctionInfo,
10    SolidityTestContract, SolidityTestContractBuilder,
11};
12use crate::invariant_breaker::{break_invariant, InvariantBreakerValue};
13use anyhow::Result;
14use traverse_graph::cg::{CallGraph, EdgeType, NodeType, ParameterInfo};
15use std::collections::HashMap;
16
17use traverse_solidity::ast::*;
18use traverse_solidity::builder::*;
19
20pub fn generate_revert_tests_from_cfg(
21    graph: &CallGraph,
22    contract_name: &str,
23    function_name: &str,
24    function_params: &[ParameterInfo],
25) -> Result<Vec<SolidityTestContract>> {
26    debug!(
27        "[REVERT DEBUG] Starting enhanced revert test generation for {}.{}",
28        contract_name, function_name
29    );
30    debug!(
31        "[REVERT DEBUG] Function has {} parameters",
32        function_params.len()
33    );
34    for (i, param) in function_params.iter().enumerate() {
35        debug!(
36            "[REVERT DEBUG] Param {}: {} {}",
37            i, param.param_type, param.name
38        );
39    }
40
41    let mut test_contracts = Vec::new();
42
43    let func_node = graph.nodes.iter().find(|n| {
44        n.contract_name.as_deref() == Some(contract_name)
45            && n.name == function_name
46            && n.node_type == NodeType::Function
47    });
48
49    if let Some(func_node) = func_node {
50        debug!(
51            "[REVERT DEBUG] Found function node with ID: {}",
52            func_node.id
53        );
54
55        let require_edges: Vec<_> = graph
56            .edges
57            .iter()
58            .filter(|edge| {
59                edge.source_node_id == func_node.id && edge.edge_type == EdgeType::Require
60            })
61            .collect();
62
63        debug!(
64            "[REVERT DEBUG] Found {} require edges for function",
65            require_edges.len()
66        );
67
68        for (edge_idx, edge) in require_edges.iter().enumerate() {
69            debug!(
70                "[REVERT DEBUG] Processing require edge {} of {}",
71                edge_idx + 1,
72                require_edges.len()
73            );
74
75            let condition_node = &graph.nodes[edge.target_node_id];
76            let original_condition_name = condition_node.name.clone();
77
78            let descriptive_condition_text = condition_node
79                .condition_expression
80                .as_ref()
81                .filter(|s| !s.is_empty())
82                .cloned()
83                .unwrap_or_else(|| original_condition_name.clone());
84
85            let error_message = condition_node.revert_message.clone().unwrap_or_default();
86
87            debug!(
88                "[REVERT DEBUG] Processing condition: '{}'",
89                descriptive_condition_text
90            );
91
92            let rt = tokio::runtime::Runtime::new().unwrap();
93            let invariant_result = match rt.block_on(break_invariant(&descriptive_condition_text)) {
94                Ok(result) => {
95                    debug!(
96                        "[REVERT DEBUG] Invariant breaker for '{}': success={}, entries={}",
97                        descriptive_condition_text, result.success, result.entries.len()
98                    );
99                    if result.success && !result.entries.is_empty() {
100                        Some(result)
101                    } else {
102                        debug!(
103                            "[REVERT DEBUG] Invariant breaker did not find a counterexample for '{}'",
104                            descriptive_condition_text
105                        );
106                        None
107                    }
108                }
109                Err(e) => {
110                    debug!(
111                        "[REVERT DEBUG] Error calling invariant breaker for '{}': {}",
112                        descriptive_condition_text, e
113                    );
114                    None
115                }
116            };
117
118            if let Some(invariant_result) = invariant_result {
119                let test_contract = create_revert_test_contract(
120                    graph,
121                    contract_name,
122                    function_name,
123                    function_params,
124                    &descriptive_condition_text,
125                    &error_message,
126                    &invariant_result,
127                    edge_idx,
128                )?;
129
130                test_contracts.push(test_contract);
131            } else {
132                debug!(
133                    "[REVERT DEBUG] Skipping test for condition '{}' - no counterexample found",
134                    descriptive_condition_text
135                );
136            }
137        }
138    } else {
139        debug!(
140            "[REVERT DEBUG] Function node not found in graph for {}.{}",
141            contract_name, function_name
142        );
143    }
144
145    debug!(
146        "[REVERT DEBUG] Generated {} enhanced revert test contracts",
147        test_contracts.len()
148    );
149    Ok(test_contracts)
150}
151
152fn create_revert_test_contract(
153    graph: &CallGraph,
154    contract_name: &str,
155    function_name: &str,
156    function_params: &[ParameterInfo],
157    condition: &str,
158    error_message: &str,
159    invariant_result: &crate::invariant_breaker::InvariantBreakerResult,
160    edge_idx: usize,
161) -> Result<SolidityTestContract> {
162    let mut test_name_condition_identifier = sanitize_identifier(condition);
163    if test_name_condition_identifier.len() > 40 {
164        test_name_condition_identifier.truncate(40);
165        while test_name_condition_identifier.ends_with('_') {
166            test_name_condition_identifier.pop();
167        }
168    }
169    if test_name_condition_identifier.is_empty() {
170        test_name_condition_identifier = "condition".to_string();
171    }
172
173    let test_contract_name = format!(
174        "{}{}RevertTest{}",
175        contract_name,
176        to_pascal_case(function_name),
177        edge_idx
178    );
179
180    let test_function_name = format!(
181        "test_{}_reverts_{}",
182        function_name, test_name_condition_identifier
183    );
184
185    debug!(
186        "[REVERT DEBUG] Creating test contract '{}' with function '{}'",
187        test_contract_name, test_function_name
188    );
189
190    // Extract counterexample variables
191    let first_entry = &invariant_result.entries[0];
192    let (needs_prank, prank_address) = extract_prank_info(&first_entry.variables, condition);
193
194    // Build the test contract using SolidityTestContractBuilder
195    let test_contract = SolidityTestContractBuilder::new(test_contract_name.clone())
196        .add_import(format!("../src/{}.sol", contract_name))
197        .build_with_contract(|contract| {
198            // Add state variable for contract instance
199            contract.state_variable(
200                user_type(contract_name),
201                "contractInstance",
202                Some(Visibility::Private),
203                None,
204            );
205
206            // Add setUp function
207            contract.function("setUp", |func| {
208                func.visibility(Visibility::Public).body(|body| {
209                    // Deploy contract instance - use constructor parameters from call graph
210                    let constructor_args = if let Some(constructor_node) = graph.nodes.iter().find(|n| {
211                        n.contract_name.as_deref() == Some(contract_name) && 
212                        n.node_type == NodeType::Constructor
213                    }) {
214                        generate_constructor_args_as_expressions(&constructor_node.parameters)
215                    } else {
216                        // Fallback to empty args if no constructor found
217                        vec![]
218                    };
219                    
220                    body.expression(Expression::Assignment(AssignmentExpression {
221                        left: Box::new(identifier("contractInstance")),
222                        operator: AssignmentOperator::Assign,
223                        right: Box::new(Expression::FunctionCall(FunctionCallExpression {
224                            function: Box::new(identifier(format!("new {}", contract_name))),
225                            arguments: constructor_args,
226                        })),
227                    }));
228                });
229            });
230
231            // Add the revert test function
232            contract.function(&test_function_name, |func| {
233                func.visibility(Visibility::Public).body(|body| {
234                    // Add prank if needed
235                    if needs_prank {
236                        if let Some(prank_addr) = prank_address {
237                            body.expression(Expression::FunctionCall(FunctionCallExpression {
238                                function: Box::new(Expression::MemberAccess(MemberAccessExpression {
239                                    object: Box::new(identifier("vm")),
240                                    member: "prank".to_string(),
241                                })),
242                                arguments: vec![string_literal(&prank_addr)],
243                            }));
244                        }
245                    }
246
247                    // Add expectRevert - convert statement to expression
248                    if !error_message.is_empty() {
249                        body.expression(Expression::FunctionCall(FunctionCallExpression {
250                            function: Box::new(Expression::MemberAccess(MemberAccessExpression {
251                                object: Box::new(identifier("vm")),
252                                member: "expectRevert".to_string(),
253                            })),
254                            arguments: vec![Expression::FunctionCall(FunctionCallExpression {
255                                function: Box::new(identifier("bytes")),
256                                arguments: vec![string_literal(error_message)],
257                            })],
258                        }));
259                    } else {
260                        body.expression(Expression::FunctionCall(FunctionCallExpression {
261                            function: Box::new(Expression::MemberAccess(MemberAccessExpression {
262                                object: Box::new(identifier("vm")),
263                                member: "expectRevert".to_string(),
264                            })),
265                            arguments: vec![],
266                        }));
267                    }
268
269                    // Generate function call with counterexample arguments
270                    match generate_function_args_from_invariant(
271                        function_params,
272                        &first_entry.variables,
273                    ) {
274                        Ok(function_args) => {
275                            body.expression(Expression::FunctionCall(FunctionCallExpression {
276                                function: Box::new(Expression::MemberAccess(MemberAccessExpression {
277                                    object: Box::new(identifier("contractInstance")),
278                                    member: function_name.to_string(),
279                                })),
280                                arguments: function_args,
281                            }));
282                        }
283                        Err(e) => {
284                            debug!("[REVERT DEBUG] Failed to generate function arguments: {}", e);
285                            // Add a comment about the error
286                            body.expression(Expression::FunctionCall(FunctionCallExpression {
287                                function: Box::new(identifier("// Failed to generate arguments")),
288                                arguments: vec![],
289                            }));
290                        }
291                    }
292
293                    // Stop prank if needed
294                    if needs_prank {
295                        body.expression(Expression::FunctionCall(FunctionCallExpression {
296                            function: Box::new(Expression::MemberAccess(MemberAccessExpression {
297                                object: Box::new(identifier("vm")),
298                                member: "stopPrank".to_string(),
299                            })),
300                            arguments: vec![],
301                        }));
302                    }
303                });
304            });
305        });
306
307    Ok(test_contract)
308}
309
310fn extract_prank_info(
311    variables: &HashMap<String, InvariantBreakerValue>,
312    condition: &str,
313) -> (bool, Option<String>) {
314    let lower_condition = condition.to_lowercase();
315    let involves_sender = lower_condition.contains("msg.sender") || lower_condition.contains("caller");
316    let involves_access_control = lower_condition.contains("owner")
317        || lower_condition.contains("admin")
318        || lower_condition.contains("role");
319
320    if involves_sender && involves_access_control {
321        // Look for address variables in the counterexample
322        for (var_name, var_value) in variables {
323            if let InvariantBreakerValue::Address(addr) = var_value {
324                debug!(
325                    "[REVERT DEBUG] Found address variable '{}' = '{}' for prank",
326                    var_name, addr
327                );
328                return (true, Some(addr.clone()));
329            }
330        }
331
332        debug!(
333            "[REVERT DEBUG] No address variable found in invariant breaker results for condition: {}",
334            condition
335        );
336        (true, None) // We need a prank but don't have an address
337    } else {
338        (false, None)
339    }
340}
341
342fn generate_constructor_args_as_expressions(params: &[ParameterInfo]) -> Vec<Expression> {
343    params
344        .iter()
345        .map(|param| {
346            match param.param_type.as_str() {
347                "string" => string_literal("test"),
348                "address" => Expression::FunctionCall(FunctionCallExpression {
349                    function: Box::new(identifier("address")),
350                    arguments: vec![number("1")],
351                }),
352                "bool" => boolean(true),
353                t if t.starts_with("uint") => number("42"),
354                t if t.starts_with("int") => number("42"),
355                _ => number("0"), // Default fallback
356            }
357        })
358        .collect()
359}
360
361fn generate_function_args_from_invariant(
362    function_params: &[ParameterInfo],
363    variables: &HashMap<String, InvariantBreakerValue>,
364) -> Result<Vec<Expression>> {
365    let mut args = Vec::new();
366    let mut missing_params = Vec::new();
367
368    for param in function_params {
369        if let Some(var_value) = variables.get(&param.name) {
370            debug!(
371                "[REVERT DEBUG] Using invariant value for param '{}': {:?}",
372                param.name, var_value
373            );
374            // Convert the invariant value to match the parameter's Solidity type
375            args.push(invariant_value_to_expression_with_type(var_value, &param.param_type));
376        } else {
377            debug!(
378                "[REVERT DEBUG] No invariant value found for param '{}'",
379                param.name
380            );
381            missing_params.push(param.name.clone());
382        }
383    }
384
385    if !missing_params.is_empty() {
386        return Err(anyhow::anyhow!(
387            "Missing invariant values for parameters: {}",
388            missing_params.join(", ")
389        ));
390    }
391
392    Ok(args)
393}
394
395fn invariant_value_to_expression_with_type(value: &InvariantBreakerValue, solidity_type: &str) -> Expression {
396    match value {
397        InvariantBreakerValue::Bool(b) => boolean(*b),
398        InvariantBreakerValue::UInt(n) => {
399            // For uint types, ensure the value is non-negative
400            // UInt values are already non-negative
401            number(n.to_string())
402        },
403        InvariantBreakerValue::Int(n) => {
404            // Check if the target type is uint - if so, convert to positive
405            if solidity_type.starts_with("uint") {
406                // UInt values are already non-negative
407            number(n.to_string())
408            } else {
409                number(n.to_string())
410            }
411        },
412        InvariantBreakerValue::String(s) => string_literal(s),
413        InvariantBreakerValue::Address(addr) => {
414            // For addresses, we might want to use address() cast or just the literal
415            if addr.starts_with("0x") {
416                Expression::FunctionCall(FunctionCallExpression {
417                    function: Box::new(identifier("address")),
418                    arguments: vec![Expression::Literal(Literal::HexString(HexStringLiteral {
419                        value: addr.strip_prefix("0x").unwrap_or(addr).to_string(),
420                    }))],
421                })
422            } else {
423                string_literal(addr)
424            }
425        }
426        InvariantBreakerValue::Bytes(b) => {
427            Expression::Literal(Literal::HexString(HexStringLiteral {
428                value: hex::encode(b),
429            }))
430        }
431    }
432}
433
434#[allow(dead_code)]
435fn invariant_value_to_expression(value: &InvariantBreakerValue) -> Expression {
436    match value {
437        InvariantBreakerValue::Bool(b) => boolean(*b),
438        InvariantBreakerValue::UInt(n) => {
439            // Ensure uint values are non-negative
440            // UInt values are already non-negative
441            number(n.to_string())
442        },
443        InvariantBreakerValue::Int(n) => number(n.to_string()),
444        InvariantBreakerValue::String(s) => string_literal(s),
445        InvariantBreakerValue::Address(addr) => {
446            // For addresses, we might want to use address() cast or just the literal
447            if addr.starts_with("0x") {
448                Expression::FunctionCall(FunctionCallExpression {
449                    function: Box::new(identifier("address")),
450                    arguments: vec![Expression::Literal(Literal::HexString(HexStringLiteral {
451                        value: addr.strip_prefix("0x").unwrap_or(addr).to_string(),
452                    }))],
453                })
454            } else {
455                string_literal(addr)
456            }
457        }
458        InvariantBreakerValue::Bytes(b) => {
459            Expression::Literal(Literal::HexString(HexStringLiteral {
460                value: hex::encode(b),
461            }))
462        }
463    }
464}
465
466pub fn create_comprehensive_revert_test_contract(
467    contract_info: &ContractInfo,
468    function_info: &FunctionInfo,
469    graph: &CallGraph,
470) -> Result<SolidityTestContract> {
471    debug!(
472        "[REVERT DEBUG] Creating comprehensive revert test contract for {}.{}",
473        contract_info.name, function_info.name
474    );
475
476    let test_contracts = generate_revert_tests_from_cfg(
477        graph,
478        &contract_info.name,
479        &function_info.name,
480        &function_info.parameters,
481    )?;
482
483    if test_contracts.is_empty() {
484        return Err(anyhow::anyhow!(
485            "No revert tests could be generated for function {}",
486            function_info.name
487        ));
488    }
489
490    // For now, return the first test contract
491    // In a more sophisticated implementation, we might combine multiple test contracts
492    Ok(test_contracts.into_iter().next().unwrap())
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_invariant_value_to_expression() {
        let bool_val = InvariantBreakerValue::Bool(true);
        let expr = invariant_value_to_expression(&bool_val);
        assert!(matches!(expr, Expression::Literal(Literal::Boolean(true))));

        let uint_val = InvariantBreakerValue::UInt(42);
        let expr = invariant_value_to_expression(&uint_val);
        if let Expression::Literal(Literal::Number(num)) = expr {
            assert_eq!(num.value, "42");
        } else {
            panic!("Expected number literal");
        }

        let string_val = InvariantBreakerValue::String("test".to_string());
        let expr = invariant_value_to_expression(&string_val);
        if let Expression::Literal(Literal::String(s)) = expr {
            assert_eq!(s.value, "test");
        } else {
            panic!("Expected string literal");
        }
    }

    #[test]
    fn test_typed_conversion_keeps_uint_literal() {
        let expr = invariant_value_to_expression_with_type(
            &InvariantBreakerValue::UInt(7),
            "uint256",
        );
        if let Expression::Literal(Literal::Number(num)) = expr {
            assert_eq!(num.value, "7");
        } else {
            panic!("Expected number literal");
        }
    }

    #[test]
    fn test_extract_prank_info() {
        let mut variables = HashMap::new();
        variables.insert(
            "caller".to_string(),
            InvariantBreakerValue::Address("0x1234567890123456789012345678901234567890".to_string()),
        );

        let condition = "msg.sender == owner";
        let (needs_prank, prank_address) = extract_prank_info(&variables, condition);

        assert!(needs_prank);
        assert_eq!(
            prank_address,
            Some("0x1234567890123456789012345678901234567890".to_string())
        );
    }

    #[test]
    fn test_extract_prank_info_without_access_control() {
        // No caller/access-control terms in the condition, so no prank is requested.
        let variables: HashMap<String, InvariantBreakerValue> = HashMap::new();
        assert_eq!(extract_prank_info(&variables, "amount > 0"), (false, None));
    }

    #[test]
    fn test_extract_prank_info_without_address_variable() {
        // The condition needs a prank, but the counterexample has no address to use.
        let mut variables = HashMap::new();
        variables.insert("amount".to_string(), InvariantBreakerValue::UInt(42));
        assert_eq!(
            extract_prank_info(&variables, "msg.sender == owner"),
            (true, None)
        );
    }
}