Skip to main content

tauri_typegen/generators/
mod.rs

1pub mod base;
2pub mod ts;
3pub mod zod;
4
5use crate::analysis::CommandAnalyzer;
6use crate::models::{CommandInfo, EventInfo, StructInfo};
7use crate::GenerateConfig;
8use base::template_context::{CommandContext, EventContext, FieldContext, StructContext};
9use base::type_visitor::TypeVisitor;
10use std::collections::HashMap;
11
12pub use base::templates::GlobalContext;
13pub use base::BaseBindingsGenerator as BindingsGenerator;
14pub use ts::generator::TypeScriptBindingsGenerator;
15pub use zod::generator::ZodBindingsGenerator;
16
/// Registers a compile-time embedded template on a Tera-like registry.
///
/// Expands to a call to `add_raw_template` on `$tera`, passing `$name` and the
/// contents of the file at `$path` (embedded with `include_str!`). If that
/// call fails, the error is turned into a `String` of the form
/// `Failed to register <name>: <error>` and propagated with `?`, so the macro
/// must be invoked where `?` can propagate a `String` error.
#[macro_export]
macro_rules! template {
    ($tera:expr, $name:expr, $path:expr) => {
        $tera
            .add_raw_template($name, include_str!($path))
            .map_err(|err| format!("Failed to register {}: {}", $name, err))?;
    };
}
26
27/// Factory function to create the appropriate bindings generator
28/// Returns a boxed trait object for polymorphism
29pub fn create_generator(validation_library: Option<String>) -> Box<dyn BindingsGenerator> {
30    match validation_library.as_deref().unwrap_or("none") {
31        "zod" => Box::new(ZodBindingsGenerator::new()),
32        _ => Box::new(TypeScriptBindingsGenerator::new()),
33    }
34}
35
36/// Utility for collecting and organizing types for bindings generation
37///
38/// This struct provides filtering and transformation utilities that sit between
39/// the analysis phase (which produces TypeStructure) and the generation phase
40/// (which consumes filtered types and contexts). It acts as a one-stop-shop for
41/// filtering unused code and collecting only the types needed for generation.
42pub struct TypeCollector {
43    pub known_structs: HashMap<String, StructInfo>,
44}
45
46impl TypeCollector {
47    pub fn new() -> Self {
48        Self {
49            known_structs: HashMap::new(),
50        }
51    }
52
53    /// Filter only the types used by commands
54    pub fn collect_used_types(
55        &self,
56        commands: &[CommandInfo],
57        events: &[EventInfo],
58        all_structs: &HashMap<String, StructInfo>,
59    ) -> HashMap<String, StructInfo> {
60        let mut used_types = std::collections::HashSet::new();
61
62        // Collect types from commands using structured TypeStructure
63        for command in commands {
64            // Add parameter types from type_structure
65            for param in &command.parameters {
66                Self::collect_referenced_types_from_structure(
67                    &param.type_structure,
68                    &mut used_types,
69                );
70            }
71            // Add return type from return_type_structure
72            Self::collect_referenced_types_from_structure(
73                &command.return_type_structure,
74                &mut used_types,
75            );
76            // Add channel message types from message_type_structure
77            for channel in &command.channels {
78                Self::collect_referenced_types_from_structure(
79                    &channel.message_type_structure,
80                    &mut used_types,
81                );
82            }
83        }
84
85        // Collect types from events
86        for event in events {
87            Self::collect_referenced_types_from_structure(
88                &event.payload_type_structure,
89                &mut used_types,
90            );
91        }
92
93        // Clone to avoid borrow checker issues
94        let initial_types = used_types.clone();
95
96        // Discover nested dependencies (types referenced by the collected types)
97        self.discover_nested_dependencies(&initial_types, all_structs, &mut used_types);
98
99        // Filter to only include used types
100        all_structs
101            .iter()
102            .filter(|(name, _)| used_types.contains(*name))
103            .map(|(k, v)| (k.clone(), v.clone()))
104            .collect()
105    }
106
107    /// Recursively discover nested dependencies
108    fn discover_nested_dependencies(
109        &self,
110        initial_types: &std::collections::HashSet<String>,
111        all_structs: &HashMap<String, StructInfo>,
112        all_types: &mut std::collections::HashSet<String>,
113    ) {
114        let mut to_process: Vec<String> = initial_types.iter().cloned().collect();
115        let mut processed: std::collections::HashSet<String> = std::collections::HashSet::new();
116
117        while let Some(type_name) = to_process.pop() {
118            if processed.contains(&type_name) {
119                continue;
120            }
121            processed.insert(type_name.clone());
122
123            if let Some(struct_info) = all_structs.get(&type_name) {
124                // Collect from fields (for structs and legacy enums)
125                for field in &struct_info.fields {
126                    let mut nested_types = std::collections::HashSet::new();
127                    Self::collect_referenced_types_from_structure(
128                        &field.type_structure,
129                        &mut nested_types,
130                    );
131
132                    for nested_type in nested_types {
133                        if !all_types.contains(&nested_type)
134                            && all_structs.contains_key(&nested_type)
135                        {
136                            all_types.insert(nested_type.clone());
137                            to_process.push(nested_type);
138                        }
139                    }
140                }
141
142                // Collect from enum variants (for richer enums)
143                if let Some(variants) = &struct_info.enum_variants {
144                    for variant in variants {
145                        match &variant.kind {
146                            crate::models::EnumVariantKind::Unit => {}
147                            crate::models::EnumVariantKind::Tuple(types) => {
148                                for type_struct in types {
149                                    let mut nested_types = std::collections::HashSet::new();
150                                    Self::collect_referenced_types_from_structure(
151                                        type_struct,
152                                        &mut nested_types,
153                                    );
154
155                                    for nested_type in nested_types {
156                                        if !all_types.contains(&nested_type)
157                                            && all_structs.contains_key(&nested_type)
158                                        {
159                                            all_types.insert(nested_type.clone());
160                                            to_process.push(nested_type);
161                                        }
162                                    }
163                                }
164                            }
165                            crate::models::EnumVariantKind::Struct(fields) => {
166                                for field in fields {
167                                    let mut nested_types = std::collections::HashSet::new();
168                                    Self::collect_referenced_types_from_structure(
169                                        &field.type_structure,
170                                        &mut nested_types,
171                                    );
172
173                                    for nested_type in nested_types {
174                                        if !all_types.contains(&nested_type)
175                                            && all_structs.contains_key(&nested_type)
176                                        {
177                                            all_types.insert(nested_type.clone());
178                                            to_process.push(nested_type);
179                                        }
180                                    }
181                                }
182                            }
183                        }
184                    }
185                }
186            }
187        }
188    }
189
190    /// Recursively collect custom type names from TypeStructure
191    /// Works directly with structured type information instead of string parsing
192    pub fn collect_referenced_types_from_structure(
193        type_structure: &crate::TypeStructure,
194        used_types: &mut std::collections::HashSet<String>,
195    ) {
196        use crate::TypeStructure;
197
198        match type_structure {
199            TypeStructure::Custom(name) => {
200                used_types.insert(name.clone());
201            }
202            TypeStructure::Array(inner)
203            | TypeStructure::Set(inner)
204            | TypeStructure::Optional(inner)
205            | TypeStructure::Result(inner) => {
206                Self::collect_referenced_types_from_structure(inner, used_types);
207            }
208            TypeStructure::Map { key, value } => {
209                Self::collect_referenced_types_from_structure(key, used_types);
210                Self::collect_referenced_types_from_structure(value, used_types);
211            }
212            TypeStructure::Tuple(types) => {
213                for t in types {
214                    Self::collect_referenced_types_from_structure(t, used_types);
215                }
216            }
217            TypeStructure::Primitive(_) => {
218                // Primitives are not custom types
219            }
220        }
221    }
222
223    /// Create CommandContext instances from CommandInfo using the provided visitor
224    pub fn create_command_contexts<V: TypeVisitor>(
225        &self,
226        commands: &[CommandInfo],
227        visitor: &V,
228        analyzer: &CommandAnalyzer,
229        config: &GenerateConfig,
230    ) -> Vec<CommandContext> {
231        let type_resolver = analyzer.get_type_resolver();
232        let mut sorted_commands: Vec<_> = commands.iter().collect();
233        sorted_commands.sort_by(|a, b| {
234            a.name
235                .cmp(&b.name)
236                .then_with(|| a.file_path.cmp(&b.file_path))
237                .then_with(|| a.line_number.cmp(&b.line_number))
238        });
239
240        sorted_commands
241            .into_iter()
242            .map(|cmd| {
243                CommandContext::new(config).from_command_info(cmd, visitor, &|rust_type: &str| {
244                    type_resolver.borrow_mut().parse_type_structure(rust_type)
245                })
246            })
247            .collect()
248    }
249
250    /// Create EventContext instances from EventInfo using the provided visitor
251    pub fn create_event_contexts<V: TypeVisitor>(
252        &self,
253        events: &[EventInfo],
254        visitor: &V,
255        analyzer: &CommandAnalyzer,
256        config: &GenerateConfig,
257    ) -> Vec<EventContext> {
258        let type_resolver = analyzer.get_type_resolver();
259
260        // Deduplicate events by name - first occurrence wins
261        let mut seen_events: std::collections::HashSet<&str> = std::collections::HashSet::new();
262        let mut sorted_events: Vec<&EventInfo> = Vec::new();
263        for event in events {
264            if seen_events.insert(event.event_name.as_str()) {
265                sorted_events.push(event);
266            }
267        }
268
269        sorted_events.sort_by(|a, b| {
270            a.event_name
271                .cmp(&b.event_name)
272                .then_with(|| a.file_path.cmp(&b.file_path))
273                .then_with(|| a.line_number.cmp(&b.line_number))
274                .then_with(|| a.payload_type.cmp(&b.payload_type))
275        });
276
277        sorted_events
278            .into_iter()
279            .map(|event| {
280                EventContext::new(config).from_event_info(event, visitor, &|rust_type: &str| {
281                    type_resolver.borrow_mut().parse_type_structure(rust_type)
282                })
283            })
284            .collect()
285    }
286
287    /// Create StructContext instances from StructInfo using the provided visitor
288    pub fn create_struct_contexts<V: TypeVisitor>(
289        &self,
290        used_structs: &HashMap<String, StructInfo>,
291        visitor: &V,
292        config: &GenerateConfig,
293    ) -> Vec<StructContext> {
294        let mut sorted_structs: Vec<_> = used_structs.iter().collect();
295        sorted_structs.sort_by(|(name_a, struct_a), (name_b, struct_b)| {
296            name_a
297                .cmp(name_b)
298                .then_with(|| struct_a.file_path.cmp(&struct_b.file_path))
299        });
300
301        sorted_structs
302            .into_iter()
303            .map(|(name, struct_info)| {
304                StructContext::new(config).from_struct_info(name, struct_info, visitor)
305            })
306            .collect()
307    }
308
309    /// Create FieldContext instances from StructInfo using the provided visitor
310    pub fn create_field_contexts<V: TypeVisitor>(
311        &self,
312        struct_info: &StructInfo,
313        visitor: &V,
314        config: &GenerateConfig,
315    ) -> Vec<FieldContext> {
316        struct_info
317            .fields
318            .iter()
319            .map(|field| {
320                FieldContext::new(config).from_field_info(
321                    field,
322                    &struct_info.serde_rename_all,
323                    visitor,
324                )
325            })
326            .collect()
327    }
328}
329
330impl Default for TypeCollector {
331    fn default() -> Self {
332        Self::new()
333    }
334}
335
336#[cfg(test)]
337mod tests {
338    use super::*;
339    use crate::TypeStructure;
340    use std::collections::HashSet;
341
342    mod factory {
343        use super::*;
344
345        #[test]
346        fn test_create_generator_zod() {
347            let gen = create_generator(Some("zod".to_string()));
348            // Just verify it creates without panic - we can't easily inspect trait objects
349            assert!(std::any::type_name_of_val(&gen).contains("Box"));
350        }
351
352        #[test]
353        fn test_create_generator_none() {
354            let gen = create_generator(Some("none".to_string()));
355            assert!(std::any::type_name_of_val(&gen).contains("Box"));
356        }
357
358        #[test]
359        fn test_create_generator_default() {
360            let gen = create_generator(None);
361            assert!(std::any::type_name_of_val(&gen).contains("Box"));
362        }
363
364        #[test]
365        fn test_create_generator_unknown_fallback() {
366            let gen = create_generator(Some("unknown".to_string()));
367            assert!(std::any::type_name_of_val(&gen).contains("Box"));
368        }
369    }
370
371    mod type_collector {
372        use super::*;
373
374        #[test]
375        fn test_new_creates_empty_collector() {
376            let collector = TypeCollector::new();
377            assert!(collector.known_structs.is_empty());
378        }
379
380        #[test]
381        fn test_default_creates_empty_collector() {
382            let collector = TypeCollector::default();
383            assert!(collector.known_structs.is_empty());
384        }
385    }
386
387    mod collect_referenced_types {
388        use super::*;
389
390        #[test]
391        fn test_collect_primitive() {
392            let mut used = HashSet::new();
393            let ts = TypeStructure::Primitive("string".to_string());
394            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
395            assert!(used.is_empty());
396        }
397
398        #[test]
399        fn test_collect_custom() {
400            let mut used = HashSet::new();
401            let ts = TypeStructure::Custom("User".to_string());
402            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
403            assert_eq!(used.len(), 1);
404            assert!(used.contains("User"));
405        }
406
407        #[test]
408        fn test_collect_array() {
409            let mut used = HashSet::new();
410            let ts = TypeStructure::Array(Box::new(TypeStructure::Custom("User".to_string())));
411            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
412            assert_eq!(used.len(), 1);
413            assert!(used.contains("User"));
414        }
415
416        #[test]
417        fn test_collect_optional() {
418            let mut used = HashSet::new();
419            let ts = TypeStructure::Optional(Box::new(TypeStructure::Custom("User".to_string())));
420            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
421            assert_eq!(used.len(), 1);
422            assert!(used.contains("User"));
423        }
424
425        #[test]
426        fn test_collect_result() {
427            let mut used = HashSet::new();
428            let ts = TypeStructure::Result(Box::new(TypeStructure::Custom("User".to_string())));
429            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
430            assert_eq!(used.len(), 1);
431            assert!(used.contains("User"));
432        }
433
434        #[test]
435        fn test_collect_set() {
436            let mut used = HashSet::new();
437            let ts = TypeStructure::Set(Box::new(TypeStructure::Custom("User".to_string())));
438            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
439            assert_eq!(used.len(), 1);
440            assert!(used.contains("User"));
441        }
442
443        #[test]
444        fn test_collect_map() {
445            let mut used = HashSet::new();
446            let ts = TypeStructure::Map {
447                key: Box::new(TypeStructure::Primitive("string".to_string())),
448                value: Box::new(TypeStructure::Custom("User".to_string())),
449            };
450            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
451            assert_eq!(used.len(), 1);
452            assert!(used.contains("User"));
453        }
454
455        #[test]
456        fn test_collect_map_both_custom() {
457            let mut used = HashSet::new();
458            let ts = TypeStructure::Map {
459                key: Box::new(TypeStructure::Custom("UserId".to_string())),
460                value: Box::new(TypeStructure::Custom("User".to_string())),
461            };
462            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
463            assert_eq!(used.len(), 2);
464            assert!(used.contains("User"));
465            assert!(used.contains("UserId"));
466        }
467
468        #[test]
469        fn test_collect_tuple() {
470            let mut used = HashSet::new();
471            let ts = TypeStructure::Tuple(vec![
472                TypeStructure::Custom("User".to_string()),
473                TypeStructure::Custom("Product".to_string()),
474            ]);
475            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
476            assert_eq!(used.len(), 2);
477            assert!(used.contains("User"));
478            assert!(used.contains("Product"));
479        }
480
481        #[test]
482        fn test_collect_nested() {
483            let mut used = HashSet::new();
484            let ts = TypeStructure::Array(Box::new(TypeStructure::Optional(Box::new(
485                TypeStructure::Custom("User".to_string()),
486            ))));
487            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
488            assert_eq!(used.len(), 1);
489            assert!(used.contains("User"));
490        }
491
492        #[test]
493        fn test_collect_multiple_calls_accumulate() {
494            let mut used = HashSet::new();
495            let ts1 = TypeStructure::Custom("User".to_string());
496            let ts2 = TypeStructure::Custom("Product".to_string());
497            TypeCollector::collect_referenced_types_from_structure(&ts1, &mut used);
498            TypeCollector::collect_referenced_types_from_structure(&ts2, &mut used);
499            assert_eq!(used.len(), 2);
500        }
501
502        #[test]
503        fn test_collect_duplicates_deduped() {
504            let mut used = HashSet::new();
505            let ts = TypeStructure::Custom("User".to_string());
506            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
507            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
508            assert_eq!(used.len(), 1);
509        }
510    }
511
512    mod collect_used_types {
513        use super::*;
514        use crate::models::{CommandInfo, ParameterInfo, StructInfo};
515
516        fn create_struct(name: &str) -> StructInfo {
517            StructInfo {
518                name: name.to_string(),
519                fields: vec![],
520                file_path: "test.rs".to_string(),
521                is_enum: false,
522                serde_rename_all: None,
523                serde_tag: None,
524                enum_variants: None,
525            }
526        }
527
528        fn create_param(
529            name: &str,
530            rust_type: &str,
531            type_structure: TypeStructure,
532        ) -> ParameterInfo {
533            ParameterInfo {
534                name: name.to_string(),
535                rust_type: rust_type.to_string(),
536                is_optional: false,
537                type_structure,
538                serde_rename: None,
539            }
540        }
541
542        #[test]
543        fn test_collect_from_empty_commands() {
544            let collector = TypeCollector::new();
545            let commands = vec![];
546            let all_structs = HashMap::new();
547            let used = collector.collect_used_types(&commands, &[], &all_structs);
548            assert!(used.is_empty());
549        }
550
551        #[test]
552        fn test_collect_from_command_parameters() {
553            let collector = TypeCollector::new();
554            let mut all_structs = HashMap::new();
555            let user_struct = create_struct("User");
556            all_structs.insert("User".to_string(), user_struct.clone());
557
558            let param = create_param("user", "User", TypeStructure::Custom("User".to_string()));
559            let command = CommandInfo::new_for_test(
560                "greet",
561                "test.rs",
562                1,
563                vec![param],
564                "string",
565                false,
566                vec![],
567            );
568
569            let used = collector.collect_used_types(&[command], &[], &all_structs);
570            assert_eq!(used.len(), 1);
571            assert!(used.contains_key("User"));
572        }
573
574        #[test]
575        fn test_collect_from_command_return_type() {
576            let collector = TypeCollector::new();
577            let mut all_structs = HashMap::new();
578            let result_struct = create_struct("ApiResult");
579            all_structs.insert("ApiResult".to_string(), result_struct.clone());
580
581            // Create command that returns ApiResult
582            let mut command = CommandInfo::new_for_test(
583                "fetch_data",
584                "test.rs",
585                1,
586                vec![],
587                "ApiResult",
588                false,
589                vec![],
590            );
591            // Set the return_type_structure
592            command.return_type_structure = TypeStructure::Custom("ApiResult".to_string());
593
594            let used = collector.collect_used_types(&[command], &[], &all_structs);
595            assert_eq!(used.len(), 1);
596            assert!(used.contains_key("ApiResult"));
597        }
598
599        #[test]
600        fn test_filters_unused_types() {
601            let collector = TypeCollector::new();
602            let mut all_structs = HashMap::new();
603
604            // Add two structs but only use one
605            let user_struct = create_struct("User");
606            let product_struct = create_struct("Product");
607            all_structs.insert("User".to_string(), user_struct);
608            all_structs.insert("Product".to_string(), product_struct);
609
610            let param = create_param("user", "User", TypeStructure::Custom("User".to_string()));
611            let command = CommandInfo::new_for_test(
612                "greet",
613                "test.rs",
614                1,
615                vec![param],
616                "string",
617                false,
618                vec![],
619            );
620
621            let used = collector.collect_used_types(&[command], &[], &all_structs);
622            assert_eq!(used.len(), 1);
623            assert!(used.contains_key("User"));
624            assert!(!used.contains_key("Product"));
625        }
626    }
627
628    mod nested_dependencies {
629        use super::*;
630        use crate::models::{CommandInfo, FieldInfo, ParameterInfo, StructInfo};
631
632        fn create_field(name: &str, rust_type: &str, type_structure: TypeStructure) -> FieldInfo {
633            FieldInfo {
634                name: name.to_string(),
635                rust_type: rust_type.to_string(),
636                is_optional: false,
637                is_public: true,
638                validator_attributes: None,
639                serde_rename: None,
640                type_structure,
641            }
642        }
643
644        fn create_struct_with_fields(name: &str, fields: Vec<FieldInfo>) -> StructInfo {
645            StructInfo {
646                name: name.to_string(),
647                fields,
648                file_path: "test.rs".to_string(),
649                is_enum: false,
650                serde_rename_all: None,
651                serde_tag: None,
652                enum_variants: None,
653            }
654        }
655
656        fn create_param(
657            name: &str,
658            rust_type: &str,
659            type_structure: TypeStructure,
660        ) -> ParameterInfo {
661            ParameterInfo {
662                name: name.to_string(),
663                rust_type: rust_type.to_string(),
664                is_optional: false,
665                type_structure,
666                serde_rename: None,
667            }
668        }
669
670        #[test]
671        fn test_discovers_nested_dependencies() {
672            let collector = TypeCollector::new();
673            let mut all_structs = HashMap::new();
674
675            // User has a field of type Address
676            let address_field = create_field(
677                "address",
678                "Address",
679                TypeStructure::Custom("Address".to_string()),
680            );
681            let user_struct = create_struct_with_fields("User", vec![address_field]);
682            let address_struct = create_struct_with_fields("Address", vec![]);
683
684            all_structs.insert("User".to_string(), user_struct);
685            all_structs.insert("Address".to_string(), address_struct);
686
687            // Command only uses User directly
688            let param = create_param("user", "User", TypeStructure::Custom("User".to_string()));
689            let command = CommandInfo::new_for_test(
690                "greet",
691                "test.rs",
692                1,
693                vec![param],
694                "string",
695                false,
696                vec![],
697            );
698
699            let used = collector.collect_used_types(&[command], &[], &all_structs);
700
701            // Should include both User and Address
702            assert_eq!(used.len(), 2);
703            assert!(used.contains_key("User"));
704            assert!(used.contains_key("Address"));
705        }
706
707        #[test]
708        fn test_handles_deep_nesting() {
709            let collector = TypeCollector::new();
710            let mut all_structs = HashMap::new();
711
712            // A -> B -> C chain
713            let c_struct = create_struct_with_fields("C", vec![]);
714            let b_field = create_field("c", "C", TypeStructure::Custom("C".to_string()));
715            let b_struct = create_struct_with_fields("B", vec![b_field]);
716            let a_field = create_field("b", "B", TypeStructure::Custom("B".to_string()));
717            let a_struct = create_struct_with_fields("A", vec![a_field]);
718
719            all_structs.insert("A".to_string(), a_struct);
720            all_structs.insert("B".to_string(), b_struct);
721            all_structs.insert("C".to_string(), c_struct);
722
723            let param = create_param("data", "A", TypeStructure::Custom("A".to_string()));
724            let command = CommandInfo::new_for_test(
725                "process",
726                "test.rs",
727                1,
728                vec![param],
729                "void",
730                false,
731                vec![],
732            );
733
734            let used = collector.collect_used_types(&[command], &[], &all_structs);
735
736            // Should include A, B, and C
737            assert_eq!(used.len(), 3);
738            assert!(used.contains_key("A"));
739            assert!(used.contains_key("B"));
740            assert!(used.contains_key("C"));
741        }
742    }
743}