tauri_typegen/analysis/
channel_parser.rs

1use crate::analysis::type_resolver::TypeResolver;
2use crate::models::ChannelInfo;
3use std::path::Path;
4use syn::spanned::Spanned;
5use syn::{
6    AngleBracketedGenericArguments, FnArg, GenericArgument, ItemFn, PathArguments, PathSegment,
7    Type,
8};
9
/// Detects Tauri `Channel<T>` parameters in `#[tauri::command]` function signatures.
///
/// The parser is a stateless unit struct; construct it with [`ChannelParser::new`] or
/// [`Default::default`].
#[derive(Debug)]
pub struct ChannelParser;
13
14impl ChannelParser {
15    pub fn new() -> Self {
16        Self
17    }
18
19    /// Extract channel parameters from a command function signature
20    /// Looks for parameters with type `Channel<T>`, `tauri::ipc::Channel<T>`, etc.
21    pub fn extract_channels_from_command(
22        &self,
23        func: &ItemFn,
24        command_name: &str,
25        file_path: &Path,
26        type_resolver: &mut TypeResolver,
27    ) -> Result<Vec<ChannelInfo>, Box<dyn std::error::Error>> {
28        let mut channels = Vec::new();
29
30        // Iterate through function parameters
31        for input in &func.sig.inputs {
32            if let FnArg::Typed(pat_type) = input {
33                // Extract parameter name
34                let param_name = if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
35                    pat_ident.ident.to_string()
36                } else {
37                    continue;
38                };
39
40                // Check if this parameter is a Channel type
41                if let Some(message_type) = self.extract_channel_message_type(&pat_type.ty) {
42                    // Get line number from parameter span
43                    let line_number = pat_type.ty.span().start().line;
44
45                    // Parse message type into TypeStructure
46                    let message_type_structure = type_resolver.parse_type_structure(&message_type);
47
48                    channels.push(ChannelInfo {
49                        parameter_name: param_name,
50                        message_type: message_type.clone(),
51                        command_name: command_name.to_string(),
52                        file_path: file_path.to_string_lossy().to_string(),
53                        line_number,
54                        serde_rename: None,
55                        message_type_structure,
56                    });
57                }
58            }
59        }
60
61        Ok(channels)
62    }
63
64    /// Extract the message type T from Channel<T>
65    /// Returns Some(T) if the type is a Channel, None otherwise
66    fn extract_channel_message_type(&self, ty: &Type) -> Option<String> {
67        match ty {
68            Type::Path(type_path) => {
69                // Get the last segment of the path (e.g., "Channel" from "tauri::ipc::Channel")
70                let last_segment = type_path.path.segments.last()?;
71
72                // Check if this is a Channel type
73                if self.is_channel_segment(last_segment, &type_path.path.segments) {
74                    // Extract the generic argument T from Channel<T>
75                    if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
76                        args,
77                        ..
78                    }) = &last_segment.arguments
79                    {
80                        if let Some(GenericArgument::Type(inner_type)) = args.first() {
81                            return Some(Self::type_to_string(inner_type));
82                        }
83                    }
84                }
85                None
86            }
87            _ => None,
88        }
89    }
90
91    /// Check if a path segment represents a Channel type
92    /// Handles: Channel, tauri::ipc::Channel, tauri::Channel
93    fn is_channel_segment(
94        &self,
95        segment: &PathSegment,
96        all_segments: &syn::punctuated::Punctuated<PathSegment, syn::Token![::]>,
97    ) -> bool {
98        let segment_name = segment.ident.to_string();
99
100        // Must be named "Channel"
101        if segment_name != "Channel" {
102            return false;
103        }
104
105        // Accept bare "Channel" or qualified paths like "tauri::ipc::Channel"
106        if all_segments.len() == 1 {
107            // Bare "Channel"
108            return true;
109        }
110
111        // Check for tauri::* namespace
112        if all_segments.len() >= 2 {
113            let first = all_segments.first().unwrap().ident.to_string();
114            if first == "tauri" {
115                return true;
116            }
117        }
118
119        false
120    }
121
122    /// Convert a syn::Type to a string representation
123    /// Simplified version - handles common cases
124    fn type_to_string(ty: &Type) -> String {
125        match ty {
126            Type::Path(type_path) => {
127                let segments: Vec<String> = type_path
128                    .path
129                    .segments
130                    .iter()
131                    .map(|seg| {
132                        let ident = seg.ident.to_string();
133                        // Handle generic arguments if present
134                        if let PathArguments::AngleBracketed(args) = &seg.arguments {
135                            let generic_args: Vec<String> = args
136                                .args
137                                .iter()
138                                .filter_map(|arg| {
139                                    if let GenericArgument::Type(t) = arg {
140                                        Some(Self::type_to_string(t))
141                                    } else {
142                                        None
143                                    }
144                                })
145                                .collect();
146                            if !generic_args.is_empty() {
147                                return format!("{}<{}>", ident, generic_args.join(", "));
148                            }
149                        }
150                        ident
151                    })
152                    .collect();
153                segments.join("::")
154            }
155            Type::Reference(type_ref) => {
156                format!("&{}", Self::type_to_string(&type_ref.elem))
157            }
158            Type::Tuple(tuple) => {
159                if tuple.elems.is_empty() {
160                    "()".to_string()
161                } else {
162                    let types: Vec<String> = tuple.elems.iter().map(Self::type_to_string).collect();
163                    format!("({})", types.join(", "))
164                }
165            }
166            Type::Array(arr) => {
167                format!("[{}]", Self::type_to_string(&arr.elem))
168            }
169            _ => "unknown".to_string(),
170        }
171    }
172}
173
174impl Default for ChannelParser {
175    fn default() -> Self {
176        Self::new()
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;
    use syn::{parse_quote, ItemFn};

    fn parser() -> ChannelParser {
        ChannelParser::new()
    }

    fn type_resolver() -> TypeResolver {
        TypeResolver::new()
    }

    /// Runs `extract_channels_from_command` on `func` (file path `test.rs`) with a fresh
    /// resolver, failing the test if extraction returns `Err`.
    fn extract(func: &ItemFn, command_name: &str) -> Vec<ChannelInfo> {
        let mut resolver = type_resolver();
        parser()
            .extract_channels_from_command(func, command_name, Path::new("test.rs"), &mut resolver)
            .expect("channel extraction should succeed")
    }

    mod channel_detection {
        use super::*;

        /// Message type detected for `ty`, or `None` if it is not a channel.
        fn detect(ty: Type) -> Option<String> {
            parser().extract_channel_message_type(&ty)
        }

        #[test]
        fn test_detect_simple_channel() {
            let detected = detect(parse_quote!(Channel<ProgressUpdate>));
            assert_eq!(detected.as_deref(), Some("ProgressUpdate"));
        }

        #[test]
        fn test_detect_qualified_channel() {
            let detected = detect(parse_quote!(tauri::ipc::Channel<DownloadEvent>));
            assert_eq!(detected.as_deref(), Some("DownloadEvent"));
        }

        #[test]
        fn test_detect_tauri_channel() {
            let detected = detect(parse_quote!(tauri::Channel<Message>));
            assert_eq!(detected.as_deref(), Some("Message"));
        }

        #[test]
        fn test_detect_channel_with_primitive() {
            let detected = detect(parse_quote!(Channel<i32>));
            assert_eq!(detected.as_deref(), Some("i32"));
        }

        #[test]
        fn test_non_channel_type() {
            assert_eq!(detect(parse_quote!(String)), None);
        }

        #[test]
        fn test_channel_with_complex_type() {
            let detected = detect(parse_quote!(Channel<Vec<String>>));
            assert_eq!(detected.as_deref(), Some("Vec<String>"));
        }

        #[test]
        fn test_non_tauri_qualified_channel() {
            // Should not match - not tauri namespace
            assert_eq!(detect(parse_quote!(my_lib::Channel<String>)), None);
        }

        #[test]
        fn test_channel_without_generic() {
            assert_eq!(detect(parse_quote!(Channel)), None);
        }

        #[test]
        fn test_other_generic_type() {
            assert_eq!(detect(parse_quote!(Handler<String>)), None);
        }
    }

    mod type_to_string_conversion {
        use super::*;

        /// Renders `ty` with `ChannelParser::type_to_string`.
        fn render(ty: Type) -> String {
            ChannelParser::type_to_string(&ty)
        }

        #[test]
        fn test_simple_type() {
            assert_eq!(render(parse_quote!(String)), "String");
        }

        #[test]
        fn test_generic_type() {
            assert_eq!(render(parse_quote!(Vec<i32>)), "Vec<i32>");
        }

        #[test]
        fn test_option_type() {
            assert_eq!(render(parse_quote!(Option<User>)), "Option<User>");
        }

        #[test]
        fn test_nested_generic() {
            assert_eq!(render(parse_quote!(Vec<Option<String>>)), "Vec<Option<String>>");
        }

        #[test]
        fn test_multiple_generics() {
            assert_eq!(render(parse_quote!(HashMap<String, i32>)), "HashMap<String, i32>");
        }

        #[test]
        fn test_reference_type() {
            assert_eq!(render(parse_quote!(&String)), "&String");
        }

        #[test]
        fn test_reference_with_generic() {
            assert_eq!(render(parse_quote!(&Vec<i32>)), "&Vec<i32>");
        }

        #[test]
        fn test_tuple_type() {
            assert_eq!(render(parse_quote!((String, i32))), "(String, i32)");
        }

        #[test]
        fn test_empty_tuple() {
            assert_eq!(render(parse_quote!(())), "()");
        }

        #[test]
        fn test_slice_type() {
            // `&[i32]` is a reference to a *slice* (`Type::Slice`, not `Type::Array`), which
            // `type_to_string` does not render, so the referent falls back to "unknown".
            assert_eq!(render(parse_quote!(&[i32])), "&unknown");
        }

        #[test]
        fn test_tuple_with_multiple_elements() {
            assert_eq!(render(parse_quote!((String, i32, bool))), "(String, i32, bool)");
        }

        #[test]
        fn test_qualified_path() {
            assert_eq!(render(parse_quote!(std::string::String)), "std::string::String");
        }
    }

    mod extract_channels_from_command {
        use super::*;

        #[test]
        fn test_extract_no_channels() {
            let func: ItemFn = parse_quote! {
                #[tauri::command]
                fn greet(name: String) -> String {
                    format!("Hello {}", name)
                }
            };

            assert!(extract(&func, "greet").is_empty());
        }

        #[test]
        fn test_extract_single_channel() {
            let func: ItemFn = parse_quote! {
                #[tauri::command]
                fn download(progress: Channel<ProgressUpdate>) {
                    // implementation
                }
            };

            let channels = extract(&func, "download");
            assert_eq!(channels.len(), 1);
            assert_eq!(channels[0].parameter_name, "progress");
            assert_eq!(channels[0].message_type, "ProgressUpdate");
            assert_eq!(channels[0].command_name, "download");
        }

        #[test]
        fn test_extract_multiple_channels() {
            let func: ItemFn = parse_quote! {
                #[tauri::command]
                fn process(
                    progress: Channel<Progress>,
                    logs: Channel<LogEntry>,
                ) {
                    // implementation
                }
            };

            let channels = extract(&func, "process");
            assert_eq!(channels.len(), 2);
            assert_eq!(channels[0].parameter_name, "progress");
            assert_eq!(channels[0].message_type, "Progress");
            assert_eq!(channels[1].parameter_name, "logs");
            assert_eq!(channels[1].message_type, "LogEntry");
        }

        #[test]
        fn test_extract_channel_with_qualified_path() {
            let func: ItemFn = parse_quote! {
                #[tauri::command]
                fn monitor(events: tauri::ipc::Channel<Event>) {
                    // implementation
                }
            };

            let channels = extract(&func, "monitor");
            assert_eq!(channels.len(), 1);
            assert_eq!(channels[0].message_type, "Event");
        }

        #[test]
        fn test_extract_mixed_parameters() {
            let func: ItemFn = parse_quote! {
                #[tauri::command]
                fn process(
                    name: String,
                    progress: Channel<Progress>,
                    count: i32,
                ) {
                    // implementation
                }
            };

            let channels = extract(&func, "process");
            // Should only extract the channel, not other parameters
            assert_eq!(channels.len(), 1);
            assert_eq!(channels[0].parameter_name, "progress");
        }

        #[test]
        fn test_channel_with_complex_message_type() {
            let func: ItemFn = parse_quote! {
                #[tauri::command]
                fn stream(data: Channel<Vec<Option<String>>>) {
                    // implementation
                }
            };

            let channels = extract(&func, "stream");
            assert_eq!(channels.len(), 1);
            assert_eq!(channels[0].message_type, "Vec<Option<String>>");
        }
    }

    mod edge_cases {
        use super::*;

        #[test]
        fn test_function_with_no_parameters() {
            let func: ItemFn = parse_quote! {
                #[tauri::command]
                fn simple() -> String {
                    "test".to_string()
                }
            };

            assert!(extract(&func, "simple").is_empty());
        }

        #[test]
        fn test_self_parameter_ignored() {
            let func: ItemFn = parse_quote! {
                #[tauri::command]
                fn method(&self, ch: Channel<Event>) {
                    // implementation
                }
            };

            let channels = extract(&func, "method");
            assert_eq!(channels.len(), 1);
            assert_eq!(channels[0].parameter_name, "ch");
        }
    }
}