Skip to main content

vox_codegen/targets/swift/
wire.rs

1//! Swift wire type code generation.
2//!
3//! Generates a complete Swift source file containing all wire protocol types
4//! (Message, MessagePayload, etc.) with encode/decode methods. The generated
5//! code is driven by facet `Shape`s from `vox_types::message`.
6//!
7//! The only special-cased type is `Payload` (`ShapeKind::Opaque`), which maps to
8//! `OpaquePayload` with both length-prefixed and trailing byte handling.
9//! Everything else is normal struct/enum codegen.
10
11use facet_core::{Field, ScalarType, Shape};
12use heck::ToLowerCamelCase;
13use vox_types::{
14    EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape, classify_variant,
15    extract_schemas, is_bytes,
16};
17
18/// A wire type to generate Swift code for.
19pub struct WireType {
20    /// The Swift name for this type (e.g. "Message", "HelloV7")
21    pub swift_name: String,
22    /// The facet Shape for this type
23    pub shape: &'static Shape,
24}
25
26/// Generate a complete Swift wire types file.
27pub fn generate_wire_types(types: &[WireType]) -> String {
28    let mut out = String::new();
29    out.push_str("// @generated by vox-codegen\n");
30    out.push_str("// DO NOT EDIT — regenerate with `cargo xtask codegen --swift-wire`\n\n");
31    out.push_str("import Foundation\n\n");
32
33    // Preamble: error type + helpers + OpaquePayload
34    out.push_str(&generate_preamble());
35
36    // Metadata typealias (Metadata is Vec<MetadataEntry> in Rust, we alias for convenience)
37    out.push_str("public typealias Metadata = [MetadataEntry]\n\n");
38
39    // Generate each type
40    for wt in types {
41        out.push_str(&generate_one_type(&wt.swift_name, wt.shape, types));
42        out.push('\n');
43    }
44
45    // Generate factory methods extension on Message
46    out.push_str(&generate_factory_methods(types));
47
48    if let Some(root) = types
49        .iter()
50        .find(|wt| wt.swift_name == "Message")
51        .or_else(|| types.last())
52    {
53        let extracted = extract_schemas(root.shape).expect("wire schema extraction should succeed");
54        let cbor_bytes = facet_cbor::to_vec(&extracted.schemas)
55            .expect("wire schema CBOR serialization should succeed");
56        let body = cbor_bytes
57            .iter()
58            .map(|b| b.to_string())
59            .collect::<Vec<_>>()
60            .join(", ");
61        out.push('\n');
62        out.push_str(&format!(
63            "public let wireMessageSchemasCbor: [UInt8] = [{body}]\n"
64        ));
65    }
66
67    out
68}
69
/// Generate the fixed Swift preamble that every wire file starts with.
///
/// The text does not depend on any shape. It contains:
/// - `WireError`, the error enum thrown by generated decoders;
/// - `OpaquePayload`, a byte wrapper. Its `encode`/`decode` use a 4-byte little-endian
///   length prefix. Its `encodeTrailing`/`decodeTrailing` write the raw bytes and
///   consume the rest of the buffer, respectively;
/// - `decodeWireVarintU32`, `decodeWireString` and `decodeWireBytes`. These wrap
///   `decodeVarint`/`decodeString`/`decodeBytes`, rejecting values above `UInt32.max`
///   and re-throwing the listed `PostcardError` cases as `WireError`.
///
/// `decodeVarint`, `decodeString`, `decodeBytes` and `PostcardError` are referenced but
/// not emitted here. The output ends with a blank line.
fn generate_preamble() -> String {
    // Kept as a single raw string so the Swift reads as Swift. Everything between the
    // delimiters is copied into the output byte-for-byte, including the final blank line.
    const PREAMBLE: &str = r#"public enum WireError: Error, Equatable {
    case truncated
    case unknownVariant(UInt64)
    case overflow
    case invalidUtf8
    case trailingBytes
}

public struct OpaquePayload: Sendable, Equatable {
    public var bytes: [UInt8]

    public init(_ bytes: [UInt8]) {
        self.bytes = bytes
    }

    func encode() -> [UInt8] {
        let len = UInt32(bytes.count)
        return [
            UInt8(truncatingIfNeeded: len),
            UInt8(truncatingIfNeeded: len >> 8),
            UInt8(truncatingIfNeeded: len >> 16),
            UInt8(truncatingIfNeeded: len >> 24),
        ] + bytes
    }

    static func decode(from data: Data, offset: inout Int) throws -> Self {
        guard offset + 4 <= data.count else { throw WireError.truncated }
        let start = data.startIndex + offset
        let len = Int(UInt32(data[start]) | (UInt32(data[start + 1]) << 8) | (UInt32(data[start + 2]) << 16) | (UInt32(data[start + 3]) << 24))
        offset += 4
        guard offset + len <= data.count else { throw WireError.truncated }
        let payloadStart = data.startIndex + offset
        let payloadEnd = payloadStart + len
        let payload = Array(data[payloadStart..<payloadEnd])
        offset += len
        return .init(payload)
    }

    /// Encode without a length prefix — for trailing fields only.
    func encodeTrailing() -> [UInt8] {
        bytes
    }

    /// Decode by consuming all remaining bytes — for trailing fields only.
    static func decodeTrailing(from data: Data, offset: inout Int) -> Self {
        let start = data.startIndex + offset
        let remaining = Array(data[start...])
        offset = data.count
        return .init(remaining)
    }
}

@inline(__always)
private func decodeWireVarintU32(from data: Data, offset: inout Int) throws -> UInt32 {
    let value = try decodeVarint(from: data, offset: &offset)
    guard value <= UInt64(UInt32.max) else {
        throw WireError.overflow
    }
    return UInt32(value)
}

@inline(__always)
private func decodeWireString(from data: Data, offset: inout Int) throws -> String {
    do {
        return try decodeString(from: data, offset: &offset)
    } catch PostcardError.invalidUtf8 {
        throw WireError.invalidUtf8
    } catch PostcardError.truncated {
        throw WireError.truncated
    } catch {
        throw error
    }
}

@inline(__always)
private func decodeWireBytes(from data: Data, offset: inout Int) throws -> Data {
    do {
        return try decodeBytes(from: data, offset: &offset)
    } catch PostcardError.truncated {
        throw WireError.truncated
    } catch {
        throw error
    }
}

"#;
    PREAMBLE.to_owned()
}
165
166/// Get the wire Swift type for a field's shape, taking the `WireType` list into account.
167fn swift_wire_type(shape: &'static Shape, _field: Option<&Field>, types: &[WireType]) -> String {
168    if is_bytes(shape) {
169        return "[UInt8]".into();
170    }
171
172    match classify_shape(shape) {
173        ShapeKind::Scalar(scalar) => swift_scalar_type(scalar),
174        ShapeKind::List { element } | ShapeKind::Slice { element } => {
175            format!("[{}]", swift_wire_type(element, None, types))
176        }
177        ShapeKind::Option { inner } => {
178            format!("{}?", swift_wire_type(inner, None, types))
179        }
180        ShapeKind::Array { element, .. } => {
181            format!("[{}]", swift_wire_type(element, None, types))
182        }
183        ShapeKind::Struct(StructInfo {
184            name: Some(name), ..
185        }) => lookup_wire_name(name, types),
186        ShapeKind::Enum(EnumInfo {
187            name: Some(name), ..
188        }) => lookup_wire_name(name, types),
189        ShapeKind::Pointer { pointee } => swift_wire_type(pointee, _field, types),
190        ShapeKind::Opaque => "OpaquePayload".into(),
191        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
192            swift_wire_type(fields[0].shape(), None, types)
193        }
194        _ => "Any /* unsupported */".into(),
195    }
196}
197
198fn swift_scalar_type(scalar: ScalarType) -> String {
199    match scalar {
200        ScalarType::Bool => "Bool".into(),
201        ScalarType::U8 => "UInt8".into(),
202        ScalarType::U16 => "UInt16".into(),
203        ScalarType::U32 => "UInt32".into(),
204        ScalarType::U64 | ScalarType::USize => "UInt64".into(),
205        ScalarType::I8 => "Int8".into(),
206        ScalarType::I16 => "Int16".into(),
207        ScalarType::I32 => "Int32".into(),
208        ScalarType::I64 | ScalarType::ISize => "Int64".into(),
209        ScalarType::F32 => "Float".into(),
210        ScalarType::F64 => "Double".into(),
211        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
212            "String".into()
213        }
214        ScalarType::Unit => "Void".into(),
215        _ => "Data".into(),
216    }
217}
218
219/// Look up the wire name for a Rust type name. If a matching WireType exists, use its
220/// swift_name; otherwise fall back to the Rust name.
221fn lookup_wire_name(rust_name: &str, types: &[WireType]) -> String {
222    for wt in types {
223        // Match by checking the reflected type identifier for the underlying Rust shape.
224        if wt.shape.type_identifier.ends_with(rust_name) {
225            return wt.swift_name.clone();
226        }
227    }
228    rust_name.to_string()
229}
230
231/// Generate a single wire type (struct or enum).
232fn generate_one_type(swift_name: &str, shape: &'static Shape, types: &[WireType]) -> String {
233    match classify_shape(shape) {
234        ShapeKind::Struct(StructInfo { fields, .. }) => {
235            generate_struct(swift_name, fields, types, swift_name == "Message")
236        }
237        ShapeKind::Enum(EnumInfo { variants, .. }) => generate_enum(swift_name, variants, types),
238        _ => format!("// Unsupported shape for {swift_name}\n"),
239    }
240}
241
242/// Generate a Swift struct with encode/decode methods.
243fn generate_struct(name: &str, fields: &[Field], types: &[WireType], is_top_level: bool) -> String {
244    let mut out = String::new();
245
246    // Struct definition
247    out.push_str(&format!("public struct {name}: Sendable, Equatable {{\n"));
248    for f in fields {
249        let field_name = f.name.to_lower_camel_case();
250        let field_type = swift_wire_type(f.shape(), Some(f), types);
251        out.push_str(&format!("    public var {field_name}: {field_type}\n"));
252    }
253
254    // Initializer
255    if !fields.is_empty() {
256        out.push_str("\n    public init(");
257        for (i, f) in fields.iter().enumerate() {
258            if i > 0 {
259                out.push_str(", ");
260            }
261            let field_name = f.name.to_lower_camel_case();
262            let field_type = swift_wire_type(f.shape(), Some(f), types);
263            out.push_str(&format!("{field_name}: {field_type}"));
264        }
265        out.push_str(") {\n");
266        for f in fields {
267            let field_name = f.name.to_lower_camel_case();
268            out.push_str(&format!("        self.{field_name} = {field_name}\n"));
269        }
270        out.push_str("    }\n");
271    }
272
273    // Encode method
274    let vis = if is_top_level { "public " } else { "" };
275    out.push_str(&format!("\n    {vis}func encode() -> [UInt8] {{\n"));
276    if fields.is_empty() {
277        out.push_str("        []\n");
278    } else if fields.len() == 1 {
279        let f = &fields[0];
280        let expr = encode_field_expr(f, types);
281        out.push_str(&format!("        {expr}\n"));
282    } else {
283        out.push_str("        var out: [UInt8] = []\n");
284        for f in fields {
285            let expr = encode_field_expr(f, types);
286            out.push_str(&format!("        out += {expr}\n"));
287        }
288        out.push_str("        return out\n");
289    }
290    out.push_str("    }\n");
291
292    // Decode method (with offset)
293    out.push_str(&format!(
294        "\n    {vis}static func decode(from data: Data, offset: inout Int) throws -> Self {{\n"
295    ));
296    for f in fields {
297        let field_name = f.name.to_lower_camel_case();
298        let decode_expr = decode_field_expr(f, types);
299        out.push_str(&format!("        let {field_name} = {decode_expr}\n"));
300    }
301    let field_names: Vec<String> = fields
302        .iter()
303        .map(|f| {
304            let n = f.name.to_lower_camel_case();
305            format!("{n}: {n}")
306        })
307        .collect();
308    out.push_str(&format!(
309        "        return .init({})\n",
310        field_names.join(", ")
311    ));
312    out.push_str("    }\n");
313
314    // Top-level Message gets an extra decode(from:) without offset that checks trailing bytes
315    if is_top_level {
316        out.push_str("\n    public static func decode(from data: Data) throws -> Self {\n");
317        out.push_str("        var offset = 0\n");
318        out.push_str("        let result = try decode(from: data, offset: &offset)\n");
319        out.push_str("        guard offset == data.count else {\n");
320        out.push_str("            throw WireError.trailingBytes\n");
321        out.push_str("        }\n");
322        out.push_str("        return result\n");
323        out.push_str("    }\n");
324    }
325
326    out.push_str("}\n");
327    out
328}
329
/// Generate a Swift enum with varint-discriminated encode/decode methods.
///
/// One `case` is emitted per Rust variant, named in lowerCamelCase:
/// - unit variants become bare cases;
/// - newtype and tuple variants carry unlabeled associated values;
/// - struct variants carry labeled associated values (`case foo(a: A, b: B)`).
///
/// On the wire, `encode()` writes the variant's 0-based position in `variants` with
/// `encodeVarint`, followed by each payload value in declaration order. `decode` reads
/// that varint back and throws `WireError.unknownVariant` for any value with no matching
/// case. Both methods are emitted without `public`.
fn generate_enum(name: &str, variants: &[facet_core::Variant], types: &[WireType]) -> String {
    let mut out = String::new();

    // Enum definition: one `case` line per variant.
    out.push_str(&format!("public enum {name}: Sendable, Equatable {{\n"));
    for v in variants {
        let variant_name = v.name.to_lower_camel_case();
        match classify_variant(v) {
            VariantKind::Unit => {
                out.push_str(&format!("    case {variant_name}\n"));
            }
            VariantKind::Newtype { inner } => {
                let inner_type = swift_wire_type(inner, v.data.fields.first(), types);
                out.push_str(&format!("    case {variant_name}({inner_type})\n"));
            }
            VariantKind::Tuple { fields } => {
                let field_types: Vec<String> = fields
                    .iter()
                    .map(|f| swift_wire_type(f.shape(), Some(f), types))
                    .collect();
                out.push_str(&format!(
                    "    case {variant_name}({})\n",
                    field_types.join(", ")
                ));
            }
            VariantKind::Struct { fields } => {
                // Struct variants keep their field names as Swift argument labels.
                let field_decls: Vec<String> = fields
                    .iter()
                    .map(|f| {
                        format!(
                            "{}: {}",
                            f.name.to_lower_camel_case(),
                            swift_wire_type(f.shape(), Some(f), types)
                        )
                    })
                    .collect();
                out.push_str(&format!(
                    "    case {variant_name}({})\n",
                    field_decls.join(", ")
                ));
            }
        }
    }

    // Encode: discriminant (the enumeration index `i`) as a varint, then each payload
    // value. Tuple payloads are bound as f0, f1, …; struct payloads bind by field name.
    // NOTE(review): a tuple/struct variant with zero fields would emit a dangling ` + `
    // after the discriminant; presumably `classify_variant` reports those as `Unit` — confirm.
    out.push_str("\n    func encode() -> [UInt8] {\n");
    out.push_str("        switch self {\n");
    for (i, v) in variants.iter().enumerate() {
        let variant_name = v.name.to_lower_camel_case();
        match classify_variant(v) {
            VariantKind::Unit => {
                out.push_str(&format!(
                    "        case .{variant_name}:\n            return encodeVarint(UInt64({i}))\n"
                ));
            }
            VariantKind::Newtype { inner } => {
                // The variant's single field is passed so a `trailing` opaque payload
                // is encoded without its length prefix (see `encode_shape_expr`).
                let encode = encode_shape_expr(inner, "val", v.data.fields.first(), types);
                out.push_str(&format!(
                    "        case .{variant_name}(let val):\n            return encodeVarint(UInt64({i})) + {encode}\n"
                ));
            }
            VariantKind::Tuple { fields } => {
                let bindings: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
                let binding_str = bindings
                    .iter()
                    .map(|b| format!("let {b}"))
                    .collect::<Vec<_>>()
                    .join(", ");
                let encodes: Vec<String> = fields
                    .iter()
                    .enumerate()
                    .map(|(j, f)| encode_shape_expr(f.shape(), &format!("f{j}"), Some(f), types))
                    .collect();
                out.push_str(&format!(
                    "        case .{variant_name}({binding_str}):\n            return encodeVarint(UInt64({i})) + {}\n",
                    encodes.join(" + ")
                ));
            }
            VariantKind::Struct { fields } => {
                let bindings: Vec<String> = fields
                    .iter()
                    .map(|f| f.name.to_lower_camel_case())
                    .collect();
                let binding_str = bindings
                    .iter()
                    .map(|b| format!("let {b}"))
                    .collect::<Vec<_>>()
                    .join(", ");
                let encodes: Vec<String> = fields
                    .iter()
                    .map(|f| {
                        encode_shape_expr(f.shape(), &f.name.to_lower_camel_case(), Some(f), types)
                    })
                    .collect();
                out.push_str(&format!(
                    "        case .{variant_name}({binding_str}):\n            return encodeVarint(UInt64({i})) + {}\n",
                    encodes.join(" + ")
                ));
            }
        }
    }
    out.push_str("        }\n");
    out.push_str("    }\n");

    // Decode: read the varint discriminant and dispatch on it. Payload values are
    // decoded into locals in declaration order before the case is constructed.
    out.push_str("\n    static func decode(from data: Data, offset: inout Int) throws -> Self {\n");
    out.push_str("        let disc = try decodeVarint(from: data, offset: &offset)\n");
    out.push_str("        switch disc {\n");
    for (i, v) in variants.iter().enumerate() {
        let variant_name = v.name.to_lower_camel_case();
        out.push_str(&format!("        case {i}:\n"));
        match classify_variant(v) {
            VariantKind::Unit => {
                out.push_str(&format!("            return .{variant_name}\n"));
            }
            VariantKind::Newtype { inner } => {
                let decode = decode_shape_expr(inner, v.data.fields.first(), types);
                out.push_str(&format!("            return .{variant_name}({decode})\n"));
            }
            VariantKind::Tuple { fields } => {
                for (j, f) in fields.iter().enumerate() {
                    let decode = decode_shape_expr(f.shape(), Some(f), types);
                    out.push_str(&format!("            let f{j} = {decode}\n"));
                }
                let args: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
                out.push_str(&format!(
                    "            return .{variant_name}({})\n",
                    args.join(", ")
                ));
            }
            VariantKind::Struct { fields } => {
                for f in fields {
                    let field_name = f.name.to_lower_camel_case();
                    let decode = decode_shape_expr(f.shape(), Some(f), types);
                    out.push_str(&format!("            let {field_name} = {decode}\n"));
                }
                // Struct-variant cases are constructed with labeled arguments.
                let args: Vec<String> = fields
                    .iter()
                    .map(|f| {
                        let n = f.name.to_lower_camel_case();
                        format!("{n}: {n}")
                    })
                    .collect();
                out.push_str(&format!(
                    "            return .{variant_name}({})\n",
                    args.join(", ")
                ));
            }
        }
    }
    // Any discriminant without a matching case is rejected.
    out.push_str("        default:\n");
    out.push_str("            throw WireError.unknownVariant(disc)\n");
    out.push_str("        }\n");
    out.push_str("    }\n");

    out.push_str("}\n");
    out
}
489
490/// Generate an encode expression for a struct field.
491fn encode_field_expr(field: &Field, types: &[WireType]) -> String {
492    let field_name = field.name.to_lower_camel_case();
493    encode_shape_expr(field.shape(), &field_name, Some(field), types)
494}
495
496/// Generate an encode expression for a shape with a given value expression.
497fn encode_shape_expr(
498    shape: &'static Shape,
499    value: &str,
500    field: Option<&Field>,
501    types: &[WireType],
502) -> String {
503    let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
504
505    // Opaque type → OpaquePayload
506    if matches!(classify_shape(shape), ShapeKind::Opaque) {
507        return if is_trailing {
508            format!("{value}.encodeTrailing()")
509        } else {
510            format!("{value}.encode()")
511        };
512    }
513
514    if is_bytes(shape) {
515        return format!("encodeBytes({value})");
516    }
517
518    match classify_shape(shape) {
519        ShapeKind::Scalar(scalar) => encode_scalar(scalar, value),
520        ShapeKind::List { element } | ShapeKind::Slice { element } => {
521            let inner = encode_element_closure(element, types);
522            format!("encodeVec({value}, encoder: {inner})")
523        }
524        ShapeKind::Option { inner } => {
525            let inner_closure = encode_element_closure(inner, types);
526            format!("encodeOption({value}, encoder: {inner_closure})")
527        }
528        ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
529            format!("{value}.encode()")
530        }
531        ShapeKind::Pointer { pointee } => encode_shape_expr(pointee, value, field, types),
532        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
533            encode_shape_expr(fields[0].shape(), value, field, types)
534        }
535        _ => format!("[] /* unsupported encode for {value} */"),
536    }
537}
538
539fn encode_scalar(scalar: ScalarType, value: &str) -> String {
540    match scalar {
541        ScalarType::Bool => format!("encodeBool({value})"),
542        ScalarType::U8 => format!("encodeU8({value})"),
543        ScalarType::I8 => format!("encodeI8({value})"),
544        ScalarType::U16 => format!("encodeU16({value})"),
545        ScalarType::I16 => format!("encodeI16({value})"),
546        ScalarType::U32 => format!("encodeVarint(UInt64({value}))"),
547        ScalarType::I32 => format!("encodeI32({value})"),
548        ScalarType::U64 | ScalarType::USize => format!("encodeVarint({value})"),
549        ScalarType::I64 | ScalarType::ISize => format!("encodeI64({value})"),
550        ScalarType::F32 => format!("encodeF32({value})"),
551        ScalarType::F64 => format!("encodeF64({value})"),
552        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
553            format!("encodeString({value})")
554        }
555        _ => "[] /* unsupported scalar */".to_string(),
556    }
557}
558
559/// Generate an encode closure for use with encodeVec.
560fn encode_element_closure(shape: &'static Shape, _types: &[WireType]) -> String {
561    if is_bytes(shape) {
562        return "{ encodeBytes($0) }".into();
563    }
564
565    match classify_shape(shape) {
566        ShapeKind::Scalar(scalar) => {
567            let expr = encode_scalar(scalar, "$0");
568            format!("{{ {expr} }}")
569        }
570        ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
571            "{ $0.encode() }".into()
572        }
573        ShapeKind::List { element } | ShapeKind::Slice { element } => {
574            let inner = encode_element_closure(element, _types);
575            format!("{{ encodeVec($0, encoder: {inner}) }}")
576        }
577        ShapeKind::Opaque => "{ $0.encode() }".into(),
578        ShapeKind::Pointer { pointee } => encode_element_closure(pointee, _types),
579        _ => "{ _ in [] }".into(),
580    }
581}
582
583/// Generate a decode expression for a struct field.
584fn decode_field_expr(field: &Field, types: &[WireType]) -> String {
585    decode_shape_expr(field.shape(), Some(field), types)
586}
587
588/// Generate a decode expression for a shape.
589fn decode_shape_expr(shape: &'static Shape, field: Option<&Field>, types: &[WireType]) -> String {
590    let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
591
592    // Opaque type → OpaquePayload
593    if matches!(classify_shape(shape), ShapeKind::Opaque) {
594        return if is_trailing {
595            "OpaquePayload.decodeTrailing(from: data, offset: &offset)".into()
596        } else {
597            "try OpaquePayload.decode(from: data, offset: &offset)".into()
598        };
599    }
600
601    if is_bytes(shape) {
602        return "Array(try decodeWireBytes(from: data, offset: &offset))".into();
603    }
604
605    match classify_shape(shape) {
606        ShapeKind::Scalar(scalar) => decode_scalar(scalar),
607        ShapeKind::List { element } | ShapeKind::Slice { element } => {
608            let inner = decode_element_closure(element, types);
609            format!("try decodeVec(from: data, offset: &offset, decoder: {inner})")
610        }
611        ShapeKind::Option { inner } => {
612            let inner_closure = decode_element_closure(inner, types);
613            format!("try decodeOption(from: data, offset: &offset, decoder: {inner_closure})")
614        }
615        ShapeKind::Struct(StructInfo {
616            name: Some(name), ..
617        }) => {
618            let swift_name = lookup_wire_name(name, types);
619            format!("try {swift_name}.decode(from: data, offset: &offset)")
620        }
621        ShapeKind::Enum(EnumInfo {
622            name: Some(name), ..
623        }) => {
624            let swift_name = lookup_wire_name(name, types);
625            format!("try {swift_name}.decode(from: data, offset: &offset)")
626        }
627        ShapeKind::Pointer { pointee } => decode_shape_expr(pointee, field, types),
628        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
629            decode_shape_expr(fields[0].shape(), field, types)
630        }
631        _ => "nil /* unsupported decode */".into(),
632    }
633}
634
635fn decode_scalar(scalar: ScalarType) -> String {
636    match scalar {
637        ScalarType::Bool => "try decodeBool(from: data, offset: &offset)".into(),
638        ScalarType::U8 => "try decodeU8(from: data, offset: &offset)".into(),
639        ScalarType::I8 => "try decodeI8(from: data, offset: &offset)".into(),
640        ScalarType::U16 => "try decodeU16(from: data, offset: &offset)".into(),
641        ScalarType::I16 => "try decodeI16(from: data, offset: &offset)".into(),
642        ScalarType::U32 => "try decodeWireVarintU32(from: data, offset: &offset)".into(),
643        ScalarType::I32 => "try decodeI32(from: data, offset: &offset)".into(),
644        ScalarType::U64 | ScalarType::USize => {
645            "try decodeVarint(from: data, offset: &offset)".into()
646        }
647        ScalarType::I64 | ScalarType::ISize => "try decodeI64(from: data, offset: &offset)".into(),
648        ScalarType::F32 => "try decodeF32(from: data, offset: &offset)".into(),
649        ScalarType::F64 => "try decodeF64(from: data, offset: &offset)".into(),
650        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
651            "try decodeWireString(from: data, offset: &offset)".into()
652        }
653        _ => "nil /* unsupported scalar decode */".into(),
654    }
655}
656
657/// Generate a decode closure for use with decodeVec.
658fn decode_element_closure(shape: &'static Shape, types: &[WireType]) -> String {
659    if is_bytes(shape) {
660        return "{ data, off in Array(try decodeWireBytes(from: data, offset: &off)) }".into();
661    }
662
663    match classify_shape(shape) {
664        ShapeKind::Scalar(scalar) => {
665            let expr = decode_scalar_with_params(scalar, "data", "off");
666            format!("{{ data, off in {expr} }}")
667        }
668        ShapeKind::Struct(StructInfo {
669            name: Some(name), ..
670        }) => {
671            let swift_name = lookup_wire_name(name, types);
672            format!("{{ data, off in try {swift_name}.decode(from: data, offset: &off) }}")
673        }
674        ShapeKind::Enum(EnumInfo {
675            name: Some(name), ..
676        }) => {
677            let swift_name = lookup_wire_name(name, types);
678            format!("{{ data, off in try {swift_name}.decode(from: data, offset: &off) }}")
679        }
680        ShapeKind::List { element } | ShapeKind::Slice { element } => {
681            let inner = decode_element_closure(element, types);
682            format!("{{ data, off in try decodeVec(from: data, offset: &off, decoder: {inner}) }}")
683        }
684        ShapeKind::Opaque => {
685            "{ data, off in try OpaquePayload.decode(from: data, offset: &off) }".into()
686        }
687        ShapeKind::Pointer { pointee } => decode_element_closure(pointee, types),
688        _ => "{ _, _ in throw WireError.truncated }".into(),
689    }
690}
691
692fn decode_scalar_with_params(scalar: ScalarType, data: &str, offset: &str) -> String {
693    match scalar {
694        ScalarType::Bool => format!("try decodeBool(from: {data}, offset: &{offset})"),
695        ScalarType::U8 => format!("try decodeU8(from: {data}, offset: &{offset})"),
696        ScalarType::I8 => format!("try decodeI8(from: {data}, offset: &{offset})"),
697        ScalarType::U16 => format!("try decodeU16(from: {data}, offset: &{offset})"),
698        ScalarType::I16 => format!("try decodeI16(from: {data}, offset: &{offset})"),
699        ScalarType::U32 => format!("try decodeWireVarintU32(from: {data}, offset: &{offset})"),
700        ScalarType::I32 => format!("try decodeI32(from: {data}, offset: &{offset})"),
701        ScalarType::U64 | ScalarType::USize => {
702            format!("try decodeVarint(from: {data}, offset: &{offset})")
703        }
704        ScalarType::I64 | ScalarType::ISize => {
705            format!("try decodeI64(from: {data}, offset: &{offset})")
706        }
707        ScalarType::F32 => format!("try decodeF32(from: {data}, offset: &{offset})"),
708        ScalarType::F64 => format!("try decodeF64(from: {data}, offset: &{offset})"),
709        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
710            format!("try decodeWireString(from: {data}, offset: &{offset})")
711        }
712        _ => "nil /* unsupported scalar */".to_string(),
713    }
714}
715
716/// Generate convenience factory methods on Message.
717fn generate_factory_methods(types: &[WireType]) -> String {
718    // Find the MessagePayload shape to inspect variants
719    let payload_wt = types.iter().find(|wt| wt.swift_name == "MessagePayload");
720    let payload_wt = match payload_wt {
721        Some(wt) => wt,
722        None => return String::new(),
723    };
724
725    let variants = match classify_shape(payload_wt.shape) {
726        ShapeKind::Enum(EnumInfo { variants, .. }) => variants,
727        _ => return String::new(),
728    };
729
730    let mut out = String::new();
731    out.push_str("public extension Message {\n");
732
733    for v in variants {
734        let variant_name = v.name.to_lower_camel_case();
735        if let VariantKind::Newtype { inner } = classify_variant(v) {
736            let inner_swift = swift_wire_type(inner, None, types);
737            // Control messages use connId=0.
738            let is_control = matches!(
739                v.name,
740                "Hello" | "HelloYourself" | "ProtocolError" | "Ping" | "Pong"
741            );
742
743            if is_control {
744                out.push_str(&format!(
745                    "    static func {variant_name}(_ value: {inner_swift}) -> Message {{\n"
746                ));
747                out.push_str(&format!(
748                    "        Message(connectionId: 0, payload: .{variant_name}(value))\n"
749                ));
750                out.push_str("    }\n\n");
751            } else {
752                out.push_str(&format!(
753                    "    static func {variant_name}(connId: UInt64, _ value: {inner_swift}) -> Message {{\n"
754                ));
755                out.push_str(&format!(
756                    "        Message(connectionId: connId, payload: .{variant_name}(value))\n"
757                ));
758                out.push_str("    }\n\n");
759            }
760        }
761    }
762
763    // Additional ergonomic factory methods that flatten nested structs.
764    // These match the existing hand-coded API.
765    out.push_str("    static func protocolError(description: String) -> Message {\n");
766    out.push_str("        Message(connectionId: 0, payload: .protocolError(.init(description: description)))\n");
767    out.push_str("    }\n\n");
768
769    out.push_str("    static func connectionOpen(connId: UInt64, settings: ConnectionSettings, metadata: [MetadataEntry]) -> Message {\n");
770    out.push_str("        Message(connectionId: connId, payload: .connectionOpen(.init(connectionSettings: settings, metadata: metadata)))\n");
771    out.push_str("    }\n\n");
772
773    out.push_str("    static func connectionAccept(connId: UInt64, settings: ConnectionSettings, metadata: [MetadataEntry]) -> Message {\n");
774    out.push_str("        Message(connectionId: connId, payload: .connectionAccept(.init(connectionSettings: settings, metadata: metadata)))\n");
775    out.push_str("    }\n\n");
776
777    out.push_str("    static func connectionReject(connId: UInt64, metadata: [MetadataEntry]) -> Message {\n");
778    out.push_str("        Message(connectionId: connId, payload: .connectionReject(.init(metadata: metadata)))\n");
779    out.push_str("    }\n\n");
780
781    out.push_str(
782        "    static func connectionClose(connId: UInt64, metadata: [MetadataEntry]) -> Message {\n",
783    );
784    out.push_str("        Message(connectionId: connId, payload: .connectionClose(.init(metadata: metadata)))\n");
785    out.push_str("    }\n\n");
786
787    out.push_str("    static func request(\n");
788    out.push_str("        connId: UInt64,\n");
789    out.push_str("        requestId: UInt64,\n");
790    out.push_str("        methodId: UInt64,\n");
791    out.push_str("        metadata: [MetadataEntry],\n");
792    out.push_str("        schemas: [UInt8] = [],\n");
793    out.push_str("        payload: [UInt8]\n");
794    out.push_str("    ) -> Message {\n");
795    out.push_str("        Message(\n");
796    out.push_str("            connectionId: connId,\n");
797    out.push_str("            payload: .requestMessage(\n");
798    out.push_str("                .init(\n");
799    out.push_str("                    id: requestId,\n");
800    out.push_str("                    body: .call(.init(methodId: methodId, metadata: metadata, args: .init(payload), schemas: schemas))\n");
801    out.push_str("                ))\n");
802    out.push_str("        )\n");
803    out.push_str("    }\n\n");
804
805    out.push_str("    static func response(\n");
806    out.push_str("        connId: UInt64,\n");
807    out.push_str("        requestId: UInt64,\n");
808    out.push_str("        metadata: [MetadataEntry],\n");
809    out.push_str("        schemas: [UInt8] = [],\n");
810    out.push_str("        payload: [UInt8]\n");
811    out.push_str("    ) -> Message {\n");
812    out.push_str("        Message(\n");
813    out.push_str("            connectionId: connId,\n");
814    out.push_str("            payload: .requestMessage(\n");
815    out.push_str("                .init(\n");
816    out.push_str("                    id: requestId,\n");
817    out.push_str("                    body: .response(.init(metadata: metadata, ret: .init(payload), schemas: schemas))\n");
818    out.push_str("                ))\n");
819    out.push_str("        )\n");
820    out.push_str("    }\n\n");
821
822    out.push_str("    static func cancel(connId: UInt64, requestId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
823    out.push_str("        Message(\n");
824    out.push_str("            connectionId: connId,\n");
825    out.push_str("            payload: .requestMessage(\n");
826    out.push_str("                .init(\n");
827    out.push_str("                    id: requestId,\n");
828    out.push_str("                    body: .cancel(.init(metadata: metadata))\n");
829    out.push_str("                ))\n");
830    out.push_str("        )\n");
831    out.push_str("    }\n\n");
832
833    out.push_str(
834        "    static func data(connId: UInt64, channelId: UInt64, payload: [UInt8]) -> Message {\n",
835    );
836    out.push_str("        Message(\n");
837    out.push_str("            connectionId: connId,\n");
838    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .item(.init(item: .init(payload)))))\n");
839    out.push_str("        )\n");
840    out.push_str("    }\n\n");
841
842    out.push_str("    static func close(connId: UInt64, channelId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
843    out.push_str("        Message(\n");
844    out.push_str("            connectionId: connId,\n");
845    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .close(.init(metadata: metadata))))\n");
846    out.push_str("        )\n");
847    out.push_str("    }\n\n");
848
849    out.push_str("    static func reset(connId: UInt64, channelId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
850    out.push_str("        Message(\n");
851    out.push_str("            connectionId: connId,\n");
852    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .reset(.init(metadata: metadata))))\n");
853    out.push_str("        )\n");
854    out.push_str("    }\n\n");
855
856    out.push_str(
857        "    static func credit(connId: UInt64, channelId: UInt64, bytes: UInt32) -> Message {\n",
858    );
859    out.push_str("        Message(\n");
860    out.push_str("            connectionId: connId,\n");
861    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .grantCredit(.init(additional: bytes))))\n");
862    out.push_str("        )\n");
863    out.push_str("    }\n");
864
865    out.push_str("}\n");
866    out
867}