Skip to main content

swift_demangler/
function.rs

1//! Function symbol representation.
2//!
3//! This module provides the `Function` struct for representing Swift function symbols.
4
5use crate::context::{SymbolContext, extract_context};
6use crate::helpers::{
7    HasExtensionContext, HasFunctionSignature, HasGenericSignature, HasModule, NodeExt,
8};
9use crate::raw::Node;
10use crate::types::{FunctionType, GenericSignature, TypeRef};
11
12/// A Swift function symbol.
13#[derive(Clone, Copy)]
14pub struct Function<'ctx> {
15    raw: Node<'ctx>,
16    is_static: bool,
17}
18
19impl<'ctx> Function<'ctx> {
20    /// Create a Function from a raw node.
21    pub fn new(raw: Node<'ctx>) -> Self {
22        Self {
23            raw,
24            is_static: false,
25        }
26    }
27
28    /// Create a Function from a raw node, marking it as static.
29    pub fn new_static(raw: Node<'ctx>) -> Self {
30        Self {
31            raw,
32            is_static: true,
33        }
34    }
35
36    /// Get the underlying raw node.
37    pub fn raw(&self) -> Node<'ctx> {
38        self.raw
39    }
40
41    /// Get the context (location) where this function is defined.
42    pub fn context(&self) -> SymbolContext<'ctx> {
43        extract_context(self.raw)
44    }
45
46    /// Get the name of this function.
47    pub fn name(&self) -> Option<&'ctx str> {
48        self.raw.find_identifier_extended()
49    }
50
51    /// Get the argument labels for this function.
52    ///
53    /// Returns a vector where each element is `Some(label)` for labeled parameters
54    /// and `None` for unlabeled parameters (using `_`).
55    pub fn labels(&self) -> Vec<Option<&'ctx str>> {
56        self.raw.extract_labels()
57    }
58
59    /// Get the return type of this function.
60    pub fn return_type(&self) -> Option<TypeRef<'ctx>> {
61        self.signature().and_then(|s| s.return_type())
62    }
63
64    /// Check if this is a method (defined in a type context).
65    pub fn is_method(&self) -> bool {
66        self.raw.has_type_context()
67    }
68
69    /// Check if this is a static/class method.
70    pub fn is_static(&self) -> bool {
71        self.is_static
72    }
73
74    /// Get the containing type name if this is a method.
75    pub fn containing_type(&self) -> Option<&'ctx str> {
76        self.raw.find_containing_type()
77    }
78
79    /// Check if the containing type is a class (reference type).
80    pub fn containing_type_is_class(&self) -> bool {
81        self.raw.containing_type_is_class()
82    }
83
84    /// Check if the containing type is a protocol.
85    pub fn containing_type_is_protocol(&self) -> bool {
86        self.raw.containing_type_is_protocol()
87    }
88
89    /// Get the full name with labels (e.g., "foo(bar:baz:)").
90    pub fn full_name(&self) -> String {
91        let name = self.name().unwrap_or("");
92        let labels = self.labels();
93
94        if labels.is_empty() {
95            // Check if there are parameters from the signature
96            if let Some(sig) = self.signature() {
97                let params = sig.parameters();
98                if params.is_empty() {
99                    format!("{name}()")
100                } else {
101                    let param_labels: Vec<String> = params
102                        .iter()
103                        .map(|p| {
104                            p.label
105                                .map(|l| format!("{l}:"))
106                                .unwrap_or_else(|| "_:".to_string())
107                        })
108                        .collect();
109                    format!("{}({})", name, param_labels.join(""))
110                }
111            } else {
112                format!("{name}()")
113            }
114        } else {
115            let label_strs: Vec<String> = labels
116                .iter()
117                .map(|l| {
118                    l.map(|s| format!("{s}:"))
119                        .unwrap_or_else(|| "_:".to_string())
120                })
121                .collect();
122            format!("{}({})", name, label_strs.join(""))
123        }
124    }
125}
126
127impl<'ctx> HasGenericSignature<'ctx> for Function<'ctx> {
128    fn generic_signature(&self) -> Option<GenericSignature<'ctx>> {
129        self.raw.find_generic_signature()
130    }
131}
132
133impl<'ctx> HasFunctionSignature<'ctx> for Function<'ctx> {
134    fn signature(&self) -> Option<FunctionType<'ctx>> {
135        self.raw.find_function_type()
136    }
137}
138
139impl<'ctx> HasExtensionContext<'ctx> for Function<'ctx> {
140    fn raw(&self) -> Node<'ctx> {
141        self.raw
142    }
143}
144
145impl<'ctx> HasModule<'ctx> for Function<'ctx> {
146    fn module(&self) -> Option<&'ctx str> {
147        self.raw.find_module()
148    }
149}
150
151impl std::fmt::Debug for Function<'_> {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        let mut s = f.debug_struct("Function");
154        s.field("module", &self.module())
155            .field("containing_type", &self.containing_type())
156            .field("name", &self.name())
157            .field("labels", &self.labels())
158            .field("is_method", &self.is_method())
159            .field("is_static", &self.is_static())
160            .field("is_async", &self.is_async())
161            .field("is_throwing", &self.is_throwing())
162            .field("is_extension", &self.is_extension())
163            .field("is_generic", &self.is_generic());
164        if let Some(sig) = self.signature() {
165            s.field("signature", &sig);
166        }
167        if self.is_extension() {
168            s.field("extension_module", &self.extension_module());
169            let ext_requirements = self.extension_generic_requirements();
170            if !ext_requirements.is_empty() {
171                s.field("extension_generic_requirements", &ext_requirements);
172            }
173        }
174        let requirements = self.generic_requirements();
175        if !requirements.is_empty() {
176            s.field("generic_requirements", &requirements);
177        }
178        s.finish()
179    }
180}
181
182impl std::fmt::Display for Function<'_> {
183    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184        write!(f, "{}", self.raw)
185    }
186}
187
#[cfg(test)]
mod tests {
    use crate::helpers::{HasFunctionSignature, HasModule};
    use crate::raw::Context;
    use crate::symbol::Symbol;

    /// A free async throwing function exposes its name, module and effects.
    #[test]
    fn test_simple_function() {
        let ctx = Context::new();
        let func = match Symbol::parse(&ctx, "$s4main5helloSSyYaKF").unwrap() {
            Symbol::Function(func) => func,
            _ => panic!("Expected function"),
        };

        assert_eq!(func.name(), Some("hello"));
        assert_eq!(func.module(), Some("main"));
        assert!(func.is_async());
        assert!(func.is_throwing());
        assert!(!func.is_method());
    }

    /// A method reports its containing type alongside name and module.
    #[test]
    fn test_method() {
        let ctx = Context::new();
        // foo.bar.bas(zim: foo.zim) -> ()
        let func = match Symbol::parse(&ctx, "_TFC3foo3bar3basfT3zimCS_3zim_T_").unwrap() {
            Symbol::Function(func) => func,
            _ => panic!("Expected function"),
        };

        assert_eq!(func.name(), Some("bas"));
        assert_eq!(func.module(), Some("foo"));
        assert!(func.is_method());
        assert_eq!(func.containing_type(), Some("bar"));
    }

    /// The signature carries the async/throwing effects and a return type.
    #[test]
    fn test_function_signature() {
        let ctx = Context::new();
        let func = match Symbol::parse(&ctx, "$s4main5helloSSyYaKF").unwrap() {
            Symbol::Function(func) => func,
            _ => panic!("Expected function"),
        };

        let sig = func.signature().expect("Expected signature");
        assert!(sig.is_async());
        assert!(sig.is_throwing());
        assert!(sig.return_type().is_some());
    }

    /// The `Display` output mentions the function name.
    #[test]
    fn test_function_display() {
        let ctx = Context::new();
        let func = match Symbol::parse(&ctx, "$s4main5helloSSyYaKF").unwrap() {
            Symbol::Function(func) => func,
            _ => panic!("Expected function"),
        };

        let display = func.to_string();
        assert!(display.contains("hello"));
    }
}
249}