1use convert_case::Case;
2use convert_case::Casing;
3use serde_json::json;
4use sha2::Digest;
5use sha2::Sha256;
6use tera::Context;
7use tera::Tera;
8use trident_idl_spec::Idl;
9use trident_idl_spec::IdlInstruction;
10use trident_idl_spec::IdlType;
11use trident_idl_spec::IdlTypeDef;
12use trident_idl_spec::IdlTypeDefTy;
13
14use crate::error::TemplateError;
15
16pub mod error;
17
18pub struct TridentTemplates {
20 tera: Tera,
21}
22
23impl TridentTemplates {
24 pub fn new() -> Result<Self, TemplateError> {
25 let mut tera = Tera::default();
26 tera.add_raw_templates(vec![
27 (
28 "instruction.rs",
29 include_str!("../templates/instruction.rs.tera"),
30 ),
31 (
32 "transaction.rs",
33 include_str!("../templates/transaction.rs.tera"),
34 ),
35 (
36 "test_fuzz.rs",
37 include_str!("../templates/test_fuzz.rs.tera"),
38 ),
39 (
40 "fuzz_accounts.rs",
41 include_str!("../templates/fuzz_accounts.rs.tera"),
42 ),
43 ("types.rs", include_str!("../templates/types.rs.tera")),
44 (
45 "Trident.toml",
46 include_str!("../templates/Trident.toml.tera"),
47 ),
48 (
49 "Cargo_fuzz.toml",
50 include_str!("../templates/Cargo_fuzz.toml.tera"),
51 ),
52 ])?;
53 Ok(Self { tera })
54 }
55
56 pub fn generate(
58 &self,
59 idls: &[Idl],
60 trident_version: &str,
61 ) -> Result<GeneratedFiles, TemplateError> {
62 let mut instructions = Vec::new();
63 let mut transactions = Vec::new();
64 let programs = self.build_programs_data(idls);
65
66 for idl in idls.iter() {
68 let program_id = if idl.address.is_empty() {
69 "fill corresponding program ID here"
70 } else {
71 &idl.address
72 };
73
74 for instruction in &idl.instructions {
75 let template_data = self.build_instruction_data(instruction, program_id)?;
76 let snake_name = &template_data["snake_name"].as_str().unwrap();
77
78 let context = Context::from_serialize(json!({"instruction": template_data}))?;
79
80 instructions.push((
81 snake_name.to_string(),
82 self.tera.render("instruction.rs", &context)?,
83 ));
84 transactions.push((
85 snake_name.to_string(),
86 self.tera.render("transaction.rs", &context)?,
87 ));
88 }
89 }
90
91 let test_fuzz = self
93 .tera
94 .render("test_fuzz.rs", &Context::from_serialize(json!({}))?)?;
95 let fuzz_accounts = self.tera.render(
96 "fuzz_accounts.rs",
97 &Context::from_serialize(json!({"accounts": self.collect_all_accounts(idls)}))?,
98 )?;
99 let custom_types = self.tera.render(
100 "types.rs",
101 &Context::from_serialize(json!({"custom_types": self.collect_custom_types(idls)}))?,
102 )?;
103 let trident_toml = self.tera.render(
104 "Trident.toml",
105 &Context::from_serialize(json!({"programs": programs}))?,
106 )?;
107 let cargo_fuzz_toml = self.tera.render(
108 "Cargo_fuzz.toml",
109 &Context::from_serialize(json!({
110 "trident_version": trident_version,
111 }))?,
112 )?;
113
114 let instructions_mod = self.generate_mod_from_names(
116 &instructions
117 .iter()
118 .map(|(name, _)| name.clone())
119 .collect::<Vec<_>>(),
120 );
121 let transactions_mod = self.generate_mod_from_names(
122 &transactions
123 .iter()
124 .map(|(name, _)| name.clone())
125 .collect::<Vec<_>>(),
126 );
127
128 Ok(GeneratedFiles {
129 instructions,
130 transactions,
131 test_fuzz,
132 instructions_mod,
133 transactions_mod,
134 custom_types,
135 fuzz_accounts,
136 trident_toml,
137 cargo_fuzz_toml,
138 })
139 }
140
141 fn build_programs_data(&self, idls: &[Idl]) -> Vec<serde_json::Value> {
143 idls.iter()
144 .map(|idl| {
145 let program_id = if idl.address.is_empty() {
146 "fill corresponding program ID here"
147 } else {
148 &idl.address
149 };
150
151 let program_name = if idl.metadata.name.is_empty() {
152 "fill corresponding program name here"
153 } else {
154 &idl.metadata.name
155 };
156
157 json!({
158 "name": program_name,
159 "program_id": program_id,
160 })
161 })
162 .collect()
163 }
164
165 fn build_instruction_data(
167 &self,
168 instruction: &IdlInstruction,
169 program_id: &str,
170 ) -> Result<serde_json::Value, TemplateError> {
171 let name = &instruction.name;
172 let camel_name = name.to_case(Case::UpperCamel);
173 let snake_name = name.to_case(Case::Snake);
174
175 let discriminator = if instruction.discriminator.is_empty() {
176 self.generate_discriminator(name)
177 } else {
178 instruction.discriminator.clone()
179 };
180
181 let (accounts, composite_accounts) = self.process_accounts(&instruction.accounts);
182 let data_fields = self.process_data_fields(&instruction.args);
183
184 Ok(json!({
185 "name": name,
186 "camel_name": camel_name,
187 "snake_name": snake_name,
188 "program_id": program_id,
189 "discriminator": discriminator,
190 "accounts": accounts,
191 "composite_accounts": composite_accounts,
192 "data_fields": data_fields
193 }))
194 }
195
196 fn process_data_fields(&self, args: &[trident_idl_spec::IdlField]) -> Vec<serde_json::Value> {
198 args.iter()
199 .map(|field| {
200 json!({
201 "name": field.name,
202 "rust_type": self.idl_type_to_rust(&field.ty)
203 })
204 })
205 .collect()
206 }
207
208 #[allow(clippy::only_used_in_recursion)]
209 fn process_accounts(
211 &self,
212 accounts: &[trident_idl_spec::IdlInstructionAccountItem],
213 ) -> (Vec<serde_json::Value>, Vec<serde_json::Value>) {
214 let mut main_accounts = Vec::new();
215 let mut composite_accounts = Vec::new();
216
217 for account in accounts {
218 match account {
219 trident_idl_spec::IdlInstructionAccountItem::Single(acc) => {
220 main_accounts.push(json!({
221 "name": acc.name,
222 "is_signer": acc.signer,
223 "is_writable": acc.writable,
224 "address": acc.address,
225 "is_composite": false,
226 "composite_type_name": null
227 }));
228 }
229 trident_idl_spec::IdlInstructionAccountItem::Composite(comp) => {
230 let camel_name = comp.name.to_case(Case::UpperCamel);
231
232 main_accounts.push(json!({
234 "name": comp.name,
235 "is_signer": false,
236 "is_writable": false,
237 "address": null,
238 "is_composite": true,
239 "composite_type_name": camel_name
240 }));
241
242 let (comp_accounts, nested_composites) = self.process_accounts(&comp.accounts);
244 composite_accounts.push(json!({
245 "name": comp.name,
246 "camel_name": camel_name,
247 "accounts": comp_accounts,
248 "nested_composites": nested_composites
249 }));
250 }
252 }
253 }
254
255 (main_accounts, composite_accounts)
256 }
257
258 #[allow(clippy::only_used_in_recursion)]
259 fn idl_type_to_rust(&self, idl_type: &IdlType) -> String {
261 match idl_type {
262 IdlType::Bool => "bool".to_string(),
263 IdlType::U8 => "u8".to_string(),
264 IdlType::I8 => "i8".to_string(),
265 IdlType::U16 => "u16".to_string(),
266 IdlType::I16 => "i16".to_string(),
267 IdlType::U32 => "u32".to_string(),
268 IdlType::I32 => "i32".to_string(),
269 IdlType::F32 => "f32".to_string(),
270 IdlType::U64 => "u64".to_string(),
271 IdlType::I64 => "i64".to_string(),
272 IdlType::F64 => "f64".to_string(),
273 IdlType::U128 => "u128".to_string(),
274 IdlType::I128 => "i128".to_string(),
275 IdlType::U256 => "u256".to_string(),
276 IdlType::I256 => "i256".to_string(),
277 IdlType::Bytes => "Vec<u8>".to_string(),
278 IdlType::String => "String".to_string(),
279 IdlType::Pubkey | IdlType::PublicKey => "TridentPubkey".to_string(),
280 IdlType::Option(inner) => format!("Option<{}>", self.idl_type_to_rust(inner)),
281 IdlType::Vec(inner) => format!("Vec<{}>", self.idl_type_to_rust(inner)),
282 IdlType::Array(inner, len) => {
283 let len_str = match len {
284 trident_idl_spec::IdlArrayLen::Value(n) => n.to_string(),
285 _ => "0".to_string(),
286 };
287 format!("[{}; {}]", self.idl_type_to_rust(inner), len_str)
288 }
289 IdlType::Defined(defined) => match defined {
290 trident_idl_spec::DefinedType::Simple(name) => name.clone(),
291 trident_idl_spec::DefinedType::Complex { name, .. } => name.clone(),
292 },
293 IdlType::Generic(name) => name.clone(),
294 _ => "UnknownType".to_string(),
295 }
296 }
297
298 fn generate_discriminator(&self, name: &str) -> Vec<u8> {
300 let preimage = format!("global:{}", name.to_case(Case::Snake));
301 let mut hasher = Sha256::new();
302 hasher.update(preimage);
303 hasher.finalize()[..8].to_vec()
304 }
305
306 fn collect_all_accounts(&self, idls: &[Idl]) -> Vec<serde_json::Value> {
308 let mut accounts = std::collections::HashSet::new();
309 for idl in idls {
310 for instruction in &idl.instructions {
311 self.collect_accounts_recursive(&instruction.accounts, &mut accounts);
312 }
313 }
314 accounts
315 .into_iter()
316 .map(|name| json!({ "name": name }))
317 .collect()
318 }
319
320 #[allow(clippy::only_used_in_recursion)]
321 fn collect_accounts_recursive(
322 &self,
323 accounts: &[trident_idl_spec::IdlInstructionAccountItem],
324 acc: &mut std::collections::HashSet<String>,
325 ) {
326 for account in accounts {
327 match account {
328 trident_idl_spec::IdlInstructionAccountItem::Single(a) => {
329 acc.insert(a.name.clone());
330 }
331 trident_idl_spec::IdlInstructionAccountItem::Composite(c) => {
332 acc.insert(c.name.clone());
333 self.collect_accounts_recursive(&c.accounts, acc);
334 }
335 }
336 }
337 }
338
339 fn collect_custom_types(&self, idls: &[Idl]) -> Vec<serde_json::Value> {
341 idls.iter()
342 .flat_map(|idl| &idl.types)
343 .map(|type_def| self.convert_type_def_to_template_data(type_def))
344 .collect()
345 }
346
347 fn convert_type_def_to_template_data(&self, type_def: &IdlTypeDef) -> serde_json::Value {
349 match &type_def.ty {
350 IdlTypeDefTy::Struct { fields } => json!({
351 "type": "struct",
352 "name": type_def.name,
353 "fields": fields.as_ref().map(|f| self.convert_fields_to_template_data(f))
354 }),
355 IdlTypeDefTy::Enum { variants } => json!({
356 "type": "enum",
357 "name": type_def.name,
358 "variants": variants.iter().map(|v| json!({
359 "name": v.name,
360 "fields": v.fields.as_ref().map(|f| self.convert_fields_to_template_data(f))
361 })).collect::<Vec<_>>()
362 }),
363 IdlTypeDefTy::Type { .. } => json!({
364 "type": "type_alias",
365 "name": type_def.name
366 }),
367 }
368 }
369
370 fn convert_fields_to_template_data(
372 &self,
373 fields: &trident_idl_spec::IdlDefinedFields,
374 ) -> serde_json::Value {
375 match fields {
376 trident_idl_spec::IdlDefinedFields::Named(named) => json!({
377 "type": "named",
378 "fields": named.iter().map(|field| json!({
379 "name": field.name,
380 "rust_type": self.idl_type_to_rust(&field.ty)
381 })).collect::<Vec<_>>()
382 }),
383 trident_idl_spec::IdlDefinedFields::Tuple(tuple) => json!({
384 "type": "tuple",
385 "fields": tuple.iter().enumerate().map(|(i, field_type)| json!({
386 "name": format!("field_{}", i),
387 "rust_type": self.idl_type_to_rust(field_type)
388 })).collect::<Vec<_>>()
389 }),
390 }
391 }
392
393 fn generate_mod_from_names(&self, names: &[String]) -> String {
394 let mut content = String::new();
395 for name in names {
396 content.push_str(&format!("pub mod {};\n", name));
397 }
398 for name in names {
399 content.push_str(&format!("pub use {}::*;\n", name));
400 }
401 content
402 }
403}
404
/// Rendered output of [`TridentTemplates::generate`].
#[derive(Clone, Debug)]
pub struct GeneratedFiles {
    /// `(snake_case instruction name, rendered instruction source)` pairs.
    pub instructions: Vec<(String, String)>,
    /// `(snake_case instruction name, rendered transaction source)` pairs.
    pub transactions: Vec<(String, String)>,
    /// Rendered `test_fuzz.rs` template.
    pub test_fuzz: String,
    /// Module listing (`pub mod` / `pub use` lines) for the instructions.
    pub instructions_mod: String,
    /// Module listing (`pub mod` / `pub use` lines) for the transactions.
    pub transactions_mod: String,
    /// Rendered `types.rs` template.
    pub custom_types: String,
    /// Rendered `fuzz_accounts.rs` template.
    pub fuzz_accounts: String,
    /// Rendered `Trident.toml` template.
    pub trident_toml: String,
    /// Rendered `Cargo_fuzz.toml` template.
    pub cargo_fuzz_toml: String,
}
417
418impl Default for TridentTemplates {
419 fn default() -> Self {
420 Self::new().expect("Failed to create template engine")
421 }
422}