Skip to main content

vox_codegen/targets/swift/
wire.rs

1//! Swift wire type code generation.
2//!
3//! Generates a complete Swift source file containing all wire protocol types
4//! (Message, MessagePayload, etc.) with encode/decode methods. The generated
5//! code is driven by facet `Shape`s from `vox_types::message`.
6//!
7//! The only special-cased type is `Payload` (`ShapeKind::Opaque`), which maps to
8//! `OpaquePayload` with both length-prefixed and trailing byte handling.
9//! Everything else is normal struct/enum codegen.
10
11use facet_core::{Field, ScalarType, Shape};
12use heck::ToLowerCamelCase;
13use vox_types::{
14    EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape, classify_variant,
15    extract_schemas, is_bytes,
16};
17
18/// A wire type to generate Swift code for.
19pub struct WireType {
20    /// The Swift name for this type (e.g. "Message", "HelloV7")
21    pub swift_name: String,
22    /// The facet Shape for this type
23    pub shape: &'static Shape,
24}
25
26/// Generate a complete Swift wire types file, returning the Swift source and the raw
27/// CBOR schema bytes separately. The caller is responsible for writing the `.bin` file
28/// and declaring it as a `Bundle.module` resource; the generated Swift loads it at runtime.
29pub fn generate_wire_types(types: &[WireType]) -> (String, Vec<u8>) {
30    let mut out = String::new();
31    out.push_str("// @generated by vox-codegen\n");
32    out.push_str("// DO NOT EDIT — regenerate with `cargo xtask codegen --swift-wire`\n\n");
33    out.push_str("import Foundation\n");
34    out.push_str("@preconcurrency import NIOCore\n\n");
35
36    // Preamble: error type + helpers + OpaquePayload
37    out.push_str(&generate_preamble());
38
39    // Metadata typealias (Metadata is Vec<MetadataEntry> in Rust, we alias for convenience)
40    out.push_str("public typealias Metadata = [MetadataEntry]\n\n");
41
42    // Generate each type
43    for wt in types {
44        out.push_str(&generate_one_type(&wt.swift_name, wt.shape, types));
45        out.push('\n');
46    }
47
48    // Generate factory methods extension on Message
49    out.push_str(&generate_factory_methods(types));
50
51    // Emit a Bundle.module loader instead of inlining the bytes as a literal.
52    out.push_str(
53        "\n/// CBOR-encoded wire message schemas, loaded from the bundled binary resource.\n",
54    );
55    out.push_str("public let wireMessageSchemasCbor: [UInt8] = {\n");
56    out.push_str(
57        "    guard let url = Bundle.module.url(forResource: \"wireMessageSchemas\", withExtension: \"bin\"),\n",
58    );
59    out.push_str("          let data = try? Data(contentsOf: url) else {\n");
60    out.push_str(
61        "        preconditionFailure(\"wireMessageSchemas.bin resource not found in Bundle.module\")\n",
62    );
63    out.push_str("    }\n");
64    out.push_str("    return Array(data)\n");
65    out.push_str("}()\n");
66
67    // Extract the CBOR bytes to return separately.
68    let cbor_bytes = types
69        .iter()
70        .find(|wt| wt.swift_name == "Message")
71        .or_else(|| types.last())
72        .map(|root| {
73            let extracted =
74                extract_schemas(root.shape).expect("wire schema extraction should succeed");
75            facet_cbor::to_vec(&extracted.schemas)
76                .expect("wire schema CBOR serialization should succeed")
77        })
78        .unwrap_or_default();
79
80    (out, cbor_bytes)
81}
82
/// Build the fixed Swift preamble that every generated wire file starts with.
///
/// The preamble contains three things:
///
/// - the `WireError` enum;
/// - the `ByteBuffer`-backed `OpaquePayload` type, which has a u32 little-endian
///   length-prefixed codec plus prefix-less "trailing" variants;
/// - three private decode helpers.
///
/// The helpers narrow varints to `UInt32`, and map the `PostcardError` cases they recognise
/// onto `WireError` for string and byte-sequence decoding. They call runtime functions
/// (`decodeVarint`, `decodeString`, `decodeBytes`) that the generated file does not define
/// itself.
fn generate_preamble() -> String {
    const WIRE_ERROR: &str = concat!(
        "public enum WireError: Error, Equatable {\n",
        "    case truncated\n",
        "    case unknownVariant(UInt64)\n",
        "    case overflow\n",
        "    case invalidUtf8\n",
        "    case trailingBytes\n",
        "}\n\n",
    );

    const OPAQUE_PAYLOAD: &str = concat!(
        "public struct OpaquePayload: Sendable, Equatable {\n",
        "    public var bytes: ByteBuffer\n\n",
        "    public init(_ bytes: ByteBuffer) {\n",
        "        self.bytes = bytes\n",
        "    }\n\n",
        "    /// Init from a [UInt8] for convenience (e.g. from legacy code)\n",
        "    public init(_ bytes: [UInt8]) {\n",
        "        var buf = ByteBufferAllocator().buffer(capacity: bytes.count)\n",
        "        buf.writeBytes(bytes)\n",
        "        self.bytes = buf\n",
        "    }\n\n",
        "    func encode(into buffer: inout ByteBuffer) {\n",
        "        let len = UInt32(bytes.readableBytes)\n",
        "        buffer.writeInteger(len, endianness: .little)\n",
        "        var copy = bytes\n",
        "        buffer.writeBuffer(&copy)\n",
        "    }\n\n",
        "    static func decode(from buffer: inout ByteBuffer) throws -> Self {\n",
        "        guard let len: UInt32 = buffer.readInteger(endianness: .little) else {\n",
        "            throw WireError.truncated\n",
        "        }\n",
        "        guard let slice = buffer.readSlice(length: Int(len)) else {\n",
        "            throw WireError.truncated\n",
        "        }\n",
        "        return .init(slice)\n",
        "    }\n\n",
        "    /// Encode without a length prefix — for trailing fields only.\n",
        "    func encodeTrailing(into buffer: inout ByteBuffer) {\n",
        "        var copy = bytes\n",
        "        buffer.writeBuffer(&copy)\n",
        "    }\n\n",
        "    /// Decode by consuming all remaining bytes — for trailing fields only.\n",
        "    static func decodeTrailing(from buffer: inout ByteBuffer) -> Self {\n",
        "        let slice = buffer.readSlice(length: buffer.readableBytes) ?? ByteBuffer()\n",
        "        return .init(slice)\n",
        "    }\n",
        "}\n\n",
    );

    const VARINT_U32_HELPER: &str = concat!(
        "@inline(__always)\n",
        "private func decodeWireVarintU32(from buffer: inout ByteBuffer) throws -> UInt32 {\n",
        "    let value = try decodeVarint(from: &buffer)\n",
        "    guard value <= UInt64(UInt32.max) else {\n",
        "        throw WireError.overflow\n",
        "    }\n",
        "    return UInt32(value)\n",
        "}\n\n",
    );

    const STRING_HELPER: &str = concat!(
        "@inline(__always)\n",
        "private func decodeWireString(from buffer: inout ByteBuffer) throws -> String {\n",
        "    do {\n",
        "        return try decodeString(from: &buffer)\n",
        "    } catch PostcardError.invalidUtf8 {\n",
        "        throw WireError.invalidUtf8\n",
        "    } catch PostcardError.truncated {\n",
        "        throw WireError.truncated\n",
        "    } catch {\n",
        "        throw error\n",
        "    }\n",
        "}\n\n",
    );

    const BYTES_HELPER: &str = concat!(
        "@inline(__always)\n",
        "private func decodeWireBytes(from buffer: inout ByteBuffer) throws -> ByteBuffer {\n",
        "    do {\n",
        "        return try decodeBytes(from: &buffer)\n",
        "    } catch PostcardError.truncated {\n",
        "        throw WireError.truncated\n",
        "    } catch {\n",
        "        throw error\n",
        "    }\n",
        "}\n\n",
    );

    [
        WIRE_ERROR,
        OPAQUE_PAYLOAD,
        VARINT_U32_HELPER,
        STRING_HELPER,
        BYTES_HELPER,
    ]
    .concat()
}
181
182/// Get the wire Swift type for a field's shape, taking the `WireType` list into account.
183fn swift_wire_type(shape: &'static Shape, _field: Option<&Field>, types: &[WireType]) -> String {
184    if is_bytes(shape) {
185        return "[UInt8]".into();
186    }
187
188    match classify_shape(shape) {
189        ShapeKind::Scalar(scalar) => swift_scalar_type(scalar),
190        ShapeKind::List { element } | ShapeKind::Slice { element } => {
191            format!("[{}]", swift_wire_type(element, None, types))
192        }
193        ShapeKind::Option { inner } => {
194            format!("{}?", swift_wire_type(inner, None, types))
195        }
196        ShapeKind::Array { element, .. } => {
197            format!("[{}]", swift_wire_type(element, None, types))
198        }
199        ShapeKind::Struct(StructInfo {
200            name: Some(name), ..
201        }) => lookup_wire_name(name, types),
202        ShapeKind::Enum(EnumInfo {
203            name: Some(name), ..
204        }) => lookup_wire_name(name, types),
205        ShapeKind::Pointer { pointee } => swift_wire_type(pointee, _field, types),
206        ShapeKind::Opaque => "OpaquePayload".into(),
207        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
208            swift_wire_type(fields[0].shape(), None, types)
209        }
210        _ => "Any /* unsupported */".into(),
211    }
212}
213
214fn swift_scalar_type(scalar: ScalarType) -> String {
215    match scalar {
216        ScalarType::Bool => "Bool".into(),
217        ScalarType::U8 => "UInt8".into(),
218        ScalarType::U16 => "UInt16".into(),
219        ScalarType::U32 => "UInt32".into(),
220        ScalarType::U64 | ScalarType::USize => "UInt64".into(),
221        ScalarType::I8 => "Int8".into(),
222        ScalarType::I16 => "Int16".into(),
223        ScalarType::I32 => "Int32".into(),
224        ScalarType::I64 | ScalarType::ISize => "Int64".into(),
225        ScalarType::F32 => "Float".into(),
226        ScalarType::F64 => "Double".into(),
227        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
228            "String".into()
229        }
230        ScalarType::Unit => "Void".into(),
231        _ => "Any".into(),
232    }
233}
234
235/// Look up the wire name for a Rust type name. If a matching WireType exists, use its
236/// swift_name; otherwise fall back to the Rust name.
237fn lookup_wire_name(rust_name: &str, types: &[WireType]) -> String {
238    for wt in types {
239        // Match by checking the reflected type identifier for the underlying Rust shape.
240        if wt.shape.type_identifier.ends_with(rust_name) {
241            return wt.swift_name.clone();
242        }
243    }
244    rust_name.to_string()
245}
246
247/// Generate a single wire type (struct or enum).
248fn generate_one_type(swift_name: &str, shape: &'static Shape, types: &[WireType]) -> String {
249    match classify_shape(shape) {
250        ShapeKind::Struct(StructInfo { fields, .. }) => {
251            generate_struct(swift_name, fields, types, swift_name == "Message")
252        }
253        ShapeKind::Enum(EnumInfo { variants, .. }) => generate_enum(swift_name, variants, types),
254        _ => format!("// Unsupported shape for {swift_name}\n"),
255    }
256}
257
258/// Generate a Swift struct with encode/decode methods.
259fn generate_struct(name: &str, fields: &[Field], types: &[WireType], is_top_level: bool) -> String {
260    let mut out = String::new();
261
262    // Struct definition
263    out.push_str(&format!("public struct {name}: Sendable, Equatable {{\n"));
264    for f in fields {
265        let field_name = f.name.to_lower_camel_case();
266        let field_type = swift_wire_type(f.shape(), Some(f), types);
267        out.push_str(&format!("    public var {field_name}: {field_type}\n"));
268    }
269
270    // Initializer
271    if !fields.is_empty() {
272        out.push_str("\n    public init(");
273        for (i, f) in fields.iter().enumerate() {
274            if i > 0 {
275                out.push_str(", ");
276            }
277            let field_name = f.name.to_lower_camel_case();
278            let field_type = swift_wire_type(f.shape(), Some(f), types);
279            out.push_str(&format!("{field_name}: {field_type}"));
280        }
281        out.push_str(") {\n");
282        for f in fields {
283            let field_name = f.name.to_lower_camel_case();
284            out.push_str(&format!("        self.{field_name} = {field_name}\n"));
285        }
286        out.push_str("    }\n");
287    }
288
289    // Encode method — writes into a ByteBuffer, returns void
290    let vis = if is_top_level { "public " } else { "" };
291    out.push_str(&format!(
292        "\n    {vis}func encode(into buffer: inout ByteBuffer) {{\n"
293    ));
294    if fields.is_empty() {
295        // Nothing to write
296    } else {
297        for f in fields {
298            let stmt = encode_field_stmt(f, types);
299            out.push_str(&format!("        {stmt}\n"));
300        }
301    }
302    out.push_str("    }\n");
303
304    // Top-level Message gets a [UInt8] bridge shim for encode
305    if is_top_level {
306        out.push_str(
307            "\n    /// Encode to a `[UInt8]` array (bridge for callers that need bytes).\n",
308        );
309        out.push_str("    public func encode() -> [UInt8] {\n");
310        out.push_str("        var buffer = ByteBufferAllocator().buffer(capacity: 64)\n");
311        out.push_str("        encode(into: &buffer)\n");
312        out.push_str("        return buffer.readBytes(length: buffer.readableBytes) ?? []\n");
313        out.push_str("    }\n");
314    }
315
316    // Decode method — reads from inout ByteBuffer
317    out.push_str(&format!(
318        "\n    {vis}static func decode(from buffer: inout ByteBuffer) throws -> Self {{\n"
319    ));
320    for f in fields {
321        for stmt in decode_field_stmts(f, types) {
322            out.push_str(&format!("        {stmt}\n"));
323        }
324    }
325    let field_names: Vec<String> = fields
326        .iter()
327        .map(|f| {
328            let n = f.name.to_lower_camel_case();
329            format!("{n}: {n}")
330        })
331        .collect();
332    out.push_str(&format!(
333        "        return .init({})\n",
334        field_names.join(", ")
335    ));
336    out.push_str("    }\n");
337
338    // Top-level Message gets a [UInt8] bridge shim for decode
339    if is_top_level {
340        out.push_str(
341            "\n    /// Decode from a `[UInt8]` array (bridge for callers that have raw bytes).\n",
342        );
343        out.push_str("    public static func decode(fromBytes data: [UInt8]) throws -> Self {\n");
344        out.push_str("        var buffer = ByteBufferAllocator().buffer(capacity: data.count)\n");
345        out.push_str("        buffer.writeBytes(data)\n");
346        out.push_str("        let result = try decode(from: &buffer)\n");
347        out.push_str("        guard buffer.readableBytes == 0 else {\n");
348        out.push_str("            throw WireError.trailingBytes\n");
349        out.push_str("        }\n");
350        out.push_str("        return result\n");
351        out.push_str("    }\n");
352    }
353
354    out.push_str("}\n");
355    out
356}
357
358/// Generate a Swift enum with varint-discriminanted encode/decode methods.
359fn generate_enum(name: &str, variants: &[facet_core::Variant], types: &[WireType]) -> String {
360    let mut out = String::new();
361
362    // Enum definition
363    out.push_str(&format!("public enum {name}: Sendable, Equatable {{\n"));
364    for v in variants {
365        let variant_name = v.name.to_lower_camel_case();
366        match classify_variant(v) {
367            VariantKind::Unit => {
368                out.push_str(&format!("    case {variant_name}\n"));
369            }
370            VariantKind::Newtype { inner } => {
371                let inner_type = swift_wire_type(inner, v.data.fields.first(), types);
372                out.push_str(&format!("    case {variant_name}({inner_type})\n"));
373            }
374            VariantKind::Tuple { fields } => {
375                let field_types: Vec<String> = fields
376                    .iter()
377                    .map(|f| swift_wire_type(f.shape(), Some(f), types))
378                    .collect();
379                out.push_str(&format!(
380                    "    case {variant_name}({})\n",
381                    field_types.join(", ")
382                ));
383            }
384            VariantKind::Struct { fields } => {
385                let field_decls: Vec<String> = fields
386                    .iter()
387                    .map(|f| {
388                        format!(
389                            "{}: {}",
390                            f.name.to_lower_camel_case(),
391                            swift_wire_type(f.shape(), Some(f), types)
392                        )
393                    })
394                    .collect();
395                out.push_str(&format!(
396                    "    case {variant_name}({})\n",
397                    field_decls.join(", ")
398                ));
399            }
400        }
401    }
402
403    // Encode — void, writes into buffer
404    out.push_str("\n    func encode(into buffer: inout ByteBuffer) {\n");
405    out.push_str("        switch self {\n");
406    for (i, v) in variants.iter().enumerate() {
407        let variant_name = v.name.to_lower_camel_case();
408        match classify_variant(v) {
409            VariantKind::Unit => {
410                out.push_str(&format!(
411                    "        case .{variant_name}:\n            encodeVarint(UInt64({i}), into: &buffer)\n"
412                ));
413            }
414            VariantKind::Newtype { inner } => {
415                let stmt = encode_shape_stmt(inner, "val", v.data.fields.first(), types);
416                out.push_str(&format!(
417                    "        case .{variant_name}(let val):\n            encodeVarint(UInt64({i}), into: &buffer)\n            {stmt}\n"
418                ));
419            }
420            VariantKind::Tuple { fields } => {
421                let bindings: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
422                let binding_str = bindings
423                    .iter()
424                    .map(|b| format!("let {b}"))
425                    .collect::<Vec<_>>()
426                    .join(", ");
427                let stmts: Vec<String> = fields
428                    .iter()
429                    .enumerate()
430                    .map(|(j, f)| encode_shape_stmt(f.shape(), &format!("f{j}"), Some(f), types))
431                    .collect();
432                out.push_str(&format!(
433                    "        case .{variant_name}({binding_str}):\n            encodeVarint(UInt64({i}), into: &buffer)\n"
434                ));
435                for stmt in &stmts {
436                    out.push_str(&format!("            {stmt}\n"));
437                }
438            }
439            VariantKind::Struct { fields } => {
440                let bindings: Vec<String> = fields
441                    .iter()
442                    .map(|f| f.name.to_lower_camel_case())
443                    .collect();
444                let binding_str = bindings
445                    .iter()
446                    .map(|b| format!("let {b}"))
447                    .collect::<Vec<_>>()
448                    .join(", ");
449                let stmts: Vec<String> = fields
450                    .iter()
451                    .map(|f| {
452                        encode_shape_stmt(f.shape(), &f.name.to_lower_camel_case(), Some(f), types)
453                    })
454                    .collect();
455                out.push_str(&format!(
456                    "        case .{variant_name}({binding_str}):\n            encodeVarint(UInt64({i}), into: &buffer)\n"
457                ));
458                for stmt in &stmts {
459                    out.push_str(&format!("            {stmt}\n"));
460                }
461            }
462        }
463    }
464    out.push_str("        }\n");
465    out.push_str("    }\n");
466
467    // Decode — reads from inout ByteBuffer
468    out.push_str("\n    static func decode(from buffer: inout ByteBuffer) throws -> Self {\n");
469    out.push_str("        let disc = try decodeVarint(from: &buffer)\n");
470    out.push_str("        switch disc {\n");
471    for (i, v) in variants.iter().enumerate() {
472        let variant_name = v.name.to_lower_camel_case();
473        out.push_str(&format!("        case {i}:\n"));
474        match classify_variant(v) {
475            VariantKind::Unit => {
476                out.push_str(&format!("            return .{variant_name}\n"));
477            }
478            VariantKind::Newtype { inner } => {
479                for stmt in decode_stmts_for(inner, v.data.fields.first(), "_newtype_val", types) {
480                    out.push_str(&format!("            {stmt}\n"));
481                }
482                out.push_str(&format!(
483                    "            return .{variant_name}(_newtype_val)\n"
484                ));
485            }
486            VariantKind::Tuple { fields } => {
487                for (j, f) in fields.iter().enumerate() {
488                    for stmt in decode_stmts_for(f.shape(), Some(f), &format!("f{j}"), types) {
489                        out.push_str(&format!("            {stmt}\n"));
490                    }
491                }
492                let args: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
493                out.push_str(&format!(
494                    "            return .{variant_name}({})\n",
495                    args.join(", ")
496                ));
497            }
498            VariantKind::Struct { fields } => {
499                for f in fields {
500                    let field_name = f.name.to_lower_camel_case();
501                    for stmt in decode_stmts_for(f.shape(), Some(f), &field_name, types) {
502                        out.push_str(&format!("            {stmt}\n"));
503                    }
504                }
505                let args: Vec<String> = fields
506                    .iter()
507                    .map(|f| {
508                        let n = f.name.to_lower_camel_case();
509                        format!("{n}: {n}")
510                    })
511                    .collect();
512                out.push_str(&format!(
513                    "            return .{variant_name}({})\n",
514                    args.join(", ")
515                ));
516            }
517        }
518    }
519    out.push_str("        default:\n");
520    out.push_str("            throw WireError.unknownVariant(disc)\n");
521    out.push_str("        }\n");
522    out.push_str("    }\n");
523
524    out.push_str("}\n");
525    out
526}
527
528/// Generate an encode statement for a struct field (writes into `buffer`).
529fn encode_field_stmt(field: &Field, types: &[WireType]) -> String {
530    let field_name = field.name.to_lower_camel_case();
531    encode_shape_stmt(field.shape(), &field_name, Some(field), types)
532}
533
534/// Generate an encode statement for a shape with a given value expression.
535/// The statement writes into the implicit `buffer: inout ByteBuffer` in scope.
536fn encode_shape_stmt(
537    shape: &'static Shape,
538    value: &str,
539    field: Option<&Field>,
540    types: &[WireType],
541) -> String {
542    let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
543
544    // Opaque type → OpaquePayload
545    if matches!(classify_shape(shape), ShapeKind::Opaque) {
546        return if is_trailing {
547            format!("{value}.encodeTrailing(into: &buffer)")
548        } else {
549            format!("{value}.encode(into: &buffer)")
550        };
551    }
552
553    if is_bytes(shape) {
554        return format!("encodeByteSeq({value}, into: &buffer)");
555    }
556
557    match classify_shape(shape) {
558        ShapeKind::Scalar(scalar) => encode_scalar_stmt(scalar, value),
559        ShapeKind::List { element } | ShapeKind::Slice { element } => {
560            let inner = encode_element_closure(element, types);
561            format!("encodeVec({value}, into: &buffer, encoder: {inner})")
562        }
563        ShapeKind::Option { inner } => {
564            let inner_closure = encode_element_closure(inner, types);
565            format!("encodeOption({value}, into: &buffer, encoder: {inner_closure})")
566        }
567        ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
568            format!("{value}.encode(into: &buffer)")
569        }
570        ShapeKind::Pointer { pointee } => encode_shape_stmt(pointee, value, field, types),
571        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
572            encode_shape_stmt(fields[0].shape(), value, field, types)
573        }
574        _ => format!("/* unsupported encode for {value} */"),
575    }
576}
577
578fn encode_scalar_stmt(scalar: ScalarType, value: &str) -> String {
579    match scalar {
580        ScalarType::Bool => format!("encodeBool({value}, into: &buffer)"),
581        ScalarType::U8 => format!("encodeU8({value}, into: &buffer)"),
582        ScalarType::I8 => format!("encodeI8({value}, into: &buffer)"),
583        ScalarType::U16 => format!("encodeU16({value}, into: &buffer)"),
584        ScalarType::I16 => format!("encodeI16({value}, into: &buffer)"),
585        ScalarType::U32 => format!("encodeVarint(UInt64({value}), into: &buffer)"),
586        ScalarType::I32 => format!("encodeI32({value}, into: &buffer)"),
587        ScalarType::U64 | ScalarType::USize => format!("encodeVarint({value}, into: &buffer)"),
588        ScalarType::I64 | ScalarType::ISize => format!("encodeI64({value}, into: &buffer)"),
589        ScalarType::F32 => format!("encodeF32({value}, into: &buffer)"),
590        ScalarType::F64 => format!("encodeF64({value}, into: &buffer)"),
591        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
592            format!("encodeString({value}, into: &buffer)")
593        }
594        _ => format!("/* unsupported scalar encode for {value} */"),
595    }
596}
597
598/// Generate an encode closure for use with encodeVec / encodeOption.
599/// Closures take `(T, inout ByteBuffer)` — parameters named `val` and `buf`.
600fn encode_element_closure(shape: &'static Shape, _types: &[WireType]) -> String {
601    if is_bytes(shape) {
602        return "{ val, buf in encodeByteSeq(val, into: &buf) }".into();
603    }
604
605    match classify_shape(shape) {
606        ShapeKind::Scalar(scalar) => {
607            let stmt = encode_scalar_stmt(scalar, "val").replace("into: &buffer", "into: &buf");
608            format!("{{ val, buf in {stmt} }}")
609        }
610        ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
611            "{ val, buf in val.encode(into: &buf) }".into()
612        }
613        ShapeKind::List { element } | ShapeKind::Slice { element } => {
614            let inner = encode_element_closure(element, _types);
615            format!("{{ val, buf in encodeVec(val, into: &buf, encoder: {inner}) }}")
616        }
617        ShapeKind::Opaque => "{ val, buf in val.encode(into: &buf) }".into(),
618        ShapeKind::Pointer { pointee } => encode_element_closure(pointee, _types),
619        _ => "{ _, _ in /* unsupported */ }".into(),
620    }
621}
622
623/// Generate decode statements for a struct field into a named variable.
624/// Returns one or two Swift statements (as lines to be written individually).
625fn decode_field_stmts(field: &Field, types: &[WireType]) -> Vec<String> {
626    let field_name = field.name.to_lower_camel_case();
627    decode_stmts_for(field.shape(), Some(field), &field_name, types)
628}
629
630/// Produce one or two Swift statements that decode `shape` into a local named `var_name`.
631/// When the shape (or its pointer-unwrapped inner) is bytes, two statements are needed.
632fn decode_stmts_for(
633    shape: &'static Shape,
634    field: Option<&Field>,
635    var_name: &str,
636    types: &[WireType],
637) -> Vec<String> {
638    // Unwrap pointers transparently
639    if let ShapeKind::Pointer { pointee } = classify_shape(shape) {
640        return decode_stmts_for(pointee, field, var_name, types);
641    }
642    if is_bytes(shape) {
643        return vec![
644            format!("var _{var_name}Buf = try decodeWireBytes(from: &buffer)"),
645            format!(
646                "let {var_name} = _{var_name}Buf.readBytes(length: _{var_name}Buf.readableBytes) ?? []"
647            ),
648        ];
649    }
650    vec![format!(
651        "let {var_name} = {}",
652        decode_shape_expr(shape, field, types)
653    )]
654}
655
656/// Generate a decode expression for a shape.
657/// All decode calls read from the implicit `buffer: inout ByteBuffer` in scope.
658fn decode_shape_expr(shape: &'static Shape, field: Option<&Field>, types: &[WireType]) -> String {
659    let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
660
661    // Opaque type → OpaquePayload
662    if matches!(classify_shape(shape), ShapeKind::Opaque) {
663        return if is_trailing {
664            "OpaquePayload.decodeTrailing(from: &buffer)".into()
665        } else {
666            "try OpaquePayload.decode(from: &buffer)".into()
667        };
668    }
669
670    match classify_shape(shape) {
671        ShapeKind::Scalar(scalar) => decode_scalar(scalar),
672        ShapeKind::List { element } | ShapeKind::Slice { element } => {
673            let inner = decode_element_closure(element, types);
674            format!("try decodeVec(from: &buffer, decoder: {inner})")
675        }
676        ShapeKind::Option { inner } => {
677            let inner_closure = decode_element_closure(inner, types);
678            format!("try decodeOption(from: &buffer, decoder: {inner_closure})")
679        }
680        ShapeKind::Struct(StructInfo {
681            name: Some(name), ..
682        }) => {
683            let swift_name = lookup_wire_name(name, types);
684            format!("try {swift_name}.decode(from: &buffer)")
685        }
686        ShapeKind::Enum(EnumInfo {
687            name: Some(name), ..
688        }) => {
689            let swift_name = lookup_wire_name(name, types);
690            format!("try {swift_name}.decode(from: &buffer)")
691        }
692        ShapeKind::Pointer { pointee } => decode_shape_expr(pointee, field, types),
693        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
694            decode_shape_expr(fields[0].shape(), field, types)
695        }
696        _ => "nil /* unsupported decode */".into(),
697    }
698}
699
700fn decode_scalar(scalar: ScalarType) -> String {
701    match scalar {
702        ScalarType::Bool => "try decodeBool(from: &buffer)".into(),
703        ScalarType::U8 => "try decodeU8(from: &buffer)".into(),
704        ScalarType::I8 => "try decodeI8(from: &buffer)".into(),
705        ScalarType::U16 => "try decodeU16(from: &buffer)".into(),
706        ScalarType::I16 => "try decodeI16(from: &buffer)".into(),
707        ScalarType::U32 => "try decodeWireVarintU32(from: &buffer)".into(),
708        ScalarType::I32 => "try decodeI32(from: &buffer)".into(),
709        ScalarType::U64 | ScalarType::USize => "try decodeVarint(from: &buffer)".into(),
710        ScalarType::I64 | ScalarType::ISize => "try decodeI64(from: &buffer)".into(),
711        ScalarType::F32 => "try decodeF32(from: &buffer)".into(),
712        ScalarType::F64 => "try decodeF64(from: &buffer)".into(),
713        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
714            "try decodeWireString(from: &buffer)".into()
715        }
716        _ => "nil /* unsupported scalar decode */".into(),
717    }
718}
719
720/// Generate a decode closure for use with decodeVec / decodeOption.
721/// Closures take `(inout ByteBuffer)` — parameter named `buf`.
722fn decode_element_closure(shape: &'static Shape, types: &[WireType]) -> String {
723    if is_bytes(shape) {
724        // The closure is itself throwing so `try` inside is valid here.
725        return "{ buf in var s = try decodeWireBytes(from: &buf); return s.readBytes(length: s.readableBytes) ?? [] }".into();
726    }
727
728    match classify_shape(shape) {
729        ShapeKind::Scalar(scalar) => {
730            let expr = decode_scalar(scalar).replace("from: &buffer", "from: &buf");
731            format!("{{ buf in {expr} }}")
732        }
733        ShapeKind::Struct(StructInfo {
734            name: Some(name), ..
735        }) => {
736            let swift_name = lookup_wire_name(name, types);
737            format!("{{ buf in try {swift_name}.decode(from: &buf) }}")
738        }
739        ShapeKind::Enum(EnumInfo {
740            name: Some(name), ..
741        }) => {
742            let swift_name = lookup_wire_name(name, types);
743            format!("{{ buf in try {swift_name}.decode(from: &buf) }}")
744        }
745        ShapeKind::List { element } | ShapeKind::Slice { element } => {
746            let inner = decode_element_closure(element, types);
747            format!("{{ buf in try decodeVec(from: &buf, decoder: {inner}) }}")
748        }
749        ShapeKind::Opaque => "{ buf in try OpaquePayload.decode(from: &buf) }".into(),
750        ShapeKind::Pointer { pointee } => decode_element_closure(pointee, types),
751        _ => "{ _ in throw WireError.truncated }".into(),
752    }
753}
754
755/// Generate convenience factory methods on Message.
756fn generate_factory_methods(types: &[WireType]) -> String {
757    // Find the MessagePayload shape to inspect variants
758    let payload_wt = types.iter().find(|wt| wt.swift_name == "MessagePayload");
759    let payload_wt = match payload_wt {
760        Some(wt) => wt,
761        None => return String::new(),
762    };
763
764    let variants = match classify_shape(payload_wt.shape) {
765        ShapeKind::Enum(EnumInfo { variants, .. }) => variants,
766        _ => return String::new(),
767    };
768
769    let mut out = String::new();
770    out.push_str("public extension Message {\n");
771
772    for v in variants {
773        let variant_name = v.name.to_lower_camel_case();
774        if let VariantKind::Newtype { inner } = classify_variant(v) {
775            let inner_swift = swift_wire_type(inner, None, types);
776            // Control messages use connId=0.
777            let is_control = matches!(
778                v.name,
779                "Hello" | "HelloYourself" | "ProtocolError" | "Ping" | "Pong"
780            );
781
782            if is_control {
783                out.push_str(&format!(
784                    "    static func {variant_name}(_ value: {inner_swift}) -> Message {{\n"
785                ));
786                out.push_str(&format!(
787                    "        Message(connectionId: 0, payload: .{variant_name}(value))\n"
788                ));
789                out.push_str("    }\n\n");
790            } else {
791                out.push_str(&format!(
792                    "    static func {variant_name}(connId: UInt64, _ value: {inner_swift}) -> Message {{\n"
793                ));
794                out.push_str(&format!(
795                    "        Message(connectionId: connId, payload: .{variant_name}(value))\n"
796                ));
797                out.push_str("    }\n\n");
798            }
799        }
800    }
801
802    // Additional ergonomic factory methods that flatten nested structs.
803    // These match the existing hand-coded API.
804    out.push_str("    static func protocolError(description: String) -> Message {\n");
805    out.push_str("        Message(connectionId: 0, payload: .protocolError(.init(description: description)))\n");
806    out.push_str("    }\n\n");
807
808    out.push_str("    static func connectionOpen(\n");
809    out.push_str(
810        "        connId: UInt64, settings: ConnectionSettings, metadata: [MetadataEntry]\n",
811    );
812    out.push_str("    ) -> Message {\n");
813    out.push_str("        Message(\n");
814    out.push_str("            connectionId: connId,\n");
815    out.push_str("            payload: .connectionOpen(.init(connectionSettings: settings, metadata: metadata)))\n");
816    out.push_str("    }\n\n");
817
818    out.push_str("    static func connectionAccept(\n");
819    out.push_str(
820        "        connId: UInt64, settings: ConnectionSettings, metadata: [MetadataEntry]\n",
821    );
822    out.push_str("    ) -> Message {\n");
823    out.push_str("        Message(\n");
824    out.push_str("            connectionId: connId,\n");
825    out.push_str("            payload: .connectionAccept(.init(connectionSettings: settings, metadata: metadata)))\n");
826    out.push_str("    }\n\n");
827
828    out.push_str("    static func connectionReject(connId: UInt64, metadata: [MetadataEntry]) -> Message {\n");
829    out.push_str("        Message(connectionId: connId, payload: .connectionReject(.init(metadata: metadata)))\n");
830    out.push_str("    }\n\n");
831
832    out.push_str(
833        "    static func connectionClose(connId: UInt64, metadata: [MetadataEntry]) -> Message {\n",
834    );
835    out.push_str("        Message(connectionId: connId, payload: .connectionClose(.init(metadata: metadata)))\n");
836    out.push_str("    }\n\n");
837
838    out.push_str("    static func request(\n");
839    out.push_str("        connId: UInt64,\n");
840    out.push_str("        requestId: UInt64,\n");
841    out.push_str("        methodId: UInt64,\n");
842    out.push_str("        metadata: [MetadataEntry],\n");
843    out.push_str("        schemas: [UInt8] = [],\n");
844    out.push_str("        payload: [UInt8]\n");
845    out.push_str("    ) -> Message {\n");
846    out.push_str("        Message(\n");
847    out.push_str("            connectionId: connId,\n");
848    out.push_str("            payload: .requestMessage(\n");
849    out.push_str("                .init(\n");
850    out.push_str("                    id: requestId,\n");
851    out.push_str("                    body: .call(.init(methodId: methodId, metadata: metadata, args: .init(payload), schemas: schemas))\n");
852    out.push_str("                ))\n");
853    out.push_str("        )\n");
854    out.push_str("    }\n\n");
855
856    out.push_str("    static func response(\n");
857    out.push_str("        connId: UInt64,\n");
858    out.push_str("        requestId: UInt64,\n");
859    out.push_str("        metadata: [MetadataEntry],\n");
860    out.push_str("        schemas: [UInt8] = [],\n");
861    out.push_str("        payload: [UInt8]\n");
862    out.push_str("    ) -> Message {\n");
863    out.push_str("        Message(\n");
864    out.push_str("            connectionId: connId,\n");
865    out.push_str("            payload: .requestMessage(\n");
866    out.push_str("                .init(\n");
867    out.push_str("                    id: requestId,\n");
868    out.push_str("                    body: .response(.init(metadata: metadata, ret: .init(payload), schemas: schemas))\n");
869    out.push_str("                ))\n");
870    out.push_str("        )\n");
871    out.push_str("    }\n\n");
872
873    out.push_str("    static func cancel(connId: UInt64, requestId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
874    out.push_str("        Message(\n");
875    out.push_str("            connectionId: connId,\n");
876    out.push_str("            payload: .requestMessage(\n");
877    out.push_str("                .init(\n");
878    out.push_str("                    id: requestId,\n");
879    out.push_str("                    body: .cancel(.init(metadata: metadata))\n");
880    out.push_str("                ))\n");
881    out.push_str("        )\n");
882    out.push_str("    }\n\n");
883
884    out.push_str(
885        "    static func data(connId: UInt64, channelId: UInt64, payload: [UInt8]) -> Message {\n",
886    );
887    out.push_str("        Message(\n");
888    out.push_str("            connectionId: connId,\n");
889    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .item(.init(item: .init(payload)))))\n");
890    out.push_str("        )\n");
891    out.push_str("    }\n\n");
892
893    out.push_str("    static func close(connId: UInt64, channelId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
894    out.push_str("        Message(\n");
895    out.push_str("            connectionId: connId,\n");
896    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .close(.init(metadata: metadata))))\n");
897    out.push_str("        )\n");
898    out.push_str("    }\n\n");
899
900    out.push_str("    static func reset(connId: UInt64, channelId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
901    out.push_str("        Message(\n");
902    out.push_str("            connectionId: connId,\n");
903    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .reset(.init(metadata: metadata))))\n");
904    out.push_str("        )\n");
905    out.push_str("    }\n\n");
906
907    out.push_str(
908        "    static func credit(connId: UInt64, channelId: UInt64, bytes: UInt32) -> Message {\n",
909    );
910    out.push_str("        Message(\n");
911    out.push_str("            connectionId: connId,\n");
912    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .grantCredit(.init(additional: bytes))))\n");
913    out.push_str("        )\n");
914    out.push_str("    }\n");
915
916    out.push_str("}\n");
917    out
918}