1use convert_case::Case;
2use convert_case::Casing;
3use serde_json::json;
4use sha2::Digest;
5use sha2::Sha256;
6use tera::Context;
7use tera::Tera;
8use trident_idl_spec::Idl;
9use trident_idl_spec::IdlInstruction;
10use trident_idl_spec::IdlType;
11use trident_idl_spec::IdlTypeDef;
12use trident_idl_spec::IdlTypeDefTy;
13
14use crate::error::TemplateError;
15
16pub mod error;
17
/// Renders Trident fuzz-test scaffolding from program IDLs.
///
/// Wraps a [`Tera`] engine preloaded with the embedded templates; construct it
/// with [`TridentTemplates::new`] and produce output with
/// [`TridentTemplates::generate`].
pub struct TridentTemplates {
    /// Template engine holding every registered template, keyed by the output
    /// file name (e.g. `"test_fuzz.rs"`).
    tera: Tera,
}
22
23impl TridentTemplates {
24 pub fn new() -> Result<Self, TemplateError> {
25 let mut tera = Tera::default();
26 tera.add_raw_templates(vec![
27 (
28 "test_fuzz.rs",
29 include_str!("../templates/test_fuzz.rs.tera"),
30 ),
31 (
32 "fuzz_accounts.rs",
33 include_str!("../templates/fuzz_accounts.rs.tera"),
34 ),
35 ("types.rs", include_str!("../templates/types.rs.tera")),
36 (
37 "Trident.toml",
38 include_str!("../templates/Trident.toml.tera"),
39 ),
40 (
41 "Cargo_fuzz.toml",
42 include_str!("../templates/Cargo_fuzz.toml.tera"),
43 ),
44 ])?;
45 Ok(Self { tera })
46 }
47
48 pub fn generate(
50 &self,
51 idls: &[Idl],
52 trident_version: &str,
53 ) -> Result<GeneratedFiles, TemplateError> {
54 let programs_data = self.build_programs_with_instructions_data(idls)?;
55
56 let test_fuzz = self
58 .tera
59 .render("test_fuzz.rs", &Context::from_serialize(json!({}))?)?;
60 let fuzz_accounts = self.tera.render(
61 "fuzz_accounts.rs",
62 &Context::from_serialize(json!({"accounts": self.collect_all_accounts(idls)}))?,
63 )?;
64 let types = self.tera.render(
65 "types.rs",
66 &Context::from_serialize(json!({
67 "programs": programs_data,
68 "custom_types": self.collect_custom_types(idls)
69 }))?,
70 )?;
71 let trident_toml = self.tera.render(
72 "Trident.toml",
73 &Context::from_serialize(json!({"programs": programs_data}))?,
74 )?;
75 let cargo_fuzz_toml = self.tera.render(
76 "Cargo_fuzz.toml",
77 &Context::from_serialize(json!({
78 "trident_version": trident_version,
79 }))?,
80 )?;
81
82 Ok(GeneratedFiles {
83 test_fuzz,
84 types,
85 fuzz_accounts,
86 trident_toml,
87 cargo_fuzz_toml,
88 })
89 }
90
91 fn build_programs_with_instructions_data(
93 &self,
94 idls: &[Idl],
95 ) -> Result<Vec<serde_json::Value>, TemplateError> {
96 let mut programs_data = Vec::new();
97
98 for idl in idls.iter() {
99 let program_id = if idl.address.is_empty() {
100 "fill corresponding program ID here"
101 } else {
102 &idl.address
103 };
104
105 let program_name = if idl.metadata.name.is_empty() {
106 "unknown_program"
107 } else {
108 &idl.metadata.name
109 };
110
111 let module_name = program_name.to_case(Case::Snake);
112
113 let mut instructions_data = Vec::new();
115 let mut composite_accounts = Vec::new();
116 let mut seen_composites = std::collections::HashSet::new();
117
118 for instruction in &idl.instructions {
119 let instruction_data = self.build_instruction_data_with_lifetimes(
120 instruction,
121 program_id,
122 &std::collections::HashMap::new(),
123 )?;
124
125 if let Some(composites) = instruction_data
127 .get("composite_accounts")
128 .and_then(|v| v.as_array())
129 {
130 for composite in composites {
131 if let Some(name) = composite.get("camel_name").and_then(|v| v.as_str()) {
132 if seen_composites.insert(name.to_string()) {
134 composite_accounts.push(composite.clone());
135 }
136 }
137 }
138 }
139
140 instructions_data.push(instruction_data);
141 }
142
143 programs_data.push(json!({
144 "name": program_name,
145 "module_name": module_name,
146 "program_id": program_id,
147 "instructions": instructions_data,
148 "composite_accounts": composite_accounts
149 }));
150 }
151
152 Ok(programs_data)
153 }
154
155 fn process_data_fields(&self, args: &[trident_idl_spec::IdlField]) -> Vec<serde_json::Value> {
157 args.iter()
158 .map(|field| {
159 json!({
160 "name": field.name,
161 "rust_type": self.idl_type_to_rust(&field.ty)
162 })
163 })
164 .collect()
165 }
166
167 #[allow(clippy::only_used_in_recursion)]
168 fn idl_type_to_rust(&self, idl_type: &IdlType) -> String {
170 match idl_type {
171 IdlType::Bool => "bool".to_string(),
172 IdlType::U8 => "u8".to_string(),
173 IdlType::I8 => "i8".to_string(),
174 IdlType::U16 => "u16".to_string(),
175 IdlType::I16 => "i16".to_string(),
176 IdlType::U32 => "u32".to_string(),
177 IdlType::I32 => "i32".to_string(),
178 IdlType::F32 => "f32".to_string(),
179 IdlType::U64 => "u64".to_string(),
180 IdlType::I64 => "i64".to_string(),
181 IdlType::F64 => "f64".to_string(),
182 IdlType::U128 => "u128".to_string(),
183 IdlType::I128 => "i128".to_string(),
184 IdlType::U256 => "u256".to_string(),
185 IdlType::I256 => "i256".to_string(),
186 IdlType::Bytes => "Vec<u8>".to_string(),
187 IdlType::String => "String".to_string(),
188 IdlType::Pubkey | IdlType::PublicKey => "Pubkey".to_string(),
189 IdlType::Option(inner) => format!("Option<{}>", self.idl_type_to_rust(inner)),
190 IdlType::Vec(inner) => format!("Vec<{}>", self.idl_type_to_rust(inner)),
191 IdlType::Array(inner, len) => {
192 let len_str = match len {
193 trident_idl_spec::IdlArrayLen::Value(n) => n.to_string(),
194 _ => "0".to_string(),
195 };
196 format!("[{}; {}]", self.idl_type_to_rust(inner), len_str)
197 }
198 IdlType::Defined(defined) => match defined {
199 trident_idl_spec::DefinedType::Simple(name) => name.clone(),
200 trident_idl_spec::DefinedType::Complex { name, .. } => name.clone(),
201 },
202 IdlType::Generic(name) => name.clone(),
203 _ => "UnknownType".to_string(),
204 }
205 }
206
207 fn generate_discriminator(&self, name: &str) -> Vec<u8> {
209 let preimage = format!("global:{}", name.to_case(Case::Snake));
210 let mut hasher = Sha256::new();
211 hasher.update(preimage);
212 hasher.finalize()[..8].to_vec()
213 }
214
215 fn collect_all_accounts(&self, idls: &[Idl]) -> Vec<serde_json::Value> {
217 let mut accounts = Vec::new();
218 for idl in idls {
219 for instruction in &idl.instructions {
220 self.collect_accounts_recursive(&instruction.accounts, &mut accounts);
221 }
222 }
223
224 let mut seen = std::collections::HashSet::new();
226 accounts.retain(|name| seen.insert(name.clone()));
227
228 accounts
229 .into_iter()
230 .map(|name| json!({ "name": name }))
231 .collect()
232 }
233
234 #[allow(clippy::only_used_in_recursion)]
235 fn collect_accounts_recursive(
236 &self,
237 accounts: &[trident_idl_spec::IdlInstructionAccountItem],
238 acc: &mut Vec<String>,
239 ) {
240 for account in accounts {
241 match account {
242 trident_idl_spec::IdlInstructionAccountItem::Single(a) => {
243 acc.push(a.name.clone());
244 }
245 trident_idl_spec::IdlInstructionAccountItem::Composite(c) => {
246 acc.push(c.name.clone());
247 self.collect_accounts_recursive(&c.accounts, acc);
248 }
249 }
250 }
251 }
252
253 fn collect_custom_types(&self, idls: &[Idl]) -> Vec<serde_json::Value> {
255 idls.iter()
256 .flat_map(|idl| &idl.types)
257 .map(|type_def| self.convert_type_def_to_template_data(type_def))
258 .collect()
259 }
260
261 fn convert_type_def_to_template_data(&self, type_def: &IdlTypeDef) -> serde_json::Value {
263 match &type_def.ty {
264 IdlTypeDefTy::Struct { fields } => json!({
265 "type": "struct",
266 "name": type_def.name,
267 "fields": fields.as_ref().map(|f| self.convert_fields_to_template_data(f))
268 }),
269 IdlTypeDefTy::Enum { variants } => json!({
270 "type": "enum",
271 "name": type_def.name,
272 "variants": variants.iter().map(|v| json!({
273 "name": v.name,
274 "fields": v.fields.as_ref().map(|f| self.convert_fields_to_template_data(f))
275 })).collect::<Vec<_>>()
276 }),
277 IdlTypeDefTy::Type { .. } => json!({
278 "type": "type_alias",
279 "name": type_def.name
280 }),
281 }
282 }
283
284 fn convert_fields_to_template_data(
286 &self,
287 fields: &trident_idl_spec::IdlDefinedFields,
288 ) -> serde_json::Value {
289 match fields {
290 trident_idl_spec::IdlDefinedFields::Named(named) => json!({
291 "type": "named",
292 "fields": named.iter().map(|field| json!({
293 "name": field.name,
294 "rust_type": self.idl_type_to_rust(&field.ty)
295 })).collect::<Vec<_>>()
296 }),
297 trident_idl_spec::IdlDefinedFields::Tuple(tuple) => json!({
298 "type": "tuple",
299 "fields": tuple.iter().enumerate().map(|(i, field_type)| json!({
300 "name": format!("field_{}", i),
301 "rust_type": self.idl_type_to_rust(field_type)
302 })).collect::<Vec<_>>()
303 }),
304 }
305 }
306
307 fn extract_seeds(
309 &self,
310 seeds: &[trident_idl_spec::IdlSeed],
311 ) -> (Vec<String>, Vec<String>, Vec<String>) {
312 let mut static_seeds = Vec::new();
313 let mut account_seeds = Vec::new();
314 let mut arg_seeds = Vec::new();
315
316 for seed in seeds {
317 match seed {
318 trident_idl_spec::IdlSeed::Const(const_seed) => {
319 let bytes_str = format!(
321 "[{}]",
322 const_seed
323 .value
324 .iter()
325 .map(|b| format!("{}u8", b))
326 .collect::<Vec<_>>()
327 .join(", ")
328 );
329 static_seeds.push(bytes_str);
330 }
331 trident_idl_spec::IdlSeed::Account(account_seed) => {
332 account_seeds.push(account_seed.path.clone());
334 }
335 trident_idl_spec::IdlSeed::Arg(arg_seed) => {
336 arg_seeds.push(arg_seed.path.clone());
338 }
339 }
340 }
341
342 (static_seeds, account_seeds, arg_seeds)
343 }
344
345 fn build_instruction_data_with_lifetimes(
347 &self,
348 instruction: &IdlInstruction,
349 program_id: &str,
350 _composite_lifetime_map: &std::collections::HashMap<String, bool>,
351 ) -> Result<serde_json::Value, TemplateError> {
352 let name = &instruction.name;
353 let camel_name = name.to_case(Case::UpperCamel);
354 let snake_name = name.to_case(Case::Snake);
355
356 let discriminator = if instruction.discriminator.is_empty() {
357 self.generate_discriminator(name)
358 } else {
359 instruction.discriminator.clone()
360 };
361
362 let (accounts, composite_accounts) =
363 self.process_accounts_with_lifetimes(&instruction.accounts);
364 let data_fields = self.process_data_fields(&instruction.args);
365
366 Ok(json!({
367 "name": name,
368 "camel_name": camel_name,
369 "snake_name": snake_name,
370 "program_id": program_id,
371 "discriminator": discriminator,
372 "accounts": accounts,
373 "composite_accounts": composite_accounts,
374 "data_fields": data_fields,
375 "needs_lifetime": false
376 }))
377 }
378
379 #[allow(clippy::only_used_in_recursion)]
380 fn process_accounts_with_lifetimes(
381 &self,
382 accounts: &[trident_idl_spec::IdlInstructionAccountItem],
383 ) -> (Vec<serde_json::Value>, Vec<serde_json::Value>) {
384 let mut main_accounts = Vec::new();
385 let mut composite_accounts = Vec::new();
386
387 for account in accounts {
388 match account {
389 trident_idl_spec::IdlInstructionAccountItem::Single(acc) => {
390 let has_pda_seeds = acc.pda.is_some();
391 let (static_seeds, account_seeds, arg_seeds) = if let Some(pda) = &acc.pda {
392 self.extract_seeds(&pda.seeds)
393 } else {
394 (Vec::new(), Vec::new(), Vec::new())
395 };
396
397 main_accounts.push(json!({
398 "name": acc.name,
399 "is_signer": acc.signer,
400 "is_writable": acc.writable,
401 "address": acc.address,
402 "is_composite": false,
403 "composite_type_name": null,
404 "has_pda_seeds": has_pda_seeds,
405 "composite_needs_lifetime": false,
406 "static_seeds": static_seeds,
407 "account_seeds": account_seeds,
408 "arg_seeds": arg_seeds
409 }));
410 }
411 trident_idl_spec::IdlInstructionAccountItem::Composite(comp) => {
412 let camel_name = comp.name.to_case(Case::UpperCamel);
413
414 main_accounts.push(json!({
416 "name": comp.name,
417 "is_signer": false,
418 "is_writable": false,
419 "address": null,
420 "is_composite": true,
421 "composite_type_name": camel_name,
422 "has_pda_seeds": false,
423 "composite_needs_lifetime": false
424 }));
425
426 let (comp_accounts, nested_composites) =
428 self.process_accounts_with_lifetimes(&comp.accounts);
429
430 composite_accounts.push(json!({
431 "name": comp.name,
432 "camel_name": camel_name,
433 "accounts": comp_accounts,
434 "nested_composites": nested_composites,
435 "needs_lifetime": false
436 }));
437 }
438 }
439 }
440
441 (main_accounts, composite_accounts)
443 }
444}
445
/// Rendered contents of every file produced by [`TridentTemplates::generate`].
#[derive(Debug, Clone)]
pub struct GeneratedFiles {
    /// Output of the `test_fuzz.rs` template (rendered with an empty context).
    pub test_fuzz: String,
    /// Output of the `types.rs` template (per-program data plus IDL custom types).
    pub types: String,
    /// Output of the `fuzz_accounts.rs` template (deduplicated account names).
    pub fuzz_accounts: String,
    /// Output of the `Trident.toml` template (per-program data).
    pub trident_toml: String,
    /// Output of the `Cargo_fuzz.toml` template (parameterized by the Trident version).
    pub cargo_fuzz_toml: String,
}
454
455impl Default for TridentTemplates {
456 fn default() -> Self {
457 Self::new().expect("Failed to create template engine")
458 }
459}