solify_generator/
lib.rs

1use anyhow::{ Context, Result };
2use std::collections::HashMap;
3use std::fs::{ create_dir_all, File };
4use std::io::Write;
5use std::path::Path;
6
7use solify_common::{
8    IdlData,
9    SeedComponent,
10    SeedType,
11    SetupType,
12    TestMetadata,
13    TestValueType,
14};
15use solify_common::errors::SolifyError;
16use tera::{ Tera, Context as TeraContext };
17use serde::{Serialize, Deserialize};
18
19#[derive(Serialize, Deserialize)]
20struct AccountInfo {
21    original_name: String,
22    camel_name: String,
23}
24
25#[derive(Serialize, Deserialize)]
26struct InstructionTestCaseWrapper {
27    instruction_name: String,
28    instruction_name_camel: String,
29    arguments: Vec<solify_common::ArgumentInfo>,
30    positive_cases: Vec<solify_common::TestCase>,
31    negative_cases: Vec<solify_common::TestCase>,
32}
33
34pub fn generate_with_tera(
35    meta: &TestMetadata,
36    idl: &IdlData,
37    out_dir: impl AsRef<Path>
38) -> Result<()> {
39    let out_dir = out_dir.as_ref();
40    create_dir_all(out_dir).with_context(|| format!("creating output dir {:?}", out_dir))?;
41
42    let mut tera = Tera::default();
43    tera
44        .add_raw_template("aggregated_tests.tera", AGGREGATED_TEMPLATE)
45        .context("add aggregated template")?;
46
47    let mut ctx = TeraContext::new();
48
49    let program_name = &idl.name;
50    let program_name_pascal = cut_program_name(program_name);
51    let program_capitalized = capitalize_first_letter(&program_name_pascal);
52    let program_name_camel = camel_case(program_name);
53    let program_name_pascal_case = to_pascal_case(program_name);
54    ctx.insert("program_name", program_name);
55    ctx.insert("program_name_pascal", &program_name_pascal);
56    ctx.insert("program_capitalized", &program_capitalized);
57    ctx.insert("program_name_camel", &program_name_camel);
58    ctx.insert("program_name_pascal_case", &program_name_pascal_case);
59    
60
61    // setup requirements
62    let setup_requirements = meta.setup_requirements.clone();
63    let mut map = HashMap::new();
64    let mut index = 0;
65
66    for setup_requirement in setup_requirements.iter().cloned() {
67        index += 1;
68
69        match setup_requirement.requirement_type {
70            SetupType::CreateKeypair => {
71                map.insert(index, "Keypair.generate()");
72            }
73            SetupType::FundAccount => {
74                map.insert(index, "FundAccount");
75            }
76            SetupType::InitializePda => {
77                map.insert(index, "PublicKey");
78            }
79            _ => {
80                return Err(SolifyError::InvalidSetupRequirement.into());
81            }
82        }
83    }
84    ctx.insert("setup_requirements", &map);
85
86    let mut pda_indices = Vec::new();
87    let mut index_1 = 0;
88
89    for r in setup_requirements.iter().cloned() {
90        index_1 += 1;
91
92        if r.requirement_type == SetupType::InitializePda {
93            pda_indices.push(index_1);
94        }
95    }
96
97    // pda initialization
98    let mut pda_map = HashMap::new();
99    let pda_init_sequence = meta.pda_init_sequence.clone();
100
101    for (i, pda_init) in pda_init_sequence.iter().enumerate() {
102        if let Some(index) = pda_indices.get(i) {
103        let seeds_expr = render_pda_seeds_expression(&pda_init.seeds);
104            pda_map.insert(*index, seeds_expr);
105        }
106    }
107
108    ctx.insert("pda_seeds", &pda_map);
109
110    let mut account_vars: HashMap<String, String> = HashMap::new();
111
112    for ad in meta.account_dependencies.iter() {
113        if ad.is_pda {
114            if
115                let Some((pos, _)) = meta.pda_init_sequence
116                    .iter()
117                    .enumerate()
118                    .find(|(_, p)| p.account_name == ad.account_name)
119            {
120                let setup_index = pda_indices[pos];
121                account_vars.insert(ad.account_name.clone(), format!("pda{}", setup_index));
122            } else {
123                account_vars.insert(
124                    ad.account_name.clone(),
125                    format!("/* missing pda for {} */ null", ad.account_name)
126                );
127            }
128        } else if ad.account_name == "authority" {
129            account_vars.insert(ad.account_name.clone(), "authorityPubkey".to_string());
130        } else if ad.account_name == "system_program" {
131            account_vars.insert(ad.account_name.clone(), "SystemProgram.programId".to_string());
132        } else {
133            account_vars.insert(ad.account_name.clone(), format!("{}", ad.account_name));
134        }
135    }
136
137    for instruction in &idl.instructions {
138        for acc in &instruction.accounts {
139            if !account_vars.contains_key(&acc.name) {
140                if acc.name == "system_program" || acc.name == "systemProgram" {
141                    account_vars.insert(acc.name.clone(), "SystemProgram.programId".to_string());
142                } else if acc.name == "authority" {
143                    account_vars.insert(acc.name.clone(), "authorityPubkey".to_string());
144                } else {
145                    if let Some((pos, _)) = meta.pda_init_sequence
146                        .iter()
147                        .enumerate()
148                        .find(|(_, p)| p.account_name == acc.name)
149                    {
150                        if let Some(setup_index) = pda_indices.get(pos) {
151                            account_vars.insert(acc.name.clone(), format!("pda{}", setup_index));
152                        }
153                    }
154                }
155            }
156        }
157    }
158
159    ctx.insert("account_vars", &account_vars);
160    let mut instruction_accounts: HashMap<String, Vec<AccountInfo>> = HashMap::new();
161    for instruction in &idl.instructions {
162        let account_infos: Vec<AccountInfo> = instruction.accounts.iter()
163            .map(|acc| {
164                AccountInfo {
165                    original_name: acc.name.clone(),
166                    camel_name: to_camel_case(&acc.name),
167                }
168            })
169            .collect();
170        instruction_accounts.insert(instruction.name.clone(), account_infos);
171    }
172    ctx.insert("instruction_accounts", &instruction_accounts);
173
174    let processed_test_cases: Vec<InstructionTestCaseWrapper> = meta.test_cases.iter()
175        .map(|test_case| {
176            let mut positive_cases = test_case.positive_cases.clone();
177            for arg_value in &mut positive_cases {
178                for arg in &mut arg_value.argument_values {
179                    arg.value_type = convert_to_typescript_value(arg.value_type.clone());
180                }
181            }
182            
183            let mut negative_cases = test_case.negative_cases.clone();
184            for arg_value in &mut negative_cases {
185                for arg in &mut arg_value.argument_values {
186                    arg.value_type = convert_to_typescript_value(arg.value_type.clone());
187                }
188            }
189            
190            InstructionTestCaseWrapper {
191                instruction_name: test_case.instruction_name.clone(),
192                instruction_name_camel: to_camel_case(&test_case.instruction_name),
193                arguments: test_case.arguments.clone(),
194                positive_cases,
195                negative_cases,
196            }
197        })
198        .collect();
199    ctx.insert("instruction_tests", &processed_test_cases);
200
201    let rendered = tera.render("aggregated_tests.tera", &ctx).context("render tera")?;
202
203    let out_path = out_dir.join(format!("{}.ts", program_name_pascal));
204    let mut f = File::create(&out_path).with_context(|| format!("create file {:?}", out_path))?;
205    f.write_all(rendered.as_bytes()).with_context(|| format!("write file {:?}", out_path))?;
206
207    println!("Wrote {}", out_path.display());
208    Ok(())
209}
210
/// Tera template for the aggregated TypeScript (Anchor + mocha/chai) test file.
///
/// Context variables read: `program_name`, `program_name_pascal_case`,
/// `setup_requirements` (id -> setup code), `pda_seeds` (id -> seeds array expression),
/// `account_vars` (IDL account name -> TS expression), `instruction_accounts`
/// (instruction name -> account infos) and `instruction_tests`.
///
/// The "first keypair becomes `authority`" flags are updated with `set_global`. Tera
/// scopes a plain `set` inside a `for` loop to that loop, so it cannot be relied on to
/// carry the flag across iterations.
///
/// Negative tests record whether the call threw and assert on it afterwards. As a
/// result, an instruction that unexpectedly succeeds fails the generated test.
const AGGREGATED_TEMPLATE: &str =
    r#"
import * as anchor from "@coral-xyz/anchor";
import { Program } from "@coral-xyz/anchor";
import { {{ program_name_pascal_case }} } from "../target/types/{{ program_name }}";
import { assert } from "chai";
import { Keypair, SystemProgram, PublicKey, LAMPORTS_PER_SOL } from "@solana/web3.js";

// This file is generated by solify. You can edit it manually

describe("{{ program_name | default(value='program') }}", () => {
    // Configure the client
    let provider = anchor.AnchorProvider.env();
    anchor.setProvider(provider);
    const connection = provider.connection;

    const program = anchor.workspace.{{ program_name }} as Program<{{ program_name_pascal_case }}>;

    // Setup Requirements
    // keypair declarations
    {%- set keypair_found = false %}
    {%- for id, code in setup_requirements %}
    {%- if code == "Keypair.generate()" %}
    {%- if not keypair_found %}
    {%- set_global keypair_found = true %}
    const authority = Keypair.generate();
    const authorityPubkey = authority.publicKey;
    {%- else %}
    const user{{ id }} = Keypair.generate();
    const user{{ id }}Pubkey = user{{ id }}.publicKey;
    {%- endif %}
    {%- endif %}
    {%- endfor %}

    // PDA Declarations
    {%- for id, code in setup_requirements %}
    {%- if code == "PublicKey" %}
    let pda{{ id }}: PublicKey;
    let bump{{ id }}: number;
    {%- endif %}
    {%- endfor %}

    before(async () => {
        // ----- Airdrop for each user Keypair -----
        {%- set keypair_found_airdrop = false %}
        {%- for id, code in setup_requirements %}
        {%- if code == "Keypair.generate()" %}
        {%- if not keypair_found_airdrop %}
        {%- set_global keypair_found_airdrop = true %}
        const sig{{ id }} = await connection.requestAirdrop(authorityPubkey, 10 * LAMPORTS_PER_SOL);
        await connection.confirmTransaction(sig{{ id }}, "confirmed");
        {%- else %}
        const sig{{ id }} = await connection.requestAirdrop(user{{ id }}Pubkey, 10 * LAMPORTS_PER_SOL);
        await connection.confirmTransaction(sig{{ id }}, "confirmed");
        {%- endif %}
        {%- endif %}
        {%- endfor %}

        // ----- PDA Initialization -----
        {%- for id, seeds in pda_seeds %}
        [pda{{ id }}, bump{{ id }}] = PublicKey.findProgramAddressSync(
            {{ seeds }},
            program.programId
        );
        {%- endfor %}

    });

    {# ---------------- INSTRUCTION TEST CASES ---------------- #}

    {%- for instr in instruction_tests %}


    {# ---------- POSITIVE TESTS ---------- #}
    {%- for test in instr.positive_cases %}
    it("{{ test.description | addslashes }}", async () => {
        // Prepare arguments
        {%- for arg in test.argument_values %}
        {%- if arg.value_type.variant == "Valid" or arg.value_type.variant == "Invalid" %}
        const {{ arg.argument_name }}Value = {{ arg.value_type.description }};
        {%- else %}
        const {{ arg.argument_name }}Value = null;
        {%- endif %}
        {%- endfor %}
        // Execute instruction
        try {
            await program.methods
                .{{ instr.instruction_name_camel }}(
                    {%- for arg in test.argument_values %}
                    {{ arg.argument_name }}Value{%- if not loop.last %},{%- endif %}
                    {%- endfor %}
                )
                .accountsStrict({
                    {%- if instruction_accounts[instr.instruction_name] %}
                    {%- for acc_info in instruction_accounts[instr.instruction_name] %}
                    {%- set js_var = account_vars[acc_info.original_name] | default(value="null") %}
                    {{ acc_info.camel_name }}: {{ js_var }}{%- if not loop.last %},{%- endif %}
                    {%- endfor %}
                    {%- endif %}
                })
                .signers([
                    authority
                ])
                .rpc();
            // Expect success
            assert.ok(true);
        } catch (err) {
            assert.fail("Instruction should not have failed: " + err);
        }
    });
    {%- endfor %}
    {# ---------- NEGATIVE TESTS ---------- #}
    {%- for test in instr.negative_cases %}
    it("{{ test.description | addslashes }}", async () => {
        // Prepare arguments
        {%- for arg in test.argument_values %}
        {%- if arg.value_type.variant == "Valid" or arg.value_type.variant == "Invalid" %}
        const {{ arg.argument_name }}Value = {{ arg.value_type.description }};
        {%- else %}
        const {{ arg.argument_name }}Value = null;
        {%- endif %}
        {%- endfor %}
        // Execute instruction expecting failure
        let failed = false;
        try {
            await program.methods
                .{{ instr.instruction_name_camel }}(
                    {%- for arg in test.argument_values %}
                    {{ arg.argument_name }}Value{%- if not loop.last %},{%- endif %}
                    {%- endfor %}
                )
                .accountsStrict({
                    {%- if instruction_accounts[instr.instruction_name] %}
                    {%- for acc_info in instruction_accounts[instr.instruction_name] %}
                    {%- set js_var = account_vars[acc_info.original_name] | default(value="null") %}
                    {{ acc_info.camel_name }}: {{ js_var }}{%- if not loop.last %},{%- endif %}
                    {%- endfor %}
                    {%- endif %}
                })
                .signers([
                    authority
                ])
                .rpc();
        } catch (err) {
            failed = true;
            {%- if test.expected_outcome.variant == "Failure" %}
            assert(err.message.includes("{{ test.expected_outcome.error_message }}"));
            {%- endif %}
        }
        assert.ok(failed, "Instruction should have failed");
    });
    {%- endfor %}

    {%- endfor %}

})

"#;
377
378// ------------------- Helper functions (rendering helpers) -------------------
379
380fn render_pda_seeds_expression(seeds: &[SeedComponent]) -> String {
381    let parts: Vec<String> = seeds
382        .iter()
383        .map(|seed| {
384            match seed.seed_type {
385                SeedType::Static => { format!("Buffer.from(\"{}\")", seed.value) }
386                SeedType::AccountKey => { format!("{}Pubkey.toBuffer()", seed.value) }
387                SeedType::Argument => { format!("Buffer.from(String({}))", seed.value) }
388            }
389        })
390        .collect();
391
392    format!("[{}]", parts.join(", "))
393}
394
/// Returns the part of `s` before the first `_`, or all of `s` when it has no `_`.
///
/// For example, `"my_program"` becomes `"my"`. The result is used as the output file stem.
fn cut_program_name(s: &str) -> String {
    let head = match s.find('_') {
        Some(end) => &s[..end],
        None => s,
    };
    head.to_owned()
}
398
/// Upper-cases the first character of `s` and leaves the rest unchanged.
///
/// Returns an empty string for empty input. Works on `char`s rather than byte offsets, so
/// a multi-byte first character (e.g. `é`, `ß`) is handled instead of panicking on a
/// non-char-boundary slice.
fn capitalize_first_letter(s: &str) -> String {
    let mut chars = s.chars();
    match chars.next() {
        Some(first) => first.to_uppercase().chain(chars).collect(),
        None => String::new(),
    }
}
402
/// Converts a `snake_case` identifier to `camelCase`.
///
/// The first segment is lower-cased entirely. Each later segment has its first character
/// upper-cased and the rest lower-cased, and empty segments (from `__` or a trailing `_`)
/// are dropped. Characters are handled as `char`s, so a segment starting with a
/// multi-byte character does not panic.
fn camel_case(s: &str) -> String {
    let mut segments = s.split('_');
    let mut out = segments.next().unwrap_or("").to_lowercase();
    for segment in segments {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            out.extend(first.to_uppercase());
            out.push_str(&chars.as_str().to_lowercase());
        }
    }
    out
}
420
/// Converts a `snake_case` identifier to `camelCase`.
///
/// Segment 0 is lower-cased as a whole. Every following segment contributes its first
/// character upper-cased followed by the remainder lower-cased. Empty segments (from
/// leading, trailing or doubled `_`) contribute nothing.
fn to_camel_case(s: &str) -> String {
    s.split('_')
        .enumerate()
        .map(|(position, segment)| {
            if position == 0 {
                return segment.to_lowercase();
            }
            let mut rest = segment.chars();
            match rest.next() {
                None => String::new(),
                Some(head) => {
                    let mut piece: String = head.to_uppercase().collect();
                    piece.push_str(&rest.as_str().to_lowercase());
                    piece
                }
            }
        })
        .collect()
}
443
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// Every non-empty segment contributes its first character upper-cased followed by the
/// rest lower-cased; empty segments are skipped.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for segment in s.split('_').filter(|seg| !seg.is_empty()) {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            out.extend(head.to_uppercase());
            out.push_str(&rest.as_str().to_lowercase());
        }
    }
    out
}
464
465fn convert_to_typescript_value(value_type: TestValueType) -> TestValueType {
466    match value_type {
467        TestValueType::Valid { description } => {
468            TestValueType::Valid {
469                description: convert_rust_to_typescript(&description),
470            }
471        }
472        TestValueType::Invalid { description, reason } => {
473            TestValueType::Invalid {
474                description: convert_rust_to_typescript(&description),
475                reason,
476            }
477        }
478    }
479}
480
/// Converts a Rust-flavoured literal or expression into TypeScript source for an Anchor
/// test.
///
/// Rules, checked in order on the trimmed input:
/// * Integer type bounds (`u8`..`u128`, `i8`..`i128`, `::MIN` / `::MAX`) become
///   `new anchor.BN("<decimal>")`, with digits taken from Rust's own constants.
/// * Anything parsing as `i128` becomes `new anchor.BN("<input>")`.
/// * Anything else parsing as `f64` is returned as-is when it contains `.`. Otherwise it
///   is wrapped in `new anchor.BN(...)`, which covers integers beyond the `i128` range.
/// * The following are returned unchanged:
///   - text starting with `"` (a string literal or a JS-compatible expression such as
///     `"a".repeat(10)`);
///   - `true` / `false`;
///   - text starting with `new ` or `authority.`, or containing `Pubkey`.
/// * Everything else becomes a double-quoted TypeScript string, with `\`, `"`, `\n` and
///   `\r` escaped so the generated file stays syntactically valid.
fn convert_rust_to_typescript(value: &str) -> String {
    let trimmed = value.trim();

    let bound: Option<String> = match trimmed {
        "u8::MIN" | "u16::MIN" | "u32::MIN" | "u64::MIN" | "u128::MIN" => Some("0".to_owned()),
        "u8::MAX" => Some(u8::MAX.to_string()),
        "u16::MAX" => Some(u16::MAX.to_string()),
        "u32::MAX" => Some(u32::MAX.to_string()),
        "u64::MAX" => Some(u64::MAX.to_string()),
        "u128::MAX" => Some(u128::MAX.to_string()),
        "i8::MIN" => Some(i8::MIN.to_string()),
        "i8::MAX" => Some(i8::MAX.to_string()),
        "i16::MIN" => Some(i16::MIN.to_string()),
        "i16::MAX" => Some(i16::MAX.to_string()),
        "i32::MIN" => Some(i32::MIN.to_string()),
        "i32::MAX" => Some(i32::MAX.to_string()),
        "i64::MIN" => Some(i64::MIN.to_string()),
        "i64::MAX" => Some(i64::MAX.to_string()),
        "i128::MIN" => Some(i128::MIN.to_string()),
        "i128::MAX" => Some(i128::MAX.to_string()),
        _ => None,
    };
    if let Some(digits) = bound {
        return format!("new anchor.BN(\"{}\")", digits);
    }

    if trimmed.parse::<i128>().is_ok() {
        return format!("new anchor.BN(\"{}\")", trimmed);
    }
    if trimmed.parse::<f64>().is_ok() {
        return if trimmed.contains('.') {
            trimmed.to_string()
        } else {
            format!("new anchor.BN(\"{}\")", trimmed)
        };
    }

    let passthrough = trimmed.starts_with('"')
        || trimmed == "true"
        || trimmed == "false"
        || trimmed.starts_with("new ")
        || trimmed.starts_with("authority.")
        || trimmed.contains("Pubkey");
    if passthrough {
        return trimmed.to_string();
    }

    // Bare text: emit a properly escaped TypeScript string literal.
    let mut quoted = String::with_capacity(trimmed.len() + 2);
    quoted.push('"');
    for ch in trimmed.chars() {
        match ch {
            '\\' => quoted.push_str("\\\\"),
            '"' => quoted.push_str("\\\""),
            '\n' => quoted.push_str("\\n"),
            '\r' => quoted.push_str("\\r"),
            other => quoted.push(other),
        }
    }
    quoted.push('"');
    quoted
}