// trident_template/lib.rs

1use convert_case::Case;
2use convert_case::Casing;
3use serde_json::json;
4use sha2::Digest;
5use sha2::Sha256;
6use tera::Context;
7use tera::Tera;
8use trident_idl_spec::Idl;
9use trident_idl_spec::IdlInstruction;
10use trident_idl_spec::IdlType;
11use trident_idl_spec::IdlTypeDef;
12use trident_idl_spec::IdlTypeDefTy;
13
14use crate::error::TemplateError;
15
16pub mod error;
17
18/// Simple template engine for Trident code generation
19pub struct TridentTemplates {
20    tera: Tera,
21}
22
23impl TridentTemplates {
24    pub fn new() -> Result<Self, TemplateError> {
25        let mut tera = Tera::default();
26        tera.add_raw_templates(vec![
27            (
28                "test_fuzz.rs",
29                include_str!("../templates/test_fuzz.rs.tera"),
30            ),
31            (
32                "fuzz_accounts.rs",
33                include_str!("../templates/fuzz_accounts.rs.tera"),
34            ),
35            ("types.rs", include_str!("../templates/types.rs.tera")),
36            (
37                "Trident.toml",
38                include_str!("../templates/Trident.toml.tera"),
39            ),
40            (
41                "Cargo_fuzz.toml",
42                include_str!("../templates/Cargo_fuzz.toml.tera"),
43            ),
44        ])?;
45        Ok(Self { tera })
46    }
47
48    /// Generate all templates from IDLs
49    pub fn generate(
50        &self,
51        idls: &[Idl],
52        trident_version: &str,
53    ) -> Result<GeneratedFiles, TemplateError> {
54        let programs_data = self.build_programs_with_instructions_data(idls)?;
55
56        // Generate files
57        let test_fuzz = self
58            .tera
59            .render("test_fuzz.rs", &Context::from_serialize(json!({}))?)?;
60        let fuzz_accounts = self.tera.render(
61            "fuzz_accounts.rs",
62            &Context::from_serialize(json!({"accounts": self.collect_all_accounts(idls)}))?,
63        )?;
64        let types = self.tera.render(
65            "types.rs",
66            &Context::from_serialize(json!({
67                "programs": programs_data,
68                "custom_types": self.collect_custom_types(idls)
69            }))?,
70        )?;
71        let trident_toml = self.tera.render(
72            "Trident.toml",
73            &Context::from_serialize(json!({"programs": programs_data}))?,
74        )?;
75        let cargo_fuzz_toml = self.tera.render(
76            "Cargo_fuzz.toml",
77            &Context::from_serialize(json!({
78                "trident_version": trident_version,
79            }))?,
80        )?;
81
82        Ok(GeneratedFiles {
83            test_fuzz,
84            types,
85            fuzz_accounts,
86            trident_toml,
87            cargo_fuzz_toml,
88        })
89    }
90
91    // Helper function to build programs with instructions data
92    fn build_programs_with_instructions_data(
93        &self,
94        idls: &[Idl],
95    ) -> Result<Vec<serde_json::Value>, TemplateError> {
96        let mut programs_data = Vec::new();
97
98        for idl in idls.iter() {
99            let program_id = if idl.address.is_empty() {
100                "fill corresponding program ID here"
101            } else {
102                &idl.address
103            };
104
105            let program_name = if idl.metadata.name.is_empty() {
106                "unknown_program"
107            } else {
108                &idl.metadata.name
109            };
110
111            let module_name = program_name.to_case(Case::Snake);
112
113            // Process instructions and collect composite accounts (preserving IDL order)
114            let mut instructions_data = Vec::new();
115            let mut composite_accounts = Vec::new();
116            let mut seen_composites = std::collections::HashSet::new();
117
118            for instruction in &idl.instructions {
119                let instruction_data = self.build_instruction_data_with_lifetimes(
120                    instruction,
121                    program_id,
122                    &std::collections::HashMap::new(),
123                )?;
124
125                // Collect composite accounts for deduplication (preserving first occurrence order)
126                if let Some(composites) = instruction_data
127                    .get("composite_accounts")
128                    .and_then(|v| v.as_array())
129                {
130                    for composite in composites {
131                        if let Some(name) = composite.get("camel_name").and_then(|v| v.as_str()) {
132                            // Only add if not already seen (preserves first occurrence and IDL order)
133                            if seen_composites.insert(name.to_string()) {
134                                composite_accounts.push(composite.clone());
135                            }
136                        }
137                    }
138                }
139
140                instructions_data.push(instruction_data);
141            }
142
143            programs_data.push(json!({
144                "name": program_name,
145                "module_name": module_name,
146                "program_id": program_id,
147                "instructions": instructions_data,
148                "composite_accounts": composite_accounts
149            }));
150        }
151
152        Ok(programs_data)
153    }
154
155    // Helper function to process data fields
156    fn process_data_fields(&self, args: &[trident_idl_spec::IdlField]) -> Vec<serde_json::Value> {
157        args.iter()
158            .map(|field| {
159                json!({
160                    "name": field.name,
161                    "rust_type": self.idl_type_to_rust(&field.ty)
162                })
163            })
164            .collect()
165    }
166
167    #[allow(clippy::only_used_in_recursion)]
168    /// Simple type conversion
169    fn idl_type_to_rust(&self, idl_type: &IdlType) -> String {
170        match idl_type {
171            IdlType::Bool => "bool".to_string(),
172            IdlType::U8 => "u8".to_string(),
173            IdlType::I8 => "i8".to_string(),
174            IdlType::U16 => "u16".to_string(),
175            IdlType::I16 => "i16".to_string(),
176            IdlType::U32 => "u32".to_string(),
177            IdlType::I32 => "i32".to_string(),
178            IdlType::F32 => "f32".to_string(),
179            IdlType::U64 => "u64".to_string(),
180            IdlType::I64 => "i64".to_string(),
181            IdlType::F64 => "f64".to_string(),
182            IdlType::U128 => "u128".to_string(),
183            IdlType::I128 => "i128".to_string(),
184            IdlType::U256 => "u256".to_string(),
185            IdlType::I256 => "i256".to_string(),
186            IdlType::Bytes => "Vec<u8>".to_string(),
187            IdlType::String => "String".to_string(),
188            IdlType::Pubkey | IdlType::PublicKey => "Pubkey".to_string(),
189            IdlType::Option(inner) => format!("Option<{}>", self.idl_type_to_rust(inner)),
190            IdlType::Vec(inner) => format!("Vec<{}>", self.idl_type_to_rust(inner)),
191            IdlType::Array(inner, len) => {
192                let len_str = match len {
193                    trident_idl_spec::IdlArrayLen::Value(n) => n.to_string(),
194                    _ => "0".to_string(),
195                };
196                format!("[{}; {}]", self.idl_type_to_rust(inner), len_str)
197            }
198            IdlType::Defined(defined) => match defined {
199                trident_idl_spec::DefinedType::Simple(name) => name.clone(),
200                trident_idl_spec::DefinedType::Complex { name, .. } => name.clone(),
201            },
202            IdlType::Generic(name) => name.clone(),
203            _ => "UnknownType".to_string(),
204        }
205    }
206
207    /// Generate discriminator
208    fn generate_discriminator(&self, name: &str) -> Vec<u8> {
209        let preimage = format!("global:{}", name.to_case(Case::Snake));
210        let mut hasher = Sha256::new();
211        hasher.update(preimage);
212        hasher.finalize()[..8].to_vec()
213    }
214
215    /// Collect all accounts for fuzz_accounts (preserving IDL order and deterministic)
216    fn collect_all_accounts(&self, idls: &[Idl]) -> Vec<serde_json::Value> {
217        let mut accounts = Vec::new();
218        for idl in idls {
219            for instruction in &idl.instructions {
220                self.collect_accounts_recursive(&instruction.accounts, &mut accounts);
221            }
222        }
223
224        // Deduplicate while preserving order (keep first occurrence)
225        let mut seen = std::collections::HashSet::new();
226        accounts.retain(|name| seen.insert(name.clone()));
227
228        accounts
229            .into_iter()
230            .map(|name| json!({ "name": name }))
231            .collect()
232    }
233
234    #[allow(clippy::only_used_in_recursion)]
235    fn collect_accounts_recursive(
236        &self,
237        accounts: &[trident_idl_spec::IdlInstructionAccountItem],
238        acc: &mut Vec<String>,
239    ) {
240        for account in accounts {
241            match account {
242                trident_idl_spec::IdlInstructionAccountItem::Single(a) => {
243                    acc.push(a.name.clone());
244                }
245                trident_idl_spec::IdlInstructionAccountItem::Composite(c) => {
246                    acc.push(c.name.clone());
247                    self.collect_accounts_recursive(&c.accounts, acc);
248                }
249            }
250        }
251    }
252
253    /// Collect custom types
254    fn collect_custom_types(&self, idls: &[Idl]) -> Vec<serde_json::Value> {
255        idls.iter()
256            .flat_map(|idl| &idl.types)
257            .map(|type_def| self.convert_type_def_to_template_data(type_def))
258            .collect()
259    }
260
261    /// Convert IDL type definition to template data (simplified)
262    fn convert_type_def_to_template_data(&self, type_def: &IdlTypeDef) -> serde_json::Value {
263        match &type_def.ty {
264            IdlTypeDefTy::Struct { fields } => json!({
265                "type": "struct",
266                "name": type_def.name,
267                "fields": fields.as_ref().map(|f| self.convert_fields_to_template_data(f))
268            }),
269            IdlTypeDefTy::Enum { variants } => json!({
270                "type": "enum",
271                "name": type_def.name,
272                "variants": variants.iter().map(|v| json!({
273                    "name": v.name,
274                    "fields": v.fields.as_ref().map(|f| self.convert_fields_to_template_data(f))
275                })).collect::<Vec<_>>()
276            }),
277            IdlTypeDefTy::Type { .. } => json!({
278                "type": "type_alias",
279                "name": type_def.name
280            }),
281        }
282    }
283
284    /// Helper to convert fields to template data
285    fn convert_fields_to_template_data(
286        &self,
287        fields: &trident_idl_spec::IdlDefinedFields,
288    ) -> serde_json::Value {
289        match fields {
290            trident_idl_spec::IdlDefinedFields::Named(named) => json!({
291                "type": "named",
292                "fields": named.iter().map(|field| json!({
293                    "name": field.name,
294                    "rust_type": self.idl_type_to_rust(&field.ty)
295                })).collect::<Vec<_>>()
296            }),
297            trident_idl_spec::IdlDefinedFields::Tuple(tuple) => json!({
298                "type": "tuple",
299                "fields": tuple.iter().enumerate().map(|(i, field_type)| json!({
300                    "name": format!("field_{}", i),
301                    "rust_type": self.idl_type_to_rust(field_type)
302                })).collect::<Vec<_>>()
303            }),
304        }
305    }
306
307    /// Extract seeds from PDA seeds
308    fn extract_seeds(
309        &self,
310        seeds: &[trident_idl_spec::IdlSeed],
311    ) -> (Vec<String>, Vec<String>, Vec<String>) {
312        let mut static_seeds = Vec::new();
313        let mut account_seeds = Vec::new();
314        let mut arg_seeds = Vec::new();
315
316        for seed in seeds {
317            match seed {
318                trident_idl_spec::IdlSeed::Const(const_seed) => {
319                    // Convert byte array to Rust byte array literal
320                    let bytes_str = format!(
321                        "[{}]",
322                        const_seed
323                            .value
324                            .iter()
325                            .map(|b| format!("{}u8", b))
326                            .collect::<Vec<_>>()
327                            .join(", ")
328                    );
329                    static_seeds.push(bytes_str);
330                }
331                trident_idl_spec::IdlSeed::Account(account_seed) => {
332                    // Account reference for PDA seeds
333                    account_seeds.push(account_seed.path.clone());
334                }
335                trident_idl_spec::IdlSeed::Arg(arg_seed) => {
336                    // Argument reference for PDA seeds
337                    arg_seeds.push(arg_seed.path.clone());
338                }
339            }
340        }
341
342        (static_seeds, account_seeds, arg_seeds)
343    }
344
345    /// Build instruction data
346    fn build_instruction_data_with_lifetimes(
347        &self,
348        instruction: &IdlInstruction,
349        program_id: &str,
350        _composite_lifetime_map: &std::collections::HashMap<String, bool>,
351    ) -> Result<serde_json::Value, TemplateError> {
352        let name = &instruction.name;
353        let camel_name = name.to_case(Case::UpperCamel);
354        let snake_name = name.to_case(Case::Snake);
355
356        let discriminator = if instruction.discriminator.is_empty() {
357            self.generate_discriminator(name)
358        } else {
359            instruction.discriminator.clone()
360        };
361
362        let (accounts, composite_accounts) =
363            self.process_accounts_with_lifetimes(&instruction.accounts);
364        let data_fields = self.process_data_fields(&instruction.args);
365
366        Ok(json!({
367            "name": name,
368            "camel_name": camel_name,
369            "snake_name": snake_name,
370            "program_id": program_id,
371            "discriminator": discriminator,
372            "accounts": accounts,
373            "composite_accounts": composite_accounts,
374            "data_fields": data_fields,
375            "needs_lifetime": false
376        }))
377    }
378
379    #[allow(clippy::only_used_in_recursion)]
380    fn process_accounts_with_lifetimes(
381        &self,
382        accounts: &[trident_idl_spec::IdlInstructionAccountItem],
383    ) -> (Vec<serde_json::Value>, Vec<serde_json::Value>) {
384        let mut main_accounts = Vec::new();
385        let mut composite_accounts = Vec::new();
386
387        for account in accounts {
388            match account {
389                trident_idl_spec::IdlInstructionAccountItem::Single(acc) => {
390                    let has_pda_seeds = acc.pda.is_some();
391                    let (static_seeds, account_seeds, arg_seeds) = if let Some(pda) = &acc.pda {
392                        self.extract_seeds(&pda.seeds)
393                    } else {
394                        (Vec::new(), Vec::new(), Vec::new())
395                    };
396
397                    main_accounts.push(json!({
398                        "name": acc.name,
399                        "is_signer": acc.signer,
400                        "is_writable": acc.writable,
401                        "address": acc.address,
402                        "is_composite": false,
403                        "composite_type_name": null,
404                        "has_pda_seeds": has_pda_seeds,
405                        "composite_needs_lifetime": false,
406                        "static_seeds": static_seeds,
407                        "account_seeds": account_seeds,
408                        "arg_seeds": arg_seeds
409                    }));
410                }
411                trident_idl_spec::IdlInstructionAccountItem::Composite(comp) => {
412                    let camel_name = comp.name.to_case(Case::UpperCamel);
413
414                    // Add to main accounts as composite reference
415                    main_accounts.push(json!({
416                        "name": comp.name,
417                        "is_signer": false,
418                        "is_writable": false,
419                        "address": null,
420                        "is_composite": true,
421                        "composite_type_name": camel_name,
422                        "has_pda_seeds": false,
423                        "composite_needs_lifetime": false
424                    }));
425
426                    // Process composite account itself
427                    let (comp_accounts, nested_composites) =
428                        self.process_accounts_with_lifetimes(&comp.accounts);
429
430                    composite_accounts.push(json!({
431                        "name": comp.name,
432                        "camel_name": camel_name,
433                        "accounts": comp_accounts,
434                        "nested_composites": nested_composites,
435                        "needs_lifetime": false
436                    }));
437                }
438            }
439        }
440
441        // Preserve original IDL order - account order is critical for Solana programs
442        (main_accounts, composite_accounts)
443    }
444}
445
/// Rendered contents of every file produced by [`TridentTemplates::generate`].
#[derive(Debug, Clone)]
pub struct GeneratedFiles {
    /// Output of the `test_fuzz.rs` template.
    pub test_fuzz: String,
    /// Output of the `types.rs` template.
    pub types: String,
    /// Output of the `fuzz_accounts.rs` template.
    pub fuzz_accounts: String,
    /// Output of the `Trident.toml` template.
    pub trident_toml: String,
    /// Output of the `Cargo_fuzz.toml` template.
    pub cargo_fuzz_toml: String,
}

455impl Default for TridentTemplates {
456    fn default() -> Self {
457        Self::new().expect("Failed to create template engine")
458    }
459}