1use anyhow::{ Context, Result };
2use std::collections::HashMap;
3use std::fs::{ create_dir_all, File };
4use std::io::Write;
5use std::path::Path;
6
7use solify_common::{
8 IdlData,
9 SeedComponent,
10 SeedType,
11 SetupType,
12 TestMetadata,
13 TestValueType,
14};
15use solify_common::errors::SolifyError;
16use tera::{ Tera, Context as TeraContext };
17use serde::{Serialize, Deserialize};
18
19#[derive(Serialize, Deserialize)]
20struct AccountInfo {
21 original_name: String,
22 camel_name: String,
23}
24
25#[derive(Serialize, Deserialize)]
26struct InstructionTestCaseWrapper {
27 instruction_name: String,
28 instruction_name_camel: String,
29 arguments: Vec<solify_common::ArgumentInfo>,
30 positive_cases: Vec<solify_common::TestCase>,
31 negative_cases: Vec<solify_common::TestCase>,
32}
33
34pub fn generate_with_tera(
35 meta: &TestMetadata,
36 idl: &IdlData,
37 out_dir: impl AsRef<Path>
38) -> Result<()> {
39 let out_dir = out_dir.as_ref();
40 create_dir_all(out_dir).with_context(|| format!("creating output dir {:?}", out_dir))?;
41
42 let mut tera = Tera::default();
43 tera
44 .add_raw_template("aggregated_tests.tera", AGGREGATED_TEMPLATE)
45 .context("add aggregated template")?;
46
47 let mut ctx = TeraContext::new();
48
49 let program_name = &idl.name;
50 let program_name_pascal = cut_program_name(program_name);
51 let program_capitalized = capitalize_first_letter(&program_name_pascal);
52 let program_name_camel = camel_case(program_name);
53 let program_name_pascal_case = to_pascal_case(program_name);
54 ctx.insert("program_name", program_name);
55 ctx.insert("program_name_pascal", &program_name_pascal);
56 ctx.insert("program_capitalized", &program_capitalized);
57 ctx.insert("program_name_camel", &program_name_camel);
58 ctx.insert("program_name_pascal_case", &program_name_pascal_case);
59
60
61 let setup_requirements = meta.setup_requirements.clone();
63 let mut map = HashMap::new();
64 let mut index = 0;
65
66 for setup_requirement in setup_requirements.iter().cloned() {
67 index += 1;
68
69 match setup_requirement.requirement_type {
70 SetupType::CreateKeypair => {
71 map.insert(index, "Keypair.generate()");
72 }
73 SetupType::FundAccount => {
74 map.insert(index, "FundAccount");
75 }
76 SetupType::InitializePda => {
77 map.insert(index, "PublicKey");
78 }
79 _ => {
80 return Err(SolifyError::InvalidSetupRequirement.into());
81 }
82 }
83 }
84 ctx.insert("setup_requirements", &map);
85
86 let mut pda_indices = Vec::new();
87 let mut index_1 = 0;
88
89 for r in setup_requirements.iter().cloned() {
90 index_1 += 1;
91
92 if r.requirement_type == SetupType::InitializePda {
93 pda_indices.push(index_1);
94 }
95 }
96
97 let mut pda_map = HashMap::new();
99 let pda_init_sequence = meta.pda_init_sequence.clone();
100
101 for (i, pda_init) in pda_init_sequence.iter().enumerate() {
102 if let Some(index) = pda_indices.get(i) {
103 let seeds_expr = render_pda_seeds_expression(&pda_init.seeds);
104 pda_map.insert(*index, seeds_expr);
105 }
106 }
107
108 ctx.insert("pda_seeds", &pda_map);
109
110 let mut account_vars: HashMap<String, String> = HashMap::new();
111
112 for ad in meta.account_dependencies.iter() {
113 if ad.is_pda {
114 if
115 let Some((pos, _)) = meta.pda_init_sequence
116 .iter()
117 .enumerate()
118 .find(|(_, p)| p.account_name == ad.account_name)
119 {
120 let setup_index = pda_indices[pos];
121 account_vars.insert(ad.account_name.clone(), format!("pda{}", setup_index));
122 } else {
123 account_vars.insert(
124 ad.account_name.clone(),
125 format!("/* missing pda for {} */ null", ad.account_name)
126 );
127 }
128 } else if ad.account_name == "authority" {
129 account_vars.insert(ad.account_name.clone(), "authorityPubkey".to_string());
130 } else if ad.account_name == "system_program" {
131 account_vars.insert(ad.account_name.clone(), "SystemProgram.programId".to_string());
132 } else {
133 account_vars.insert(ad.account_name.clone(), format!("{}", ad.account_name));
134 }
135 }
136
137 for instruction in &idl.instructions {
138 for acc in &instruction.accounts {
139 if !account_vars.contains_key(&acc.name) {
140 if acc.name == "system_program" || acc.name == "systemProgram" {
141 account_vars.insert(acc.name.clone(), "SystemProgram.programId".to_string());
142 } else if acc.name == "authority" {
143 account_vars.insert(acc.name.clone(), "authorityPubkey".to_string());
144 } else {
145 if let Some((pos, _)) = meta.pda_init_sequence
146 .iter()
147 .enumerate()
148 .find(|(_, p)| p.account_name == acc.name)
149 {
150 if let Some(setup_index) = pda_indices.get(pos) {
151 account_vars.insert(acc.name.clone(), format!("pda{}", setup_index));
152 }
153 }
154 }
155 }
156 }
157 }
158
159 ctx.insert("account_vars", &account_vars);
160 let mut instruction_accounts: HashMap<String, Vec<AccountInfo>> = HashMap::new();
161 for instruction in &idl.instructions {
162 let account_infos: Vec<AccountInfo> = instruction.accounts.iter()
163 .map(|acc| {
164 AccountInfo {
165 original_name: acc.name.clone(),
166 camel_name: to_camel_case(&acc.name),
167 }
168 })
169 .collect();
170 instruction_accounts.insert(instruction.name.clone(), account_infos);
171 }
172 ctx.insert("instruction_accounts", &instruction_accounts);
173
174 let processed_test_cases: Vec<InstructionTestCaseWrapper> = meta.test_cases.iter()
175 .map(|test_case| {
176 let mut positive_cases = test_case.positive_cases.clone();
177 for arg_value in &mut positive_cases {
178 for arg in &mut arg_value.argument_values {
179 arg.value_type = convert_to_typescript_value(arg.value_type.clone());
180 }
181 }
182
183 let mut negative_cases = test_case.negative_cases.clone();
184 for arg_value in &mut negative_cases {
185 for arg in &mut arg_value.argument_values {
186 arg.value_type = convert_to_typescript_value(arg.value_type.clone());
187 }
188 }
189
190 InstructionTestCaseWrapper {
191 instruction_name: test_case.instruction_name.clone(),
192 instruction_name_camel: to_camel_case(&test_case.instruction_name),
193 arguments: test_case.arguments.clone(),
194 positive_cases,
195 negative_cases,
196 }
197 })
198 .collect();
199 ctx.insert("instruction_tests", &processed_test_cases);
200
201 let rendered = tera.render("aggregated_tests.tera", &ctx).context("render tera")?;
202
203 let out_path = out_dir.join(format!("{}.ts", program_name_pascal));
204 let mut f = File::create(&out_path).with_context(|| format!("create file {:?}", out_path))?;
205 f.write_all(rendered.as_bytes()).with_context(|| format!("write file {:?}", out_path))?;
206
207 println!("Wrote {}", out_path.display());
208 Ok(())
209}
210
/// Tera template for the aggregated TypeScript test suite rendered by `generate_with_tera`.
///
/// Expected context keys:
/// - `program_name`, `program_name_pascal_case`;
/// - `setup_requirements` (setup position -> marker string);
/// - `pda_seeds` (setup position -> seeds array expression);
/// - `account_vars` (IDL account name -> TS expression);
/// - `instruction_accounts` (instruction name -> accounts);
/// - `instruction_tests`.
///
/// The first `Keypair.generate()` requirement becomes `authority`; every later one becomes
/// `user<N>`. The "first seen" flags use `set_global` because a plain `set` inside a Tera for
/// loop is scoped to the loop and would not carry over to the next iteration.
///
/// Descriptions and expected error messages go through `json_encode()` so that quotes or
/// backslashes in them still yield valid string literals.
///
/// NOTE(review): map keys reach Tera as strings, so iteration order is presumably
/// lexicographic ("1", "10", "2"). Confirm this is acceptable if more than nine setup
/// requirements are expected.
const AGGREGATED_TEMPLATE: &str = r#"
import * as anchor from "@coral-xyz/anchor";
import { Program } from "@coral-xyz/anchor";
import { {{ program_name_pascal_case }} } from "../target/types/{{ program_name }}";
import { assert } from "chai";
import { Keypair, SystemProgram, PublicKey, LAMPORTS_PER_SOL } from "@solana/web3.js";

// This file is generated by solify. You can edit it manually

describe("{{ program_name | default(value='program') }}", () => {
  // Configure the client
  let provider = anchor.AnchorProvider.env();
  anchor.setProvider(provider);
  const connection = provider.connection;

  const program = anchor.workspace.{{ program_name }} as Program<{{ program_name_pascal_case }}>;

  // Setup Requirements
  // keypair declarations
  {%- set keypair_found = false %}
  {%- for id, code in setup_requirements %}
  {%- if code == "Keypair.generate()" %}
  {%- if not keypair_found %}
  {%- set_global keypair_found = true %}
  const authority = Keypair.generate();
  const authorityPubkey = authority.publicKey;
  {%- else %}
  const user{{ id }} = Keypair.generate();
  const user{{ id }}Pubkey = user{{ id }}.publicKey;
  {%- endif %}
  {%- endif %}
  {%- endfor %}

  // PDA declarations
  {%- for id, code in setup_requirements %}
  {%- if code == "PublicKey" %}
  let pda{{ id }}: PublicKey;
  let bump{{ id }}: number;
  {%- endif %}
  {%- endfor %}

  before(async () => {
    // ----- Airdrop for each user Keypair -----
    {%- set keypair_found_airdrop = false %}
    {%- for id, code in setup_requirements %}
    {%- if code == "Keypair.generate()" %}
    {%- if not keypair_found_airdrop %}
    {%- set_global keypair_found_airdrop = true %}
    const sig{{ id }} = await connection.requestAirdrop(authorityPubkey, 10 * LAMPORTS_PER_SOL);
    await connection.confirmTransaction(sig{{ id }}, "confirmed");
    {%- else %}
    const sig{{ id }} = await connection.requestAirdrop(user{{ id }}Pubkey, 10 * LAMPORTS_PER_SOL);
    await connection.confirmTransaction(sig{{ id }}, "confirmed");
    {%- endif %}
    {%- endif %}
    {%- endfor %}

    // ----- PDA Initialization -----
    {%- for id, seeds in pda_seeds %}
    [pda{{ id }}, bump{{ id }}] = PublicKey.findProgramAddressSync(
      {{ seeds }},
      program.programId
    );
    {%- endfor %}
  });
  {%- for instr in instruction_tests %}
  {%- for test in instr.positive_cases %}

  it({{ test.description | json_encode() }}, async () => {
    // Prepare arguments
    {%- for arg in test.argument_values %}
    {%- if arg.value_type.variant == "Valid" or arg.value_type.variant == "Invalid" %}
    const {{ arg.argument_name }}Value = {{ arg.value_type.description }};
    {%- else %}
    const {{ arg.argument_name }}Value = null;
    {%- endif %}
    {%- endfor %}
    // Execute instruction
    try {
      await program.methods
        .{{ instr.instruction_name_camel }}(
          {%- for arg in test.argument_values %}
          {{ arg.argument_name }}Value{%- if not loop.last %},{%- endif %}
          {%- endfor %}
        )
        .accountsStrict({
          {%- if instruction_accounts[instr.instruction_name] %}
          {%- for acc_info in instruction_accounts[instr.instruction_name] %}
          {%- set js_var = account_vars[acc_info.original_name] | default(value="null") %}
          {{ acc_info.camel_name }}: {{ js_var }}{%- if not loop.last %},{%- endif %}
          {%- endfor %}
          {%- endif %}
        })
        .signers([
          authority
        ])
        .rpc();
    } catch (err) {
      assert.fail("Instruction should not have failed: " + err);
    }
  });
  {%- endfor %}
  {%- for test in instr.negative_cases %}

  it({{ test.description | json_encode() }}, async () => {
    // Prepare arguments
    {%- for arg in test.argument_values %}
    {%- if arg.value_type.variant == "Valid" or arg.value_type.variant == "Invalid" %}
    const {{ arg.argument_name }}Value = {{ arg.value_type.description }};
    {%- else %}
    const {{ arg.argument_name }}Value = null;
    {%- endif %}
    {%- endfor %}
    // Execute instruction expecting failure
    let failed = false;
    try {
      await program.methods
        .{{ instr.instruction_name_camel }}(
          {%- for arg in test.argument_values %}
          {{ arg.argument_name }}Value{%- if not loop.last %},{%- endif %}
          {%- endfor %}
        )
        .accountsStrict({
          {%- if instruction_accounts[instr.instruction_name] %}
          {%- for acc_info in instruction_accounts[instr.instruction_name] %}
          {%- set js_var = account_vars[acc_info.original_name] | default(value="null") %}
          {{ acc_info.camel_name }}: {{ js_var }}{%- if not loop.last %},{%- endif %}
          {%- endfor %}
          {%- endif %}
        })
        .signers([
          authority
        ])
        .rpc();
    } catch (err) {
      failed = true;
      {%- if test.expected_outcome.variant == "Failure" %}
      assert(err.message.includes({{ test.expected_outcome.error_message | json_encode() }}));
      {%- endif %}
    }
    assert.ok(failed, "Instruction should have failed");
  });
  {%- endfor %}
  {%- endfor %}
});
"#;
377
378fn render_pda_seeds_expression(seeds: &[SeedComponent]) -> String {
381 let parts: Vec<String> = seeds
382 .iter()
383 .map(|seed| {
384 match seed.seed_type {
385 SeedType::Static => { format!("Buffer.from(\"{}\")", seed.value) }
386 SeedType::AccountKey => { format!("{}Pubkey.toBuffer()", seed.value) }
387 SeedType::Argument => { format!("Buffer.from(String({}))", seed.value) }
388 }
389 })
390 .collect();
391
392 format!("[{}]", parts.join(", "))
393}
394
/// Returns the portion of `s` before its first `_`, or all of `s` when it contains none.
///
/// The generator uses this as the output file stem (`<stem>.ts`).
fn cut_program_name(s: &str) -> String {
    s.split_once('_')
        .map_or(s, |(head, _rest)| head)
        .to_owned()
}
398
/// Returns `s` with its first character uppercased; the remaining characters are unchanged.
///
/// Works on `char`s rather than byte offsets. A multi-byte first character (e.g. `é`)
/// therefore no longer panics on a non-boundary slice, and an empty input returns an empty
/// string instead of panicking.
fn capitalize_first_letter(s: &str) -> String {
    let mut chars = s.chars();
    match chars.next() {
        Some(first) => first.to_uppercase().chain(chars).collect(),
        None => String::new(),
    }
}
402
/// Converts a `snake_case` identifier to `camelCase`; an alias of [`to_camel_case`].
fn camel_case(s: &str) -> String {
    to_camel_case(s)
}

/// Converts a `snake_case` identifier to `camelCase`.
///
/// The first `_`-separated segment is lowercased. Every later segment gets an uppercase
/// first character and a lowercase remainder. Empty segments (from leading, trailing or
/// doubled underscores) contribute nothing.
fn to_camel_case(s: &str) -> String {
    let mut segments = s.split('_');
    // `split` always yields at least one segment, so the fallback is never used.
    let head = segments.next().unwrap_or("").to_lowercase();
    segments.fold(head, |mut out, segment| {
        out.push_str(&capitalize_segment(segment));
        out
    })
}

/// Converts a `snake_case` identifier to `PascalCase`.
///
/// Every `_`-separated segment gets an uppercase first character and a lowercase remainder.
/// Empty segments contribute nothing.
fn to_pascal_case(s: &str) -> String {
    s.split('_').map(capitalize_segment).collect()
}

/// Uppercases the first character of `segment` and lowercases the rest.
///
/// Operates on `char`s, so multi-byte characters are safe; an empty segment yields `""`.
fn capitalize_segment(segment: &str) -> String {
    let mut chars = segment.chars();
    match chars.next() {
        Some(first) => {
            let mut out: String = first.to_uppercase().collect();
            out.push_str(&chars.as_str().to_lowercase());
            out
        }
        None => String::new(),
    }
}
464
465fn convert_to_typescript_value(value_type: TestValueType) -> TestValueType {
466 match value_type {
467 TestValueType::Valid { description } => {
468 TestValueType::Valid {
469 description: convert_rust_to_typescript(&description),
470 }
471 }
472 TestValueType::Invalid { description, reason } => {
473 TestValueType::Invalid {
474 description: convert_rust_to_typescript(&description),
475 reason,
476 }
477 }
478 }
479}
480
/// Translates a Rust-flavoured literal from the test metadata into a TypeScript expression
/// for the generated Anchor tests.
///
/// Rules, applied to the trimmed input in order:
/// - Integer bounds (`u8::MIN` … `u128::MAX`, `i8::MIN` … `i128::MAX`) and plain decimal
///   integers of any size become `new anchor.BN("<digits>")`. A leading `+` is dropped
///   because bn.js does not accept it.
/// - Other literals accepted by `f64` parsing (fractions, exponents) are emitted as JS
///   numbers. Non-finite values map to `NaN`, `Infinity` and `-Infinity`.
/// - Double-quoted strings, `true`/`false`, and expressions that start with `new ` or
///   `authority.` or contain `Pubkey` are passed through verbatim.
/// - Anything else is emitted as an escaped double-quoted string literal.
fn convert_rust_to_typescript(value: &str) -> String {
    let trimmed = value.trim();

    if let Some(bound) = rust_integer_bound(trimmed) {
        return bn_literal(&bound);
    }
    if is_decimal_integer(trimmed) {
        return bn_literal(trimmed.strip_prefix('+').unwrap_or(trimmed));
    }
    if let Ok(number) = trimmed.parse::<f64>() {
        // Integers were handled above, so this is a fraction, an exponent, or a non-finite
        // value; none of these are valid `BN` input, so emit a plain JS number instead.
        let js = if number.is_nan() {
            "NaN"
        } else if number == f64::INFINITY {
            "Infinity"
        } else if number == f64::NEG_INFINITY {
            "-Infinity"
        } else {
            trimmed
        };
        return js.to_string();
    }

    let is_quoted = trimmed.len() >= 2 && trimmed.starts_with('"') && trimmed.ends_with('"');
    let is_bool = trimmed == "true" || trimmed == "false";
    let is_expression = trimmed.starts_with("new ")
        || trimmed.starts_with("authority.")
        || trimmed.contains("Pubkey");
    if is_quoted || is_bool || is_expression {
        return trimmed.to_string();
    }

    quote_js_string(trimmed)
}

/// Wraps decimal `digits` in an `anchor.BN` constructor call.
fn bn_literal(digits: &str) -> String {
    format!("new anchor.BN(\"{}\")", digits)
}

/// Returns the decimal value of a Rust integer bound expression such as `u64::MAX` or
/// `i8::MIN`, or `None` for any other input.
fn rust_integer_bound(expr: &str) -> Option<String> {
    let value = match expr {
        "u8::MIN" | "u16::MIN" | "u32::MIN" | "u64::MIN" | "u128::MIN" => "0".to_string(),
        "u8::MAX" => u8::MAX.to_string(),
        "u16::MAX" => u16::MAX.to_string(),
        "u32::MAX" => u32::MAX.to_string(),
        "u64::MAX" => u64::MAX.to_string(),
        "u128::MAX" => u128::MAX.to_string(),
        "i8::MIN" => i8::MIN.to_string(),
        "i8::MAX" => i8::MAX.to_string(),
        "i16::MIN" => i16::MIN.to_string(),
        "i16::MAX" => i16::MAX.to_string(),
        "i32::MIN" => i32::MIN.to_string(),
        "i32::MAX" => i32::MAX.to_string(),
        "i64::MIN" => i64::MIN.to_string(),
        "i64::MAX" => i64::MAX.to_string(),
        "i128::MIN" => i128::MIN.to_string(),
        "i128::MAX" => i128::MAX.to_string(),
        _ => return None,
    };
    Some(value)
}

/// Returns `true` when `s` is a non-empty run of ASCII digits with an optional `+`/`-` sign.
fn is_decimal_integer(s: &str) -> bool {
    let digits = s
        .strip_prefix('-')
        .or_else(|| s.strip_prefix('+'))
        .unwrap_or(s);
    !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit())
}

/// Renders `s` as a double-quoted TypeScript string literal, escaping `\`, `"` and line breaks.
fn quote_js_string(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for c in s.chars() {
        match c {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            other => out.push(other),
        }
    }
    out.push('"');
    out
}