1use tracing::debug;
6
7use anyhow::{Context, Result};
8use traverse_graph::cg::{CallGraph, CallGraphGeneratorContext, ParameterInfo};
9use serde::{Deserialize, Serialize};
10use std::{collections::HashMap, path::PathBuf};
11use std::{fs, path::Path};
12
13use traverse_solidity::ast::*;
15use traverse_solidity::builder::*;
16use traverse_solidity::solidity_writer::write_source_unit;
17
18use crate::deployer_stub;
19use crate::revert_stub;
20use crate::state_change_stub;
21use crate::access_control_stub;
22use crate::CodeGenError;
23
24pub use traverse_solidity::ast::{Expression, Statement, TypeName, Visibility, StateMutability};
26pub use traverse_solidity::builder::{SolidityBuilder, ContractBuilder, FunctionBuilder, BlockBuilder};
27
28#[derive(Debug, serde::Serialize)]
29pub struct ContractInfo {
30 pub name: String,
31 pub has_constructor: bool,
32 pub constructor_params: Vec<ParameterInfo>,
33 pub functions: Vec<FunctionInfo>,
34}
35
36#[derive(Debug, serde::Serialize, Clone)]
37pub struct FunctionInfo {
38 pub name: String,
39 pub visibility: String,
40 pub return_type: Option<String>,
41 pub parameters: Vec<ParameterInfo>,
42}
43
44#[derive(Debug, Clone)]
46pub struct SolidityTestContract {
47 pub source_unit: SourceUnit,
48 pub contract_name: String,
49}
50
51impl SolidityTestContract {
52 pub fn new(contract_name: String, source_unit: SourceUnit) -> Self {
53 Self {
54 contract_name,
55 source_unit,
56 }
57 }
58
59 pub fn to_solidity_code(&self) -> String {
61 write_source_unit(&self.source_unit)
62 }
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67pub enum TestType {
68 RevertTest {
69 expected_error: String,
70 condition: String,
71 },
72 StateChangeTest {
73 variable_name: String,
74 expected_change: String,
75 },
76 AccessControlTest {
77 role_required: String,
78 unauthorized_caller: String,
79 },
80 DeployerTest {
81 contract_name: String,
82 constructor_args: Vec<String>,
83 },
84 FuzzTest {
85 property: String,
86 input_constraints: Vec<String>,
87 },
88}
89
90pub struct SolidityTestContractBuilder {
92 builder: SolidityBuilder,
93 contract_name: String,
94}
95
96impl SolidityTestContractBuilder {
97 pub fn new(contract_name: String) -> Self {
98 let mut builder = SolidityBuilder::new();
99
100 builder
102 .pragma("solidity", "^0.8.0")
103 .import("forge-std/Test.sol"); Self {
106 builder,
107 contract_name,
108 }
109 }
110
111 pub fn add_import(mut self, import_path: String) -> Self {
112 self.builder.import(import_path);
113 self
114 }
115
116 pub fn build_with_contract<F>(mut self, build_contract: F) -> SolidityTestContract
117 where
118 F: FnOnce(&mut ContractBuilder),
119 {
120 self.builder.contract(&self.contract_name, |contract| {
121 contract.inherits("Test");
123
124 build_contract(contract);
128 });
129
130 let source_unit = self.builder.build();
131 SolidityTestContract::new(self.contract_name.clone(), source_unit)
132 }
133}
134
135pub struct FoundryIntegration {
136 pub project_root: PathBuf,
137}
138
139impl FoundryIntegration {
140 pub fn new_with_project_setup(project_root: PathBuf) -> Result<Self> {
141 if !project_root.exists() {
142 fs::create_dir_all(&project_root).context("Failed to create project root directory")?;
143 }
144
145 let foundry = Self { project_root };
146
147 if !foundry.project_root.join("foundry.toml").exists() {
148 FoundryIntegration::init_foundry_project(&foundry.project_root)?;
149 }
150
151 Ok(foundry)
152 }
153
154 fn init_foundry_project(project_root: &Path) -> Result<()> {
155 use std::process::Command;
156
157 debug!("๐ง Initializing Foundry project at: {}", project_root.display());
158
159 let output = Command::new("forge")
160 .arg("init")
161 .arg("--force")
162 .current_dir(project_root)
163 .output()
164 .context("Failed to execute 'forge init' command")?;
165
166 if output.status.success() {
167 debug!("โ
Foundry project initialized successfully");
168 Ok(())
169 } else {
170 let stderr = String::from_utf8_lossy(&output.stderr);
171 if stderr.contains("already exists") || stderr.contains("already initialized") {
172 debug!("โ
Foundry project initialized (some files already existed)");
173 Ok(())
174 } else {
175 Err(anyhow::anyhow!("Failed to initialize Foundry project: {}", stderr))
176 }
177 }
178 }
179
180 pub fn copy_contract_to_src(&self, contract_path: &Path, contract_name: &str) -> Result<()> {
181 let src_dir = self.project_root.join("src");
182 fs::create_dir_all(&src_dir).context("Failed to create src directory")?;
183
184 let dest_path = src_dir.join(format!("{}.sol", contract_name));
185 if contract_path.exists() {
186 fs::copy(contract_path, &dest_path)
187 .context("Failed to copy contract to src directory")?;
188 debug!("๐ Copied {} to {}", contract_path.display(), dest_path.display());
189 }
190
191 Ok(())
192 }
193
194 pub fn write_test_contract(
195 &self,
196 contract: &SolidityTestContract,
197 test_file_path: &Path,
198 ) -> Result<()> {
199 let source_code = contract.to_solidity_code();
200
201 if let Some(parent_dir) = test_file_path.parent() {
202 fs::create_dir_all(parent_dir).context("Failed to create test directory")?;
203 }
204
205 fs::write(test_file_path, &source_code).context(format!(
206 "Failed to write test contract to {}",
207 test_file_path.display()
208 ))?;
209
210 Ok(())
211 }
212
213 pub fn run_project_build(&self) -> Result<bool> {
214 use std::process::Command;
215
216 let output = Command::new("forge")
217 .arg("build")
218 .current_dir(&self.project_root)
219 .output()
220 .context("Failed to execute 'forge build' command")?;
221
222 if output.status.success() {
223 Ok(true)
224 } else {
225 let stderr = String::from_utf8_lossy(&output.stderr);
226 debug!("Forge build failed: {}", stderr);
227 Ok(false)
228 }
229 }
230
231 pub fn run_tests(&self, test_pattern: Option<&str>) -> Result<bool> {
232 use std::process::Command;
233
234 let mut cmd = Command::new("forge");
235 cmd.arg("test").current_dir(&self.project_root);
236
237 if let Some(pattern) = test_pattern {
238 cmd.arg("--match-test").arg(pattern);
239 }
240
241 let output = cmd.output().context("Failed to execute 'forge test' command")?;
242
243 if output.status.success() {
244 debug!("โ
All tests passed!");
245 Ok(true)
246 } else {
247 let stderr = String::from_utf8_lossy(&output.stderr);
248 debug!("Some tests failed: {}", stderr);
249 Ok(false)
250 }
251 }
252}
253
254pub mod expression_helpers {
255 use super::*;
256
257
258 pub fn require_statement(condition: Expression, message: &str) -> Statement {
260 Statement::Expression(ExpressionStatement {
261 expression: Expression::FunctionCall(FunctionCallExpression {
262 function: Box::new(Expression::Identifier("require".to_string())),
263 arguments: vec![
264 condition,
265 Expression::Literal(Literal::String(StringLiteral {
266 value: message.to_string(),
267 })),
268 ],
269 }),
270 })
271 }
272
273 pub fn expect_revert_statement(error_message: &str) -> Statement {
275 Statement::Expression(ExpressionStatement {
276 expression: Expression::FunctionCall(FunctionCallExpression {
277 function: Box::new(Expression::MemberAccess(MemberAccessExpression {
278 object: Box::new(Expression::Identifier("vm".to_string())),
279 member: "expectRevert".to_string(),
280 })),
281 arguments: vec![Expression::FunctionCall(FunctionCallExpression {
282 function: Box::new(Expression::Identifier("bytes".to_string())),
283 arguments: vec![Expression::Literal(Literal::String(StringLiteral {
284 value: error_message.to_string(),
285 }))],
286 })],
287 }),
288 })
289 }
290
291 pub fn assert_statement(condition: Expression) -> Statement {
293 Statement::Expression(ExpressionStatement {
294 expression: Expression::FunctionCall(FunctionCallExpression {
295 function: Box::new(Expression::Identifier("assert".to_string())),
296 arguments: vec![condition],
297 }),
298 })
299 }
300
301 pub fn declare_and_assign(type_name: TypeName, name: &str, value: Expression) -> Statement {
303 Statement::Variable(VariableDeclarationStatement {
304 declaration: VariableDeclaration {
305 type_name,
306 data_location: None,
307 name: name.to_string(),
308 },
309 initial_value: Some(value),
310 })
311 }
312
313 pub fn function_call(target: Option<Expression>, function_name: &str, args: Vec<Expression>) -> Expression {
315 if let Some(target_expr) = target {
316 Expression::FunctionCall(FunctionCallExpression {
317 function: Box::new(Expression::MemberAccess(MemberAccessExpression {
318 object: Box::new(target_expr),
319 member: function_name.to_string(),
320 })),
321 arguments: args,
322 })
323 } else {
324 Expression::FunctionCall(FunctionCallExpression {
325 function: Box::new(Expression::Identifier(function_name.to_string())),
326 arguments: args,
327 })
328 }
329 }
330}
331
332pub fn strings_to_expressions(arg_strings: &[String]) -> Vec<Expression> {
333 arg_strings
334 .iter()
335 .map(|s| Expression::Literal(Literal::String(StringLiteral {
336 value: s.clone(),
337 })))
338 .collect()
339}
340
341pub fn generate_valid_args_for_function(
342 function_params: &[ParameterInfo],
343 actual_args_opt: Option<&Vec<String>>,
344) -> Result<Vec<Expression>> {
345 if let Some(actual_args) = actual_args_opt {
346 Ok(strings_to_expressions(actual_args))
347 } else {
348 let args = function_params
349 .iter()
350 .map(|param| match param.param_type.as_str() {
351 "string" => string_literal("updated test value"),
352 "address" => Expression::FunctionCall(FunctionCallExpression {
353 function: Box::new(Expression::Identifier("address".to_string())),
354 arguments: vec![number("1")],
355 }),
356 "bool" => boolean(true),
357 t if t.starts_with("uint") => number("42"),
358 t if t.starts_with("int") => number("42"),
359 _ => number("1"),
360 })
361 .collect();
362
363 Ok(args)
364 }
365}
366
/// Turns arbitrary text into an identifier-safe string.
///
/// Every character that is not an ASCII letter or digit becomes `_`, and
/// leading/trailing `_` are then trimmed.
///
/// Only ASCII alphanumerics are kept because Solidity identifiers are
/// restricted to ASCII. `char::is_alphanumeric` would also pass letters such
/// as `é` or `中`, producing names that do not compile.
///
/// The result may be empty or start with a digit. Callers that need a
/// standalone identifier should add a prefix.
pub fn sanitize_identifier(input: &str) -> String {
    input
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .collect::<String>()
        .trim_matches('_')
        .to_string()
}
375
/// Returns `s` with its first character uppercased and the rest unchanged.
///
/// The full Unicode uppercase mapping is used, so the first character may
/// expand (e.g. `ß` becomes `SS`). An empty input yields an empty string.
pub fn capitalize_first_letter(s: &str) -> String {
    if let Some(head) = s.chars().next() {
        let tail = &s[head.len_utf8()..];
        head.to_uppercase().chain(tail.chars()).collect()
    } else {
        String::new()
    }
}
383
/// Converts `snake_case`, `kebab-case` or space-separated text to PascalCase.
///
/// `_`, `-` and ` ` act as word separators and are dropped. The first
/// character of the input and the first character after each separator are
/// uppercased; all other characters are kept as they are.
///
/// The full Unicode uppercase mapping is used, consistent with
/// `capitalize_first_letter`, so characters whose uppercase form spans
/// several characters (e.g. `ß` → `SS`) are not truncated.
pub fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;

    for c in s.chars() {
        match c {
            '_' | '-' | ' ' => capitalize_next = true,
            _ if capitalize_next => {
                result.extend(c.to_uppercase());
                capitalize_next = false;
            }
            _ => result.push(c),
        }
    }
    result
}
400
401pub(crate) fn extract_contracts_from_graph(
402 graph: &CallGraph,
403 ctx: &CallGraphGeneratorContext,
404) -> Vec<ContractInfo> {
405 let mut contracts_map: HashMap<String, ContractInfo> = HashMap::new();
407
408 for node in graph.nodes.iter() {
409 if let Some(contract_name_str) = &node.contract_name {
410 let is_interface_scope = ctx.all_interfaces.contains_key(contract_name_str);
411
412 if node.node_type == traverse_graph::cg::NodeType::Interface && &node.name == contract_name_str {
413 continue;
414 }
415
416 let contract_info = contracts_map
417 .entry(contract_name_str.clone())
418 .or_insert_with(|| ContractInfo {
419 name: contract_name_str.clone(),
420 has_constructor: false,
421 constructor_params: Vec::new(),
422 functions: Vec::new(),
423 });
424
425 match node.node_type {
426 traverse_graph::cg::NodeType::Constructor => {
427 contract_info.has_constructor = true;
428 contract_info.constructor_params = graph.nodes[node.id].parameters.clone();
429 }
430 traverse_graph::cg::NodeType::Function => {
431 if !is_interface_scope {
432 let params = graph.nodes[node.id].parameters.clone();
433 contract_info.functions.push(FunctionInfo {
434 name: node.name.clone(),
435 visibility: "public".to_string(),
436 return_type: None,
437 parameters: params,
438 });
439 }
440 }
441 _ => {}
442 }
443 }
444 }
445
446 contracts_map.into_values().collect()
447}
448
449pub(crate) fn generate_and_write_test_file(
450 foundry: &FoundryIntegration,
451 contract: &SolidityTestContract,
452 test_file_path: &Path,
453 verbose: bool,
454) -> Result<()> {
455 foundry
456 .write_test_contract(contract, test_file_path)
457 .map_err(|e| {
458 CodeGenError::FoundryError(format!(
459 "Failed to write test file {}: {}",
460 test_file_path.display(),
461 e
462 ))
463 })?;
464
465 if verbose {
466 debug!("๐ Generated test file: {}", test_file_path.display());
467 }
468
469 Ok(())
470}
471
472pub fn generate_tests_with_foundry(
473 graph: &CallGraph,
474 ctx: &CallGraphGeneratorContext,
475 verbose: bool,
476 output_dir: &Path,
477 foundry_root: Option<PathBuf>,
478 deployer_only: bool,
479 validate_compilation: bool,
480 original_contract_paths: &HashMap<String, PathBuf>,
481) -> Result<()> {
482 if verbose {
483 debug!("๐ Starting sol2test with enhanced Foundry integration");
484 }
485
486 let foundry_root = foundry_root
487 .clone()
488 .unwrap_or_else(|| output_dir.parent().unwrap_or(Path::new(".")).to_path_buf());
489
490 let foundry = FoundryIntegration::new_with_project_setup(foundry_root)
491 .map_err(|e| CodeGenError::FoundryError(format!("Failed to initialize Foundry: {}", e)))?;
492
493 let contracts = extract_contracts_from_graph(graph, ctx);
494
495 for contract_info in &contracts {
496 let contract_name = &contract_info.name;
497 if let Some(original_path) = original_contract_paths.get(contract_name) {
498 if verbose {
499 debug!(
500 "๐ Copying original contract '{}' from {} to Foundry src...",
501 contract_name,
502 original_path.display()
503 );
504 }
505 foundry
506 .copy_contract_to_src(original_path, contract_name)
507 .map_err(|e| {
508 CodeGenError::FoundryError(format!(
509 "Failed to copy contract {}: {}",
510 contract_name, e
511 ))
512 })?;
513 } else {
514 debug!(
515 "โ ๏ธ Warning: Original source path for contract '{}' not found. Skipping copy.",
516 contract_name
517 );
518 }
519 }
520
521 if verbose {
522 debug!("๐๏ธ Found {} contracts in CFG:", contracts.len());
523 for contract in &contracts {
524 debug!(
525 " - {} (functions: {})",
526 contract.name,
527 contract.functions.len()
528 );
529 }
530 }
531
532 let mut generated_count = 0;
533 let mut validated_count = 0;
534
535 for contract_info in &contracts {
536 if graph
537 .nodes
538 .iter()
539 .any(|n| n.node_type == traverse_graph::cg::NodeType::Interface && n.name == contract_info.name)
540 {
541 if verbose {
542 debug!(" โญ๏ธ Skipping interface: {}", contract_info.name);
543 }
544 continue;
545 }
546
547 let test_dir = foundry.project_root.join("test");
548 fs::create_dir_all(&test_dir).context("Failed to create test directory")?;
549
550 if !deployer_only {
551 match deployer_stub::generate_foundry_deployer_test_contract(contract_info) {
552 Ok(deployer_source_unit) => {
553 let deployer_test_contract = SolidityTestContract::new(
554 format!("{}DeployerTest", contract_info.name),
555 deployer_source_unit,
556 );
557 let deployer_test_filename = format!("{}.t.sol", deployer_test_contract.contract_name);
558 let deployer_test_path = test_dir.join(deployer_test_filename);
559
560 generate_and_write_test_file(
561 &foundry,
562 &deployer_test_contract,
563 &deployer_test_path,
564 verbose,
565 )?;
566 generated_count += 1;
567 }
568 Err(e) => {
569 debug!("Failed to generate deployer test for {}: {}", contract_info.name, e);
570 }
571 }
572 }
573
574 for function_info in &contract_info.functions {
575 if verbose {
576 debug!(
577 "Processing function: {}.{}",
578 contract_info.name, function_info.name
579 );
580 }
581
582 match revert_stub::generate_revert_tests_from_cfg(
584 graph,
585 &contract_info.name,
586 &function_info.name,
587 &function_info.parameters,
588 ) {
589 Ok(revert_test_contracts) => {
590 for (i, test_contract) in revert_test_contracts.iter().enumerate() {
591 let test_filename = format!("{}RevertTest{}.t.sol",
592 format!("{}{}", contract_info.name, function_info.name), i);
593 let test_path = test_dir.join(test_filename);
594
595 generate_and_write_test_file(&foundry, test_contract, &test_path, verbose)?;
596 generated_count += 1;
597 }
598 }
599 Err(e) => {
600 if verbose {
601 debug!(
602 "Error generating revert tests for {}.{}: {}",
603 contract_info.name, function_info.name, e
604 );
605 }
606 }
607 }
608
609 match state_change_stub::generate_state_change_tests_from_cfg(
610 graph,
611 ctx,
612 &contract_info.name,
613 &function_info.name,
614 &function_info.parameters,
615 ) {
616 Ok(state_test_contracts) => {
617 for (i, test_contract) in state_test_contracts.iter().enumerate() {
618 let test_filename = format!("{}StateTest{}.t.sol",
619 format!("{}{}", contract_info.name, function_info.name), i);
620 let test_path = test_dir.join(test_filename);
621
622 generate_and_write_test_file(&foundry, test_contract, &test_path, verbose)?;
623 generated_count += 1;
624 }
625 }
626 Err(e) => {
627 if verbose {
628 debug!(
629 "Error generating state change tests for {}.{}: {}",
630 contract_info.name, function_info.name, e
631 );
632 }
633 }
634 }
635
636 match access_control_stub::generate_access_control_tests_from_cfg(
637 graph,
638 &contract_info.name,
639 &function_info.name,
640 &function_info.parameters,
641 &contract_info.constructor_params,
642 ) {
643 Ok(access_test_contracts) => {
644 for (i, test_contract) in access_test_contracts.iter().enumerate() {
645 let test_filename = format!("{}AccessTest{}.t.sol",
646 format!("{}{}", contract_info.name, function_info.name), i);
647 let test_path = test_dir.join(test_filename);
648
649 generate_and_write_test_file(&foundry, test_contract, &test_path, verbose)?;
650 generated_count += 1;
651 }
652 }
653 Err(e) => {
654 if verbose {
655 debug!(
656 "Error generating access control tests for {}.{}: {}",
657 contract_info.name, function_info.name, e
658 );
659 }
660 }
661 }
662 }
663 }
664
665 if validate_compilation && generated_count > 0 {
666 if verbose {
667 debug!("\nโ๏ธ Attempting to compile the entire project with 'forge build'...");
668 }
669 match foundry.run_project_build() {
670 Ok(build_successful) => {
671 if build_successful {
672 validated_count = generated_count;
673 if verbose {
674 debug!("โ
Project build successful. All {} generated test contracts are valid.", generated_count);
675 }
676
677 if verbose {
679 debug!("\n๐งช Running generated tests with 'forge test'...");
680 }
681 match foundry.run_tests(None) {
682 Ok(tests_passed) => {
683 if tests_passed {
684 if verbose {
685 debug!("โ
All tests passed successfully!");
686 }
687 } else if verbose {
688 debug!("โ Some tests failed. Check 'forge test' output above for details.");
689 }
690 }
691 Err(e) => {
692 if verbose {
693 debug!("โ ๏ธ Error running tests: {}", e);
694 }
695 }
696 }
697 } else if verbose {
698 debug!("โ Project build failed. Some of the {} generated test contracts may have errors. Check 'forge build' output above.", generated_count);
699 }
700 }
701 Err(e) => {
702 if verbose {
703 debug!(
704 "โ ๏ธ Error during final project build: {}. Validation status uncertain.",
705 e
706 );
707 }
708 }
709 }
710 } else if generated_count == 0 && verbose && validate_compilation {
711 debug!("\n๐คท No test contracts were generated, skipping compilation validation.");
712 } else if !validate_compilation && verbose {
713 debug!("\nโน๏ธ Compilation validation was skipped via configuration.");
714 }
715
716 if verbose {
717 debug!("\n๐ Generation Summary:");
718 debug!(" - Generated: {} test contracts", generated_count);
719 if validate_compilation {
720 debug!(" - Validated: {} test contracts", validated_count);
721 debug!(
722 " - Validation rate: {:.1}%",
723 if generated_count > 0 {
724 (validated_count as f64 / generated_count as f64) * 100.0
725 } else {
726 0.0
727 }
728 );
729 }
730 }
731
732 Ok(())
733}
734
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder must emit the pragma, the forge-std import, any extra
    /// imports, the `is Test` inheritance and whatever the closure adds.
    #[test]
    fn test_enhanced_test_contract_builder() {
        let generated = SolidityTestContractBuilder::new("TestContract".to_string())
            .add_import("../src/MyContract.sol".to_string())
            .build_with_contract(|c| {
                c.state_variable(uint256(), "testVar", Some(Visibility::Private), Some(number("42")))
                    .function("testSetValue", |f| {
                        f.parameter(uint256(), "_value")
                            .visibility(Visibility::Public)
                            .body(|body| {
                                let assignment = AssignmentExpression {
                                    left: Box::new(identifier("testVar")),
                                    operator: AssignmentOperator::Assign,
                                    right: Box::new(identifier("_value")),
                                };
                                body.expression(Expression::Assignment(assignment));
                            });
                    });
            });

        let code = generated.to_solidity_code();
        let expected_fragments = [
            "pragma solidity ^0.8.0;",
            "import \"forge-std/Test.sol\";",
            "import \"../src/MyContract.sol\";",
            "contract TestContract is Test",
            "uint256 private testVar = 42;",
            "function testSetValue(uint256 _value) public",
        ];
        for fragment in expected_fragments.iter() {
            assert!(code.contains(*fragment));
        }
    }

    /// `require_statement` must produce an expression statement wrapping a
    /// two-argument function call.
    #[test]
    fn test_expression_helpers() {
        use expression_helpers::*;

        let condition = binary(
            identifier("balance"),
            BinaryOperator::GreaterThanOrEqual,
            identifier("amount"),
        );
        let statement = require_statement(condition, "Insufficient balance");

        match statement {
            Statement::Expression(ExpressionStatement {
                expression: Expression::FunctionCall(call),
            }) => assert_eq!(call.arguments.len(), 2),
            Statement::Expression(_) => panic!("Expected function call expression"),
            _ => panic!("Expected expression statement"),
        }
    }

    /// Builder helpers and AST enums must map to the expected Solidity forms.
    #[test]
    fn test_type_safety_improvements() {
        assert!(matches!(
            uint256(),
            TypeName::Elementary(ElementaryTypeName::UnsignedInteger(Some(256)))
        ));
        assert_eq!(Visibility::Public.to_string(), "public");
        assert_eq!(BinaryOperator::GreaterThanOrEqual.to_string(), ">=");
    }
}