solify_generator/
lib.rs

1use anyhow::{ Context, Result };
2use std::collections::HashMap;
3use std::fs::{ create_dir_all, File };
4use std::io::Write;
5use std::path::Path;
6
7use solify_common::{
8    IdlData,
9    SeedComponent,
10    SeedType,
11    SetupType,
12    TestMetadata,
13    TestValueType,
14};
15use solify_common::errors::SolifyError;
16use tera::{ Tera, Context as TeraContext };
17use serde::{Serialize, Deserialize};
18
19#[derive(Serialize, Deserialize)]
20struct AccountInfo {
21    original_name: String,
22    camel_name: String,
23}
24
25pub fn generate_with_tera(
26    meta: &TestMetadata,
27    idl: &IdlData,
28    out_dir: impl AsRef<Path>
29) -> Result<()> {
30    let out_dir = out_dir.as_ref();
31    create_dir_all(out_dir).with_context(|| format!("creating output dir {:?}", out_dir))?;
32
33    let mut tera = Tera::default();
34    tera
35        .add_raw_template("aggregated_tests.tera", AGGREGATED_TEMPLATE)
36        .context("add aggregated template")?;
37
38    let mut ctx = TeraContext::new();
39
40    let program_name = &idl.name;
41    let program_name_pascal = cut_program_name(program_name);
42    let program_capitalized = capitalize_first_letter(&program_name_pascal);
43    let program_name_camel = camel_case(program_name);
44    let program_name_pascal_case = to_pascal_case(program_name);
45    ctx.insert("program_name", program_name);
46    ctx.insert("program_name_pascal", &program_name_pascal);
47    ctx.insert("program_capitalized", &program_capitalized);
48    ctx.insert("program_name_camel", &program_name_camel);
49    ctx.insert("program_name_pascal_case", &program_name_pascal_case);
50    
51
52    // setup requirements
53    let setup_requirements = meta.setup_requirements.clone();
54    let mut map = HashMap::new();
55    let mut index = 0;
56
57    for setup_requirement in setup_requirements.iter().cloned() {
58        index += 1;
59
60        match setup_requirement.requirement_type {
61            SetupType::CreateKeypair => {
62                map.insert(index, "Keypair.generate()");
63            }
64            SetupType::FundAccount => {
65                map.insert(index, "FundAccount");
66            }
67            SetupType::InitializePda => {
68                map.insert(index, "PublicKey");
69            }
70            _ => {
71                return Err(SolifyError::InvalidSetupRequirement.into());
72            }
73        }
74    }
75    ctx.insert("setup_requirements", &map);
76
77    let mut pda_indices = Vec::new();
78    let mut index_1 = 0;
79
80    for r in setup_requirements.iter().cloned() {
81        index_1 += 1;
82
83        if r.requirement_type == SetupType::InitializePda {
84            pda_indices.push(index_1);
85        }
86    }
87
88    // pda initialization
89    let mut pda_map = HashMap::new();
90    let pda_init_sequence = meta.pda_init_sequence.clone();
91
92    for (i, pda_init) in pda_init_sequence.iter().enumerate() {
93        if let Some(index) = pda_indices.get(i) {
94        let seeds_expr = render_pda_seeds_expression(&pda_init.seeds);
95            pda_map.insert(*index, seeds_expr);
96        }
97    }
98
99    ctx.insert("pda_seeds", &pda_map);
100
101    let mut account_vars: HashMap<String, String> = HashMap::new();
102
103    for ad in meta.account_dependencies.iter() {
104        if ad.is_pda {
105            if
106                let Some((pos, _)) = meta.pda_init_sequence
107                    .iter()
108                    .enumerate()
109                    .find(|(_, p)| p.account_name == ad.account_name)
110            {
111                let setup_index = pda_indices[pos];
112                account_vars.insert(ad.account_name.clone(), format!("pda{}", setup_index));
113            } else {
114                account_vars.insert(
115                    ad.account_name.clone(),
116                    format!("/* missing pda for {} */ null", ad.account_name)
117                );
118            }
119        } else if ad.account_name == "authority" {
120            account_vars.insert(ad.account_name.clone(), "authorityPubkey".to_string());
121        } else if ad.account_name == "system_program" {
122            account_vars.insert(ad.account_name.clone(), "SystemProgram.programId".to_string());
123        } else {
124            account_vars.insert(ad.account_name.clone(), format!("{}", ad.account_name));
125        }
126    }
127
128    for instruction in &idl.instructions {
129        for acc in &instruction.accounts {
130            if !account_vars.contains_key(&acc.name) {
131                if acc.name == "system_program" || acc.name == "systemProgram" {
132                    account_vars.insert(acc.name.clone(), "SystemProgram.programId".to_string());
133                } else if acc.name == "authority" {
134                    account_vars.insert(acc.name.clone(), "authorityPubkey".to_string());
135                } else {
136                    if let Some((pos, _)) = meta.pda_init_sequence
137                        .iter()
138                        .enumerate()
139                        .find(|(_, p)| p.account_name == acc.name)
140                    {
141                        if let Some(setup_index) = pda_indices.get(pos) {
142                            account_vars.insert(acc.name.clone(), format!("pda{}", setup_index));
143                        }
144                    }
145                }
146            }
147        }
148    }
149
150    ctx.insert("account_vars", &account_vars);
151    let mut instruction_accounts: HashMap<String, Vec<AccountInfo>> = HashMap::new();
152    for instruction in &idl.instructions {
153        let account_infos: Vec<AccountInfo> = instruction.accounts.iter()
154            .map(|acc| {
155                AccountInfo {
156                    original_name: acc.name.clone(),
157                    camel_name: to_camel_case(&acc.name),
158                }
159            })
160            .collect();
161        instruction_accounts.insert(instruction.name.clone(), account_infos);
162    }
163    ctx.insert("instruction_accounts", &instruction_accounts);
164
165    let mut processed_test_cases = meta.test_cases.clone();
166    for test_case in &mut processed_test_cases {
167        for arg_value in &mut test_case.positive_cases {
168            for arg in &mut arg_value.argument_values {
169                arg.value_type = convert_to_typescript_value(arg.value_type.clone());
170            }
171        }
172        for arg_value in &mut test_case.negative_cases {
173            for arg in &mut arg_value.argument_values {
174                arg.value_type = convert_to_typescript_value(arg.value_type.clone());
175            }
176        }
177    }
178    ctx.insert("instruction_tests", &processed_test_cases);
179
180    let rendered = tera.render("aggregated_tests.tera", &ctx).context("render tera")?;
181
182    let out_path = out_dir.join(format!("{}.ts", program_name_pascal));
183    let mut f = File::create(&out_path).with_context(|| format!("create file {:?}", out_path))?;
184    f.write_all(rendered.as_bytes()).with_context(|| format!("write file {:?}", out_path))?;
185
186    println!("Wrote {}", out_path.display());
187    Ok(())
188}
189
// Tera template for the generated TypeScript test file.
//
// The "first keypair becomes `authority`" flags are updated with `set_global`:
// a plain `set` inside a Tera `for` loop only lives until the end of the
// current iteration, which would make every keypair render as `authority`.
const AGGREGATED_TEMPLATE: &str =
    r#"
import * as anchor from "@coral-xyz/anchor";
import { Program } from "@coral-xyz/anchor";
import { {{ program_name_pascal_case }} } from "../target/types/{{ program_name }}";
import { assert } from "chai";
import { Keypair, SystemProgram, PublicKey, LAMPORTS_PER_SOL } from "@solana/web3.js";

describe("{{ program_name | default(value='program') }}", () => {
    // Configure the client
    let provider = anchor.AnchorProvider.env();
    anchor.setProvider(provider);
    const connection = provider.connection;

    const program = anchor.workspace.{{ program_name }} as Program<{{ program_name_pascal_case }}>;

    // Setup Requirements
    // keypair decelarations
    {%- set keypair_found = false %}
    {%- for id, code in setup_requirements %}
    {%- if code == "Keypair.generate()" %}
    {%- if not keypair_found %}
    {%- set_global keypair_found = true %}
    const authority = Keypair.generate();
    const authorityPubkey = authority.publicKey;
    {%- else %}
    const user{{ id }} = Keypair.generate();
    const user{{ id }}Pubkey = user{{ id }}.publicKey;
    {%- endif %}
    {%- endif %}
    {%- endfor %}

    // PDA Decelaration
    {%- for id, code in setup_requirements %}
    {%- if code == "PublicKey" %}
    let pda{{ id }}: PublicKey;
    let bump{{ id }}: number;
    {%- endif %}
    {%- endfor %}

    before(async () => {
        // ----- Airdrop for each user Keypair -----
        {%- set keypair_found_airdrop = false %}
        {%- for id, code in setup_requirements %}
        {%- if code == "Keypair.generate()" %}
        {%- if not keypair_found_airdrop %}
        {%- set_global keypair_found_airdrop = true %}
        const sig{{ id }} = await connection.requestAirdrop(authorityPubkey, 10 * LAMPORTS_PER_SOL);
        await connection.confirmTransaction(sig{{ id }}, "confirmed");
        {%- else %}
        const sig{{ id }} = await connection.requestAirdrop(user{{ id }}Pubkey, 10 * LAMPORTS_PER_SOL);
        await connection.confirmTransaction(sig{{ id }}, "confirmed");
        {%- endif %}
        {%- endif %}
        {%- endfor %}

        // ----- PDA Initialization -----
        {%- for id, seeds in pda_seeds %}
        [pda{{ id }}, bump{{ id }}] = PublicKey.findProgramAddressSync(
            {{ seeds }},
            program.programId
        );
        {%- endfor %}

    });

    {%- macro render_accounts(account_list) -%}
    {%- for acc in account_list %}
    {%- set var = account_vars[acc] | default(value='/* missing */ null') %}
    {{ acc }}: {{ var }}{%- if not loop.last %},{%- endif %}
    {%- endfor %}
    {%- endmacro %}

    {# ---------------- INSTRUCTION DESCRIBE BLOCKS ---------------- #}

    {%- for instr in instruction_tests %}


    {# ---------- POSITIVE TESTS ---------- #}
    {%- for test in instr.positive_cases %}
    it("{{ test.description }}", async () => {
        // Prepare arguments
        {%- for arg in test.argument_values %}
        {%- if arg.value_type.variant == "Valid" %}
        const {{ arg.argument_name }}Value = {{ arg.value_type.description }};
        {%- elif arg.value_type.variant == "Invalid" %}
        const {{ arg.argument_name }}Value = {{ arg.value_type.description }};
        {%- else %}
        const {{ arg.argument_name }}Value = null;
        {%- endif %}
        {%- endfor %}
        // Execute instruction
        try {
            await program.methods
                .{{ instr.instruction_name }}(
                    {%- for arg in test.argument_values %}
                    {{ arg.argument_name }}Value{%- if not loop.last %},{%- endif %}
                    {%- endfor %}
                )
                .accountsStrict({
                    {%- if instruction_accounts[instr.instruction_name] %}
                    {%- for acc_info in instruction_accounts[instr.instruction_name] %}
                    {%- set js_var = account_vars[acc_info.original_name] | default(value="null") %}
                    {{ acc_info.camel_name }}: {{ js_var }}{%- if not loop.last %},{%- endif %}
                    {%- endfor %}
                    {%- endif %}
                })
                .signers([
                    authority
                ])
                .rpc();
            // Expect success
            assert.ok(true);
        } catch (err) {
            assert.fail("Instruction should not have failed: " + err);
        }
    });
    {%- endfor %}
    {# ---------- NEGATIVE TESTS ---------- #}
    {%- for test in instr.negative_cases %}
    it("{{ test.description }}", async () => {
        // Prepare arguments
        {%- for arg in test.argument_values %}
        {%- if arg.value_type.variant == "Valid" %}
        const {{ arg.argument_name }}Value = {{ arg.value_type.description }};
        {%- elif arg.value_type.variant == "Invalid" %}
        const {{ arg.argument_name }}Value = {{ arg.value_type.description }};
        {%- else %}
        const {{ arg.argument_name }}Value = null;
        {%- endif %}
        {%- endfor %}
        // Execute instruction expecting failure
        try {
            await program.methods
                .{{ instr.instruction_name }}(
                    {%- for arg in test.argument_values %}
                    {{ arg.argument_name }}Value{%- if not loop.last %},{%- endif %}
                    {%- endfor %}
                )
                .accountsStrict({
                    {%- if instruction_accounts[instr.instruction_name] %}
                    {%- for acc_info in instruction_accounts[instr.instruction_name] %}
                    {%- set js_var = account_vars[acc_info.original_name] | default(value="null") %}
                    {{ acc_info.camel_name }}: {{ js_var }}{%- if not loop.last %},{%- endif %}
                    {%- endfor %}
                    {%- endif %}
                })
                .signers([
                    authority
                ])
                .rpc();
        } catch (err) {
            {%- if test.expected_outcome.variant == "Failure" %}
            assert(err.message.includes("{{ test.expected_outcome.error_message }}"));
            {%- endif %}
        }
    });
    {%- endfor %}

    {%- endfor %}

})

"#;
354
355// ------------------- Helper functions (rendering helpers) -------------------
356
357fn render_pda_seeds_expression(seeds: &[SeedComponent]) -> String {
358    let parts: Vec<String> = seeds
359        .iter()
360        .map(|seed| {
361            match seed.seed_type {
362                SeedType::Static => { format!("Buffer.from(\"{}\")", seed.value) }
363                SeedType::AccountKey => { format!("{}Pubkey.toBuffer()", seed.value) }
364                SeedType::Argument => { format!("Buffer.from(String({}))", seed.value) }
365            }
366        })
367        .collect();
368
369    format!("[{}]", parts.join(", "))
370}
371
/// Returns the part of `s` before its first underscore, or all of `s` when it
/// contains no underscore.
fn cut_program_name(s: &str) -> String {
    match s.split_once('_') {
        Some((head, _)) => head.to_owned(),
        None => s.to_owned(),
    }
}
375
/// Returns `s` with its first character upper-cased and the rest unchanged.
///
/// An empty input yields an empty string. The remainder is taken from the
/// character iterator rather than by byte slicing, so a multi-byte first
/// character cannot cause a char-boundary panic.
fn capitalize_first_letter(s: &str) -> String {
    let mut chars = s.chars();
    match chars.next() {
        // `to_uppercase` may expand to several chars, hence the iterator chain.
        Some(first) => first.to_uppercase().chain(chars).collect(),
        None => String::new(),
    }
}
379
/// Converts a `snake_case` identifier to `camelCase`.
///
/// The first underscore-separated segment is lower-cased. Every following
/// non-empty segment gets an upper-cased first character and a lower-cased
/// remainder; empty segments (from consecutive underscores) contribute
/// nothing. Segments are split on character boundaries, so non-ASCII input
/// does not panic.
fn camel_case(s: &str) -> String {
    let mut segments = s.split('_');
    // `split` always yields at least one (possibly empty) segment.
    let mut out = segments.next().unwrap_or("").to_lowercase();

    for word in segments {
        let mut chars = word.chars();
        if let Some(initial) = chars.next() {
            out.extend(initial.to_uppercase());
            out.push_str(&chars.as_str().to_lowercase());
        }
    }

    out
}
397
/// Converts an underscore-separated name to `camelCase`.
///
/// The leading segment is lower-cased as a whole; each later segment has its
/// first character upper-cased and the rest lower-cased. Empty segments add
/// nothing to the result.
fn to_camel_case(s: &str) -> String {
    let mut result = String::new();

    for (position, segment) in s.split('_').enumerate() {
        if position == 0 {
            result.push_str(&segment.to_lowercase());
            continue;
        }

        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            result.extend(head.to_uppercase());
            result.push_str(&rest.as_str().to_lowercase());
        }
    }

    result
}
420
/// Converts an underscore-separated name to `PascalCase`.
///
/// Every non-empty segment has its first character upper-cased and the rest
/// lower-cased; the segments are then concatenated without separators.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::new();

    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(&rest.as_str().to_lowercase());
        }
    }

    pascal
}
441
442fn convert_to_typescript_value(value_type: TestValueType) -> TestValueType {
443    match value_type {
444        TestValueType::Valid { description } => {
445            TestValueType::Valid {
446                description: convert_rust_to_typescript(&description),
447            }
448        }
449        TestValueType::Invalid { description, reason } => {
450            TestValueType::Invalid {
451                description: convert_rust_to_typescript(&description),
452                reason,
453            }
454        }
455    }
456}
457
/// Converts a Rust-flavoured test value description into a TypeScript expression.
///
/// After trimming surrounding whitespace, the rules are applied in this order:
/// 1. Integer limit constants (`u8::MAX`, `i64::MIN`, ...) become
///    `new anchor.BN("<decimal value>")`.
/// 2. Decimal integers of any length (optional leading `+`/`-`) become
///    `new anchor.BN("<digits>")`.
/// 3. Other finite floating-point literals are emitted verbatim.
/// 4. Values that already start with `"`, the literals `true`/`false`, and
///    expressions starting with `new ` or `authority.` or containing `Pubkey`
///    are emitted verbatim.
/// 5. Anything else is emitted as a double-quoted string literal with `"`,
///    `\`, and line breaks escaped.
///
/// Text that Rust's float parser accepts but that is not a finite number
/// (`inf`, `NaN`, ...) is treated as plain text under rule 5 rather than
/// being wrapped in `BN`.
fn convert_rust_to_typescript(value: &str) -> String {
    let trimmed = value.trim();

    let limit = match trimmed {
        "u64::MAX" => Some("18446744073709551615"),
        "u32::MAX" => Some("4294967295"),
        "u16::MAX" => Some("65535"),
        "u8::MAX" => Some("255"),
        "u64::MIN" | "u32::MIN" | "u16::MIN" | "u8::MIN" => Some("0"),
        "i64::MAX" => Some("9223372036854775807"),
        "i64::MIN" => Some("-9223372036854775808"),
        "i32::MAX" => Some("2147483647"),
        "i32::MIN" => Some("-2147483648"),
        "i16::MAX" => Some("32767"),
        "i16::MIN" => Some("-32768"),
        "i8::MAX" => Some("127"),
        "i8::MIN" => Some("-128"),
        _ => None,
    };
    if let Some(digits) = limit {
        return format!("new anchor.BN(\"{}\")", digits);
    }

    // A digit-only check (rather than `parse::<i128>`) also covers integers
    // too large for any primitive type, which BN can still represent.
    let magnitude = trimmed
        .strip_prefix(|c: char| c == '+' || c == '-')
        .unwrap_or(trimmed);
    let is_integer = !magnitude.is_empty() && magnitude.bytes().all(|b| b.is_ascii_digit());
    if is_integer {
        return format!("new anchor.BN(\"{}\")", trimmed);
    }

    // `parse::<f64>` also accepts "inf"/"nan"; only finite values are numbers here.
    let is_finite_float = trimmed.parse::<f64>().map_or(false, |parsed| parsed.is_finite());
    if is_finite_float {
        return trimmed.to_string();
    }

    let is_ready_expression =
        trimmed.starts_with('"') ||
        trimmed == "true" ||
        trimmed == "false" ||
        trimmed.starts_with("new ") ||
        trimmed.starts_with("authority.") ||
        trimmed.contains("Pubkey");
    if is_ready_expression {
        return trimmed.to_string();
    }

    // Plain text: quote it, escaping characters that would end or corrupt the literal.
    let mut literal = String::with_capacity(trimmed.len() + 2);
    literal.push('"');
    for ch in trimmed.chars() {
        match ch {
            '"' => literal.push_str("\\\""),
            '\\' => literal.push_str("\\\\"),
            '\n' => literal.push_str("\\n"),
            '\r' => literal.push_str("\\r"),
            other => literal.push(other),
        }
    }
    literal.push('"');
    literal
}