trident_template/
lib.rs

1use convert_case::Case;
2use convert_case::Casing;
3use serde_json::json;
4use sha2::Digest;
5use sha2::Sha256;
6use tera::Context;
7use tera::Tera;
8use trident_idl_spec::Idl;
9use trident_idl_spec::IdlInstruction;
10use trident_idl_spec::IdlType;
11use trident_idl_spec::IdlTypeDef;
12use trident_idl_spec::IdlTypeDefTy;
13
14use crate::error::TemplateError;
15
16pub mod error;
17
18/// Simple template engine for Trident code generation
19pub struct TridentTemplates {
20    tera: Tera,
21}
22
23impl TridentTemplates {
24    pub fn new() -> Result<Self, TemplateError> {
25        let mut tera = Tera::default();
26        tera.add_raw_templates(vec![
27            (
28                "instruction.rs",
29                include_str!("../templates/instruction.rs.tera"),
30            ),
31            (
32                "transaction.rs",
33                include_str!("../templates/transaction.rs.tera"),
34            ),
35            (
36                "test_fuzz.rs",
37                include_str!("../templates/test_fuzz.rs.tera"),
38            ),
39            (
40                "fuzz_accounts.rs",
41                include_str!("../templates/fuzz_accounts.rs.tera"),
42            ),
43            ("types.rs", include_str!("../templates/types.rs.tera")),
44            (
45                "Trident.toml",
46                include_str!("../templates/Trident.toml.tera"),
47            ),
48            (
49                "Cargo_fuzz.toml",
50                include_str!("../templates/Cargo_fuzz.toml.tera"),
51            ),
52        ])?;
53        Ok(Self { tera })
54    }
55
56    /// Generate all templates from IDLs
57    pub fn generate(
58        &self,
59        idls: &[Idl],
60        trident_version: &str,
61    ) -> Result<GeneratedFiles, TemplateError> {
62        let mut instructions = Vec::new();
63        let mut transactions = Vec::new();
64        let programs = self.build_programs_data(idls);
65
66        // Process instructions for each IDL
67        for idl in idls.iter() {
68            let program_id = if idl.address.is_empty() {
69                "fill corresponding program ID here"
70            } else {
71                &idl.address
72            };
73
74            for instruction in &idl.instructions {
75                let template_data = self.build_instruction_data(instruction, program_id)?;
76                let snake_name = &template_data["snake_name"].as_str().unwrap();
77
78                let context = Context::from_serialize(json!({"instruction": template_data}))?;
79
80                instructions.push((
81                    snake_name.to_string(),
82                    self.tera.render("instruction.rs", &context)?,
83                ));
84                transactions.push((
85                    snake_name.to_string(),
86                    self.tera.render("transaction.rs", &context)?,
87                ));
88            }
89        }
90
91        // Generate other files
92        let test_fuzz = self
93            .tera
94            .render("test_fuzz.rs", &Context::from_serialize(json!({}))?)?;
95        let fuzz_accounts = self.tera.render(
96            "fuzz_accounts.rs",
97            &Context::from_serialize(json!({"accounts": self.collect_all_accounts(idls)}))?,
98        )?;
99        let custom_types = self.tera.render(
100            "types.rs",
101            &Context::from_serialize(json!({"custom_types": self.collect_custom_types(idls)}))?,
102        )?;
103        let trident_toml = self.tera.render(
104            "Trident.toml",
105            &Context::from_serialize(json!({"programs": programs}))?,
106        )?;
107        let cargo_fuzz_toml = self.tera.render(
108            "Cargo_fuzz.toml",
109            &Context::from_serialize(json!({
110                "trident_version": trident_version,
111            }))?,
112        )?;
113
114        // Generate mod files (clone to avoid borrowing issues)
115        let instructions_mod = self.generate_mod_from_names(
116            &instructions
117                .iter()
118                .map(|(name, _)| name.clone())
119                .collect::<Vec<_>>(),
120        );
121        let transactions_mod = self.generate_mod_from_names(
122            &transactions
123                .iter()
124                .map(|(name, _)| name.clone())
125                .collect::<Vec<_>>(),
126        );
127
128        Ok(GeneratedFiles {
129            instructions,
130            transactions,
131            test_fuzz,
132            instructions_mod,
133            transactions_mod,
134            custom_types,
135            fuzz_accounts,
136            trident_toml,
137            cargo_fuzz_toml,
138        })
139    }
140
141    // Helper function to build program data
142    fn build_programs_data(&self, idls: &[Idl]) -> Vec<serde_json::Value> {
143        idls.iter()
144            .map(|idl| {
145                let program_id = if idl.address.is_empty() {
146                    "fill corresponding program ID here"
147                } else {
148                    &idl.address
149                };
150
151                let program_name = if idl.metadata.name.is_empty() {
152                    "fill corresponding program name here"
153                } else {
154                    &idl.metadata.name
155                };
156
157                json!({
158                    "name": program_name,
159                    "program_id": program_id,
160                })
161            })
162            .collect()
163    }
164
165    // Helper function to build instruction data
166    fn build_instruction_data(
167        &self,
168        instruction: &IdlInstruction,
169        program_id: &str,
170    ) -> Result<serde_json::Value, TemplateError> {
171        let name = &instruction.name;
172        let camel_name = name.to_case(Case::UpperCamel);
173        let snake_name = name.to_case(Case::Snake);
174
175        let discriminator = if instruction.discriminator.is_empty() {
176            self.generate_discriminator(name)
177        } else {
178            instruction.discriminator.clone()
179        };
180
181        let (accounts, composite_accounts) = self.process_accounts(&instruction.accounts);
182        let data_fields = self.process_data_fields(&instruction.args);
183
184        Ok(json!({
185            "name": name,
186            "camel_name": camel_name,
187            "snake_name": snake_name,
188            "program_id": program_id,
189            "discriminator": discriminator,
190            "accounts": accounts,
191            "composite_accounts": composite_accounts,
192            "data_fields": data_fields
193        }))
194    }
195
196    // Helper function to process data fields
197    fn process_data_fields(&self, args: &[trident_idl_spec::IdlField]) -> Vec<serde_json::Value> {
198        args.iter()
199            .map(|field| {
200                json!({
201                    "name": field.name,
202                    "rust_type": self.idl_type_to_rust(&field.ty)
203                })
204            })
205            .collect()
206    }
207
208    #[allow(clippy::only_used_in_recursion)]
209    /// Simplified account processing
210    fn process_accounts(
211        &self,
212        accounts: &[trident_idl_spec::IdlInstructionAccountItem],
213    ) -> (Vec<serde_json::Value>, Vec<serde_json::Value>) {
214        let mut main_accounts = Vec::new();
215        let mut composite_accounts = Vec::new();
216
217        for account in accounts {
218            match account {
219                trident_idl_spec::IdlInstructionAccountItem::Single(acc) => {
220                    main_accounts.push(json!({
221                        "name": acc.name,
222                        "is_signer": acc.signer,
223                        "is_writable": acc.writable,
224                        "address": acc.address,
225                        "is_composite": false,
226                        "composite_type_name": null
227                    }));
228                }
229                trident_idl_spec::IdlInstructionAccountItem::Composite(comp) => {
230                    let camel_name = comp.name.to_case(Case::UpperCamel);
231
232                    // Add to main accounts as composite reference
233                    main_accounts.push(json!({
234                        "name": comp.name,
235                        "is_signer": false,
236                        "is_writable": false,
237                        "address": null,
238                        "is_composite": true,
239                        "composite_type_name": camel_name
240                    }));
241
242                    // Process composite account itself
243                    let (comp_accounts, nested_composites) = self.process_accounts(&comp.accounts);
244                    composite_accounts.push(json!({
245                        "name": comp.name,
246                        "camel_name": camel_name,
247                        "accounts": comp_accounts,
248                        "nested_composites": nested_composites
249                    }));
250                    // Don't extend here - nested composites are already included in the nested_composites field
251                }
252            }
253        }
254
255        (main_accounts, composite_accounts)
256    }
257
258    #[allow(clippy::only_used_in_recursion)]
259    /// Simple type conversion
260    fn idl_type_to_rust(&self, idl_type: &IdlType) -> String {
261        match idl_type {
262            IdlType::Bool => "bool".to_string(),
263            IdlType::U8 => "u8".to_string(),
264            IdlType::I8 => "i8".to_string(),
265            IdlType::U16 => "u16".to_string(),
266            IdlType::I16 => "i16".to_string(),
267            IdlType::U32 => "u32".to_string(),
268            IdlType::I32 => "i32".to_string(),
269            IdlType::F32 => "f32".to_string(),
270            IdlType::U64 => "u64".to_string(),
271            IdlType::I64 => "i64".to_string(),
272            IdlType::F64 => "f64".to_string(),
273            IdlType::U128 => "u128".to_string(),
274            IdlType::I128 => "i128".to_string(),
275            IdlType::U256 => "u256".to_string(),
276            IdlType::I256 => "i256".to_string(),
277            IdlType::Bytes => "Vec<u8>".to_string(),
278            IdlType::String => "String".to_string(),
279            IdlType::Pubkey | IdlType::PublicKey => "TridentPubkey".to_string(),
280            IdlType::Option(inner) => format!("Option<{}>", self.idl_type_to_rust(inner)),
281            IdlType::Vec(inner) => format!("Vec<{}>", self.idl_type_to_rust(inner)),
282            IdlType::Array(inner, len) => {
283                let len_str = match len {
284                    trident_idl_spec::IdlArrayLen::Value(n) => n.to_string(),
285                    _ => "0".to_string(),
286                };
287                format!("[{}; {}]", self.idl_type_to_rust(inner), len_str)
288            }
289            IdlType::Defined(defined) => match defined {
290                trident_idl_spec::DefinedType::Simple(name) => name.clone(),
291                trident_idl_spec::DefinedType::Complex { name, .. } => name.clone(),
292            },
293            IdlType::Generic(name) => name.clone(),
294            _ => "UnknownType".to_string(),
295        }
296    }
297
298    /// Generate discriminator
299    fn generate_discriminator(&self, name: &str) -> Vec<u8> {
300        let preimage = format!("global:{}", name.to_case(Case::Snake));
301        let mut hasher = Sha256::new();
302        hasher.update(preimage);
303        hasher.finalize()[..8].to_vec()
304    }
305
306    /// Collect all accounts for fuzz_accounts
307    fn collect_all_accounts(&self, idls: &[Idl]) -> Vec<serde_json::Value> {
308        let mut accounts = std::collections::HashSet::new();
309        for idl in idls {
310            for instruction in &idl.instructions {
311                self.collect_accounts_recursive(&instruction.accounts, &mut accounts);
312            }
313        }
314        accounts
315            .into_iter()
316            .map(|name| json!({ "name": name }))
317            .collect()
318    }
319
320    #[allow(clippy::only_used_in_recursion)]
321    fn collect_accounts_recursive(
322        &self,
323        accounts: &[trident_idl_spec::IdlInstructionAccountItem],
324        acc: &mut std::collections::HashSet<String>,
325    ) {
326        for account in accounts {
327            match account {
328                trident_idl_spec::IdlInstructionAccountItem::Single(a) => {
329                    acc.insert(a.name.clone());
330                }
331                trident_idl_spec::IdlInstructionAccountItem::Composite(c) => {
332                    acc.insert(c.name.clone());
333                    self.collect_accounts_recursive(&c.accounts, acc);
334                }
335            }
336        }
337    }
338
339    /// Collect custom types
340    fn collect_custom_types(&self, idls: &[Idl]) -> Vec<serde_json::Value> {
341        idls.iter()
342            .flat_map(|idl| &idl.types)
343            .map(|type_def| self.convert_type_def_to_template_data(type_def))
344            .collect()
345    }
346
347    /// Convert IDL type definition to template data (simplified)
348    fn convert_type_def_to_template_data(&self, type_def: &IdlTypeDef) -> serde_json::Value {
349        match &type_def.ty {
350            IdlTypeDefTy::Struct { fields } => json!({
351                "type": "struct",
352                "name": type_def.name,
353                "fields": fields.as_ref().map(|f| self.convert_fields_to_template_data(f))
354            }),
355            IdlTypeDefTy::Enum { variants } => json!({
356                "type": "enum",
357                "name": type_def.name,
358                "variants": variants.iter().map(|v| json!({
359                    "name": v.name,
360                    "fields": v.fields.as_ref().map(|f| self.convert_fields_to_template_data(f))
361                })).collect::<Vec<_>>()
362            }),
363            IdlTypeDefTy::Type { .. } => json!({
364                "type": "type_alias",
365                "name": type_def.name
366            }),
367        }
368    }
369
370    /// Helper to convert fields to template data
371    fn convert_fields_to_template_data(
372        &self,
373        fields: &trident_idl_spec::IdlDefinedFields,
374    ) -> serde_json::Value {
375        match fields {
376            trident_idl_spec::IdlDefinedFields::Named(named) => json!({
377                "type": "named",
378                "fields": named.iter().map(|field| json!({
379                    "name": field.name,
380                    "rust_type": self.idl_type_to_rust(&field.ty)
381                })).collect::<Vec<_>>()
382            }),
383            trident_idl_spec::IdlDefinedFields::Tuple(tuple) => json!({
384                "type": "tuple",
385                "fields": tuple.iter().enumerate().map(|(i, field_type)| json!({
386                    "name": format!("field_{}", i),
387                    "rust_type": self.idl_type_to_rust(field_type)
388                })).collect::<Vec<_>>()
389            }),
390        }
391    }
392
393    fn generate_mod_from_names(&self, names: &[String]) -> String {
394        let mut content = String::new();
395        for name in names {
396            content.push_str(&format!("pub mod {};\n", name));
397        }
398        for name in names {
399            content.push_str(&format!("pub use {}::*;\n", name));
400        }
401        content
402    }
403}
404
/// Every file produced by `TridentTemplates::generate`, as rendered text.
///
/// Per-instruction outputs are `(snake_case_name, contents)` pairs. This type only holds
/// the rendered text; it does not write anything to disk.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedFiles {
    /// Rendered `instruction.rs` template, one `(name, contents)` entry per IDL instruction.
    pub instructions: Vec<(String, String)>,
    /// Rendered `transaction.rs` template, one `(name, contents)` entry per IDL instruction.
    pub transactions: Vec<(String, String)>,
    /// Rendered `test_fuzz.rs` template.
    pub test_fuzz: String,
    /// `pub mod` declarations and glob re-exports for the entries in `instructions`.
    pub instructions_mod: String,
    /// `pub mod` declarations and glob re-exports for the entries in `transactions`.
    pub transactions_mod: String,
    /// Rendered `types.rs` template containing the IDLs' custom types.
    pub custom_types: String,
    /// Rendered `fuzz_accounts.rs` template.
    pub fuzz_accounts: String,
    /// Rendered `Trident.toml` template.
    pub trident_toml: String,
    /// Rendered `Cargo_fuzz.toml` template.
    pub cargo_fuzz_toml: String,
}
417
418impl Default for TridentTemplates {
419    fn default() -> Self {
420        Self::new().expect("Failed to create template engine")
421    }
422}