Skip to main content

vox_codegen/targets/swift/
types.rs

1//! Swift type generation and collection.
2//!
3//! This module handles:
4//! - Collecting named types (structs and enums) from service definitions
5//! - Generating Swift type definitions (structs, enums)
6//! - Converting Rust types to Swift type strings
7
8use std::collections::HashSet;
9
10use facet_core::{ScalarType, Shape};
11use heck::ToLowerCamelCase;
12use vox_types::{
13    EnumInfo, ServiceDescriptor, ShapeKind, StructInfo, VariantKind, classify_shape,
14    classify_variant, is_bytes, is_rx, is_tx,
15};
16
17/// Collect all named types (structs and enums with a name) from a service.
18/// Returns a vector of (name, Shape) pairs in dependency order.
19pub fn collect_named_types(service: &ServiceDescriptor) -> Vec<(String, &'static Shape)> {
20    let mut seen: HashSet<String> = HashSet::new();
21    let mut types = Vec::new();
22
23    fn visit(
24        shape: &'static Shape,
25        seen: &mut HashSet<String>,
26        types: &mut Vec<(String, &'static Shape)>,
27    ) {
28        match classify_shape(shape) {
29            ShapeKind::Struct(StructInfo {
30                name: Some(name),
31                fields,
32                ..
33            }) if seen.insert(name.to_string()) => {
34                // Visit nested types first (dependencies before dependents)
35                for field in fields {
36                    visit(field.shape(), seen, types);
37                }
38                types.push((name.to_string(), shape));
39            }
40            ShapeKind::Enum(EnumInfo {
41                name: Some(name),
42                variants,
43            }) if seen.insert(name.to_string()) => {
44                // Visit nested types in variants
45                for variant in variants {
46                    match classify_variant(variant) {
47                        VariantKind::Newtype { inner } => visit(inner, seen, types),
48                        VariantKind::Struct { fields } | VariantKind::Tuple { fields } => {
49                            for field in fields {
50                                visit(field.shape(), seen, types);
51                            }
52                        }
53                        VariantKind::Unit => {}
54                    }
55                }
56                types.push((name.to_string(), shape));
57            }
58            ShapeKind::List { element }
59            | ShapeKind::Slice { element }
60            | ShapeKind::Option { inner: element }
61            | ShapeKind::Array { element, .. }
62            | ShapeKind::Set { element } => visit(element, seen, types),
63            ShapeKind::Map { key, value } => {
64                visit(key, seen, types);
65                visit(value, seen, types);
66            }
67            ShapeKind::Tuple { elements } => {
68                for param in elements {
69                    visit(param.shape, seen, types);
70                }
71            }
72            ShapeKind::Tx { inner } | ShapeKind::Rx { inner } => visit(inner, seen, types),
73            ShapeKind::Pointer { pointee } => visit(pointee, seen, types),
74            ShapeKind::Result { ok, err } => {
75                visit(ok, seen, types);
76                visit(err, seen, types);
77            }
78            _ => {}
79        }
80    }
81
82    for method in service.methods {
83        for arg in method.args {
84            visit(arg.shape, &mut seen, &mut types);
85        }
86        visit(method.return_shape, &mut seen, &mut types);
87    }
88
89    types
90}
91
92/// Generate Swift type definitions for all named types.
93pub fn generate_named_types(named_types: &[(String, &'static Shape)]) -> String {
94    let mut out = String::new();
95
96    for (name, shape) in named_types {
97        match classify_shape(shape) {
98            ShapeKind::Struct(StructInfo { fields, .. }) => {
99                out.push_str(&format!("public struct {name}: Codable, Sendable {{\n"));
100                for field in fields {
101                    let field_name = swift_field_name(field.name);
102                    let field_type = swift_type_base(field.shape());
103                    out.push_str(&format!("    public var {field_name}: {field_type}\n"));
104                }
105                out.push('\n');
106                // Generate initializer
107                out.push_str("    nonisolated public init(");
108                for (i, field) in fields.iter().enumerate() {
109                    if i > 0 {
110                        out.push_str(", ");
111                    }
112                    let field_name = swift_field_name(field.name);
113                    let field_type = swift_type_base(field.shape());
114                    out.push_str(&format!("{field_name}: {field_type}"));
115                }
116                out.push_str(") {\n");
117                for field in fields {
118                    let field_name = swift_field_name(field.name);
119                    out.push_str(&format!("        self.{field_name} = {field_name}\n"));
120                }
121                out.push_str("    }\n");
122                out.push_str("}\n\n");
123            }
124            ShapeKind::Enum(EnumInfo { variants, .. }) => {
125                // Add Error conformance if the enum name ends with "Error"
126                let protocols = if name.ends_with("Error") {
127                    "Codable, Sendable, Error"
128                } else {
129                    "Codable, Sendable"
130                };
131                out.push_str(&format!("public enum {name}: {protocols} {{\n"));
132                for variant in variants {
133                    let variant_name = swift_field_name(variant.name);
134                    match classify_variant(variant) {
135                        VariantKind::Unit => {
136                            out.push_str(&format!("    case {variant_name}\n"));
137                        }
138                        VariantKind::Newtype { inner } => {
139                            let inner_type = swift_type_base(inner);
140                            out.push_str(&format!("    case {variant_name}({inner_type})\n"));
141                        }
142                        VariantKind::Tuple { fields } => {
143                            let field_types: Vec<_> =
144                                fields.iter().map(|f| swift_type_base(f.shape())).collect();
145                            out.push_str(&format!(
146                                "    case {variant_name}({})\n",
147                                field_types.join(", ")
148                            ));
149                        }
150                        VariantKind::Struct { fields } => {
151                            let field_decls: Vec<_> = fields
152                                .iter()
153                                .map(|f| {
154                                    format!(
155                                        "{}: {}",
156                                        swift_field_name(f.name),
157                                        swift_type_base(f.shape())
158                                    )
159                                })
160                                .collect();
161                            out.push_str(&format!(
162                                "    case {variant_name}({})\n",
163                                field_decls.join(", ")
164                            ));
165                        }
166                    }
167                }
168                out.push_str("}\n\n");
169            }
170            _ => {}
171        }
172    }
173
174    out
175}
176
177/// Map a Rust field or variant name to a valid Swift identifier.
178///
179/// Two transforms layered on top of `to_lower_camel_case`:
180///   - Rust tuple-struct positional fields are exposed by facet with
181///     numeric names ("0", "1", ...); prefix them with an underscore.
182///   - Swift reserved keywords (e.g. `internal`, `class`) need
183///     backticks when used as identifiers.
184pub fn swift_field_name(name: &str) -> String {
185    if name.chars().next().is_some_and(|c| c.is_ascii_digit()) {
186        return format!("_{name}");
187    }
188    let lower = name.to_lower_camel_case();
189    if SWIFT_RESERVED.binary_search(&lower.as_str()).is_ok() {
190        format!("`{lower}`")
191    } else {
192        lower
193    }
194}
195
/// Swift keywords that cannot be used as bare identifiers, so generated
/// property, parameter and case names matching one of them must be wrapped
/// in backticks.
///
/// Backticking a word that is not actually reserved is still valid Swift,
/// so erring on the side of inclusion is harmless.
///
/// Must stay sorted in byte order (uppercase before lowercase), because
/// `swift_field_name` looks names up with `binary_search`.
const SWIFT_RESERVED: &[&str] = &[
    "Any",
    "Self",
    "as",
    "associatedtype",
    // `await` is a keyword in expressions since Swift 5.5; a Rust variant
    // named `Await` lower-camel-cases to it.
    "await",
    "break",
    "case",
    "catch",
    "class",
    "continue",
    "default",
    "defer",
    "deinit",
    "do",
    "else",
    "enum",
    "extension",
    "fallthrough",
    "false",
    "fileprivate",
    "for",
    "func",
    "guard",
    "if",
    "import",
    "in",
    "init",
    "inout",
    "internal",
    "is",
    "let",
    "nil",
    "open",
    "operator",
    "precedencegroup",
    "private",
    "protocol",
    "public",
    "repeat",
    "rethrows",
    "return",
    "self",
    "static",
    "struct",
    "subscript",
    "super",
    "switch",
    "throw",
    "throws",
    "true",
    "try",
    "typealias",
    "var",
    "where",
    "while",
];
255
256/// Convert ScalarType to Swift type string.
257pub fn swift_scalar_type(scalar: ScalarType) -> String {
258    match scalar {
259        ScalarType::Bool => "Bool".into(),
260        ScalarType::U8 => "UInt8".into(),
261        ScalarType::U16 => "UInt16".into(),
262        ScalarType::U32 => "UInt32".into(),
263        ScalarType::U64 => "UInt64".into(),
264        ScalarType::U128 => "UInt128".into(),
265        ScalarType::USize => "UInt".into(),
266        ScalarType::I8 => "Int8".into(),
267        ScalarType::I16 => "Int16".into(),
268        ScalarType::I32 => "Int32".into(),
269        ScalarType::I64 => "Int64".into(),
270        ScalarType::I128 => "Int128".into(),
271        ScalarType::ISize => "Int".into(),
272        ScalarType::F32 => "Float".into(),
273        ScalarType::F64 => "Double".into(),
274        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
275            "String".into()
276        }
277        ScalarType::Unit => "Void".into(),
278        _ => "Data".into(),
279    }
280}
281
282/// Convert Shape to Swift type string.
283pub fn swift_type_base(shape: &'static Shape) -> String {
284    // Check for bytes first
285    if is_bytes(shape) {
286        return "Data".into();
287    }
288
289    match classify_shape(shape) {
290        ShapeKind::Scalar(scalar) => swift_scalar_type(scalar),
291        ShapeKind::List { element } => format!("[{}]", swift_type_base(element)),
292        ShapeKind::Slice { element } => format!("[{}]", swift_type_base(element)),
293        ShapeKind::Option { inner } => format!("{}?", swift_type_base(inner)),
294        ShapeKind::Array { element, .. } => format!("[{}]", swift_type_base(element)),
295        ShapeKind::Map { key, value } => {
296            format!("[{}: {}]", swift_type_base(key), swift_type_base(value))
297        }
298        ShapeKind::Set { element } => format!("Set<{}>", swift_type_base(element)),
299        ShapeKind::Tuple { elements } => {
300            if elements.is_empty() {
301                "Void".into()
302            } else {
303                let types: Vec<_> = elements.iter().map(|p| swift_type_base(p.shape)).collect();
304                format!("({})", types.join(", "))
305            }
306        }
307        ShapeKind::Tx { inner } => format!("Tx<{}>", swift_type_base(inner)),
308        ShapeKind::Rx { inner } => format!("Rx<{}>", swift_type_base(inner)),
309        ShapeKind::Struct(StructInfo {
310            name: Some(name), ..
311        }) => name.to_string(),
312        ShapeKind::Enum(EnumInfo {
313            name: Some(name), ..
314        }) => name.to_string(),
315        ShapeKind::Struct(StructInfo {
316            name: None, fields, ..
317        }) => {
318            // Anonymous struct - use tuple-like representation
319            let types: Vec<_> = fields.iter().map(|f| swift_type_base(f.shape())).collect();
320            format!("({})", types.join(", "))
321        }
322        ShapeKind::Enum(EnumInfo {
323            name: None,
324            variants,
325        }) => {
326            // Anonymous enum - not well supported in Swift, use Any
327            let _ = variants; // suppress warning
328            "Any".into()
329        }
330        ShapeKind::Pointer { pointee } => swift_type_base(pointee),
331        ShapeKind::Result { ok, err } => {
332            format!("Result<{}, {}>", swift_type_base(ok), swift_type_base(err))
333        }
334        ShapeKind::TupleStruct { fields } => {
335            let types: Vec<_> = fields.iter().map(|f| swift_type_base(f.shape())).collect();
336            format!("({})", types.join(", "))
337        }
338        ShapeKind::Opaque => "Data".into(),
339    }
340}
341
342/// Convert Shape to Swift type string for client arguments.
343pub fn swift_type_client_arg(shape: &'static Shape) -> String {
344    match classify_shape(shape) {
345        ShapeKind::Tx { inner } => format!("UnboundTx<{}>", swift_type_base(inner)),
346        ShapeKind::Rx { inner } => format!("UnboundRx<{}>", swift_type_base(inner)),
347        _ => swift_type_base(shape),
348    }
349}
350
351/// Convert Shape to Swift type string for client returns.
352pub fn swift_type_client_return(shape: &'static Shape) -> String {
353    assert_no_channels_in_return_shape(shape);
354    match classify_shape(shape) {
355        ShapeKind::Scalar(ScalarType::Unit) => "Void".into(),
356        ShapeKind::Tuple { elements: [] } => "Void".into(),
357        _ => swift_type_base(shape),
358    }
359}
360
361/// Convert Shape to Swift type string for server/handler arguments.
362pub fn swift_type_server_arg(shape: &'static Shape) -> String {
363    match classify_shape(shape) {
364        ShapeKind::Tx { inner } => format!("Tx<{}>", swift_type_base(inner)),
365        ShapeKind::Rx { inner } => format!("Rx<{}>", swift_type_base(inner)),
366        _ => swift_type_base(shape),
367    }
368}
369
370/// Convert Shape to Swift type string for server returns.
371pub fn swift_type_server_return(shape: &'static Shape) -> String {
372    assert_no_channels_in_return_shape(shape);
373    match classify_shape(shape) {
374        ShapeKind::Scalar(ScalarType::Unit) => "Void".into(),
375        ShapeKind::Tuple { elements: [] } => "Void".into(),
376        _ => swift_type_base(shape),
377    }
378}
379
380/// Check if a shape represents a channel type (Tx or Rx).
381pub fn is_channel(shape: &'static Shape) -> bool {
382    is_tx(shape) || is_rx(shape)
383}
384
/// Render `doc` as Swift `///` comment lines, each prefixed with `indent`.
///
/// Every line of `doc` (as split by `str::lines`, so `\r\n` is handled)
/// becomes `"{indent}/// {line}\n"`; an empty `doc` yields an empty string.
pub fn format_doc(doc: &str, indent: &str) -> String {
    let mut rendered = String::new();
    for line in doc.lines() {
        rendered.push_str(indent);
        rendered.push_str("/// ");
        rendered.push_str(line);
        rendered.push('\n');
    }
    rendered
}
391
392pub fn assert_no_channels_in_return_shape(shape: &'static Shape) {
393    fn has_channel(shape: &'static Shape) -> bool {
394        matches!(
395            classify_shape(shape),
396            ShapeKind::Tx { .. } | ShapeKind::Rx { .. }
397        )
398    }
399    assert!(
400        !has_channel(shape),
401        "channels are not allowed in return types"
402    );
403}