Skip to main content

tauri_typegen/generators/mod.rs

1pub mod base;
2pub mod ts;
3pub mod zod;
4
5use crate::analysis::CommandAnalyzer;
6use crate::models::{CommandInfo, EventInfo, StructInfo};
7use crate::GenerateConfig;
8use base::template_context::{CommandContext, EventContext, FieldContext, StructContext};
9use base::type_visitor::TypeVisitor;
10use std::collections::HashMap;
11
12pub use base::templates::GlobalContext;
13pub use base::BaseBindingsGenerator as BindingsGenerator;
14pub use ts::generator::TypeScriptBindingsGenerator;
15pub use zod::generator::ZodBindingsGenerator;
16
/// Register a template whose source is embedded at compile time.
///
/// Expands to a call of `$tera.add_raw_template($name, include_str!($path))`.
/// `$tera` is presumably a `tera::Tera` instance, given the parameter name.
/// On failure the error is converted to a `String` of the form
/// `"Failed to register <name>: <error>"` and propagated with `?`. The macro can
/// therefore only be used inside a function whose error type `?` can produce
/// from a `String`.
///
/// `$path` is forwarded to `include_str!`, so it must be a string literal that
/// resolves relative to the file invoking the macro. `$name` is evaluated exactly
/// once.
#[macro_export]
macro_rules! template {
    ($tera:expr, $name:expr, $path:expr) => {{
        // Bind once so a non-trivial `$name` expression is not evaluated a second
        // time when building the error message.
        let name = $name;
        $tera
            .add_raw_template(name, include_str!($path))
            .map_err(|e| format!("Failed to register {}: {}", name, e))?;
    }};
}
26
27/// Factory function to create the appropriate bindings generator
28/// Returns a boxed trait object for polymorphism
29pub fn create_generator(validation_library: Option<String>) -> Box<dyn BindingsGenerator> {
30    match validation_library.as_deref().unwrap_or("none") {
31        "zod" => Box::new(ZodBindingsGenerator::new()),
32        _ => Box::new(TypeScriptBindingsGenerator::new()),
33    }
34}
35
36/// Utility for collecting and organizing types for bindings generation
37///
38/// This struct provides filtering and transformation utilities that sit between
39/// the analysis phase (which produces TypeStructure) and the generation phase
40/// (which consumes filtered types and contexts). It acts as a one-stop-shop for
41/// filtering unused code and collecting only the types needed for generation.
42pub struct TypeCollector {
43    pub known_structs: HashMap<String, StructInfo>,
44}
45
46impl TypeCollector {
47    pub fn new() -> Self {
48        Self {
49            known_structs: HashMap::new(),
50        }
51    }
52
53    /// Filter only the types used by commands
54    pub fn collect_used_types(
55        &self,
56        commands: &[CommandInfo],
57        events: &[EventInfo],
58        all_structs: &HashMap<String, StructInfo>,
59    ) -> HashMap<String, StructInfo> {
60        let mut used_types = std::collections::HashSet::new();
61
62        // Collect types from commands using structured TypeStructure
63        for command in commands {
64            // Add parameter types from type_structure
65            for param in &command.parameters {
66                Self::collect_referenced_types_from_structure(
67                    &param.type_structure,
68                    &mut used_types,
69                );
70            }
71            // Add return type from return_type_structure
72            Self::collect_referenced_types_from_structure(
73                &command.return_type_structure,
74                &mut used_types,
75            );
76            // Add channel message types from message_type_structure
77            for channel in &command.channels {
78                Self::collect_referenced_types_from_structure(
79                    &channel.message_type_structure,
80                    &mut used_types,
81                );
82            }
83        }
84
85        // Collect types from events
86        for event in events {
87            Self::collect_referenced_types_from_structure(
88                &event.payload_type_structure,
89                &mut used_types,
90            );
91        }
92
93        // Clone to avoid borrow checker issues
94        let initial_types = used_types.clone();
95
96        // Discover nested dependencies (types referenced by the collected types)
97        self.discover_nested_dependencies(&initial_types, all_structs, &mut used_types);
98
99        // Filter to only include used types
100        all_structs
101            .iter()
102            .filter(|(name, _)| used_types.contains(*name))
103            .map(|(k, v)| (k.clone(), v.clone()))
104            .collect()
105    }
106
107    /// Recursively discover nested dependencies
108    fn discover_nested_dependencies(
109        &self,
110        initial_types: &std::collections::HashSet<String>,
111        all_structs: &HashMap<String, StructInfo>,
112        all_types: &mut std::collections::HashSet<String>,
113    ) {
114        let mut to_process: Vec<String> = initial_types.iter().cloned().collect();
115        let mut processed: std::collections::HashSet<String> = std::collections::HashSet::new();
116
117        while let Some(type_name) = to_process.pop() {
118            if processed.contains(&type_name) {
119                continue;
120            }
121            processed.insert(type_name.clone());
122
123            if let Some(struct_info) = all_structs.get(&type_name) {
124                // Collect from fields (for structs and legacy enums)
125                for field in &struct_info.fields {
126                    let mut nested_types = std::collections::HashSet::new();
127                    Self::collect_referenced_types_from_structure(
128                        &field.type_structure,
129                        &mut nested_types,
130                    );
131
132                    for nested_type in nested_types {
133                        if !all_types.contains(&nested_type)
134                            && all_structs.contains_key(&nested_type)
135                        {
136                            all_types.insert(nested_type.clone());
137                            to_process.push(nested_type);
138                        }
139                    }
140                }
141
142                // Collect from enum variants (for richer enums)
143                if let Some(variants) = &struct_info.enum_variants {
144                    for variant in variants {
145                        match &variant.kind {
146                            crate::models::EnumVariantKind::Unit => {}
147                            crate::models::EnumVariantKind::Tuple(types) => {
148                                for type_struct in types {
149                                    let mut nested_types = std::collections::HashSet::new();
150                                    Self::collect_referenced_types_from_structure(
151                                        type_struct,
152                                        &mut nested_types,
153                                    );
154
155                                    for nested_type in nested_types {
156                                        if !all_types.contains(&nested_type)
157                                            && all_structs.contains_key(&nested_type)
158                                        {
159                                            all_types.insert(nested_type.clone());
160                                            to_process.push(nested_type);
161                                        }
162                                    }
163                                }
164                            }
165                            crate::models::EnumVariantKind::Struct(fields) => {
166                                for field in fields {
167                                    let mut nested_types = std::collections::HashSet::new();
168                                    Self::collect_referenced_types_from_structure(
169                                        &field.type_structure,
170                                        &mut nested_types,
171                                    );
172
173                                    for nested_type in nested_types {
174                                        if !all_types.contains(&nested_type)
175                                            && all_structs.contains_key(&nested_type)
176                                        {
177                                            all_types.insert(nested_type.clone());
178                                            to_process.push(nested_type);
179                                        }
180                                    }
181                                }
182                            }
183                        }
184                    }
185                }
186            }
187        }
188    }
189
190    /// Recursively collect custom type names from TypeStructure
191    /// Works directly with structured type information instead of string parsing
192    pub fn collect_referenced_types_from_structure(
193        type_structure: &crate::TypeStructure,
194        used_types: &mut std::collections::HashSet<String>,
195    ) {
196        use crate::TypeStructure;
197
198        match type_structure {
199            TypeStructure::Custom(name) => {
200                used_types.insert(name.clone());
201            }
202            TypeStructure::Array(inner)
203            | TypeStructure::Set(inner)
204            | TypeStructure::Optional(inner)
205            | TypeStructure::Result(inner) => {
206                Self::collect_referenced_types_from_structure(inner, used_types);
207            }
208            TypeStructure::Map { key, value } => {
209                Self::collect_referenced_types_from_structure(key, used_types);
210                Self::collect_referenced_types_from_structure(value, used_types);
211            }
212            TypeStructure::Tuple(types) => {
213                for t in types {
214                    Self::collect_referenced_types_from_structure(t, used_types);
215                }
216            }
217            TypeStructure::Primitive(_) => {
218                // Primitives are not custom types
219            }
220        }
221    }
222
223    /// Create CommandContext instances from CommandInfo using the provided visitor
224    pub fn create_command_contexts<V: TypeVisitor>(
225        &self,
226        commands: &[CommandInfo],
227        visitor: &V,
228        analyzer: &CommandAnalyzer,
229        config: &GenerateConfig,
230    ) -> Vec<CommandContext> {
231        let type_resolver = analyzer.get_type_resolver();
232        let mut sorted_commands: Vec<_> = commands.iter().collect();
233        sorted_commands.sort_by(|a, b| {
234            a.name
235                .cmp(&b.name)
236                .then_with(|| a.file_path.cmp(&b.file_path))
237                .then_with(|| a.line_number.cmp(&b.line_number))
238        });
239
240        // Deduplicate commands by name - first occurrence wins. The same
241        // command can be declared more than once under mutually-exclusive
242        // `#[cfg(...)]` gates (the standard cross-platform Tauri pattern);
243        // emitting both would produce duplicate TypeScript declarations.
244        let mut seen_commands: std::collections::HashSet<&str> = std::collections::HashSet::new();
245        sorted_commands.retain(|cmd| seen_commands.insert(cmd.name.as_str()));
246
247        sorted_commands
248            .into_iter()
249            .map(|cmd| {
250                CommandContext::new(config).from_command_info(cmd, visitor, &|rust_type: &str| {
251                    type_resolver.borrow_mut().parse_type_structure(rust_type)
252                })
253            })
254            .collect()
255    }
256
257    /// Create EventContext instances from EventInfo using the provided visitor
258    pub fn create_event_contexts<V: TypeVisitor>(
259        &self,
260        events: &[EventInfo],
261        visitor: &V,
262        analyzer: &CommandAnalyzer,
263        config: &GenerateConfig,
264    ) -> Vec<EventContext> {
265        let type_resolver = analyzer.get_type_resolver();
266
267        // Deduplicate events by name - first occurrence wins
268        let mut seen_events: std::collections::HashSet<&str> = std::collections::HashSet::new();
269        let mut sorted_events: Vec<&EventInfo> = Vec::new();
270        for event in events {
271            if seen_events.insert(event.event_name.as_str()) {
272                sorted_events.push(event);
273            }
274        }
275
276        sorted_events.sort_by(|a, b| {
277            a.event_name
278                .cmp(&b.event_name)
279                .then_with(|| a.file_path.cmp(&b.file_path))
280                .then_with(|| a.line_number.cmp(&b.line_number))
281                .then_with(|| a.payload_type.cmp(&b.payload_type))
282        });
283
284        sorted_events
285            .into_iter()
286            .map(|event| {
287                EventContext::new(config).from_event_info(event, visitor, &|rust_type: &str| {
288                    type_resolver.borrow_mut().parse_type_structure(rust_type)
289                })
290            })
291            .collect()
292    }
293
294    /// Create StructContext instances from StructInfo using the provided visitor
295    pub fn create_struct_contexts<V: TypeVisitor>(
296        &self,
297        used_structs: &HashMap<String, StructInfo>,
298        visitor: &V,
299        config: &GenerateConfig,
300    ) -> Vec<StructContext> {
301        let mut sorted_structs: Vec<_> = used_structs.iter().collect();
302        sorted_structs.sort_by(|(name_a, struct_a), (name_b, struct_b)| {
303            name_a
304                .cmp(name_b)
305                .then_with(|| struct_a.file_path.cmp(&struct_b.file_path))
306        });
307
308        sorted_structs
309            .into_iter()
310            .map(|(name, struct_info)| {
311                StructContext::new(config).from_struct_info(name, struct_info, visitor)
312            })
313            .collect()
314    }
315
316    /// Create FieldContext instances from StructInfo using the provided visitor
317    pub fn create_field_contexts<V: TypeVisitor>(
318        &self,
319        struct_info: &StructInfo,
320        visitor: &V,
321        config: &GenerateConfig,
322    ) -> Vec<FieldContext> {
323        struct_info
324            .fields
325            .iter()
326            .map(|field| {
327                FieldContext::new(config).from_field_info(
328                    field,
329                    &struct_info.serde_rename_all,
330                    visitor,
331                )
332            })
333            .collect()
334    }
335}
336
337impl Default for TypeCollector {
338    fn default() -> Self {
339        Self::new()
340    }
341}
342
#[cfg(test)]
mod tests {
    // Unit tests for the generator factory and for `TypeCollector`'s type
    // filtering and dependency discovery.
    use super::*;
    use crate::TypeStructure;
    use std::collections::HashSet;

    mod factory {
        use super::*;

        // NOTE: `type_name_of_val` on a `Box<dyn ...>` always names a `Box`. These
        // tests therefore only prove that each selection path builds a generator
        // without panicking. They do not tell which concrete generator was chosen.

        #[test]
        fn test_create_generator_zod() {
            let generator = create_generator(Some("zod".to_string()));
            // Just verify it creates without panic - we can't easily inspect trait objects
            assert!(std::any::type_name_of_val(&generator).contains("Box"));
        }

        #[test]
        fn test_create_generator_none() {
            let generator = create_generator(Some("none".to_string()));
            assert!(std::any::type_name_of_val(&generator).contains("Box"));
        }

        #[test]
        fn test_create_generator_default() {
            let generator = create_generator(None);
            assert!(std::any::type_name_of_val(&generator).contains("Box"));
        }

        #[test]
        fn test_create_generator_unknown_fallback() {
            let generator = create_generator(Some("unknown".to_string()));
            assert!(std::any::type_name_of_val(&generator).contains("Box"));
        }
    }

    mod type_collector {
        use super::*;

        #[test]
        fn test_new_creates_empty_collector() {
            let collector = TypeCollector::new();
            assert!(collector.known_structs.is_empty());
        }

        #[test]
        fn test_default_creates_empty_collector() {
            let collector = TypeCollector::default();
            assert!(collector.known_structs.is_empty());
        }
    }

    mod collect_referenced_types {
        use super::*;

        #[test]
        fn test_collect_primitive() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Primitive("string".to_string());
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert!(used.is_empty());
        }

        #[test]
        fn test_collect_custom() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Custom("User".to_string());
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
            assert!(used.contains("User"));
        }

        #[test]
        fn test_collect_array() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Array(Box::new(TypeStructure::Custom("User".to_string())));
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
            assert!(used.contains("User"));
        }

        #[test]
        fn test_collect_optional() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Optional(Box::new(TypeStructure::Custom("User".to_string())));
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
            assert!(used.contains("User"));
        }

        #[test]
        fn test_collect_result() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Result(Box::new(TypeStructure::Custom("User".to_string())));
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
            assert!(used.contains("User"));
        }

        #[test]
        fn test_collect_set() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Set(Box::new(TypeStructure::Custom("User".to_string())));
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
            assert!(used.contains("User"));
        }

        #[test]
        fn test_collect_map() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Map {
                key: Box::new(TypeStructure::Primitive("string".to_string())),
                value: Box::new(TypeStructure::Custom("User".to_string())),
            };
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
            assert!(used.contains("User"));
        }

        #[test]
        fn test_collect_map_both_custom() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Map {
                key: Box::new(TypeStructure::Custom("UserId".to_string())),
                value: Box::new(TypeStructure::Custom("User".to_string())),
            };
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 2);
            assert!(used.contains("User"));
            assert!(used.contains("UserId"));
        }

        #[test]
        fn test_collect_tuple() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Tuple(vec![
                TypeStructure::Custom("User".to_string()),
                TypeStructure::Custom("Product".to_string()),
            ]);
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 2);
            assert!(used.contains("User"));
            assert!(used.contains("Product"));
        }

        #[test]
        fn test_collect_nested() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Array(Box::new(TypeStructure::Optional(Box::new(
                TypeStructure::Custom("User".to_string()),
            ))));
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
            assert!(used.contains("User"));
        }

        #[test]
        fn test_collect_multiple_calls_accumulate() {
            let mut used = HashSet::new();
            let ts1 = TypeStructure::Custom("User".to_string());
            let ts2 = TypeStructure::Custom("Product".to_string());
            TypeCollector::collect_referenced_types_from_structure(&ts1, &mut used);
            TypeCollector::collect_referenced_types_from_structure(&ts2, &mut used);
            assert_eq!(used.len(), 2);
        }

        #[test]
        fn test_collect_duplicates_deduped() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Custom("User".to_string());
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
        }
    }

    mod collect_used_types {
        use super::*;
        use crate::models::{CommandInfo, ParameterInfo, StructInfo};

        fn create_struct(name: &str) -> StructInfo {
            StructInfo {
                name: name.to_string(),
                fields: vec![],
                file_path: "test.rs".to_string(),
                is_enum: false,
                serde_rename_all: None,
                serde_tag: None,
                enum_variants: None,
            }
        }

        fn create_param(
            name: &str,
            rust_type: &str,
            type_structure: TypeStructure,
        ) -> ParameterInfo {
            ParameterInfo {
                name: name.to_string(),
                rust_type: rust_type.to_string(),
                is_optional: false,
                type_structure,
                serde_rename: None,
            }
        }

        #[test]
        fn test_collect_from_empty_commands() {
            let collector = TypeCollector::new();
            let commands = vec![];
            let all_structs = HashMap::new();
            let used = collector.collect_used_types(&commands, &[], &all_structs);
            assert!(used.is_empty());
        }

        #[test]
        fn test_collect_from_command_parameters() {
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();
            let user_struct = create_struct("User");
            all_structs.insert("User".to_string(), user_struct.clone());

            let param = create_param("user", "User", TypeStructure::Custom("User".to_string()));
            let command = CommandInfo::new_for_test(
                "greet",
                "test.rs",
                1,
                vec![param],
                "string",
                false,
                vec![],
            );

            let used = collector.collect_used_types(&[command], &[], &all_structs);
            assert_eq!(used.len(), 1);
            assert!(used.contains_key("User"));
        }

        #[test]
        fn test_collect_from_command_return_type() {
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();
            let result_struct = create_struct("ApiResult");
            all_structs.insert("ApiResult".to_string(), result_struct.clone());

            // Create command that returns ApiResult
            let mut command = CommandInfo::new_for_test(
                "fetch_data",
                "test.rs",
                1,
                vec![],
                "ApiResult",
                false,
                vec![],
            );
            // Set the return_type_structure
            command.return_type_structure = TypeStructure::Custom("ApiResult".to_string());

            let used = collector.collect_used_types(&[command], &[], &all_structs);
            assert_eq!(used.len(), 1);
            assert!(used.contains_key("ApiResult"));
        }

        #[test]
        fn test_filters_unused_types() {
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();

            // Add two structs but only use one
            let user_struct = create_struct("User");
            let product_struct = create_struct("Product");
            all_structs.insert("User".to_string(), user_struct);
            all_structs.insert("Product".to_string(), product_struct);

            let param = create_param("user", "User", TypeStructure::Custom("User".to_string()));
            let command = CommandInfo::new_for_test(
                "greet",
                "test.rs",
                1,
                vec![param],
                "string",
                false,
                vec![],
            );

            let used = collector.collect_used_types(&[command], &[], &all_structs);
            assert_eq!(used.len(), 1);
            assert!(used.contains_key("User"));
            assert!(!used.contains_key("Product"));
        }

        #[test]
        fn test_ignores_types_without_definition() {
            // A referenced custom type with no entry in `all_structs` is dropped.
            let collector = TypeCollector::new();
            let all_structs = HashMap::new();

            let param = create_param(
                "value",
                "External",
                TypeStructure::Custom("External".to_string()),
            );
            let command = CommandInfo::new_for_test(
                "take",
                "test.rs",
                1,
                vec![param],
                "string",
                false,
                vec![],
            );

            let used = collector.collect_used_types(&[command], &[], &all_structs);
            assert!(used.is_empty());
        }
    }

    mod nested_dependencies {
        use super::*;
        use crate::models::{CommandInfo, FieldInfo, ParameterInfo, StructInfo};

        fn create_field(name: &str, rust_type: &str, type_structure: TypeStructure) -> FieldInfo {
            FieldInfo {
                name: name.to_string(),
                rust_type: rust_type.to_string(),
                is_optional: false,
                is_public: true,
                validator_attributes: None,
                serde_rename: None,
                type_structure,
            }
        }

        fn create_struct_with_fields(name: &str, fields: Vec<FieldInfo>) -> StructInfo {
            StructInfo {
                name: name.to_string(),
                fields,
                file_path: "test.rs".to_string(),
                is_enum: false,
                serde_rename_all: None,
                serde_tag: None,
                enum_variants: None,
            }
        }

        fn create_param(
            name: &str,
            rust_type: &str,
            type_structure: TypeStructure,
        ) -> ParameterInfo {
            ParameterInfo {
                name: name.to_string(),
                rust_type: rust_type.to_string(),
                is_optional: false,
                type_structure,
                serde_rename: None,
            }
        }

        #[test]
        fn test_discovers_nested_dependencies() {
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();

            // User has a field of type Address
            let address_field = create_field(
                "address",
                "Address",
                TypeStructure::Custom("Address".to_string()),
            );
            let user_struct = create_struct_with_fields("User", vec![address_field]);
            let address_struct = create_struct_with_fields("Address", vec![]);

            all_structs.insert("User".to_string(), user_struct);
            all_structs.insert("Address".to_string(), address_struct);

            // Command only uses User directly
            let param = create_param("user", "User", TypeStructure::Custom("User".to_string()));
            let command = CommandInfo::new_for_test(
                "greet",
                "test.rs",
                1,
                vec![param],
                "string",
                false,
                vec![],
            );

            let used = collector.collect_used_types(&[command], &[], &all_structs);

            // Should include both User and Address
            assert_eq!(used.len(), 2);
            assert!(used.contains_key("User"));
            assert!(used.contains_key("Address"));
        }

        #[test]
        fn test_handles_deep_nesting() {
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();

            // A -> B -> C chain
            let c_struct = create_struct_with_fields("C", vec![]);
            let b_field = create_field("c", "C", TypeStructure::Custom("C".to_string()));
            let b_struct = create_struct_with_fields("B", vec![b_field]);
            let a_field = create_field("b", "B", TypeStructure::Custom("B".to_string()));
            let a_struct = create_struct_with_fields("A", vec![a_field]);

            all_structs.insert("A".to_string(), a_struct);
            all_structs.insert("B".to_string(), b_struct);
            all_structs.insert("C".to_string(), c_struct);

            let param = create_param("data", "A", TypeStructure::Custom("A".to_string()));
            let command = CommandInfo::new_for_test(
                "process",
                "test.rs",
                1,
                vec![param],
                "void",
                false,
                vec![],
            );

            let used = collector.collect_used_types(&[command], &[], &all_structs);

            // Should include A, B, and C
            assert_eq!(used.len(), 3);
            assert!(used.contains_key("A"));
            assert!(used.contains_key("B"));
            assert!(used.contains_key("C"));
        }

        #[test]
        fn test_handles_cyclic_references() {
            // A -> B -> Option<A>: the traversal must terminate and keep both.
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();

            let a_field = create_field("b", "B", TypeStructure::Custom("B".to_string()));
            let b_field = create_field(
                "a",
                "Option<A>",
                TypeStructure::Optional(Box::new(TypeStructure::Custom("A".to_string()))),
            );
            all_structs.insert(
                "A".to_string(),
                create_struct_with_fields("A", vec![a_field]),
            );
            all_structs.insert(
                "B".to_string(),
                create_struct_with_fields("B", vec![b_field]),
            );

            let param = create_param("data", "A", TypeStructure::Custom("A".to_string()));
            let command = CommandInfo::new_for_test(
                "process",
                "test.rs",
                1,
                vec![param],
                "void",
                false,
                vec![],
            );

            let used = collector.collect_used_types(&[command], &[], &all_structs);

            assert_eq!(used.len(), 2);
            assert!(used.contains_key("A"));
            assert!(used.contains_key("B"));
        }

        #[test]
        fn test_discovers_types_nested_inside_containers() {
            // User.addresses: HashMap<String, Vec<Address>>
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();

            let addresses_field = create_field(
                "addresses",
                "HashMap<String, Vec<Address>>",
                TypeStructure::Map {
                    key: Box::new(TypeStructure::Primitive("string".to_string())),
                    value: Box::new(TypeStructure::Array(Box::new(TypeStructure::Custom(
                        "Address".to_string(),
                    )))),
                },
            );
            all_structs.insert(
                "User".to_string(),
                create_struct_with_fields("User", vec![addresses_field]),
            );
            all_structs.insert(
                "Address".to_string(),
                create_struct_with_fields("Address", vec![]),
            );

            let param = create_param("user", "User", TypeStructure::Custom("User".to_string()));
            let command = CommandInfo::new_for_test(
                "greet",
                "test.rs",
                1,
                vec![param],
                "string",
                false,
                vec![],
            );

            let used = collector.collect_used_types(&[command], &[], &all_structs);

            assert_eq!(used.len(), 2);
            assert!(used.contains_key("User"));
            assert!(used.contains_key("Address"));
        }

        #[test]
        fn test_ignores_nested_types_without_definition() {
            // A field type with no entry in `all_structs` is not pulled in.
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();

            let external_field = create_field(
                "external",
                "External",
                TypeStructure::Custom("External".to_string()),
            );
            all_structs.insert(
                "User".to_string(),
                create_struct_with_fields("User", vec![external_field]),
            );

            let param = create_param("user", "User", TypeStructure::Custom("User".to_string()));
            let command = CommandInfo::new_for_test(
                "greet",
                "test.rs",
                1,
                vec![param],
                "string",
                false,
                vec![],
            );

            let used = collector.collect_used_types(&[command], &[], &all_structs);

            assert_eq!(used.len(), 1);
            assert!(used.contains_key("User"));
            assert!(!used.contains_key("External"));
        }
    }
}