Skip to main content

swift_demangler/
helpers.rs

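//! Extension traits for [`Node`] and [`NodeKind`] with shared helper methods.
//!
//! This module provides:
//! - [`NodeKindExt`], which extends [`NodeKind`] with classification predicates
//! - [`NodeExt`], which extends [`Node`] with tree-navigation helpers
//! - public accessor traits shared by symbol types: [`HasGenericSignature`],
//!   [`HasFunctionSignature`], [`HasExtensionContext`], and [`HasModule`]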
6
7use crate::raw::{Node, NodeKind};
8use crate::types::{FunctionType, GenericRequirement, GenericSignature, TypeRef};
9
/// Extension trait adding classification predicates to [`NodeKind`].
pub(crate) trait NodeKindExt {
    /// Returns `true` if this kind is one of the function-type node kinds:
    /// [`NodeKind::FunctionType`], [`NodeKind::NoEscapeFunctionType`],
    /// [`NodeKind::CFunctionPointer`], [`NodeKind::ThinFunctionType`],
    /// [`NodeKind::ImplFunctionType`] or [`NodeKind::UncurriedFunctionType`].
    fn is_function_type(&self) -> bool;

    /// Returns `true` if this kind is a type context: [`NodeKind::Class`],
    /// [`NodeKind::Structure`], [`NodeKind::Enum`], [`NodeKind::Protocol`] or
    /// [`NodeKind::Extension`].
    ///
    /// [`NodeKind::TypeAlias`] is *not* considered a type context by this predicate.
    fn is_type_context(&self) -> bool;
}
18
19impl NodeKindExt for NodeKind {
20    #[inline]
21    fn is_function_type(&self) -> bool {
22        matches!(
23            self,
24            NodeKind::FunctionType
25                | NodeKind::NoEscapeFunctionType
26                | NodeKind::CFunctionPointer
27                | NodeKind::ThinFunctionType
28                | NodeKind::ImplFunctionType
29                | NodeKind::UncurriedFunctionType
30        )
31    }
32
33    #[inline]
34    fn is_type_context(&self) -> bool {
35        matches!(
36            self,
37            NodeKind::Class
38                | NodeKind::Structure
39                | NodeKind::Enum
40                | NodeKind::Protocol
41                | NodeKind::Extension
42        )
43    }
44}
45
/// Extension trait adding tree-navigation helpers to [`Node`].
pub(crate) trait NodeExt<'ctx> {
    /// Find the name of the module this node belongs to.
    ///
    /// Searches for a [`NodeKind::Module`] node in the following order:
    /// 1. A direct [`NodeKind::Module`] child.
    /// 2. The first [`NodeKind::Module`] among the descendants of a direct type-context
    ///    child ([`NodeKind::Class`]/[`NodeKind::Structure`]/[`NodeKind::Enum`]/
    ///    [`NodeKind::Protocol`]/[`NodeKind::Extension`]/[`NodeKind::TypeAlias`]),
    ///    checking those children in order.
    ///
    /// Returns `None` if neither search finds a module.
    fn find_module(&self) -> Option<&'ctx str>;

    /// Find the first [`NodeKind::Module`] among all of this node's descendants and
    /// return its text.
    ///
    /// Use this as a fallback when [`find_module`](NodeExt::find_module) doesn't find a result.
    fn find_module_in_descendants(&self) -> Option<&'ctx str>;

    /// Find the generic signature that applies to this node.
    ///
    /// Searches for [`NodeKind::DependentGenericSignature`] in:
    /// 1. [`NodeKind::Type`] -> [`NodeKind::DependentGenericType`] -> [`NodeKind::DependentGenericSignature`]
    ///    (the symbol's own generic signature; checked first).
    /// 2. [`NodeKind::Extension`] -> [`NodeKind::DependentGenericSignature`] (constrained
    ///    extensions), used only when step 1 finds nothing.
    fn find_generic_signature(&self) -> Option<GenericSignature<'ctx>>;

    /// Find a function type among this node's children.
    ///
    /// A kind counts as a function type when [`NodeKindExt::is_function_type`] accepts it
    /// ([`NodeKind::FunctionType`], [`NodeKind::NoEscapeFunctionType`], etc.). Searched in:
    /// 1. [`NodeKind::Type`] -> function type (direct)
    /// 2. [`NodeKind::Type`] -> [`NodeKind::DependentGenericType`] -> [`NodeKind::Type`] -> function type
    ///    (generic functions)
    fn find_function_type(&self) -> Option<FunctionType<'ctx>>;

    /// Extract argument labels from this node's first [`NodeKind::LabelList`] child.
    ///
    /// Each child of the label list produces one element. A [`NodeKind::FirstElementMarker`]
    /// becomes `None` (an unlabeled `_` parameter); any other child becomes its text.
    /// Returns an empty vector when there is no [`NodeKind::LabelList`] child.
    fn extract_labels(&self) -> Vec<Option<&'ctx str>>;

    /// Return the text of the first direct [`NodeKind::Identifier`] child, if any.
    fn find_identifier(&self) -> Option<&'ctx str>;

    /// Find a declaration name, also looking inside [`NodeKind::LocalDeclName`] and
    /// [`NodeKind::PrivateDeclName`] wrappers.
    ///
    /// If this node is itself an [`NodeKind::Identifier`], its own text is returned.
    /// Otherwise the direct children are scanned in order, and the first name found in an
    /// [`NodeKind::Identifier`] child or inside one of the wrappers is returned.
    /// This is useful for function names, which may be wrapped in these nodes.
    fn find_identifier_extended(&self) -> Option<&'ctx str>;

    /// Find the name of the type that contains this node.
    ///
    /// Candidates are direct [`NodeKind::Class`]/[`NodeKind::Structure`]/[`NodeKind::Enum`]/
    /// [`NodeKind::Protocol`] children. For a direct [`NodeKind::Extension`] child, the
    /// extended type of the same kinds inside it is used. The name of the first candidate
    /// that has one is returned.
    fn find_containing_type(&self) -> Option<&'ctx str>;

    /// Check whether the containing type is a class (reference type).
    ///
    /// The first nominal context found (a direct class/struct/enum/protocol child, or one
    /// inside a direct [`NodeKind::Extension`] child) decides the answer. Returns `false`
    /// when there is no such context.
    fn containing_type_is_class(&self) -> bool;

    /// Check whether the containing type is a protocol.
    ///
    /// Uses the same "first nominal context decides" rule as
    /// [`containing_type_is_class`](NodeExt::containing_type_is_class).
    fn containing_type_is_protocol(&self) -> bool;

    /// Check whether any direct child is a type context, as defined by
    /// [`NodeKindExt::is_type_context`] ([`NodeKind::Class`]/[`NodeKind::Structure`]/
    /// [`NodeKind::Enum`]/[`NodeKind::Protocol`]/[`NodeKind::Extension`]).
    fn has_type_context(&self) -> bool;

    /// Find the first [`NodeKind::Type`] child and wrap its inner type as a [`TypeRef`].
    ///
    /// This handles the common pattern of [`NodeKind::Type`] -> inner type. If the
    /// [`NodeKind::Type`] node has no children, the [`NodeKind::Type`] node itself is wrapped.
    fn extract_type_ref(&self) -> Option<TypeRef<'ctx>>;

    /// Find the first direct child with the given kind.
    fn child_of_kind(&self, kind: NodeKind) -> Option<Node<'ctx>>;

    /// If this node's kind matches, return its first child.
    ///
    /// Useful for unwrapping single-child wrapper nodes like [`NodeKind::Type`].
    /// Returns `None` if the kind doesn't match or the node has no children.
    fn unwrap_if_kind(&self, kind: NodeKind) -> Option<Node<'ctx>>;
}
117
118impl<'ctx> NodeExt<'ctx> for Node<'ctx> {
119    fn find_module(&self) -> Option<&'ctx str> {
120        // Direct module child
121        if let Some(module) = self.child_of_kind(NodeKind::Module) {
122            return module.text();
123        }
124        // Module inside a type context
125        for child in self.children() {
126            match child.kind() {
127                NodeKind::Class
128                | NodeKind::Structure
129                | NodeKind::Enum
130                | NodeKind::Protocol
131                | NodeKind::Extension
132                | NodeKind::TypeAlias => {
133                    for inner in child.descendants() {
134                        if inner.kind() == NodeKind::Module {
135                            return inner.text();
136                        }
137                    }
138                }
139                _ => {}
140            }
141        }
142        None
143    }
144
145    fn find_module_in_descendants(&self) -> Option<&'ctx str> {
146        for desc in self.descendants() {
147            if desc.kind() == NodeKind::Module {
148                return desc.text();
149            }
150        }
151        None
152    }
153
154    fn find_generic_signature(&self) -> Option<GenericSignature<'ctx>> {
155        // First, look for the symbol's own generic signature in Type -> DependentGenericType
156        for child in self.children() {
157            if let Some(inner) = child.unwrap_if_kind(NodeKind::Type)
158                && inner.kind() == NodeKind::DependentGenericType
159                && let Some(sig) = inner.child_of_kind(NodeKind::DependentGenericSignature)
160            {
161                return Some(GenericSignature::new(sig));
162            }
163        }
164        // Fall back to constrained extension: Extension -> DependentGenericSignature
165        // (only if the symbol itself doesn't have its own generic signature)
166        if let Some(ext) = self.child_of_kind(NodeKind::Extension)
167            && let Some(sig) = ext.child_of_kind(NodeKind::DependentGenericSignature)
168        {
169            return Some(GenericSignature::new(sig));
170        }
171        None
172    }
173
174    fn find_function_type(&self) -> Option<FunctionType<'ctx>> {
175        for child in self.children() {
176            if let Some(inner) = child.unwrap_if_kind(NodeKind::Type) {
177                // Check for direct function type
178                if inner.kind().is_function_type() {
179                    return Some(FunctionType::new(inner));
180                }
181                // For generic functions, the function type is wrapped in DependentGenericType
182                if inner.kind() == NodeKind::DependentGenericType {
183                    for dep_child in inner.children() {
184                        if let Some(func_type) = dep_child.unwrap_if_kind(NodeKind::Type)
185                            && func_type.kind().is_function_type()
186                        {
187                            return Some(FunctionType::new(func_type));
188                        }
189                    }
190                }
191            }
192        }
193        None
194    }
195
196    fn extract_labels(&self) -> Vec<Option<&'ctx str>> {
197        for child in self.children() {
198            if child.kind() == NodeKind::LabelList {
199                return child
200                    .children()
201                    .map(|label_node| {
202                        if label_node.kind() == NodeKind::Identifier {
203                            label_node.text()
204                        } else if label_node.kind() == NodeKind::FirstElementMarker {
205                            None // Represents `_`
206                        } else {
207                            label_node.text()
208                        }
209                    })
210                    .collect();
211            }
212        }
213        Vec::new()
214    }
215
216    fn find_identifier(&self) -> Option<&'ctx str> {
217        for child in self.children() {
218            if child.kind() == NodeKind::Identifier {
219                return child.text();
220            }
221        }
222        None
223    }
224
225    fn find_identifier_extended(&self) -> Option<&'ctx str> {
226        // If this node itself is an Identifier, return its text
227        if self.kind() == NodeKind::Identifier {
228            return self.text();
229        }
230        for child in self.children() {
231            if child.kind() == NodeKind::Identifier {
232                return child.text();
233            }
234            // Check for LocalDeclName (for local functions)
235            if child.kind() == NodeKind::LocalDeclName {
236                for inner in child.children() {
237                    if inner.kind() == NodeKind::Identifier {
238                        return inner.text();
239                    }
240                }
241            }
242            // Check for PrivateDeclName
243            if child.kind() == NodeKind::PrivateDeclName {
244                for inner in child.children() {
245                    if inner.kind() == NodeKind::Identifier {
246                        return inner.text();
247                    }
248                }
249            }
250        }
251        None
252    }
253
254    fn find_containing_type(&self) -> Option<&'ctx str> {
255        for child in self.children() {
256            match child.kind() {
257                NodeKind::Class | NodeKind::Structure | NodeKind::Enum | NodeKind::Protocol => {
258                    for inner in child.children() {
259                        if inner.kind() == NodeKind::Identifier {
260                            return inner.text();
261                        }
262                    }
263                }
264                NodeKind::Extension => {
265                    // Extension wraps the extended type (e.g., Extension -> Protocol -> Identifier)
266                    for inner in child.children() {
267                        match inner.kind() {
268                            NodeKind::Class
269                            | NodeKind::Structure
270                            | NodeKind::Enum
271                            | NodeKind::Protocol => {
272                                for id in inner.children() {
273                                    if id.kind() == NodeKind::Identifier {
274                                        return id.text();
275                                    }
276                                }
277                            }
278                            _ => {}
279                        }
280                    }
281                }
282                _ => {}
283            }
284        }
285        None
286    }
287
288    fn containing_type_is_class(&self) -> bool {
289        for child in self.children() {
290            match child.kind() {
291                NodeKind::Class => return true,
292                NodeKind::Structure | NodeKind::Enum | NodeKind::Protocol => return false,
293                NodeKind::Extension => {
294                    for inner in child.children() {
295                        match inner.kind() {
296                            NodeKind::Class => return true,
297                            NodeKind::Structure | NodeKind::Enum | NodeKind::Protocol => {
298                                return false;
299                            }
300                            _ => {}
301                        }
302                    }
303                }
304                _ => {}
305            }
306        }
307        false
308    }
309
310    fn containing_type_is_protocol(&self) -> bool {
311        for child in self.children() {
312            match child.kind() {
313                NodeKind::Protocol => return true,
314                NodeKind::Class | NodeKind::Structure | NodeKind::Enum => return false,
315                NodeKind::Extension => {
316                    for inner in child.children() {
317                        match inner.kind() {
318                            NodeKind::Protocol => return true,
319                            NodeKind::Class | NodeKind::Structure | NodeKind::Enum => {
320                                return false;
321                            }
322                            _ => {}
323                        }
324                    }
325                }
326                _ => {}
327            }
328        }
329        false
330    }
331
332    fn has_type_context(&self) -> bool {
333        self.children().any(|c| c.kind().is_type_context())
334    }
335
336    fn extract_type_ref(&self) -> Option<TypeRef<'ctx>> {
337        let type_node = self.child_of_kind(NodeKind::Type)?;
338        Some(TypeRef::new(type_node.child(0).unwrap_or(type_node)))
339    }
340
341    fn child_of_kind(&self, kind: NodeKind) -> Option<Node<'ctx>> {
342        self.children().find(|c| c.kind() == kind)
343    }
344
345    fn unwrap_if_kind(&self, kind: NodeKind) -> Option<Node<'ctx>> {
346        if self.kind() == kind {
347            self.child(0)
348        } else {
349            None
350        }
351    }
352}
353
354/// Trait for types that may have a generic signature.
355///
356/// This provides a uniform interface for accessing generic constraints
357/// on functions, constructors, closures, accessors, and enum cases.
358pub trait HasGenericSignature<'ctx> {
359    /// Get the generic signature if this symbol has generic constraints.
360    fn generic_signature(&self) -> Option<GenericSignature<'ctx>>;
361
362    /// Get the generic requirements (constraints) for this symbol.
363    ///
364    /// This is a convenience method that extracts just the requirements.
365    fn generic_requirements(&self) -> Vec<GenericRequirement<'ctx>> {
366        self.generic_signature()
367            .map(|sig| sig.requirements())
368            .unwrap_or_default()
369    }
370
371    /// Check if this symbol is generic.
372    fn is_generic(&self) -> bool {
373        self.generic_signature().is_some()
374    }
375}
376
377/// Trait for types that have a function signature.
378///
379/// This provides a uniform interface for accessing function type information
380/// on [`Function`](crate::Function)s, [`Constructor`](crate::Constructor)s, and [`Closure`](crate::Closure)s.
381pub trait HasFunctionSignature<'ctx> {
382    /// Get the function signature (type).
383    fn signature(&self) -> Option<FunctionType<'ctx>>;
384
385
386    /// Check if this function is async.
387    fn is_async(&self) -> bool {
388        self.signature().map(|s| s.is_async()).unwrap_or(false)
389    }
390
391    /// Check if this function throws.
392    fn is_throwing(&self) -> bool {
393        self.signature().map(|s| s.is_throwing()).unwrap_or(false)
394    }
395}
396
397/// Trait for types that can be defined in an extension.
398///
399/// This provides a uniform interface for accessing extension context information
400/// on functions, constructors, accessors, and closures.
401pub trait HasExtensionContext<'ctx> {
402    /// Get the raw node for this symbol.
403    fn raw(&self) -> Node<'ctx>;
404
405    /// Check if this symbol is defined in an extension.
406    fn is_extension(&self) -> bool {
407        self.raw()
408            .children()
409            .any(|c| c.kind() == NodeKind::Extension)
410    }
411
412    /// Get the module where the extension is defined, if this is an extension member.
413    fn extension_module(&self) -> Option<&'ctx str> {
414        for child in self.raw().children() {
415            if child.kind() == NodeKind::Extension
416                && let Some(module) = child.child_of_kind(NodeKind::Module)
417            {
418                return module.text();
419            }
420        }
421        None
422    }
423
424    /// Get the generic signature from the extension context, if any.
425    ///
426    /// This is separate from [`HasGenericSignature::generic_signature`] which returns the
427    /// symbol's own generic constraints. Extension constraints define
428    /// when the extension applies (e.g., `extension Array where Element: Comparable`).
429    fn extension_generic_signature(&self) -> Option<GenericSignature<'ctx>> {
430        for child in self.raw().children() {
431            if child.kind() == NodeKind::Extension
432                && let Some(sig) = child.child_of_kind(NodeKind::DependentGenericSignature)
433            {
434                return Some(GenericSignature::new(sig));
435            }
436        }
437        None
438    }
439
440    /// Get the generic requirements from the extension context, if any.
441    fn extension_generic_requirements(&self) -> Vec<GenericRequirement<'ctx>> {
442        self.extension_generic_signature()
443            .map(|sig| sig.requirements())
444            .unwrap_or_default()
445    }
446}
447
/// Trait for symbol types that can report the module they are defined in.
pub trait HasModule<'ctx> {
    /// Returns the name of the module where this symbol is defined, or `None` when the
    /// implementation cannot determine it.
    fn module(&self) -> Option<&'ctx str>;
}
453
#[cfg(test)]
mod tests {
    use super::*;

    /// Every kind accepted by `is_function_type` is covered, plus representative rejects.
    #[test]
    fn test_is_function_type() {
        assert!(NodeKind::FunctionType.is_function_type());
        assert!(NodeKind::NoEscapeFunctionType.is_function_type());
        assert!(NodeKind::CFunctionPointer.is_function_type());
        assert!(NodeKind::ThinFunctionType.is_function_type());
        assert!(NodeKind::ImplFunctionType.is_function_type());
        assert!(NodeKind::UncurriedFunctionType.is_function_type());

        assert!(!NodeKind::Module.is_function_type());
        assert!(!NodeKind::Class.is_function_type());
        // The generic wrapper is not itself a function type; callers must unwrap it.
        assert!(!NodeKind::DependentGenericType.is_function_type());
    }

    /// Every kind accepted by `is_type_context` is covered, plus representative rejects.
    #[test]
    fn test_is_type_context() {
        assert!(NodeKind::Class.is_type_context());
        assert!(NodeKind::Structure.is_type_context());
        assert!(NodeKind::Enum.is_type_context());
        assert!(NodeKind::Protocol.is_type_context());
        assert!(NodeKind::Extension.is_type_context());

        assert!(!NodeKind::Module.is_type_context());
        assert!(!NodeKind::FunctionType.is_type_context());
        assert!(!NodeKind::Identifier.is_type_context());
    }
}