// tauri_typegen/generators/mod.rs

pub mod base;
pub mod ts;
pub mod zod;

use crate::analysis::CommandAnalyzer;
use crate::models::{CommandInfo, EventInfo, StructInfo};
use crate::GenerateConfig;
use base::template_context::{CommandContext, EventContext, FieldContext, StructContext};
use base::type_visitor::TypeVisitor;
use std::collections::{HashMap, HashSet};

pub use base::templates::GlobalContext;
pub use base::BaseBindingsGenerator as BindingsGenerator;
pub use ts::generator::TypeScriptBindingsGenerator;
pub use zod::generator::ZodBindingsGenerator;

/// Registers one template with a template engine, returning early on failure.
///
/// * `$tera` – the engine receiving the template; its `add_raw_template` method is called
///   (presumably a `tera::Tera` — confirm at call sites).
/// * `$name` – the name the template is registered under; it is also quoted in the error text.
/// * `$path` – a path passed to `include_str!`, so the template file is embedded at compile time.
///
/// The expansion ends in `?` on a `Result<_, String>`, so the macro must be invoked inside a
/// function whose error type can be produced from a `String`.
#[macro_export]
macro_rules! template {
    ($tera:expr, $name:expr, $path:expr) => {{
        let template_source = include_str!($path);
        $tera
            .add_raw_template($name, template_source)
            .map_err(|err| format!("Failed to register {}: {}", $name, err))?;
    }};
}

27/// Factory function to create the appropriate bindings generator
28/// Returns a boxed trait object for polymorphism
29pub fn create_generator(validation_library: Option<String>) -> Box<dyn BindingsGenerator> {
30    match validation_library.as_deref().unwrap_or("none") {
31        "zod" => Box::new(ZodBindingsGenerator::new()),
32        _ => Box::new(TypeScriptBindingsGenerator::new()),
33    }
34}
35
36/// Utility for collecting and organizing types for bindings generation
37///
38/// This struct provides filtering and transformation utilities that sit between
39/// the analysis phase (which produces TypeStructure) and the generation phase
40/// (which consumes filtered types and contexts). It acts as a one-stop-shop for
41/// filtering unused code and collecting only the types needed for generation.
42pub struct TypeCollector {
43    pub known_structs: HashMap<String, StructInfo>,
44}
45
46impl TypeCollector {
47    pub fn new() -> Self {
48        Self {
49            known_structs: HashMap::new(),
50        }
51    }
52
53    /// Filter only the types used by commands
54    pub fn collect_used_types(
55        &self,
56        commands: &[CommandInfo],
57        all_structs: &HashMap<String, StructInfo>,
58    ) -> HashMap<String, StructInfo> {
59        let mut used_types = std::collections::HashSet::new();
60
61        // Collect types from commands using structured TypeStructure
62        for command in commands {
63            // Add parameter types from type_structure
64            for param in &command.parameters {
65                Self::collect_referenced_types_from_structure(
66                    &param.type_structure,
67                    &mut used_types,
68                );
69            }
70            // Add return type from return_type_structure
71            Self::collect_referenced_types_from_structure(
72                &command.return_type_structure,
73                &mut used_types,
74            );
75            // Add channel message types from message_type_structure
76            for channel in &command.channels {
77                Self::collect_referenced_types_from_structure(
78                    &channel.message_type_structure,
79                    &mut used_types,
80                );
81            }
82        }
83
84        // Clone to avoid borrow checker issues
85        let initial_types = used_types.clone();
86
87        // Discover nested dependencies (types referenced by the collected types)
88        self.discover_nested_dependencies(&initial_types, all_structs, &mut used_types);
89
90        // Filter to only include used types
91        all_structs
92            .iter()
93            .filter(|(name, _)| used_types.contains(*name))
94            .map(|(k, v)| (k.clone(), v.clone()))
95            .collect()
96    }
97
98    /// Recursively discover nested dependencies
99    fn discover_nested_dependencies(
100        &self,
101        initial_types: &std::collections::HashSet<String>,
102        all_structs: &HashMap<String, StructInfo>,
103        all_types: &mut std::collections::HashSet<String>,
104    ) {
105        let mut to_process: Vec<String> = initial_types.iter().cloned().collect();
106        let mut processed: std::collections::HashSet<String> = std::collections::HashSet::new();
107
108        while let Some(type_name) = to_process.pop() {
109            if processed.contains(&type_name) {
110                continue;
111            }
112            processed.insert(type_name.clone());
113
114            if let Some(struct_info) = all_structs.get(&type_name) {
115                for field in &struct_info.fields {
116                    let mut nested_types = std::collections::HashSet::new();
117                    // Use type_structure to collect referenced types
118                    Self::collect_referenced_types_from_structure(
119                        &field.type_structure,
120                        &mut nested_types,
121                    );
122
123                    for nested_type in nested_types {
124                        if !all_types.contains(&nested_type)
125                            && all_structs.contains_key(&nested_type)
126                        {
127                            all_types.insert(nested_type.clone());
128                            to_process.push(nested_type);
129                        }
130                    }
131                }
132            }
133        }
134    }
135
136    /// Recursively collect custom type names from TypeStructure
137    /// Works directly with structured type information instead of string parsing
138    pub fn collect_referenced_types_from_structure(
139        type_structure: &crate::TypeStructure,
140        used_types: &mut std::collections::HashSet<String>,
141    ) {
142        use crate::TypeStructure;
143
144        match type_structure {
145            TypeStructure::Custom(name) => {
146                used_types.insert(name.clone());
147            }
148            TypeStructure::Array(inner)
149            | TypeStructure::Set(inner)
150            | TypeStructure::Optional(inner)
151            | TypeStructure::Result(inner) => {
152                Self::collect_referenced_types_from_structure(inner, used_types);
153            }
154            TypeStructure::Map { key, value } => {
155                Self::collect_referenced_types_from_structure(key, used_types);
156                Self::collect_referenced_types_from_structure(value, used_types);
157            }
158            TypeStructure::Tuple(types) => {
159                for t in types {
160                    Self::collect_referenced_types_from_structure(t, used_types);
161                }
162            }
163            TypeStructure::Primitive(_) => {
164                // Primitives are not custom types
165            }
166        }
167    }
168
169    /// Create CommandContext instances from CommandInfo using the provided visitor
170    pub fn create_command_contexts<V: TypeVisitor>(
171        &self,
172        commands: &[CommandInfo],
173        visitor: &V,
174        analyzer: &CommandAnalyzer,
175        config: &GenerateConfig,
176    ) -> Vec<CommandContext> {
177        let type_resolver = analyzer.get_type_resolver();
178
179        commands
180            .iter()
181            .map(|cmd| {
182                CommandContext::new(config).from_command_info(cmd, visitor, &|rust_type: &str| {
183                    type_resolver.borrow_mut().parse_type_structure(rust_type)
184                })
185            })
186            .collect()
187    }
188
189    /// Create EventContext instances from EventInfo using the provided visitor
190    pub fn create_event_contexts<V: TypeVisitor>(
191        &self,
192        events: &[EventInfo],
193        visitor: &V,
194        analyzer: &CommandAnalyzer,
195        config: &GenerateConfig,
196    ) -> Vec<EventContext> {
197        let type_resolver = analyzer.get_type_resolver();
198
199        events
200            .iter()
201            .map(|event| {
202                EventContext::new(config).from_event_info(event, visitor, &|rust_type: &str| {
203                    type_resolver.borrow_mut().parse_type_structure(rust_type)
204                })
205            })
206            .collect()
207    }
208
209    /// Create StructContext instances from StructInfo using the provided visitor
210    pub fn create_struct_contexts<V: TypeVisitor>(
211        &self,
212        used_structs: &HashMap<String, StructInfo>,
213        visitor: &V,
214        config: &GenerateConfig,
215    ) -> Vec<StructContext> {
216        used_structs
217            .iter()
218            .map(|(name, struct_info)| {
219                StructContext::new(config).from_struct_info(name, struct_info, visitor)
220            })
221            .collect()
222    }
223
224    /// Create FieldContext instances from StructInfo using the provided visitor
225    pub fn create_field_contexts<V: TypeVisitor>(
226        &self,
227        struct_info: &StructInfo,
228        visitor: &V,
229        config: &GenerateConfig,
230    ) -> Vec<FieldContext> {
231        struct_info
232            .fields
233            .iter()
234            .map(|field| {
235                FieldContext::new(config).from_field_info(
236                    field,
237                    &struct_info.serde_rename_all,
238                    visitor,
239                )
240            })
241            .collect()
242    }
243}
244
245impl Default for TypeCollector {
246    fn default() -> Self {
247        Self::new()
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypeStructure;
    use std::collections::HashSet;

    mod factory {
        use super::*;

        // `create_generator` returns a `Box<dyn BindingsGenerator>` whose concrete type cannot
        // be inspected from here. These are smoke tests that only check construction succeeds
        // (does not panic) for each kind of input. Asserting on `type_name_of_val(&gen)` would
        // be tautological: the static type is `Box<dyn ...>` no matter which generator is built.

        #[test]
        fn test_create_generator_zod() {
            let _generator = create_generator(Some("zod".to_string()));
        }

        #[test]
        fn test_create_generator_none() {
            let _generator = create_generator(Some("none".to_string()));
        }

        #[test]
        fn test_create_generator_default() {
            let _generator = create_generator(None);
        }

        #[test]
        fn test_create_generator_unknown_fallback() {
            let _generator = create_generator(Some("unknown".to_string()));
        }
    }

    mod type_collector {
        use super::*;

        #[test]
        fn test_new_creates_empty_collector() {
            let collector = TypeCollector::new();
            assert!(collector.known_structs.is_empty());
        }

        #[test]
        fn test_default_creates_empty_collector() {
            let collector = TypeCollector::default();
            assert!(collector.known_structs.is_empty());
        }
    }

    mod collect_referenced_types {
        use super::*;

        #[test]
        fn test_collect_primitive() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Primitive("string".to_string());
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert!(used.is_empty());
        }

        #[test]
        fn test_collect_custom() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Custom("User".to_string());
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
            assert!(used.contains("User"));
        }

        #[test]
        fn test_collect_array() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Array(Box::new(TypeStructure::Custom("User".to_string())));
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
            assert!(used.contains("User"));
        }

        #[test]
        fn test_collect_optional() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Optional(Box::new(TypeStructure::Custom("User".to_string())));
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
            assert!(used.contains("User"));
        }

        #[test]
        fn test_collect_result() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Result(Box::new(TypeStructure::Custom("User".to_string())));
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
            assert!(used.contains("User"));
        }

        #[test]
        fn test_collect_set() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Set(Box::new(TypeStructure::Custom("User".to_string())));
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
            assert!(used.contains("User"));
        }

        #[test]
        fn test_collect_map() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Map {
                key: Box::new(TypeStructure::Primitive("string".to_string())),
                value: Box::new(TypeStructure::Custom("User".to_string())),
            };
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
            assert!(used.contains("User"));
        }

        #[test]
        fn test_collect_map_both_custom() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Map {
                key: Box::new(TypeStructure::Custom("UserId".to_string())),
                value: Box::new(TypeStructure::Custom("User".to_string())),
            };
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 2);
            assert!(used.contains("User"));
            assert!(used.contains("UserId"));
        }

        #[test]
        fn test_collect_tuple() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Tuple(vec![
                TypeStructure::Custom("User".to_string()),
                TypeStructure::Custom("Product".to_string()),
            ]);
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 2);
            assert!(used.contains("User"));
            assert!(used.contains("Product"));
        }

        #[test]
        fn test_collect_nested() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Array(Box::new(TypeStructure::Optional(Box::new(
                TypeStructure::Custom("User".to_string()),
            ))));
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
            assert!(used.contains("User"));
        }

        #[test]
        fn test_collect_multiple_calls_accumulate() {
            let mut used = HashSet::new();
            let ts1 = TypeStructure::Custom("User".to_string());
            let ts2 = TypeStructure::Custom("Product".to_string());
            TypeCollector::collect_referenced_types_from_structure(&ts1, &mut used);
            TypeCollector::collect_referenced_types_from_structure(&ts2, &mut used);
            assert_eq!(used.len(), 2);
        }

        #[test]
        fn test_collect_duplicates_deduped() {
            let mut used = HashSet::new();
            let ts = TypeStructure::Custom("User".to_string());
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            TypeCollector::collect_referenced_types_from_structure(&ts, &mut used);
            assert_eq!(used.len(), 1);
        }
    }

    mod collect_used_types {
        use super::*;
        use crate::models::{CommandInfo, ParameterInfo, StructInfo};

        fn create_struct(name: &str) -> StructInfo {
            StructInfo {
                name: name.to_string(),
                fields: vec![],
                file_path: "test.rs".to_string(),
                is_enum: false,
                serde_rename_all: None,
            }
        }

        fn create_param(
            name: &str,
            rust_type: &str,
            type_structure: TypeStructure,
        ) -> ParameterInfo {
            ParameterInfo {
                name: name.to_string(),
                rust_type: rust_type.to_string(),
                is_optional: false,
                type_structure,
                serde_rename: None,
            }
        }

        #[test]
        fn test_collect_from_empty_commands() {
            let collector = TypeCollector::new();
            let commands = vec![];
            let all_structs = HashMap::new();
            let used = collector.collect_used_types(&commands, &all_structs);
            assert!(used.is_empty());
        }

        #[test]
        fn test_collect_from_command_parameters() {
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();
            let user_struct = create_struct("User");
            all_structs.insert("User".to_string(), user_struct.clone());

            let param = create_param("user", "User", TypeStructure::Custom("User".to_string()));
            let command = CommandInfo::new_for_test(
                "greet",
                "test.rs",
                1,
                vec![param],
                "string",
                false,
                vec![],
            );

            let used = collector.collect_used_types(&[command], &all_structs);
            assert_eq!(used.len(), 1);
            assert!(used.contains_key("User"));
        }

        #[test]
        fn test_collect_from_command_return_type() {
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();
            let result_struct = create_struct("ApiResult");
            all_structs.insert("ApiResult".to_string(), result_struct.clone());

            // Create command that returns ApiResult
            let mut command = CommandInfo::new_for_test(
                "fetch_data",
                "test.rs",
                1,
                vec![],
                "ApiResult",
                false,
                vec![],
            );
            // Set the return_type_structure
            command.return_type_structure = TypeStructure::Custom("ApiResult".to_string());

            let used = collector.collect_used_types(&[command], &all_structs);
            assert_eq!(used.len(), 1);
            assert!(used.contains_key("ApiResult"));
        }

        #[test]
        fn test_filters_unused_types() {
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();

            // Add two structs but only use one
            let user_struct = create_struct("User");
            let product_struct = create_struct("Product");
            all_structs.insert("User".to_string(), user_struct);
            all_structs.insert("Product".to_string(), product_struct);

            let param = create_param("user", "User", TypeStructure::Custom("User".to_string()));
            let command = CommandInfo::new_for_test(
                "greet",
                "test.rs",
                1,
                vec![param],
                "string",
                false,
                vec![],
            );

            let used = collector.collect_used_types(&[command], &all_structs);
            assert_eq!(used.len(), 1);
            assert!(used.contains_key("User"));
            assert!(!used.contains_key("Product"));
        }
    }

    mod nested_dependencies {
        use super::*;
        use crate::models::{CommandInfo, FieldInfo, ParameterInfo, StructInfo};

        fn create_field(name: &str, rust_type: &str, type_structure: TypeStructure) -> FieldInfo {
            FieldInfo {
                name: name.to_string(),
                rust_type: rust_type.to_string(),
                is_optional: false,
                is_public: true,
                validator_attributes: None,
                serde_rename: None,
                type_structure,
            }
        }

        fn create_struct_with_fields(name: &str, fields: Vec<FieldInfo>) -> StructInfo {
            StructInfo {
                name: name.to_string(),
                fields,
                file_path: "test.rs".to_string(),
                is_enum: false,
                serde_rename_all: None,
            }
        }

        fn create_param(
            name: &str,
            rust_type: &str,
            type_structure: TypeStructure,
        ) -> ParameterInfo {
            ParameterInfo {
                name: name.to_string(),
                rust_type: rust_type.to_string(),
                is_optional: false,
                type_structure,
                serde_rename: None,
            }
        }

        /// Builds a command whose single parameter has the custom type `type_name`.
        fn command_taking(type_name: &str) -> CommandInfo {
            let param = create_param(
                "data",
                type_name,
                TypeStructure::Custom(type_name.to_string()),
            );
            CommandInfo::new_for_test(
                "process",
                "test.rs",
                1,
                vec![param],
                "void",
                false,
                vec![],
            )
        }

        #[test]
        fn test_discovers_nested_dependencies() {
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();

            // User has a field of type Address
            let address_field = create_field(
                "address",
                "Address",
                TypeStructure::Custom("Address".to_string()),
            );
            let user_struct = create_struct_with_fields("User", vec![address_field]);
            let address_struct = create_struct_with_fields("Address", vec![]);

            all_structs.insert("User".to_string(), user_struct);
            all_structs.insert("Address".to_string(), address_struct);

            // Command only uses User directly
            let param = create_param("user", "User", TypeStructure::Custom("User".to_string()));
            let command = CommandInfo::new_for_test(
                "greet",
                "test.rs",
                1,
                vec![param],
                "string",
                false,
                vec![],
            );

            let used = collector.collect_used_types(&[command], &all_structs);

            // Should include both User and Address
            assert_eq!(used.len(), 2);
            assert!(used.contains_key("User"));
            assert!(used.contains_key("Address"));
        }

        #[test]
        fn test_handles_deep_nesting() {
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();

            // A -> B -> C chain
            let c_struct = create_struct_with_fields("C", vec![]);
            let b_field = create_field("c", "C", TypeStructure::Custom("C".to_string()));
            let b_struct = create_struct_with_fields("B", vec![b_field]);
            let a_field = create_field("b", "B", TypeStructure::Custom("B".to_string()));
            let a_struct = create_struct_with_fields("A", vec![a_field]);

            all_structs.insert("A".to_string(), a_struct);
            all_structs.insert("B".to_string(), b_struct);
            all_structs.insert("C".to_string(), c_struct);

            let param = create_param("data", "A", TypeStructure::Custom("A".to_string()));
            let command = CommandInfo::new_for_test(
                "process",
                "test.rs",
                1,
                vec![param],
                "void",
                false,
                vec![],
            );

            let used = collector.collect_used_types(&[command], &all_structs);

            // Should include A, B, and C
            assert_eq!(used.len(), 3);
            assert!(used.contains_key("A"));
            assert!(used.contains_key("B"));
            assert!(used.contains_key("C"));
        }

        #[test]
        fn test_handles_mutually_recursive_structs() {
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();

            // A <-> B cycle: the walk must terminate and keep both types
            let a_field = create_field("b", "B", TypeStructure::Custom("B".to_string()));
            let b_field = create_field("a", "A", TypeStructure::Custom("A".to_string()));
            all_structs.insert("A".to_string(), create_struct_with_fields("A", vec![a_field]));
            all_structs.insert("B".to_string(), create_struct_with_fields("B", vec![b_field]));

            let used = collector.collect_used_types(&[command_taking("A")], &all_structs);

            assert_eq!(used.len(), 2);
            assert!(used.contains_key("A"));
            assert!(used.contains_key("B"));
        }

        #[test]
        fn test_handles_self_referential_struct() {
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();

            // Node { children: Vec<Node> }
            let children_field = create_field(
                "children",
                "Vec<Node>",
                TypeStructure::Array(Box::new(TypeStructure::Custom("Node".to_string()))),
            );
            all_structs.insert(
                "Node".to_string(),
                create_struct_with_fields("Node", vec![children_field]),
            );

            let used = collector.collect_used_types(&[command_taking("Node")], &all_structs);

            assert_eq!(used.len(), 1);
            assert!(used.contains_key("Node"));
        }

        #[test]
        fn test_discovers_dependencies_inside_containers() {
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();

            // User { addresses: Option<Vec<Address>> }
            let addresses_field = create_field(
                "addresses",
                "Option<Vec<Address>>",
                TypeStructure::Optional(Box::new(TypeStructure::Array(Box::new(
                    TypeStructure::Custom("Address".to_string()),
                )))),
            );
            all_structs.insert(
                "User".to_string(),
                create_struct_with_fields("User", vec![addresses_field]),
            );
            all_structs.insert(
                "Address".to_string(),
                create_struct_with_fields("Address", vec![]),
            );

            let used = collector.collect_used_types(&[command_taking("User")], &all_structs);

            assert_eq!(used.len(), 2);
            assert!(used.contains_key("User"));
            assert!(used.contains_key("Address"));
        }

        #[test]
        fn test_ignores_field_types_missing_from_all_structs() {
            let collector = TypeCollector::new();
            let mut all_structs = HashMap::new();

            // User { ext: External } where External has no StructInfo
            let external_field = create_field(
                "ext",
                "External",
                TypeStructure::Custom("External".to_string()),
            );
            all_structs.insert(
                "User".to_string(),
                create_struct_with_fields("User", vec![external_field]),
            );

            let used = collector.collect_used_types(&[command_taking("User")], &all_structs);

            assert_eq!(used.len(), 1);
            assert!(used.contains_key("User"));
            assert!(!used.contains_key("External"));
        }
    }
}