Skip to main content

specta_swift/
swift.rs

1//! Swift language exporter configuration and main export functionality.
2
3use std::{borrow::Cow, fmt, path::Path};
4
5use specta::{
6    Format, Types,
7    datatype::{DataType, Fields, Reference},
8};
9
10use crate::Error;
11use crate::primitives::{export_type, is_duration_struct};
12
13/// Swift language exporter.
14#[derive(Clone)]
15pub struct Swift {
16    /// Header comment for generated files.
17    pub header: Cow<'static, str>,
18    /// Indentation style for generated code.
19    pub indent: IndentStyle,
20    /// Naming convention for identifiers.
21    pub naming: NamingConvention,
22    /// Generic type style.
23    pub generics: GenericStyle,
24    /// Optional type style.
25    pub optionals: OptionalStyle,
26    /// Additional protocols to conform to.
27    pub protocols: Vec<Cow<'static, str>>,
28}
29
30impl fmt::Debug for Swift {
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        f.debug_struct("Swift")
33            .field("header", &self.header)
34            .field("indent", &self.indent)
35            .field("naming", &self.naming)
36            .field("generics", &self.generics)
37            .field("optionals", &self.optionals)
38            .field("protocols", &self.protocols)
39            .finish()
40    }
41}
42
/// How generated Swift code is indented.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndentStyle {
    /// Indent using space characters; the value is the number of spaces.
    Spaces(usize),
    /// Indent using tab characters.
    Tabs,
}

impl Default for IndentStyle {
    /// Four spaces.
    fn default() -> Self {
        IndentStyle::Spaces(4)
    }
}
57
/// Casing scheme applied to generated Swift identifiers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NamingConvention {
    /// `PascalCase`; this is the default and the usual Swift style for type names.
    #[default]
    PascalCase,
    /// `camelCase`.
    CamelCase,
    /// `snake_case`.
    SnakeCase,
}
69
/// How generic parameters are written in generated Swift.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum GenericStyle {
    /// Inline protocol constraints, e.g. `<T: Codable>` (the default).
    #[default]
    Protocol,
    /// Constraints in a where clause, e.g. `<T> where T: Codable`.
    Typealias,
}
79
/// How optional types are written in generated Swift.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OptionalStyle {
    /// Postfix sugar, e.g. `String?` (the default).
    #[default]
    QuestionMark,
    /// Spelled-out generic form, e.g. `Optional<String>`.
    Optional,
}
89
90impl Default for Swift {
91    fn default() -> Self {
92        Self {
93            header: "// This file has been generated by Specta. DO NOT EDIT.".into(),
94            indent: IndentStyle::default(),
95            naming: NamingConvention::default(),
96            generics: GenericStyle::default(),
97            optionals: OptionalStyle::default(),
98            protocols: vec![],
99        }
100    }
101}
102
103impl Swift {
104    /// Create a new Swift exporter with default configuration.
105    pub fn new() -> Self {
106        Self::default()
107    }
108
109    /// Set the header comment for generated files.
110    pub fn header(mut self, header: impl Into<Cow<'static, str>>) -> Self {
111        self.header = header.into();
112        self
113    }
114
115    /// Set the indentation style.
116    pub fn indent(mut self, style: IndentStyle) -> Self {
117        self.indent = style;
118        self
119    }
120
121    /// Set the naming convention.
122    pub fn naming(mut self, convention: NamingConvention) -> Self {
123        self.naming = convention;
124        self
125    }
126
127    /// Set the generic type style.
128    pub fn generics(mut self, style: GenericStyle) -> Self {
129        self.generics = style;
130        self
131    }
132
133    /// Set the optional type style.
134    pub fn optionals(mut self, style: OptionalStyle) -> Self {
135        self.optionals = style;
136        self
137    }
138
139    /// Add a protocol that all types should conform to.
140    pub fn add_protocol(mut self, protocol: impl Into<Cow<'static, str>>) -> Self {
141        self.protocols.push(protocol.into());
142        self
143    }
144
145    /// Export types to a Swift string.
146    pub fn export(&self, types: &Types, format: impl Format) -> Result<String, Error> {
147        let exporter = self.clone();
148        let formatted_types = format_types(types, &format)?.into_owned();
149        let raw_types = &formatted_types;
150
151        let mut result = String::new();
152
153        // Add header
154        if !exporter.header.is_empty() {
155            result.push_str(&exporter.header);
156            result.push('\n');
157        }
158
159        // Add imports
160        result.push_str("import Foundation\n");
161        for protocol in &exporter.protocols {
162            result.push_str(&format!("import {}\n", protocol));
163        }
164        result.push('\n');
165
166        // Check if we need to inject Duration helper
167        if needs_duration_helper(raw_types) {
168            result.push_str(&generate_duration_helper());
169        }
170
171        // Export types
172        for ndt in raw_types.into_sorted_iter() {
173            let exported = export_type(&exporter, Some(&format), raw_types, ndt)?;
174            if !exported.is_empty() {
175                result.push_str(&exported);
176                result.push_str("\n\n");
177            }
178        }
179
180        Ok(result)
181    }
182
183    /// Export types to a file.
184    pub fn export_to(
185        &self,
186        path: impl AsRef<Path>,
187        types: &Types,
188        format: impl Format,
189    ) -> Result<(), Error> {
190        let content = self.export(types, format)?;
191        std::fs::write(path, content)?;
192        Ok(())
193    }
194}
195
196fn format_types<'a>(types: &'a Types, format: &'a dyn Format) -> Result<Cow<'a, Types>, Error> {
197    format
198        .map_types(types)
199        .map_err(|err| Error::format("type graph formatter failed", err))
200}
201
202impl NamingConvention {
203    /// Convert a string to the appropriate naming convention.
204    pub fn convert(&self, name: &str) -> String {
205        match self {
206            Self::PascalCase => self.to_pascal_case(name),
207            Self::CamelCase => self.to_camel_case(name),
208            Self::SnakeCase => self.to_snake_case(name),
209        }
210    }
211
212    /// Convert a string to camelCase (for field names).
213    pub fn convert_to_camel_case(&self, name: &str) -> String {
214        self.to_camel_case(name)
215    }
216
217    /// Convert a string to the appropriate naming convention for fields.
218    pub fn convert_field(&self, name: &str) -> String {
219        match self {
220            Self::PascalCase => self.to_camel_case(name), // Fields should be camelCase even with PascalCase
221            Self::CamelCase => self.to_camel_case(name),
222            Self::SnakeCase => self.to_snake_case(name),
223        }
224    }
225
226    /// Convert a string to the appropriate naming convention for enum cases.
227    pub fn convert_enum_case(&self, name: &str) -> String {
228        match self {
229            Self::PascalCase => self.to_camel_case(name), // Enum cases should be camelCase
230            Self::CamelCase => self.to_camel_case(name),
231            Self::SnakeCase => self.to_snake_case(name),
232        }
233    }
234
235    #[allow(clippy::wrong_self_convention)]
236    fn to_camel_case(&self, name: &str) -> String {
237        // Convert snake_case or PascalCase to camelCase
238        if name.contains('_') {
239            // Handle snake_case
240            let parts: Vec<&str> = name.split('_').collect();
241            if parts.is_empty() {
242                return name.to_string();
243            }
244
245            let mut result = String::new();
246            for (i, part) in parts.iter().enumerate() {
247                if i == 0 {
248                    result.push_str(&part.to_lowercase());
249                } else {
250                    let mut chars = part.chars();
251                    match chars.next() {
252                        None => continue,
253                        Some(first) => {
254                            result.push(first.to_uppercase().next().unwrap_or(first));
255                            for c in chars {
256                                result.extend(c.to_lowercase());
257                            }
258                        }
259                    }
260                }
261            }
262            result
263        } else {
264            if name.chars().any(|c| c.is_ascii_alphabetic())
265                && name
266                    .chars()
267                    .all(|c| !c.is_ascii_alphabetic() || c.is_ascii_uppercase())
268            {
269                return name.to_ascii_lowercase();
270            }
271
272            // Handle PascalCase - convert to camelCase
273            let mut chars = name.chars();
274            match chars.next() {
275                None => name.to_string(),
276                Some(first) => {
277                    let mut result = String::new();
278                    result.push(first.to_lowercase().next().unwrap_or(first));
279                    for c in chars {
280                        result.push(c); // Keep the rest as-is for PascalCase
281                    }
282                    result
283                }
284            }
285        }
286    }
287
288    #[allow(clippy::wrong_self_convention)]
289    fn to_pascal_case(&self, name: &str) -> String {
290        // Convert snake_case to PascalCase
291        name.split('_')
292            .map(|part| {
293                let mut chars = part.chars();
294                match chars.next() {
295                    None => String::new(),
296                    Some(first) => first.to_uppercase().chain(chars).collect(),
297                }
298            })
299            .collect()
300    }
301
302    #[allow(clippy::wrong_self_convention)]
303    fn to_snake_case(&self, name: &str) -> String {
304        // Convert camelCase/PascalCase to snake_case
305        let mut result = String::new();
306        let chars = name.chars();
307
308        for c in chars {
309            if c.is_uppercase() && !result.is_empty() {
310                result.push('_');
311            }
312            result.push(c.to_lowercase().next().unwrap_or(c));
313        }
314
315        result
316    }
317}
318
319/// Check if the type collection contains any Duration types that need the helper
320fn needs_duration_helper(types: &Types) -> bool {
321    for ndt in types.into_sorted_iter() {
322        if ndt.name == "Duration" {
323            return true;
324        }
325        // Also check if any struct fields contain Duration
326        if let Some(DataType::Struct(s)) = &ndt.ty
327            && let Fields::Named(fields) = &s.fields
328        {
329            for (_, field) in &fields.fields {
330                if let Some(ty) = field.ty.as_ref() {
331                    if let DataType::Reference(Reference::Named(r)) = ty
332                        && let Some(referenced_ndt) = types.get(r)
333                        && referenced_ndt.name == "Duration"
334                    {
335                        return true;
336                    }
337                    // Also check if the field type is a Duration struct directly
338                    if let DataType::Struct(struct_ty) = ty
339                        && is_duration_struct(struct_ty)
340                    {
341                        return true;
342                    }
343                }
344            }
345        }
346    }
347    false
348}
349
/// Build the Swift source for the `RustDuration` helper.
///
/// `RustDuration` is a `Codable` struct with `secs: UInt64` and `nanos: UInt32`
/// fields and a computed `timeInterval` (seconds as a `Double`). The text ends
/// with a `// MARK: - Generated Types` divider so the generated types follow it.
fn generate_duration_helper() -> String {
    String::from(concat!(
        "// MARK: - Duration Helper\n",
        "/// Helper struct to decode Rust Duration format {\"secs\": u64, \"nanos\": u32}\n",
        "public struct RustDuration: Codable {\n",
        "    public let secs: UInt64\n",
        "    public let nanos: UInt32\n",
        "    \n",
        "    public var timeInterval: TimeInterval {\n",
        "        return Double(secs) + Double(nanos) / 1_000_000_000.0\n",
        "    }\n",
        "}\n\n",
        "// MARK: - Generated Types\n\n",
    ))
}