swift_demangler/helpers.rs
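//! Extension traits for [`Node`] and [`NodeKind`] with shared helper methods.
//!
//! This module provides:
//! - [`NodeKindExt`], a crate-internal trait that extends [`NodeKind`] with classification predicates
//! - [`NodeExt`], a crate-internal trait that extends [`Node`] with tree-navigation helpers
//! - the public accessor traits [`HasGenericSignature`], [`HasFunctionSignature`],
//!   [`HasExtensionContext`], and [`HasModule`], shared by the symbol wrapper types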
6
7use crate::raw::{Node, NodeKind};
8use crate::types::{FunctionType, GenericRequirement, GenericSignature, TypeRef};
9
/// Classification predicates on [`NodeKind`].
pub(crate) trait NodeKindExt {
    /// Returns `true` if this kind is one of the node kinds treated as a function type.
    fn is_function_type(&self) -> bool;

    /// Returns `true` if this kind is a type context that declarations can be nested in:
    /// a class, struct, enum, protocol, or extension.
    fn is_type_context(&self) -> bool;
}
18
19impl NodeKindExt for NodeKind {
20 #[inline]
21 fn is_function_type(&self) -> bool {
22 matches!(
23 self,
24 NodeKind::FunctionType
25 | NodeKind::NoEscapeFunctionType
26 | NodeKind::CFunctionPointer
27 | NodeKind::ThinFunctionType
28 | NodeKind::ImplFunctionType
29 | NodeKind::UncurriedFunctionType
30 )
31 }
32
33 #[inline]
34 fn is_type_context(&self) -> bool {
35 matches!(
36 self,
37 NodeKind::Class
38 | NodeKind::Structure
39 | NodeKind::Enum
40 | NodeKind::Protocol
41 | NodeKind::Extension
42 )
43 }
44}
45
46/// Extension trait adding helper methods to [`Node`].
47pub(crate) trait NodeExt<'ctx> {
48 /// Find the module name from this node's children and descendants.
49 ///
50 /// Searches for a [`NodeKind::Module`] node in the following order:
51 /// 1. Direct [`NodeKind::Module`] child
52 /// 2. [`NodeKind::Module`] inside type context children ([`NodeKind::Class`]/[`NodeKind::Structure`]/[`NodeKind::Enum`]/[`NodeKind::Protocol`]/[`NodeKind::Extension`]/[`NodeKind::TypeAlias`])
53 fn find_module(&self) -> Option<&'ctx str>;
54
55 /// Find module by searching all descendants.
56 ///
57 /// Use this as a fallback when [`find_module`](NodeExt::find_module) doesn't find a result.
58 fn find_module_in_descendants(&self) -> Option<&'ctx str>;
59
60 /// Find the generic signature from this node's children.
61 ///
62 /// Searches for [`NodeKind::DependentGenericSignature`] in:
63 /// 1. [`NodeKind::Type`] -> [`NodeKind::DependentGenericType`] -> [`NodeKind::DependentGenericSignature`]
64 /// 2. [`NodeKind::Extension`] -> [`NodeKind::DependentGenericSignature`] (for constrained extensions)
65 fn find_generic_signature(&self) -> Option<GenericSignature<'ctx>>;
66
67 /// Find a function type from this node's children.
68 ///
69 /// Searches for function types ([`NodeKind::FunctionType`], [`NodeKind::NoEscapeFunctionType`], etc.) in:
70 /// 1. [`NodeKind::Type`] -> [`NodeKind::FunctionType`] (direct)
71 /// 2. [`NodeKind::Type`] -> [`NodeKind::DependentGenericType`] -> [`NodeKind::Type`] -> [`NodeKind::FunctionType`] (generic functions)
72 fn find_function_type(&self) -> Option<FunctionType<'ctx>>;
73
74 /// Extract argument labels from this node's [`NodeKind::LabelList`] child.
75 ///
76 /// Returns a vector where each element is `Some(label)` for labeled parameters
77 /// and `None` for unlabeled parameters (using `_`).
78 fn extract_labels(&self) -> Vec<Option<&'ctx str>>;
79
80 /// Find the first [`NodeKind::Identifier`] child and return its text.
81 fn find_identifier(&self) -> Option<&'ctx str>;
82
83 /// Find identifier, also checking [`NodeKind::LocalDeclName`] and [`NodeKind::PrivateDeclName`] wrappers.
84 ///
85 /// This is useful for function names which may be wrapped in these nodes.
86 fn find_identifier_extended(&self) -> Option<&'ctx str>;
87
88 /// Find the containing type name from this node's children.
89 ///
90 /// Searches for [`NodeKind::Class`]/[`NodeKind::Structure`]/[`NodeKind::Enum`]/[`NodeKind::Protocol`] children and returns
91 /// the [`NodeKind::Identifier`] inside. For [`NodeKind::Extension`], looks inside for the extended type.
92 fn find_containing_type(&self) -> Option<&'ctx str>;
93
94 /// Check if the containing type is a class (reference type).
95 fn containing_type_is_class(&self) -> bool;
96
97 /// Check if the containing type is a protocol.
98 fn containing_type_is_protocol(&self) -> bool;
99
100 /// Check if this node has a type context child ([`NodeKind::Class`]/[`NodeKind::Structure`]/[`NodeKind::Enum`]/[`NodeKind::Protocol`]/[`NodeKind::Extension`]).
101 fn has_type_context(&self) -> bool;
102
103 /// Find a [`NodeKind::Type`] child and extract its inner type as a [`TypeRef`].
104 ///
105 /// This handles the common pattern of [`NodeKind::Type`] -> inner type.
106 fn extract_type_ref(&self) -> Option<TypeRef<'ctx>>;
107
108 /// Find the first child with the given kind.
109 fn child_of_kind(&self, kind: NodeKind) -> Option<Node<'ctx>>;
110
111 /// If this node's kind matches, return its first child.
112 ///
113 /// Useful for unwrapping single-child wrapper nodes like [`NodeKind::Type`].
114 /// Returns `None` if the kind doesn't match or the node has no children.
115 fn unwrap_if_kind(&self, kind: NodeKind) -> Option<Node<'ctx>>;
116}
117
118impl<'ctx> NodeExt<'ctx> for Node<'ctx> {
119 fn find_module(&self) -> Option<&'ctx str> {
120 // Direct module child
121 if let Some(module) = self.child_of_kind(NodeKind::Module) {
122 return module.text();
123 }
124 // Module inside a type context
125 for child in self.children() {
126 match child.kind() {
127 NodeKind::Class
128 | NodeKind::Structure
129 | NodeKind::Enum
130 | NodeKind::Protocol
131 | NodeKind::Extension
132 | NodeKind::TypeAlias => {
133 for inner in child.descendants() {
134 if inner.kind() == NodeKind::Module {
135 return inner.text();
136 }
137 }
138 }
139 _ => {}
140 }
141 }
142 None
143 }
144
145 fn find_module_in_descendants(&self) -> Option<&'ctx str> {
146 for desc in self.descendants() {
147 if desc.kind() == NodeKind::Module {
148 return desc.text();
149 }
150 }
151 None
152 }
153
154 fn find_generic_signature(&self) -> Option<GenericSignature<'ctx>> {
155 // First, look for the symbol's own generic signature in Type -> DependentGenericType
156 for child in self.children() {
157 if let Some(inner) = child.unwrap_if_kind(NodeKind::Type)
158 && inner.kind() == NodeKind::DependentGenericType
159 && let Some(sig) = inner.child_of_kind(NodeKind::DependentGenericSignature)
160 {
161 return Some(GenericSignature::new(sig));
162 }
163 }
164 // Fall back to constrained extension: Extension -> DependentGenericSignature
165 // (only if the symbol itself doesn't have its own generic signature)
166 if let Some(ext) = self.child_of_kind(NodeKind::Extension)
167 && let Some(sig) = ext.child_of_kind(NodeKind::DependentGenericSignature)
168 {
169 return Some(GenericSignature::new(sig));
170 }
171 None
172 }
173
174 fn find_function_type(&self) -> Option<FunctionType<'ctx>> {
175 for child in self.children() {
176 if let Some(inner) = child.unwrap_if_kind(NodeKind::Type) {
177 // Check for direct function type
178 if inner.kind().is_function_type() {
179 return Some(FunctionType::new(inner));
180 }
181 // For generic functions, the function type is wrapped in DependentGenericType
182 if inner.kind() == NodeKind::DependentGenericType {
183 for dep_child in inner.children() {
184 if let Some(func_type) = dep_child.unwrap_if_kind(NodeKind::Type)
185 && func_type.kind().is_function_type()
186 {
187 return Some(FunctionType::new(func_type));
188 }
189 }
190 }
191 }
192 }
193 None
194 }
195
196 fn extract_labels(&self) -> Vec<Option<&'ctx str>> {
197 for child in self.children() {
198 if child.kind() == NodeKind::LabelList {
199 return child
200 .children()
201 .map(|label_node| {
202 if label_node.kind() == NodeKind::Identifier {
203 label_node.text()
204 } else if label_node.kind() == NodeKind::FirstElementMarker {
205 None // Represents `_`
206 } else {
207 label_node.text()
208 }
209 })
210 .collect();
211 }
212 }
213 Vec::new()
214 }
215
216 fn find_identifier(&self) -> Option<&'ctx str> {
217 for child in self.children() {
218 if child.kind() == NodeKind::Identifier {
219 return child.text();
220 }
221 }
222 None
223 }
224
225 fn find_identifier_extended(&self) -> Option<&'ctx str> {
226 // If this node itself is an Identifier, return its text
227 if self.kind() == NodeKind::Identifier {
228 return self.text();
229 }
230 for child in self.children() {
231 if child.kind() == NodeKind::Identifier {
232 return child.text();
233 }
234 // Check for LocalDeclName (for local functions)
235 if child.kind() == NodeKind::LocalDeclName {
236 for inner in child.children() {
237 if inner.kind() == NodeKind::Identifier {
238 return inner.text();
239 }
240 }
241 }
242 // Check for PrivateDeclName
243 if child.kind() == NodeKind::PrivateDeclName {
244 for inner in child.children() {
245 if inner.kind() == NodeKind::Identifier {
246 return inner.text();
247 }
248 }
249 }
250 }
251 None
252 }
253
254 fn find_containing_type(&self) -> Option<&'ctx str> {
255 for child in self.children() {
256 match child.kind() {
257 NodeKind::Class | NodeKind::Structure | NodeKind::Enum | NodeKind::Protocol => {
258 for inner in child.children() {
259 if inner.kind() == NodeKind::Identifier {
260 return inner.text();
261 }
262 }
263 }
264 NodeKind::Extension => {
265 // Extension wraps the extended type (e.g., Extension -> Protocol -> Identifier)
266 for inner in child.children() {
267 match inner.kind() {
268 NodeKind::Class
269 | NodeKind::Structure
270 | NodeKind::Enum
271 | NodeKind::Protocol => {
272 for id in inner.children() {
273 if id.kind() == NodeKind::Identifier {
274 return id.text();
275 }
276 }
277 }
278 _ => {}
279 }
280 }
281 }
282 _ => {}
283 }
284 }
285 None
286 }
287
288 fn containing_type_is_class(&self) -> bool {
289 for child in self.children() {
290 match child.kind() {
291 NodeKind::Class => return true,
292 NodeKind::Structure | NodeKind::Enum | NodeKind::Protocol => return false,
293 NodeKind::Extension => {
294 for inner in child.children() {
295 match inner.kind() {
296 NodeKind::Class => return true,
297 NodeKind::Structure | NodeKind::Enum | NodeKind::Protocol => {
298 return false;
299 }
300 _ => {}
301 }
302 }
303 }
304 _ => {}
305 }
306 }
307 false
308 }
309
310 fn containing_type_is_protocol(&self) -> bool {
311 for child in self.children() {
312 match child.kind() {
313 NodeKind::Protocol => return true,
314 NodeKind::Class | NodeKind::Structure | NodeKind::Enum => return false,
315 NodeKind::Extension => {
316 for inner in child.children() {
317 match inner.kind() {
318 NodeKind::Protocol => return true,
319 NodeKind::Class | NodeKind::Structure | NodeKind::Enum => {
320 return false;
321 }
322 _ => {}
323 }
324 }
325 }
326 _ => {}
327 }
328 }
329 false
330 }
331
332 fn has_type_context(&self) -> bool {
333 self.children().any(|c| c.kind().is_type_context())
334 }
335
336 fn extract_type_ref(&self) -> Option<TypeRef<'ctx>> {
337 let type_node = self.child_of_kind(NodeKind::Type)?;
338 Some(TypeRef::new(type_node.child(0).unwrap_or(type_node)))
339 }
340
341 fn child_of_kind(&self, kind: NodeKind) -> Option<Node<'ctx>> {
342 self.children().find(|c| c.kind() == kind)
343 }
344
345 fn unwrap_if_kind(&self, kind: NodeKind) -> Option<Node<'ctx>> {
346 if self.kind() == kind {
347 self.child(0)
348 } else {
349 None
350 }
351 }
352}
353
354/// Trait for types that may have a generic signature.
355///
356/// This provides a uniform interface for accessing generic constraints
357/// on functions, constructors, closures, accessors, and enum cases.
358pub trait HasGenericSignature<'ctx> {
359 /// Get the generic signature if this symbol has generic constraints.
360 fn generic_signature(&self) -> Option<GenericSignature<'ctx>>;
361
362 /// Get the generic requirements (constraints) for this symbol.
363 ///
364 /// This is a convenience method that extracts just the requirements.
365 fn generic_requirements(&self) -> Vec<GenericRequirement<'ctx>> {
366 self.generic_signature()
367 .map(|sig| sig.requirements())
368 .unwrap_or_default()
369 }
370
371 /// Check if this symbol is generic.
372 fn is_generic(&self) -> bool {
373 self.generic_signature().is_some()
374 }
375}
376
377/// Trait for types that have a function signature.
378///
379/// This provides a uniform interface for accessing function type information
380/// on [`Function`](crate::Function)s, [`Constructor`](crate::Constructor)s, and [`Closure`](crate::Closure)s.
381pub trait HasFunctionSignature<'ctx> {
382 /// Get the function signature (type).
383 fn signature(&self) -> Option<FunctionType<'ctx>>;
384
385
386 /// Check if this function is async.
387 fn is_async(&self) -> bool {
388 self.signature().map(|s| s.is_async()).unwrap_or(false)
389 }
390
391 /// Check if this function throws.
392 fn is_throwing(&self) -> bool {
393 self.signature().map(|s| s.is_throwing()).unwrap_or(false)
394 }
395}
396
397/// Trait for types that can be defined in an extension.
398///
399/// This provides a uniform interface for accessing extension context information
400/// on functions, constructors, accessors, and closures.
401pub trait HasExtensionContext<'ctx> {
402 /// Get the raw node for this symbol.
403 fn raw(&self) -> Node<'ctx>;
404
405 /// Check if this symbol is defined in an extension.
406 fn is_extension(&self) -> bool {
407 self.raw()
408 .children()
409 .any(|c| c.kind() == NodeKind::Extension)
410 }
411
412 /// Get the module where the extension is defined, if this is an extension member.
413 fn extension_module(&self) -> Option<&'ctx str> {
414 for child in self.raw().children() {
415 if child.kind() == NodeKind::Extension
416 && let Some(module) = child.child_of_kind(NodeKind::Module)
417 {
418 return module.text();
419 }
420 }
421 None
422 }
423
424 /// Get the generic signature from the extension context, if any.
425 ///
426 /// This is separate from [`HasGenericSignature::generic_signature`] which returns the
427 /// symbol's own generic constraints. Extension constraints define
428 /// when the extension applies (e.g., `extension Array where Element: Comparable`).
429 fn extension_generic_signature(&self) -> Option<GenericSignature<'ctx>> {
430 for child in self.raw().children() {
431 if child.kind() == NodeKind::Extension
432 && let Some(sig) = child.child_of_kind(NodeKind::DependentGenericSignature)
433 {
434 return Some(GenericSignature::new(sig));
435 }
436 }
437 None
438 }
439
440 /// Get the generic requirements from the extension context, if any.
441 fn extension_generic_requirements(&self) -> Vec<GenericRequirement<'ctx>> {
442 self.extension_generic_signature()
443 .map(|sig| sig.requirements())
444 .unwrap_or_default()
445 }
446}
447
/// Common accessor for symbols that can report their declaring module.
pub trait HasModule<'ctx> {
    /// Returns the name of the module declaring this symbol, or `None` if the
    /// implementor cannot find one.
    fn module(&self) -> Option<&'ctx str>;
}
453
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_function_type() {
        // Every kind accepted by `is_function_type`.
        assert!(NodeKind::FunctionType.is_function_type());
        assert!(NodeKind::NoEscapeFunctionType.is_function_type());
        assert!(NodeKind::CFunctionPointer.is_function_type());
        assert!(NodeKind::ThinFunctionType.is_function_type());
        assert!(NodeKind::ImplFunctionType.is_function_type());
        assert!(NodeKind::UncurriedFunctionType.is_function_type());
        // Wrappers and contexts that must be rejected.
        assert!(!NodeKind::Module.is_function_type());
        assert!(!NodeKind::Class.is_function_type());
        assert!(!NodeKind::Type.is_function_type());
        assert!(!NodeKind::DependentGenericType.is_function_type());
    }

    #[test]
    fn test_is_type_context() {
        assert!(NodeKind::Class.is_type_context());
        assert!(NodeKind::Structure.is_type_context());
        assert!(NodeKind::Enum.is_type_context());
        assert!(NodeKind::Protocol.is_type_context());
        assert!(NodeKind::Extension.is_type_context());
        assert!(!NodeKind::Module.is_type_context());
        assert!(!NodeKind::FunctionType.is_type_context());
        // `find_module` also searches under `TypeAlias`, but it is not a type context.
        assert!(!NodeKind::TypeAlias.is_type_context());
    }
}