solscript_codegen/
test_gen.rs

1//! Test File Generator
2//!
3//! Generates Anchor test files for the program.
4
5use crate::ir::*;
6use crate::CodegenError;
7
/// Generates an Anchor test file (TypeScript, Mocha `describe`/`it` with Chai
/// assertions) from the compiled program IR.
#[derive(Debug, Clone)]
pub struct TestGenerator {
    /// snake_case program name; set by `TestGenerator::generate` and read by
    /// the header and per-instruction generators.
    program_name: String,
}
12
13impl Default for TestGenerator {
14    fn default() -> Self {
15        Self::new()
16    }
17}
18
19impl TestGenerator {
20    pub fn new() -> Self {
21        Self {
22            program_name: String::new(),
23        }
24    }
25
26    /// Generate the test file
27    pub fn generate(&mut self, ir: &SolanaProgram) -> Result<String, CodegenError> {
28        self.program_name = to_snake_case(&ir.name);
29        let class_name = to_camel_case(&ir.name);
30
31        let mut output = String::new();
32
33        // Header and imports
34        output.push_str(&self.generate_header(&class_name));
35        output.push('\n');
36
37        // Test setup
38        output.push_str(&self.generate_setup(&class_name));
39        output.push('\n');
40
41        // Generate test cases for each instruction
42        for instruction in &ir.instructions {
43            output.push_str(&self.generate_test_case(instruction, &class_name)?);
44            output.push('\n');
45        }
46
47        // Close the describe block
48        output.push_str("});\n");
49
50        Ok(output)
51    }
52
53    fn generate_header(&self, class_name: &str) -> String {
54        format!(
55            r#"/**
56 * Generated by SolScript compiler
57 * Anchor tests for {} program
58 */
59
60import * as anchor from "@coral-xyz/anchor";
61import {{ Program }} from "@coral-xyz/anchor";
62import {{ Keypair, PublicKey, SystemProgram }} from "@solana/web3.js";
63import {{ assert }} from "chai";
64
65// Import the generated types
66// import {{ {} }} from "../target/types/{}";
67
68describe("{}", () => {{
69  // Configure the client to use the local cluster
70  const provider = anchor.AnchorProvider.env();
71  anchor.setProvider(provider);
72
73  // Get the program
74  // const program = anchor.workspace.{} as Program<{}>;
75
76  // Test accounts
77  let stateAccount: Keypair;
78  let authority: Keypair;
79"#,
80            class_name, class_name, self.program_name, class_name, class_name, class_name
81        )
82    }
83
84    fn generate_setup(&self, _class_name: &str) -> String {
85        r#"
86  before(async () => {
87    stateAccount = Keypair.generate();
88    authority = provider.wallet.payer;
89  });
90
91  // Helper to get state PDA
92  const getStatePDA = (): [PublicKey, number] => {
93    return PublicKey.findProgramAddressSync(
94      [Buffer.from("state")],
95      // program.programId
96      SystemProgram.programId // placeholder
97    );
98  };
99"#
100        .to_string()
101    }
102
103    fn generate_test_case(
104        &self,
105        instruction: &Instruction,
106        _class_name: &str,
107    ) -> Result<String, CodegenError> {
108        let test_name = to_snake_case(&instruction.name).replace('_', " ");
109        let method_name = to_camel_case_lower(&instruction.name);
110
111        // Generate mock argument values based on parameter types
112        let mut args = Vec::new();
113        for param in &instruction.params {
114            let mock_value = self.generate_mock_value(&param.ty, &param.name);
115            args.push(format!(
116                "const {} = {};",
117                to_camel_case_lower(&param.name),
118                mock_value
119            ));
120        }
121        let args_setup = args.join("\n    ");
122
123        // Generate the instruction call arguments
124        let call_args: Vec<String> = instruction
125            .params
126            .iter()
127            .map(|p| to_camel_case_lower(&p.name))
128            .collect();
129        let call_args_str = call_args.join(", ");
130
131        // Determine if this is the initialize/constructor
132        let is_init = instruction.name.to_lowercase() == "initialize"
133            || instruction.name.to_lowercase() == "init"
134            || instruction.name.to_lowercase() == "constructor";
135
136        let accounts_str = if is_init {
137            r#"//       state: stateAccount.publicKey,
138    //       signer: authority.publicKey,
139    //       systemProgram: SystemProgram.programId,"#
140        } else {
141            r#"//       state: stateAccount.publicKey,
142    //       signer: authority.publicKey,"#
143        };
144
145        let signers_str = if is_init {
146            "[stateAccount, authority]"
147        } else {
148            "[authority]"
149        };
150
151        Ok(format!(
152            r#"
153  it("should {}", async () => {{
154    {}
155
156    // Uncomment when program is available:
157    // const tx = await program.methods
158    //   .{}({})
159    //   .accounts({{
160    {}
161    //   }})
162    //   .signers({})
163    //   .rpc();
164    // console.log("Transaction signature:", tx);
165
166    // Add assertions based on expected behavior
167    // const state = await program.account.{}State.fetch(stateAccount.publicKey);
168    // assert.ok(state, "State should exist");
169  }});
170"#,
171            test_name,
172            args_setup,
173            method_name,
174            call_args_str,
175            accounts_str,
176            signers_str,
177            to_camel_case_lower(&self.program_name)
178        ))
179    }
180
181    fn generate_mock_value(&self, ty: &SolanaType, name: &str) -> String {
182        match ty {
183            SolanaType::U8 => "42".to_string(),
184            SolanaType::U16 => "1000".to_string(),
185            SolanaType::U32 => "100000".to_string(),
186            SolanaType::U64 | SolanaType::U128 => "new anchor.BN(1000000)".to_string(),
187            SolanaType::I8 | SolanaType::I16 | SolanaType::I32 => "42".to_string(),
188            SolanaType::I64 | SolanaType::I128 => "new anchor.BN(1000000)".to_string(),
189            SolanaType::Bool => "true".to_string(),
190            SolanaType::String => format!("\"test_{}\"", name),
191            SolanaType::Pubkey | SolanaType::Signer => "Keypair.generate().publicKey".to_string(),
192            SolanaType::Bytes => "Buffer.from(\"test\")".to_string(),
193            SolanaType::FixedBytes(n) => format!("new Uint8Array({})", n),
194            SolanaType::Array(inner, size) => {
195                let inner_val = self.generate_mock_value(inner, name);
196                format!("Array({}).fill({})", size, inner_val)
197            }
198            SolanaType::Vec(inner) => {
199                let inner_val = self.generate_mock_value(inner, name);
200                format!("[{}]", inner_val)
201            }
202            SolanaType::Option(_) => "null".to_string(),
203            SolanaType::Mapping(_, _) => "new Map()".to_string(),
204            SolanaType::Custom(type_name) => {
205                // Generate a placeholder object for custom types
206                // In real usage, this should be filled with actual field values
207                format!("{{ /* {} instance - fill in fields */ }}", type_name)
208            }
209        }
210    }
211}
212
// Case-conversion helpers used by `TestGenerator`.

/// Converts a `snake_case` / `kebab-case` (or camelCase) identifier to
/// PascalCase.
///
/// Underscores and hyphens are removed. The first character of the string and
/// the first character after each run of separators are ASCII-uppercased.
/// Every other character is kept verbatim, e.g. `"my_token"` -> `"MyToken"`
/// and `"myToken"` -> `"MyToken"`.
fn to_camel_case(s: &str) -> String {
    s.split(|c: char| c == '_' || c == '-')
        .flat_map(|segment| {
            let mut chars = segment.chars();
            let head = chars.next().map(|first| first.to_ascii_uppercase());
            head.into_iter().chain(chars)
        })
        .collect()
}
231
232fn to_camel_case_lower(s: &str) -> String {
233    let camel = to_camel_case(s);
234    let mut chars = camel.chars();
235    match chars.next() {
236        None => String::new(),
237        Some(c) => c.to_ascii_lowercase().to_string() + chars.as_str(),
238    }
239}
240
/// Converts a camelCase / PascalCase identifier to snake_case.
///
/// An `_` is inserted before every uppercase character except the first, and
/// each character is ASCII-lowercased. Consecutive capitals are split one by
/// one (`"ABC"` -> `"a_b_c"`), and existing underscores are kept
/// (`"My_Token"` -> `"my__token"`).
fn to_snake_case(s: &str) -> String {
    s.chars()
        .enumerate()
        .fold(String::with_capacity(s.len()), |mut out, (idx, ch)| {
            if idx > 0 && ch.is_uppercase() {
                out.push('_');
            }
            out.push(ch.to_ascii_lowercase());
            out
        })
}