tauri_typegen/analysis/
command_parser.rs

1use crate::analysis::serde_parser::SerdeParser;
2use crate::analysis::type_resolver::TypeResolver;
3use crate::models::{CommandInfo, ParameterInfo};
4use std::path::Path;
5use syn::{File as SynFile, FnArg, ItemFn, PatType, ReturnType, Type};
6
7/// Parser for Tauri command functions
8#[derive(Debug)]
9pub struct CommandParser {
10    serde_parser: SerdeParser,
11}
12
13impl CommandParser {
14    pub fn new() -> Self {
15        Self {
16            serde_parser: SerdeParser::new(),
17        }
18    }
19
20    /// Extract commands from a cached AST
21    pub fn extract_commands_from_ast(
22        &self,
23        ast: &SynFile,
24        file_path: &Path,
25        type_resolver: &mut TypeResolver,
26    ) -> Result<Vec<CommandInfo>, Box<dyn std::error::Error>> {
27        let commands = ast
28            .items
29            .iter()
30            .filter_map(|item| {
31                if let syn::Item::Fn(func) = item {
32                    if self.is_tauri_command(func) {
33                        return self.extract_command_info(func, file_path, type_resolver);
34                    }
35                }
36                None
37            })
38            .collect();
39
40        Ok(commands)
41    }
42
43    /// Check if a function is a Tauri command
44    fn is_tauri_command(&self, func: &ItemFn) -> bool {
45        func.attrs.iter().any(|attr| {
46            attr.path().segments.len() == 2
47                && attr.path().segments[0].ident == "tauri"
48                && attr.path().segments[1].ident == "command"
49                || attr.path().is_ident("command")
50        })
51    }
52
53    /// Extract command information from a function
54    fn extract_command_info(
55        &self,
56        func: &ItemFn,
57        file_path: &Path,
58        type_resolver: &mut TypeResolver,
59    ) -> Option<CommandInfo> {
60        let name = func.sig.ident.to_string();
61
62        let parameters = self.extract_parameters(&func.sig.inputs, type_resolver);
63        let return_type = self.extract_return_type(&func.sig.output);
64        let return_type_structure = type_resolver.parse_type_structure(&return_type);
65        let is_async = func.sig.asyncness.is_some();
66
67        // Get line number from the function's span
68        let line_number = func.sig.ident.span().start().line;
69
70        // Parse serde rename_all attribute from function attributes
71        let serde_rename_all = self
72            .serde_parser
73            .parse_struct_serde_attrs(&func.attrs)
74            .rename_all;
75
76        Some(CommandInfo {
77            name,
78            parameters,
79            return_type,
80            return_type_structure,
81            file_path: file_path.to_string_lossy().to_string(),
82            line_number,
83            is_async,
84            channels: Vec::new(), // Will be populated by channel_parser
85            serde_rename_all,
86        })
87    }
88
89    /// Extract parameters from function signature
90    fn extract_parameters(
91        &self,
92        inputs: &syn::punctuated::Punctuated<FnArg, syn::token::Comma>,
93        type_resolver: &mut TypeResolver,
94    ) -> Vec<ParameterInfo> {
95        inputs
96            .iter()
97            .filter_map(|input| {
98                if let FnArg::Typed(PatType { pat, ty, attrs, .. }) = input {
99                    if let syn::Pat::Ident(pat_ident) = pat.as_ref() {
100                        let name = pat_ident.ident.to_string();
101
102                        // Skip Tauri-specific parameters
103                        if self.is_tauri_parameter_type(ty) {
104                            return None;
105                        }
106
107                        let rust_type = Self::type_to_string(ty);
108                        let type_structure = type_resolver.parse_type_structure(&rust_type);
109                        let is_optional = self.is_optional_type(ty);
110
111                        // Parse serde rename attribute from parameter attributes
112                        let serde_rename = self.serde_parser.parse_field_serde_attrs(attrs).rename;
113
114                        return Some(ParameterInfo {
115                            name,
116                            rust_type,
117                            is_optional,
118                            type_structure,
119                            serde_rename,
120                        });
121                    }
122                }
123                None
124            })
125            .collect()
126    }
127
128    /// Check if a parameter type is a Tauri-specific type that should be skipped
129    /// This checks the actual syn::Type to properly handle both imported and fully-qualified types
130    fn is_tauri_parameter_type(&self, ty: &Type) -> bool {
131        if let Type::Path(type_path) = ty {
132            let segments = &type_path.path.segments;
133
134            // Check various patterns:
135            // 1. Fully qualified: tauri::AppHandle, tauri::State<T>, tauri::ipc::Request
136            // 2. Imported: AppHandle, State<T>, Window<T>
137            if segments.len() >= 2 {
138                // Check for tauri::* or tauri::ipc::*
139                if segments[0].ident == "tauri" {
140                    if segments.len() == 2 {
141                        // tauri::AppHandle, tauri::Window, etc.
142                        let second = &segments[1].ident;
143                        return second == "AppHandle"
144                            || second == "Window"
145                            || second == "WebviewWindow"
146                            || second == "State"
147                            || second == "Manager";
148                    } else if segments.len() == 3 && segments[1].ident == "ipc" {
149                        // tauri::ipc::Request, tauri::ipc::Channel
150                        let third = &segments[2].ident;
151                        return third == "Request" || third == "Channel";
152                    }
153                }
154            }
155
156            // Check for imported types (single segment)
157            if let Some(last_segment) = segments.last() {
158                let type_ident = &last_segment.ident;
159
160                // Only match specific Tauri types that are commonly imported
161                // Be careful not to match user types with similar names
162                if type_ident == "AppHandle" || type_ident == "WebviewWindow" {
163                    return true;
164                }
165
166                // Channel should be filtered if it has generic parameters (indicating it's the Tauri IPC channel)
167                if type_ident == "Channel"
168                    && matches!(
169                        last_segment.arguments,
170                        syn::PathArguments::AngleBracketed(_)
171                    )
172                {
173                    return true;
174                }
175
176                // State and Window are common names, only match if they have generic params
177                // (Tauri's State and Window types always have generics like State<T>, Window<R>)
178                if (type_ident == "State" || type_ident == "Window")
179                    && !last_segment.arguments.is_empty()
180                {
181                    return true;
182                }
183            }
184        }
185
186        false
187    }
188
189    /// Extract return type from function signature - returns rust_type only
190    fn extract_return_type(&self, output: &ReturnType) -> String {
191        match output {
192            ReturnType::Default => "()".to_string(),
193            ReturnType::Type(_, ty) => Self::type_to_string(ty),
194        }
195    }
196
197    /// Convert a Type to its string representation
198    fn type_to_string(ty: &Type) -> String {
199        match ty {
200            Type::Path(type_path) => {
201                let segments: Vec<String> = type_path
202                    .path
203                    .segments
204                    .iter()
205                    .map(|segment| {
206                        if segment.arguments.is_empty() {
207                            segment.ident.to_string()
208                        } else {
209                            match &segment.arguments {
210                                syn::PathArguments::AngleBracketed(args) => {
211                                    let inner_types: Vec<String> = args
212                                        .args
213                                        .iter()
214                                        .filter_map(|arg| {
215                                            if let syn::GenericArgument::Type(inner_ty) = arg {
216                                                Some(Self::type_to_string(inner_ty))
217                                            } else {
218                                                None
219                                            }
220                                        })
221                                        .collect();
222                                    format!("{}<{}>", segment.ident, inner_types.join(", "))
223                                }
224                                _ => segment.ident.to_string(),
225                            }
226                        }
227                    })
228                    .collect();
229                segments.join("::")
230            }
231            Type::Reference(type_ref) => {
232                format!("&{}", Self::type_to_string(&type_ref.elem))
233            }
234            Type::Tuple(type_tuple) => {
235                if type_tuple.elems.is_empty() {
236                    "()".to_string()
237                } else {
238                    let types: Vec<String> =
239                        type_tuple.elems.iter().map(Self::type_to_string).collect();
240                    format!("({})", types.join(", "))
241                }
242            }
243            _ => "unknown".to_string(),
244        }
245    }
246
247    /// Check if a type is Option<T>
248    fn is_optional_type(&self, ty: &Type) -> bool {
249        if let Type::Path(type_path) = ty {
250            if let Some(segment) = type_path.path.segments.last() {
251                return segment.ident == "Option";
252            }
253        }
254        false
255    }
256}
257
258impl Default for CommandParser {
259    fn default() -> Self {
260        Self::new()
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    // Shared helpers: each test reads as "input -> expectation" without
    // repeating parser / resolver setup.

    fn is_command(func: ItemFn) -> bool {
        CommandParser::new().is_tauri_command(&func)
    }

    fn render(ty: Type) -> String {
        CommandParser::type_to_string(&ty)
    }

    fn is_optional(ty: Type) -> bool {
        CommandParser::new().is_optional_type(&ty)
    }

    fn is_injected(ty: Type) -> bool {
        CommandParser::new().is_tauri_parameter_type(&ty)
    }

    fn return_of(output: ReturnType) -> String {
        CommandParser::new().extract_return_type(&output)
    }

    fn params_of(
        inputs: syn::punctuated::Punctuated<FnArg, syn::token::Comma>,
    ) -> Vec<ParameterInfo> {
        let mut resolver = TypeResolver::new();
        CommandParser::new().extract_parameters(&inputs, &mut resolver)
    }

    fn names_of(params: &[ParameterInfo]) -> Vec<&str> {
        params.iter().map(|p| p.name.as_str()).collect()
    }

    fn command_of(func: ItemFn) -> CommandInfo {
        let mut resolver = TypeResolver::new();
        let path = std::path::PathBuf::from("test.rs");
        CommandParser::new()
            .extract_command_info(&func, &path, &mut resolver)
            .expect("extract_command_info returned None")
    }

    #[test]
    fn test_new_command_parser() {
        // Construction alone must not panic.
        let _parser = CommandParser::new();
    }

    #[test]
    fn test_default_impl() {
        let _parser: CommandParser = Default::default();
    }

    mod is_tauri_command {
        use super::*;

        #[test]
        fn test_recognizes_tauri_command_attribute() {
            let func: ItemFn = parse_quote! {
                #[tauri::command]
                fn greet(name: String) -> String {
                    format!("Hello, {}!", name)
                }
            };
            assert!(is_command(func));
        }

        #[test]
        fn test_recognizes_command_attribute() {
            let func: ItemFn = parse_quote! {
                #[command]
                fn greet(name: String) -> String {
                    format!("Hello, {}!", name)
                }
            };
            assert!(is_command(func));
        }

        #[test]
        fn test_rejects_non_command_function() {
            let func: ItemFn = parse_quote! {
                fn greet(name: String) -> String {
                    format!("Hello, {}!", name)
                }
            };
            assert!(!is_command(func));
        }

        #[test]
        fn test_rejects_other_attributes() {
            let func: ItemFn = parse_quote! {
                #[derive(Debug)]
                fn greet(name: String) -> String {
                    format!("Hello, {}!", name)
                }
            };
            assert!(!is_command(func));
        }
    }

    mod type_to_string {
        use super::*;

        #[test]
        fn test_simple_type() {
            assert_eq!(render(parse_quote!(String)), "String");
        }

        #[test]
        fn test_generic_type() {
            assert_eq!(render(parse_quote!(Vec<String>)), "Vec<String>");
        }

        #[test]
        fn test_nested_generic() {
            assert_eq!(render(parse_quote!(Vec<Option<String>>)), "Vec<Option<String>>");
        }

        #[test]
        fn test_multiple_generics() {
            assert_eq!(render(parse_quote!(HashMap<String, i32>)), "HashMap<String, i32>");
        }

        #[test]
        fn test_reference_type() {
            assert_eq!(render(parse_quote!(&str)), "&str");
        }

        #[test]
        fn test_empty_tuple() {
            assert_eq!(render(parse_quote!(())), "()");
        }

        #[test]
        fn test_tuple_with_elements() {
            assert_eq!(render(parse_quote!((String, i32))), "(String, i32)");
        }

        #[test]
        fn test_qualified_path() {
            let rendered = render(parse_quote!(std::collections::HashMap<String, i32>));
            assert_eq!(rendered, "std::collections::HashMap<String, i32>");
        }
    }

    mod is_optional_type {
        use super::*;

        #[test]
        fn test_recognizes_option() {
            assert!(is_optional(parse_quote!(Option<String>)));
        }

        #[test]
        fn test_recognizes_nested_option() {
            assert!(is_optional(parse_quote!(Option<Vec<String>>)));
        }

        #[test]
        fn test_rejects_non_option() {
            assert!(!is_optional(parse_quote!(String)));
        }

        #[test]
        fn test_rejects_vec() {
            assert!(!is_optional(parse_quote!(Vec<String>)));
        }
    }

    mod is_tauri_parameter_type {
        use super::*;

        #[test]
        fn test_recognizes_app_handle() {
            assert!(is_injected(parse_quote!(tauri::AppHandle)));
        }

        #[test]
        fn test_recognizes_imported_app_handle() {
            assert!(is_injected(parse_quote!(AppHandle)));
        }

        #[test]
        fn test_recognizes_window_with_generics() {
            assert!(is_injected(parse_quote!(Window<R>)));
        }

        #[test]
        fn test_recognizes_state_with_generics() {
            assert!(is_injected(parse_quote!(State<AppState>)));
        }

        #[test]
        fn test_recognizes_webview_window() {
            assert!(is_injected(parse_quote!(tauri::WebviewWindow)));
        }

        #[test]
        fn test_recognizes_imported_webview_window() {
            assert!(is_injected(parse_quote!(WebviewWindow)));
        }

        #[test]
        fn test_recognizes_ipc_request() {
            assert!(is_injected(parse_quote!(tauri::ipc::Request)));
        }

        #[test]
        fn test_recognizes_ipc_channel() {
            assert!(is_injected(parse_quote!(tauri::ipc::Channel<String>)));
        }

        #[test]
        fn test_recognizes_channel_with_generics() {
            assert!(is_injected(parse_quote!(Channel<ProgressUpdate>)));
        }

        #[test]
        fn test_rejects_user_string_type() {
            assert!(!is_injected(parse_quote!(String)));
        }

        #[test]
        fn test_rejects_user_custom_type() {
            assert!(!is_injected(parse_quote!(User)));
        }

        #[test]
        fn test_rejects_state_without_generics() {
            // A user-defined `State` without generics must be kept.
            assert!(!is_injected(parse_quote!(State)));
        }

        #[test]
        fn test_rejects_window_without_generics() {
            // A user-defined `Window` without generics must be kept.
            assert!(!is_injected(parse_quote!(Window)));
        }
    }

    mod extract_return_type {
        use super::*;

        #[test]
        fn test_extract_simple_return() {
            assert_eq!(return_of(parse_quote!(-> String)), "String");
        }

        #[test]
        fn test_extract_generic_return() {
            assert_eq!(return_of(parse_quote!(-> Vec<String>)), "Vec<String>");
        }

        #[test]
        fn test_extract_result_return() {
            assert_eq!(
                return_of(parse_quote!(-> Result<String, Error>)),
                "Result<String, Error>"
            );
        }

        #[test]
        fn test_extract_default_return() {
            assert_eq!(return_of(parse_quote!()), "()");
        }
    }

    mod extract_parameters {
        use super::*;

        #[test]
        fn test_extract_simple_parameter() {
            let params = params_of(parse_quote!(name: String));
            assert_eq!(names_of(&params), ["name"]);
            assert_eq!(params[0].rust_type, "String");
            assert!(!params[0].is_optional);
        }

        #[test]
        fn test_extract_optional_parameter() {
            let params = params_of(parse_quote!(email: Option<String>));
            assert_eq!(names_of(&params), ["email"]);
            assert!(params[0].is_optional);
        }

        #[test]
        fn test_extract_multiple_parameters() {
            let params = params_of(parse_quote!(name: String, age: i32));
            assert_eq!(names_of(&params), ["name", "age"]);
        }

        #[test]
        fn test_filters_app_handle() {
            // AppHandle is injected by Tauri, so only `name` remains.
            let params = params_of(parse_quote!(app: AppHandle, name: String));
            assert_eq!(names_of(&params), ["name"]);
        }

        #[test]
        fn test_filters_state() {
            // State<..> is injected by Tauri, so only `name` remains.
            let params = params_of(parse_quote!(state: State<AppState>, name: String));
            assert_eq!(names_of(&params), ["name"]);
        }

        #[test]
        fn test_filters_channel() {
            // Channel<..> is injected by Tauri, so only `name` remains.
            let params = params_of(parse_quote!(progress: Channel<u32>, name: String));
            assert_eq!(names_of(&params), ["name"]);
        }

        #[test]
        fn test_empty_parameters() {
            assert!(params_of(parse_quote!()).is_empty());
        }
    }

    mod extract_command_info {
        use super::*;

        #[test]
        fn test_extract_simple_command() {
            let info = command_of(parse_quote! {
                #[tauri::command]
                fn greet(name: String) -> String {
                    format!("Hello, {}!", name)
                }
            });
            assert_eq!(info.name, "greet");
            assert_eq!(info.parameters.len(), 1);
            assert_eq!(info.return_type, "String");
            assert!(!info.is_async);
        }

        #[test]
        fn test_extract_async_command() {
            let info = command_of(parse_quote! {
                #[tauri::command]
                async fn fetch_data() -> Result<String, Error> {
                    Ok("data".to_string())
                }
            });
            assert!(info.is_async);
            assert_eq!(info.return_type, "Result<String, Error>");
        }

        #[test]
        fn test_extract_command_with_no_return() {
            let info = command_of(parse_quote! {
                #[tauri::command]
                fn log_message(msg: String) {
                    println!("{}", msg);
                }
            });
            assert_eq!(info.return_type, "()");
        }
    }
}