use std::collections::BTreeMap;
use std::collections::HashMap;
use std::fs::{ create_dir_all, File };
use std::io::Write;
use std::path::Path;

use anyhow::{ Context, Result };
use serde::{Serialize, Deserialize};
use solify_common::{
    IdlData,
    SeedComponent,
    SeedType,
    SetupType,
    TestMetadata,
    TestValueType,
};
use solify_common::errors::SolifyError;
use tera::{ Tera, Context as TeraContext };
18
19#[derive(Serialize, Deserialize)]
20struct AccountInfo {
21 original_name: String,
22 camel_name: String,
23}
24
25pub fn generate_with_tera(
26 meta: &TestMetadata,
27 idl: &IdlData,
28 out_dir: impl AsRef<Path>
29) -> Result<()> {
30 let out_dir = out_dir.as_ref();
31 create_dir_all(out_dir).with_context(|| format!("creating output dir {:?}", out_dir))?;
32
33 let mut tera = Tera::default();
34 tera
35 .add_raw_template("aggregated_tests.tera", AGGREGATED_TEMPLATE)
36 .context("add aggregated template")?;
37
38 let mut ctx = TeraContext::new();
39
40 let program_name = &idl.name;
41 let program_name_pascal = cut_program_name(program_name);
42 let program_capitalized = capitalize_first_letter(&program_name_pascal);
43 let program_name_camel = camel_case(program_name);
44 let program_name_pascal_case = to_pascal_case(program_name);
45 ctx.insert("program_name", program_name);
46 ctx.insert("program_name_pascal", &program_name_pascal);
47 ctx.insert("program_capitalized", &program_capitalized);
48 ctx.insert("program_name_camel", &program_name_camel);
49 ctx.insert("program_name_pascal_case", &program_name_pascal_case);
50
51
52 let setup_requirements = meta.setup_requirements.clone();
54 let mut map = HashMap::new();
55 let mut index = 0;
56
57 for setup_requirement in setup_requirements.iter().cloned() {
58 index += 1;
59
60 match setup_requirement.requirement_type {
61 SetupType::CreateKeypair => {
62 map.insert(index, "Keypair.generate()");
63 }
64 SetupType::FundAccount => {
65 map.insert(index, "FundAccount");
66 }
67 SetupType::InitializePda => {
68 map.insert(index, "PublicKey");
69 }
70 _ => {
71 return Err(SolifyError::InvalidSetupRequirement.into());
72 }
73 }
74 }
75 ctx.insert("setup_requirements", &map);
76
77 let mut pda_indices = Vec::new();
78 let mut index_1 = 0;
79
80 for r in setup_requirements.iter().cloned() {
81 index_1 += 1;
82
83 if r.requirement_type == SetupType::InitializePda {
84 pda_indices.push(index_1);
85 }
86 }
87
88 let mut pda_map = HashMap::new();
90 let pda_init_sequence = meta.pda_init_sequence.clone();
91
92 for (i, pda_init) in pda_init_sequence.iter().enumerate() {
93 if let Some(index) = pda_indices.get(i) {
94 let seeds_expr = render_pda_seeds_expression(&pda_init.seeds);
95 pda_map.insert(*index, seeds_expr);
96 }
97 }
98
99 ctx.insert("pda_seeds", &pda_map);
100
101 let mut account_vars: HashMap<String, String> = HashMap::new();
102
103 for ad in meta.account_dependencies.iter() {
104 if ad.is_pda {
105 if
106 let Some((pos, _)) = meta.pda_init_sequence
107 .iter()
108 .enumerate()
109 .find(|(_, p)| p.account_name == ad.account_name)
110 {
111 let setup_index = pda_indices[pos];
112 account_vars.insert(ad.account_name.clone(), format!("pda{}", setup_index));
113 } else {
114 account_vars.insert(
115 ad.account_name.clone(),
116 format!("/* missing pda for {} */ null", ad.account_name)
117 );
118 }
119 } else if ad.account_name == "authority" {
120 account_vars.insert(ad.account_name.clone(), "authorityPubkey".to_string());
121 } else if ad.account_name == "system_program" {
122 account_vars.insert(ad.account_name.clone(), "SystemProgram.programId".to_string());
123 } else {
124 account_vars.insert(ad.account_name.clone(), format!("{}", ad.account_name));
125 }
126 }
127
128 for instruction in &idl.instructions {
129 for acc in &instruction.accounts {
130 if !account_vars.contains_key(&acc.name) {
131 if acc.name == "system_program" || acc.name == "systemProgram" {
132 account_vars.insert(acc.name.clone(), "SystemProgram.programId".to_string());
133 } else if acc.name == "authority" {
134 account_vars.insert(acc.name.clone(), "authorityPubkey".to_string());
135 } else {
136 if let Some((pos, _)) = meta.pda_init_sequence
137 .iter()
138 .enumerate()
139 .find(|(_, p)| p.account_name == acc.name)
140 {
141 if let Some(setup_index) = pda_indices.get(pos) {
142 account_vars.insert(acc.name.clone(), format!("pda{}", setup_index));
143 }
144 }
145 }
146 }
147 }
148 }
149
150 ctx.insert("account_vars", &account_vars);
151 let mut instruction_accounts: HashMap<String, Vec<AccountInfo>> = HashMap::new();
152 for instruction in &idl.instructions {
153 let account_infos: Vec<AccountInfo> = instruction.accounts.iter()
154 .map(|acc| {
155 AccountInfo {
156 original_name: acc.name.clone(),
157 camel_name: to_camel_case(&acc.name),
158 }
159 })
160 .collect();
161 instruction_accounts.insert(instruction.name.clone(), account_infos);
162 }
163 ctx.insert("instruction_accounts", &instruction_accounts);
164
165 let mut processed_test_cases = meta.test_cases.clone();
166 for test_case in &mut processed_test_cases {
167 for arg_value in &mut test_case.positive_cases {
168 for arg in &mut arg_value.argument_values {
169 arg.value_type = convert_to_typescript_value(arg.value_type.clone());
170 }
171 }
172 for arg_value in &mut test_case.negative_cases {
173 for arg in &mut arg_value.argument_values {
174 arg.value_type = convert_to_typescript_value(arg.value_type.clone());
175 }
176 }
177 }
178 ctx.insert("instruction_tests", &processed_test_cases);
179
180 let rendered = tera.render("aggregated_tests.tera", &ctx).context("render tera")?;
181
182 let out_path = out_dir.join(format!("{}.ts", program_name_pascal));
183 let mut f = File::create(&out_path).with_context(|| format!("create file {:?}", out_path))?;
184 f.write_all(rendered.as_bytes()).with_context(|| format!("write file {:?}", out_path))?;
185
186 println!("Wrote {}", out_path.display());
187 Ok(())
188}
189
/// Tera template for the aggregated Anchor/Mocha TypeScript test file.
///
/// Context keys read (inserted by `generate_with_tera`):
/// - `program_name`, `program_name_pascal_case`
/// - `setup_requirements`: setup index -> marker
/// - `pda_seeds`: setup index -> seed-array expression
/// - `account_vars`, `instruction_accounts`, `instruction_tests`
///
/// The first `CreateKeypair` requirement becomes the `authority` keypair and
/// later ones become `user<N>`. The "already seen" flags are updated with
/// `set_global` because Tera scopes a plain `set` inside a `for` loop to that
/// iteration. Without it, every keypair would be emitted as a duplicate
/// `const authority`.
///
/// Test descriptions and expected error messages go through `json_encode`.
/// Quotes or backslashes in them therefore cannot break the generated string
/// literals.
const AGGREGATED_TEMPLATE: &str =
    r#"
import * as anchor from "@coral-xyz/anchor";
import { Program } from "@coral-xyz/anchor";
import { {{ program_name_pascal_case }} } from "../target/types/{{ program_name }}";
import { assert } from "chai";
import { Keypair, SystemProgram, PublicKey, LAMPORTS_PER_SOL } from "@solana/web3.js";

describe("{{ program_name | default(value='program') }}", () => {
  // Configure the client
  const provider = anchor.AnchorProvider.env();
  anchor.setProvider(provider);
  const connection = provider.connection;

  const program = anchor.workspace.{{ program_name }} as Program<{{ program_name_pascal_case }}>;

  // Setup Requirements
  // keypair declarations
  {%- set keypair_found = false %}
  {%- for id, code in setup_requirements %}
  {%- if code == "Keypair.generate()" %}
  {%- if not keypair_found %}
  {%- set_global keypair_found = true %}
  const authority = Keypair.generate();
  const authorityPubkey = authority.publicKey;
  {%- else %}
  const user{{ id }} = Keypair.generate();
  const user{{ id }}Pubkey = user{{ id }}.publicKey;
  {%- endif %}
  {%- endif %}
  {%- endfor %}

  // PDA declarations
  {%- for id, code in setup_requirements %}
  {%- if code == "PublicKey" %}
  let pda{{ id }}: PublicKey;
  let bump{{ id }}: number;
  {%- endif %}
  {%- endfor %}

  before(async () => {
    // ----- Airdrop for each user Keypair -----
    {%- set keypair_found_airdrop = false %}
    {%- for id, code in setup_requirements %}
    {%- if code == "Keypair.generate()" %}
    {%- if not keypair_found_airdrop %}
    {%- set_global keypair_found_airdrop = true %}
    const sig{{ id }} = await connection.requestAirdrop(authorityPubkey, 10 * LAMPORTS_PER_SOL);
    await connection.confirmTransaction(sig{{ id }}, "confirmed");
    {%- else %}
    const sig{{ id }} = await connection.requestAirdrop(user{{ id }}Pubkey, 10 * LAMPORTS_PER_SOL);
    await connection.confirmTransaction(sig{{ id }}, "confirmed");
    {%- endif %}
    {%- endif %}
    {%- endfor %}

    // ----- PDA Initialization -----
    {%- for id, seeds in pda_seeds %}
    [pda{{ id }}, bump{{ id }}] = PublicKey.findProgramAddressSync(
      {{ seeds }},
      program.programId
    );
    {%- endfor %}

  });

  {# ---------------- INSTRUCTION DESCRIBE BLOCKS ---------------- #}

  {%- for instr in instruction_tests %}


  {# ---------- POSITIVE TESTS ---------- #}
  {%- for test in instr.positive_cases %}
  it({{ test.description | json_encode | safe }}, async () => {
    // Prepare arguments
    {%- for arg in test.argument_values %}
    {%- if arg.value_type.variant == "Valid" or arg.value_type.variant == "Invalid" %}
    const {{ arg.argument_name }}Value = {{ arg.value_type.description }};
    {%- else %}
    const {{ arg.argument_name }}Value = null;
    {%- endif %}
    {%- endfor %}
    // Execute instruction
    try {
      await program.methods
        .{{ instr.instruction_name }}(
          {%- for arg in test.argument_values %}
          {{ arg.argument_name }}Value{%- if not loop.last %},{%- endif %}
          {%- endfor %}
        )
        .accountsStrict({
          {%- if instruction_accounts[instr.instruction_name] %}
          {%- for acc_info in instruction_accounts[instr.instruction_name] %}
          {%- set js_var = account_vars[acc_info.original_name] | default(value="null") %}
          {{ acc_info.camel_name }}: {{ js_var }}{%- if not loop.last %},{%- endif %}
          {%- endfor %}
          {%- endif %}
        })
        .signers([
          authority
        ])
        .rpc();
      // Expect success
      assert.ok(true);
    } catch (err) {
      assert.fail("Instruction should not have failed: " + err);
    }
  });
  {%- endfor %}
  {# ---------- NEGATIVE TESTS ---------- #}
  {%- for test in instr.negative_cases %}
  it({{ test.description | json_encode | safe }}, async () => {
    // Prepare arguments
    {%- for arg in test.argument_values %}
    {%- if arg.value_type.variant == "Valid" or arg.value_type.variant == "Invalid" %}
    const {{ arg.argument_name }}Value = {{ arg.value_type.description }};
    {%- else %}
    const {{ arg.argument_name }}Value = null;
    {%- endif %}
    {%- endfor %}
    // Execute instruction expecting failure
    let failed = false;
    try {
      await program.methods
        .{{ instr.instruction_name }}(
          {%- for arg in test.argument_values %}
          {{ arg.argument_name }}Value{%- if not loop.last %},{%- endif %}
          {%- endfor %}
        )
        .accountsStrict({
          {%- if instruction_accounts[instr.instruction_name] %}
          {%- for acc_info in instruction_accounts[instr.instruction_name] %}
          {%- set js_var = account_vars[acc_info.original_name] | default(value="null") %}
          {{ acc_info.camel_name }}: {{ js_var }}{%- if not loop.last %},{%- endif %}
          {%- endfor %}
          {%- endif %}
        })
        .signers([
          authority
        ])
        .rpc();
    } catch (err) {
      failed = true;
      {%- if test.expected_outcome.variant == "Failure" %}
      assert(err.message.includes({{ test.expected_outcome.error_message | json_encode | safe }}));
      {%- endif %}
    }
    {%- if test.expected_outcome.variant == "Failure" %}
    assert.isTrue(failed, "Instruction should have failed");
    {%- endif %}
  });
  {%- endfor %}

  {%- endfor %}

})

"#;
354
355fn render_pda_seeds_expression(seeds: &[SeedComponent]) -> String {
358 let parts: Vec<String> = seeds
359 .iter()
360 .map(|seed| {
361 match seed.seed_type {
362 SeedType::Static => { format!("Buffer.from(\"{}\")", seed.value) }
363 SeedType::AccountKey => { format!("{}Pubkey.toBuffer()", seed.value) }
364 SeedType::Argument => { format!("Buffer.from(String({}))", seed.value) }
365 }
366 })
367 .collect();
368
369 format!("[{}]", parts.join(", "))
370}
371
/// Returns the leading segment of `s` up to (not including) the first `_`.
/// When `s` contains no underscore, the whole string is returned.
fn cut_program_name(s: &str) -> String {
    match s.find('_') {
        Some(end) => s[..end].to_string(),
        None => s.to_string(),
    }
}
375
/// Upper-cases the first character of `s` and leaves the rest unchanged.
///
/// Empty input yields an empty string. The function works on `char`s rather
/// than byte offsets, so a multi-byte first letter (e.g. `é`) is handled
/// instead of panicking on a non-char-boundary slice.
fn capitalize_first_letter(s: &str) -> String {
    let mut chars = s.chars();
    match chars.next() {
        Some(first) => first.to_uppercase().chain(chars).collect(),
        None => String::new(),
    }
}
379
/// Converts a `snake_case` identifier to `camelCase`.
///
/// The first segment is lower-cased. Each later segment has its first char
/// upper-cased and the remainder lower-cased. Empty segments (from `__` or a
/// leading/trailing `_`) contribute nothing. The function iterates over
/// `char`s, so non-ASCII segment starts do not panic.
fn camel_case(s: &str) -> String {
    let mut words = s.split('_');
    // `split` always yields at least one item, possibly "".
    let mut out = words.next().unwrap_or("").to_lowercase();
    for word in words {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            out.extend(first.to_uppercase());
            out.push_str(&chars.as_str().to_lowercase());
        }
    }
    out
}
397
/// Converts a `snake_case` name to `camelCase`.
///
/// Segment 0 is fully lower-cased. Every later non-empty segment gets an
/// upper-cased first char followed by its lower-cased remainder. Empty
/// segments are dropped.
fn to_camel_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for (position, segment) in s.split('_').enumerate() {
        if position == 0 {
            result.push_str(&segment.to_lowercase());
            continue;
        }
        let mut letters = segment.chars();
        if let Some(head) = letters.next() {
            result.extend(head.to_uppercase());
            result.push_str(&letters.as_str().to_lowercase());
        }
    }
    result
}
420
/// Converts a `snake_case` name to `PascalCase`.
///
/// Each `_`-separated segment contributes its upper-cased first char and its
/// lower-cased remainder. Empty segments contribute nothing.
fn to_pascal_case(s: &str) -> String {
    s.split('_')
        .map(|segment| {
            let mut letters = segment.chars();
            match letters.next() {
                Some(head) => {
                    let tail = letters.as_str().to_lowercase();
                    head.to_uppercase().chain(tail.chars()).collect::<String>()
                }
                None => String::new(),
            }
        })
        .collect()
}
441
442fn convert_to_typescript_value(value_type: TestValueType) -> TestValueType {
443 match value_type {
444 TestValueType::Valid { description } => {
445 TestValueType::Valid {
446 description: convert_rust_to_typescript(&description),
447 }
448 }
449 TestValueType::Invalid { description, reason } => {
450 TestValueType::Invalid {
451 description: convert_rust_to_typescript(&description),
452 reason,
453 }
454 }
455 }
456}
457
/// Translates a Rust-flavoured literal into a TypeScript expression suitable
/// as an Anchor instruction argument.
///
/// Rules are applied to the trimmed input, in order:
/// 1. Integer bounds `u8..u128::{MIN,MAX}` and `i8..i128::{MIN,MAX}` become
///    `new anchor.BN("<decimal>")`.
/// 2. Plain decimal integers (optional sign, any length) become
///    `new anchor.BN("<digits>")`. A leading `+` is dropped, since BN does not
///    accept it.
/// 3. Any other literal accepted by `f64` parsing is emitted as a JS number
///    (`1.5`, `1e5`). Non-finite values become `NaN`, `Infinity` or
///    `-Infinity` rather than an invalid `BN("inf")`.
/// 4. These are passed through verbatim:
///    - double-quoted strings (at least two chars);
///    - `true` and `false`;
///    - expressions starting with `new ` or `authority.`;
///    - expressions containing `Pubkey`.
/// 5. Anything else becomes an escaped double-quoted string literal.
fn convert_rust_to_typescript(value: &str) -> String {
    /// Wraps `raw` in double quotes, escaping characters that would terminate
    /// or corrupt a JavaScript string literal.
    fn js_string_literal(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len() + 2);
        out.push('"');
        for ch in raw.chars() {
            match ch {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                // U+2028/U+2029 terminate string literals in pre-ES2019 engines.
                c if c.is_control() || c == '\u{2028}' || c == '\u{2029}' => {
                    out.push_str(&format!("\\u{:04x}", c as u32));
                }
                c => out.push(c),
            }
        }
        out.push('"');
        out
    }

    // BN is built from a decimal string, so values beyond 2^53 stay exact.
    let bn = |digits: &str| format!("new anchor.BN(\"{}\")", digits);
    let trimmed = value.trim();

    let bound = match trimmed {
        "u8::MIN" | "u16::MIN" | "u32::MIN" | "u64::MIN" | "u128::MIN" => Some("0".to_string()),
        "u8::MAX" => Some(u8::MAX.to_string()),
        "u16::MAX" => Some(u16::MAX.to_string()),
        "u32::MAX" => Some(u32::MAX.to_string()),
        "u64::MAX" => Some(u64::MAX.to_string()),
        "u128::MAX" => Some(u128::MAX.to_string()),
        "i8::MIN" => Some(i8::MIN.to_string()),
        "i8::MAX" => Some(i8::MAX.to_string()),
        "i16::MIN" => Some(i16::MIN.to_string()),
        "i16::MAX" => Some(i16::MAX.to_string()),
        "i32::MIN" => Some(i32::MIN.to_string()),
        "i32::MAX" => Some(i32::MAX.to_string()),
        "i64::MIN" => Some(i64::MIN.to_string()),
        "i64::MAX" => Some(i64::MAX.to_string()),
        "i128::MIN" => Some(i128::MIN.to_string()),
        "i128::MAX" => Some(i128::MAX.to_string()),
        _ => None,
    };
    if let Some(digits) = bound {
        return bn(&digits);
    }

    // Integers of any width (including ones that overflow i128) go to BN.
    let unsigned = trimmed
        .strip_prefix('-')
        .or_else(|| trimmed.strip_prefix('+'))
        .unwrap_or(trimmed);
    if !unsigned.is_empty() && unsigned.bytes().all(|b| b.is_ascii_digit()) {
        return bn(trimmed.strip_prefix('+').unwrap_or(trimmed));
    }

    if let Ok(number) = trimmed.parse::<f64>() {
        return if number.is_nan() {
            "NaN".to_string()
        } else if number.is_infinite() {
            if number > 0.0 { "Infinity" } else { "-Infinity" }.to_string()
        } else {
            // Rust float syntax (`1.5`, `.5`, `1e5`) is also valid JS number syntax.
            trimmed.to_string()
        };
    }

    let already_quoted = trimmed.len() >= 2 && trimmed.starts_with('"') && trimmed.ends_with('"');
    let is_bool = trimmed == "true" || trimmed == "false";
    let is_ts_expression = trimmed.starts_with("new ")
        || trimmed.starts_with("authority.")
        || trimmed.contains("Pubkey");

    if already_quoted || is_bool || is_ts_expression {
        trimmed.to_string()
    } else {
        js_string_literal(trimmed)
    }
}