Skip to main content

vox_codegen/targets/swift/
wire.rs

1//! Swift wire type code generation.
2//!
3//! Generates a complete Swift source file containing all wire protocol types
4//! (Message, MessagePayload, etc.) with encode/decode methods. The generated
5//! code is driven by facet `Shape`s from `vox_types::message`.
6//!
7//! The only special-cased type is `Payload` (`ShapeKind::Opaque`), which maps to
8//! `OpaquePayload` with both length-prefixed and trailing byte handling.
9//! Everything else is normal struct/enum codegen.
10
11use super::types::swift_field_name;
12use facet_core::{Field, ScalarType, Shape};
13use vox_types::{
14    DEFAULT_INITIAL_CHANNEL_CREDIT, EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape,
15    classify_variant, extract_schemas, is_bytes,
16};
17
/// One wire-protocol type to emit into the generated Swift file.
///
/// `generate_wire_types` walks a slice of these in order and emits one Swift
/// declaration per entry. The same slice is consulted (via `lookup_wire_name`)
/// to translate Rust type names into Swift names when a field refers to
/// another wire type.
pub struct WireType {
    /// Name used for the generated Swift declaration (e.g. "Message", "HelloV7").
    pub swift_name: String,
    /// Reflected facet `Shape` of the Rust type; it drives the generated fields/variants.
    pub shape: &'static Shape,
}
25
26/// Generate a complete Swift wire types file, returning the Swift source and the raw
27/// CBOR schema bytes separately. The caller is responsible for writing the `.bin` file
28/// and declaring it as a `Bundle.module` resource; the generated Swift loads it at runtime.
29pub fn generate_wire_types(types: &[WireType]) -> (String, Vec<u8>) {
30    let mut out = String::new();
31    out.push_str("// @generated by vox-codegen\n");
32    out.push_str("// DO NOT EDIT — regenerate with `cargo xtask codegen --swift-wire`\n\n");
33    out.push_str("import Foundation\n");
34    out.push_str("@preconcurrency import NIOCore\n\n");
35
36    // Preamble: error type + helpers + OpaquePayload
37    out.push_str(&generate_preamble());
38
39    // Metadata typealias (Metadata is Vec<MetadataEntry> in Rust, we alias for convenience)
40    out.push_str("public typealias Metadata = [MetadataEntry]\n\n");
41
42    // Generate each type
43    for wt in types {
44        out.push_str(&generate_one_type(&wt.swift_name, wt.shape, types));
45        out.push('\n');
46    }
47
48    // Generate factory methods extension on Message
49    out.push_str(&generate_factory_methods(types));
50
51    // Emit a Bundle.module loader instead of inlining the bytes as a literal.
52    out.push_str(
53        "\n/// CBOR-encoded wire message schemas, loaded from the bundled binary resource.\n",
54    );
55    out.push_str("public let wireMessageSchemasCbor: [UInt8] = {\n");
56    out.push_str(
57        "    guard let url = Bundle.module.url(forResource: \"wireMessageSchemas\", withExtension: \"bin\"),\n",
58    );
59    out.push_str("          let data = try? Data(contentsOf: url) else {\n");
60    out.push_str(
61        "        preconditionFailure(\"wireMessageSchemas.bin resource not found in Bundle.module\")\n",
62    );
63    out.push_str("    }\n");
64    out.push_str("    return Array(data)\n");
65    out.push_str("}()\n");
66
67    // Extract the CBOR bytes to return separately.
68    let cbor_bytes = types
69        .iter()
70        .find(|wt| wt.swift_name == "Message")
71        .or_else(|| types.last())
72        .map(|root| {
73            let extracted =
74                extract_schemas(root.shape).expect("wire schema extraction should succeed");
75            facet_cbor::to_vec(&extracted.schemas)
76                .expect("wire schema CBOR serialization should succeed")
77        })
78        .unwrap_or_default();
79
80    (out, cbor_bytes)
81}
82
/// Return the fixed Swift preamble shared by every generated wire file.
///
/// The preamble declares:
/// - `WireError`, the error enum thrown by the generated decoders;
/// - `OpaquePayload`, a `ByteBuffer` wrapper. Its normal encoding is a little-endian
///   `UInt32` length prefix followed by the bytes. The "trailing" variants write and
///   read the bytes without a prefix.
/// - private helpers `decodeWireVarintU32`, `decodeWireString` and `decodeWireBytes`.
///   They wrap runtime decoders, reject varints above `UInt32.max`, and remap
///   `PostcardError` cases onto `WireError`.
///
/// The text does not depend on the wire types being generated.
fn generate_preamble() -> String {
    // A single raw literal keeps the emitted Swift readable exactly as it will appear.
    const PREAMBLE: &str = r#"public enum WireError: Error, Equatable {
    case truncated
    case unknownVariant(UInt64)
    case overflow
    case invalidUtf8
    case trailingBytes
}

public struct OpaquePayload: Sendable, Equatable {
    public var bytes: ByteBuffer

    public init(_ bytes: ByteBuffer) {
        self.bytes = bytes
    }

    /// Init from a [UInt8] for convenience (e.g. from legacy code)
    public init(_ bytes: [UInt8]) {
        var buf = ByteBufferAllocator().buffer(capacity: bytes.count)
        buf.writeBytes(bytes)
        self.bytes = buf
    }

    func encode(into buffer: inout ByteBuffer) {
        let len = UInt32(bytes.readableBytes)
        buffer.writeInteger(len, endianness: .little)
        var copy = bytes
        buffer.writeBuffer(&copy)
    }

    static func decode(from buffer: inout ByteBuffer) throws -> Self {
        guard let len: UInt32 = buffer.readInteger(endianness: .little) else {
            throw WireError.truncated
        }
        guard let slice = buffer.readSlice(length: Int(len)) else {
            throw WireError.truncated
        }
        return .init(slice)
    }

    /// Encode without a length prefix — for trailing fields only.
    func encodeTrailing(into buffer: inout ByteBuffer) {
        var copy = bytes
        buffer.writeBuffer(&copy)
    }

    /// Decode by consuming all remaining bytes — for trailing fields only.
    static func decodeTrailing(from buffer: inout ByteBuffer) -> Self {
        let slice = buffer.readSlice(length: buffer.readableBytes) ?? ByteBuffer()
        return .init(slice)
    }
}

@inline(__always)
private func decodeWireVarintU32(from buffer: inout ByteBuffer) throws -> UInt32 {
    let value = try decodeVarint(from: &buffer)
    guard value <= UInt64(UInt32.max) else {
        throw WireError.overflow
    }
    return UInt32(value)
}

@inline(__always)
private func decodeWireString(from buffer: inout ByteBuffer) throws -> String {
    do {
        return try decodeString(from: &buffer)
    } catch PostcardError.invalidUtf8 {
        throw WireError.invalidUtf8
    } catch PostcardError.truncated {
        throw WireError.truncated
    } catch {
        throw error
    }
}

@inline(__always)
private func decodeWireBytes(from buffer: inout ByteBuffer) throws -> ByteBuffer {
    do {
        return try decodeBytes(from: &buffer)
    } catch PostcardError.truncated {
        throw WireError.truncated
    } catch {
        throw error
    }
}

"#;
    PREAMBLE.to_owned()
}
181
182/// Get the wire Swift type for a field's shape, taking the `WireType` list into account.
183fn swift_wire_type(shape: &'static Shape, _field: Option<&Field>, types: &[WireType]) -> String {
184    if is_bytes(shape) {
185        return "[UInt8]".into();
186    }
187
188    match classify_shape(shape) {
189        ShapeKind::Scalar(scalar) => swift_scalar_type(scalar),
190        ShapeKind::List { element } | ShapeKind::Slice { element } => {
191            format!("[{}]", swift_wire_type(element, None, types))
192        }
193        ShapeKind::Option { inner } => {
194            format!("{}?", swift_wire_type(inner, None, types))
195        }
196        ShapeKind::Array { element, .. } => {
197            format!("[{}]", swift_wire_type(element, None, types))
198        }
199        ShapeKind::Struct(StructInfo {
200            name: Some(name), ..
201        }) => lookup_wire_name(name, types),
202        ShapeKind::Enum(EnumInfo {
203            name: Some(name), ..
204        }) => lookup_wire_name(name, types),
205        ShapeKind::Pointer { pointee } => swift_wire_type(pointee, _field, types),
206        ShapeKind::Opaque => "OpaquePayload".into(),
207        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
208            swift_wire_type(fields[0].shape(), None, types)
209        }
210        _ => "Any /* unsupported */".into(),
211    }
212}
213
214fn swift_scalar_type(scalar: ScalarType) -> String {
215    match scalar {
216        ScalarType::Bool => "Bool".into(),
217        ScalarType::U8 => "UInt8".into(),
218        ScalarType::U16 => "UInt16".into(),
219        ScalarType::U32 => "UInt32".into(),
220        ScalarType::U64 | ScalarType::USize => "UInt64".into(),
221        ScalarType::I8 => "Int8".into(),
222        ScalarType::I16 => "Int16".into(),
223        ScalarType::I32 => "Int32".into(),
224        ScalarType::I64 | ScalarType::ISize => "Int64".into(),
225        ScalarType::F32 => "Float".into(),
226        ScalarType::F64 => "Double".into(),
227        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
228            "String".into()
229        }
230        ScalarType::Unit => "Void".into(),
231        _ => "Any".into(),
232    }
233}
234
235/// Look up the wire name for a Rust type name. If a matching WireType exists, use its
236/// swift_name; otherwise fall back to the Rust name.
237fn lookup_wire_name(rust_name: &str, types: &[WireType]) -> String {
238    for wt in types {
239        // Match by checking the reflected type identifier for the underlying Rust shape.
240        if wt.shape.type_identifier.ends_with(rust_name) {
241            return wt.swift_name.clone();
242        }
243    }
244    rust_name.to_string()
245}
246
247/// Generate a single wire type (struct or enum).
248fn generate_one_type(swift_name: &str, shape: &'static Shape, types: &[WireType]) -> String {
249    match classify_shape(shape) {
250        ShapeKind::Struct(StructInfo { fields, .. }) => {
251            generate_struct(swift_name, fields, types, swift_name == "Message")
252        }
253        ShapeKind::Enum(EnumInfo { variants, .. }) => generate_enum(swift_name, variants, types),
254        _ => format!("// Unsupported shape for {swift_name}\n"),
255    }
256}
257
258/// Generate a Swift struct with encode/decode methods.
259fn generate_struct(name: &str, fields: &[Field], types: &[WireType], is_top_level: bool) -> String {
260    let mut out = String::new();
261
262    // Struct definition
263    out.push_str(&format!("public struct {name}: Sendable, Equatable {{\n"));
264    for f in fields {
265        let field_name = swift_field_name(f.name);
266        let field_type = swift_wire_type(f.shape(), Some(f), types);
267        out.push_str(&format!("    public var {field_name}: {field_type}\n"));
268    }
269
270    // Initializer
271    if !fields.is_empty() {
272        out.push_str("\n    public init(");
273        for (i, f) in fields.iter().enumerate() {
274            if i > 0 {
275                out.push_str(", ");
276            }
277            let field_name = swift_field_name(f.name);
278            let field_type = swift_wire_type(f.shape(), Some(f), types);
279            out.push_str(&format!("{field_name}: {field_type}"));
280            if let Some(default_value) = swift_default_argument(f) {
281                out.push_str(&format!(" = {default_value}"));
282            }
283        }
284        out.push_str(") {\n");
285        for f in fields {
286            let field_name = swift_field_name(f.name);
287            out.push_str(&format!("        self.{field_name} = {field_name}\n"));
288        }
289        out.push_str("    }\n");
290    }
291
292    // Encode method — writes into a ByteBuffer, returns void
293    let vis = if is_top_level { "public " } else { "" };
294    out.push_str(&format!(
295        "\n    {vis}func encode(into buffer: inout ByteBuffer) {{\n"
296    ));
297    if fields.is_empty() {
298        // Nothing to write
299    } else {
300        for f in fields {
301            let stmt = encode_field_stmt(f, types);
302            out.push_str(&format!("        {stmt}\n"));
303        }
304    }
305    out.push_str("    }\n");
306
307    // Top-level Message gets a [UInt8] bridge shim for encode
308    if is_top_level {
309        out.push_str(
310            "\n    /// Encode to a `[UInt8]` array (bridge for callers that need bytes).\n",
311        );
312        out.push_str("    public func encode() -> [UInt8] {\n");
313        out.push_str("        var buffer = ByteBufferAllocator().buffer(capacity: 64)\n");
314        out.push_str("        encode(into: &buffer)\n");
315        out.push_str("        return buffer.readBytes(length: buffer.readableBytes) ?? []\n");
316        out.push_str("    }\n");
317    }
318
319    // Decode method — reads from inout ByteBuffer
320    out.push_str(&format!(
321        "\n    {vis}static func decode(from buffer: inout ByteBuffer) throws -> Self {{\n"
322    ));
323    for f in fields {
324        for stmt in decode_field_stmts(f, types) {
325            out.push_str(&format!("        {stmt}\n"));
326        }
327    }
328    let field_names: Vec<String> = fields
329        .iter()
330        .map(|f| {
331            let n = swift_field_name(f.name);
332            format!("{n}: {n}")
333        })
334        .collect();
335    out.push_str(&format!(
336        "        return .init({})\n",
337        field_names.join(", ")
338    ));
339    out.push_str("    }\n");
340
341    // Top-level Message gets a [UInt8] bridge shim for decode
342    if is_top_level {
343        out.push_str(
344            "\n    /// Decode from a `[UInt8]` array (bridge for callers that have raw bytes).\n",
345        );
346        out.push_str("    public static func decode(fromBytes data: [UInt8]) throws -> Self {\n");
347        out.push_str("        var buffer = ByteBufferAllocator().buffer(capacity: data.count)\n");
348        out.push_str("        buffer.writeBytes(data)\n");
349        out.push_str("        let result = try decode(from: &buffer)\n");
350        out.push_str("        guard buffer.readableBytes == 0 else {\n");
351        out.push_str("            throw WireError.trailingBytes\n");
352        out.push_str("        }\n");
353        out.push_str("        return result\n");
354        out.push_str("    }\n");
355    }
356
357    out.push_str("}\n");
358    out
359}
360
361fn swift_default_argument(field: &Field) -> Option<String> {
362    if field.name == "initial_channel_credit" && field.has_default() {
363        return Some(DEFAULT_INITIAL_CHANNEL_CREDIT.to_string());
364    }
365    None
366}
367
368/// Generate a Swift enum with varint-discriminanted encode/decode methods.
369fn generate_enum(name: &str, variants: &[facet_core::Variant], types: &[WireType]) -> String {
370    let mut out = String::new();
371
372    // Enum definition
373    out.push_str(&format!("public enum {name}: Sendable, Equatable {{\n"));
374    for v in variants {
375        let variant_name = swift_field_name(v.name);
376        match classify_variant(v) {
377            VariantKind::Unit => {
378                out.push_str(&format!("    case {variant_name}\n"));
379            }
380            VariantKind::Newtype { inner } => {
381                let inner_type = swift_wire_type(inner, v.data.fields.first(), types);
382                out.push_str(&format!("    case {variant_name}({inner_type})\n"));
383            }
384            VariantKind::Tuple { fields } => {
385                let field_types: Vec<String> = fields
386                    .iter()
387                    .map(|f| swift_wire_type(f.shape(), Some(f), types))
388                    .collect();
389                out.push_str(&format!(
390                    "    case {variant_name}({})\n",
391                    field_types.join(", ")
392                ));
393            }
394            VariantKind::Struct { fields } => {
395                let field_decls: Vec<String> = fields
396                    .iter()
397                    .map(|f| {
398                        format!(
399                            "{}: {}",
400                            swift_field_name(f.name),
401                            swift_wire_type(f.shape(), Some(f), types)
402                        )
403                    })
404                    .collect();
405                out.push_str(&format!(
406                    "    case {variant_name}({})\n",
407                    field_decls.join(", ")
408                ));
409            }
410        }
411    }
412
413    // Encode — void, writes into buffer
414    out.push_str("\n    func encode(into buffer: inout ByteBuffer) {\n");
415    out.push_str("        switch self {\n");
416    for (i, v) in variants.iter().enumerate() {
417        let variant_name = swift_field_name(v.name);
418        match classify_variant(v) {
419            VariantKind::Unit => {
420                out.push_str(&format!(
421                    "        case .{variant_name}:\n            encodeVarint(UInt64({i}), into: &buffer)\n"
422                ));
423            }
424            VariantKind::Newtype { inner } => {
425                let stmt = encode_shape_stmt(inner, "val", v.data.fields.first(), types);
426                out.push_str(&format!(
427                    "        case .{variant_name}(let val):\n            encodeVarint(UInt64({i}), into: &buffer)\n            {stmt}\n"
428                ));
429            }
430            VariantKind::Tuple { fields } => {
431                let bindings: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
432                let binding_str = bindings
433                    .iter()
434                    .map(|b| format!("let {b}"))
435                    .collect::<Vec<_>>()
436                    .join(", ");
437                let stmts: Vec<String> = fields
438                    .iter()
439                    .enumerate()
440                    .map(|(j, f)| encode_shape_stmt(f.shape(), &format!("f{j}"), Some(f), types))
441                    .collect();
442                out.push_str(&format!(
443                    "        case .{variant_name}({binding_str}):\n            encodeVarint(UInt64({i}), into: &buffer)\n"
444                ));
445                for stmt in &stmts {
446                    out.push_str(&format!("            {stmt}\n"));
447                }
448            }
449            VariantKind::Struct { fields } => {
450                let bindings: Vec<String> =
451                    fields.iter().map(|f| swift_field_name(f.name)).collect();
452                let binding_str = bindings
453                    .iter()
454                    .map(|b| format!("let {b}"))
455                    .collect::<Vec<_>>()
456                    .join(", ");
457                let stmts: Vec<String> = fields
458                    .iter()
459                    .map(|f| {
460                        encode_shape_stmt(f.shape(), &swift_field_name(f.name), Some(f), types)
461                    })
462                    .collect();
463                out.push_str(&format!(
464                    "        case .{variant_name}({binding_str}):\n            encodeVarint(UInt64({i}), into: &buffer)\n"
465                ));
466                for stmt in &stmts {
467                    out.push_str(&format!("            {stmt}\n"));
468                }
469            }
470        }
471    }
472    out.push_str("        }\n");
473    out.push_str("    }\n");
474
475    // Decode — reads from inout ByteBuffer
476    out.push_str("\n    static func decode(from buffer: inout ByteBuffer) throws -> Self {\n");
477    out.push_str("        let disc = try decodeVarint(from: &buffer)\n");
478    out.push_str("        switch disc {\n");
479    for (i, v) in variants.iter().enumerate() {
480        let variant_name = swift_field_name(v.name);
481        out.push_str(&format!("        case {i}:\n"));
482        match classify_variant(v) {
483            VariantKind::Unit => {
484                out.push_str(&format!("            return .{variant_name}\n"));
485            }
486            VariantKind::Newtype { inner } => {
487                for stmt in decode_stmts_for(inner, v.data.fields.first(), "_newtype_val", types) {
488                    out.push_str(&format!("            {stmt}\n"));
489                }
490                out.push_str(&format!(
491                    "            return .{variant_name}(_newtype_val)\n"
492                ));
493            }
494            VariantKind::Tuple { fields } => {
495                for (j, f) in fields.iter().enumerate() {
496                    for stmt in decode_stmts_for(f.shape(), Some(f), &format!("f{j}"), types) {
497                        out.push_str(&format!("            {stmt}\n"));
498                    }
499                }
500                let args: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
501                out.push_str(&format!(
502                    "            return .{variant_name}({})\n",
503                    args.join(", ")
504                ));
505            }
506            VariantKind::Struct { fields } => {
507                for f in fields {
508                    let field_name = swift_field_name(f.name);
509                    for stmt in decode_stmts_for(f.shape(), Some(f), &field_name, types) {
510                        out.push_str(&format!("            {stmt}\n"));
511                    }
512                }
513                let args: Vec<String> = fields
514                    .iter()
515                    .map(|f| {
516                        let n = swift_field_name(f.name);
517                        format!("{n}: {n}")
518                    })
519                    .collect();
520                out.push_str(&format!(
521                    "            return .{variant_name}({})\n",
522                    args.join(", ")
523                ));
524            }
525        }
526    }
527    out.push_str("        default:\n");
528    out.push_str("            throw WireError.unknownVariant(disc)\n");
529    out.push_str("        }\n");
530    out.push_str("    }\n");
531
532    out.push_str("}\n");
533    out
534}
535
536/// Generate an encode statement for a struct field (writes into `buffer`).
537fn encode_field_stmt(field: &Field, types: &[WireType]) -> String {
538    let field_name = swift_field_name(field.name);
539    encode_shape_stmt(field.shape(), &field_name, Some(field), types)
540}
541
542/// Generate an encode statement for a shape with a given value expression.
543/// The statement writes into the implicit `buffer: inout ByteBuffer` in scope.
544fn encode_shape_stmt(
545    shape: &'static Shape,
546    value: &str,
547    field: Option<&Field>,
548    types: &[WireType],
549) -> String {
550    let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
551
552    // Opaque type → OpaquePayload
553    if matches!(classify_shape(shape), ShapeKind::Opaque) {
554        return if is_trailing {
555            format!("{value}.encodeTrailing(into: &buffer)")
556        } else {
557            format!("{value}.encode(into: &buffer)")
558        };
559    }
560
561    if is_bytes(shape) {
562        return format!("encodeByteSeq({value}, into: &buffer)");
563    }
564
565    match classify_shape(shape) {
566        ShapeKind::Scalar(scalar) => encode_scalar_stmt(scalar, value),
567        ShapeKind::List { element } | ShapeKind::Slice { element } => {
568            let inner = encode_element_closure(element, types);
569            format!("encodeVec({value}, into: &buffer, encoder: {inner})")
570        }
571        ShapeKind::Option { inner } => {
572            let inner_closure = encode_element_closure(inner, types);
573            format!("encodeOption({value}, into: &buffer, encoder: {inner_closure})")
574        }
575        ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
576            format!("{value}.encode(into: &buffer)")
577        }
578        ShapeKind::Pointer { pointee } => encode_shape_stmt(pointee, value, field, types),
579        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
580            encode_shape_stmt(fields[0].shape(), value, field, types)
581        }
582        _ => format!("/* unsupported encode for {value} */"),
583    }
584}
585
586fn encode_scalar_stmt(scalar: ScalarType, value: &str) -> String {
587    match scalar {
588        ScalarType::Bool => format!("encodeBool({value}, into: &buffer)"),
589        ScalarType::U8 => format!("encodeU8({value}, into: &buffer)"),
590        ScalarType::I8 => format!("encodeI8({value}, into: &buffer)"),
591        ScalarType::U16 => format!("encodeU16({value}, into: &buffer)"),
592        ScalarType::I16 => format!("encodeI16({value}, into: &buffer)"),
593        ScalarType::U32 => format!("encodeVarint(UInt64({value}), into: &buffer)"),
594        ScalarType::I32 => format!("encodeI32({value}, into: &buffer)"),
595        ScalarType::U64 | ScalarType::USize => format!("encodeVarint({value}, into: &buffer)"),
596        ScalarType::I64 | ScalarType::ISize => format!("encodeI64({value}, into: &buffer)"),
597        ScalarType::F32 => format!("encodeF32({value}, into: &buffer)"),
598        ScalarType::F64 => format!("encodeF64({value}, into: &buffer)"),
599        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
600            format!("encodeString({value}, into: &buffer)")
601        }
602        _ => format!("/* unsupported scalar encode for {value} */"),
603    }
604}
605
606/// Generate an encode closure for use with encodeVec / encodeOption.
607/// Closures take `(T, inout ByteBuffer)` — parameters named `val` and `buf`.
608fn encode_element_closure(shape: &'static Shape, _types: &[WireType]) -> String {
609    if is_bytes(shape) {
610        return "{ val, buf in encodeByteSeq(val, into: &buf) }".into();
611    }
612
613    match classify_shape(shape) {
614        ShapeKind::Scalar(scalar) => {
615            let stmt = encode_scalar_stmt(scalar, "val").replace("into: &buffer", "into: &buf");
616            format!("{{ val, buf in {stmt} }}")
617        }
618        ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
619            "{ val, buf in val.encode(into: &buf) }".into()
620        }
621        ShapeKind::List { element } | ShapeKind::Slice { element } => {
622            let inner = encode_element_closure(element, _types);
623            format!("{{ val, buf in encodeVec(val, into: &buf, encoder: {inner}) }}")
624        }
625        ShapeKind::Opaque => "{ val, buf in val.encode(into: &buf) }".into(),
626        ShapeKind::Pointer { pointee } => encode_element_closure(pointee, _types),
627        _ => "{ _, _ in /* unsupported */ }".into(),
628    }
629}
630
631/// Generate decode statements for a struct field into a named variable.
632/// Returns one or two Swift statements (as lines to be written individually).
633fn decode_field_stmts(field: &Field, types: &[WireType]) -> Vec<String> {
634    let field_name = swift_field_name(field.name);
635    decode_stmts_for(field.shape(), Some(field), &field_name, types)
636}
637
638/// Produce one or two Swift statements that decode `shape` into a local named `var_name`.
639/// When the shape (or its pointer-unwrapped inner) is bytes, two statements are needed.
640fn decode_stmts_for(
641    shape: &'static Shape,
642    field: Option<&Field>,
643    var_name: &str,
644    types: &[WireType],
645) -> Vec<String> {
646    // Unwrap pointers transparently
647    if let ShapeKind::Pointer { pointee } = classify_shape(shape) {
648        return decode_stmts_for(pointee, field, var_name, types);
649    }
650    if is_bytes(shape) {
651        return vec![
652            format!("var _{var_name}Buf = try decodeWireBytes(from: &buffer)"),
653            format!(
654                "let {var_name} = _{var_name}Buf.readBytes(length: _{var_name}Buf.readableBytes) ?? []"
655            ),
656        ];
657    }
658    vec![format!(
659        "let {var_name} = {}",
660        decode_shape_expr(shape, field, types)
661    )]
662}
663
664/// Generate a decode expression for a shape.
665/// All decode calls read from the implicit `buffer: inout ByteBuffer` in scope.
666fn decode_shape_expr(shape: &'static Shape, field: Option<&Field>, types: &[WireType]) -> String {
667    let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
668
669    // Opaque type → OpaquePayload
670    if matches!(classify_shape(shape), ShapeKind::Opaque) {
671        return if is_trailing {
672            "OpaquePayload.decodeTrailing(from: &buffer)".into()
673        } else {
674            "try OpaquePayload.decode(from: &buffer)".into()
675        };
676    }
677
678    match classify_shape(shape) {
679        ShapeKind::Scalar(scalar) => decode_scalar(scalar),
680        ShapeKind::List { element } | ShapeKind::Slice { element } => {
681            let inner = decode_element_closure(element, types);
682            format!("try decodeVec(from: &buffer, decoder: {inner})")
683        }
684        ShapeKind::Option { inner } => {
685            let inner_closure = decode_element_closure(inner, types);
686            format!("try decodeOption(from: &buffer, decoder: {inner_closure})")
687        }
688        ShapeKind::Struct(StructInfo {
689            name: Some(name), ..
690        }) => {
691            let swift_name = lookup_wire_name(name, types);
692            format!("try {swift_name}.decode(from: &buffer)")
693        }
694        ShapeKind::Enum(EnumInfo {
695            name: Some(name), ..
696        }) => {
697            let swift_name = lookup_wire_name(name, types);
698            format!("try {swift_name}.decode(from: &buffer)")
699        }
700        ShapeKind::Pointer { pointee } => decode_shape_expr(pointee, field, types),
701        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
702            decode_shape_expr(fields[0].shape(), field, types)
703        }
704        _ => "nil /* unsupported decode */".into(),
705    }
706}
707
708fn decode_scalar(scalar: ScalarType) -> String {
709    match scalar {
710        ScalarType::Bool => "try decodeBool(from: &buffer)".into(),
711        ScalarType::U8 => "try decodeU8(from: &buffer)".into(),
712        ScalarType::I8 => "try decodeI8(from: &buffer)".into(),
713        ScalarType::U16 => "try decodeU16(from: &buffer)".into(),
714        ScalarType::I16 => "try decodeI16(from: &buffer)".into(),
715        ScalarType::U32 => "try decodeWireVarintU32(from: &buffer)".into(),
716        ScalarType::I32 => "try decodeI32(from: &buffer)".into(),
717        ScalarType::U64 | ScalarType::USize => "try decodeVarint(from: &buffer)".into(),
718        ScalarType::I64 | ScalarType::ISize => "try decodeI64(from: &buffer)".into(),
719        ScalarType::F32 => "try decodeF32(from: &buffer)".into(),
720        ScalarType::F64 => "try decodeF64(from: &buffer)".into(),
721        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
722            "try decodeWireString(from: &buffer)".into()
723        }
724        _ => "nil /* unsupported scalar decode */".into(),
725    }
726}
727
728/// Generate a decode closure for use with decodeVec / decodeOption.
729/// Closures take `(inout ByteBuffer)` — parameter named `buf`.
730fn decode_element_closure(shape: &'static Shape, types: &[WireType]) -> String {
731    if is_bytes(shape) {
732        // The closure is itself throwing so `try` inside is valid here.
733        return "{ buf in var s = try decodeWireBytes(from: &buf); return s.readBytes(length: s.readableBytes) ?? [] }".into();
734    }
735
736    match classify_shape(shape) {
737        ShapeKind::Scalar(scalar) => {
738            let expr = decode_scalar(scalar).replace("from: &buffer", "from: &buf");
739            format!("{{ buf in {expr} }}")
740        }
741        ShapeKind::Struct(StructInfo {
742            name: Some(name), ..
743        }) => {
744            let swift_name = lookup_wire_name(name, types);
745            format!("{{ buf in try {swift_name}.decode(from: &buf) }}")
746        }
747        ShapeKind::Enum(EnumInfo {
748            name: Some(name), ..
749        }) => {
750            let swift_name = lookup_wire_name(name, types);
751            format!("{{ buf in try {swift_name}.decode(from: &buf) }}")
752        }
753        ShapeKind::List { element } | ShapeKind::Slice { element } => {
754            let inner = decode_element_closure(element, types);
755            format!("{{ buf in try decodeVec(from: &buf, decoder: {inner}) }}")
756        }
757        ShapeKind::Opaque => "{ buf in try OpaquePayload.decode(from: &buf) }".into(),
758        ShapeKind::Pointer { pointee } => decode_element_closure(pointee, types),
759        _ => "{ _ in throw WireError.truncated }".into(),
760    }
761}
762
763/// Generate convenience factory methods on Message.
764fn generate_factory_methods(types: &[WireType]) -> String {
765    // Find the MessagePayload shape to inspect variants
766    let payload_wt = types.iter().find(|wt| wt.swift_name == "MessagePayload");
767    let payload_wt = match payload_wt {
768        Some(wt) => wt,
769        None => return String::new(),
770    };
771
772    let variants = match classify_shape(payload_wt.shape) {
773        ShapeKind::Enum(EnumInfo { variants, .. }) => variants,
774        _ => return String::new(),
775    };
776
777    let mut out = String::new();
778    out.push_str("public extension Message {\n");
779
780    for v in variants {
781        let variant_name = swift_field_name(v.name);
782        if let VariantKind::Newtype { inner } = classify_variant(v) {
783            let inner_swift = swift_wire_type(inner, None, types);
784            // Control messages use connId=0.
785            let is_control = matches!(
786                v.name,
787                "Hello" | "HelloYourself" | "ProtocolError" | "Ping" | "Pong"
788            );
789
790            if is_control {
791                out.push_str(&format!(
792                    "    static func {variant_name}(_ value: {inner_swift}) -> Message {{\n"
793                ));
794                out.push_str(&format!(
795                    "        Message(connectionId: 0, payload: .{variant_name}(value))\n"
796                ));
797                out.push_str("    }\n\n");
798            } else {
799                out.push_str(&format!(
800                    "    static func {variant_name}(connId: UInt64, _ value: {inner_swift}) -> Message {{\n"
801                ));
802                out.push_str(&format!(
803                    "        Message(connectionId: connId, payload: .{variant_name}(value))\n"
804                ));
805                out.push_str("    }\n\n");
806            }
807        }
808    }
809
810    // Additional ergonomic factory methods that flatten nested structs.
811    // These match the existing hand-coded API.
812    out.push_str("    static func protocolError(description: String) -> Message {\n");
813    out.push_str("        Message(connectionId: 0, payload: .protocolError(.init(description: description)))\n");
814    out.push_str("    }\n\n");
815
816    out.push_str("    static func connectionOpen(\n");
817    out.push_str(
818        "        connId: UInt64, settings: ConnectionSettings, metadata: [MetadataEntry]\n",
819    );
820    out.push_str("    ) -> Message {\n");
821    out.push_str("        Message(\n");
822    out.push_str("            connectionId: connId,\n");
823    out.push_str("            payload: .connectionOpen(.init(connectionSettings: settings, metadata: metadata)))\n");
824    out.push_str("    }\n\n");
825
826    out.push_str("    static func connectionAccept(\n");
827    out.push_str(
828        "        connId: UInt64, settings: ConnectionSettings, metadata: [MetadataEntry]\n",
829    );
830    out.push_str("    ) -> Message {\n");
831    out.push_str("        Message(\n");
832    out.push_str("            connectionId: connId,\n");
833    out.push_str("            payload: .connectionAccept(.init(connectionSettings: settings, metadata: metadata)))\n");
834    out.push_str("    }\n\n");
835
836    out.push_str("    static func connectionReject(connId: UInt64, metadata: [MetadataEntry]) -> Message {\n");
837    out.push_str("        Message(connectionId: connId, payload: .connectionReject(.init(metadata: metadata)))\n");
838    out.push_str("    }\n\n");
839
840    out.push_str(
841        "    static func connectionClose(connId: UInt64, metadata: [MetadataEntry]) -> Message {\n",
842    );
843    out.push_str("        Message(connectionId: connId, payload: .connectionClose(.init(metadata: metadata)))\n");
844    out.push_str("    }\n\n");
845
846    out.push_str("    static func request(\n");
847    out.push_str("        connId: UInt64,\n");
848    out.push_str("        requestId: UInt64,\n");
849    out.push_str("        methodId: UInt64,\n");
850    out.push_str("        metadata: [MetadataEntry],\n");
851    out.push_str("        schemas: [UInt8] = [],\n");
852    out.push_str("        payload: [UInt8]\n");
853    out.push_str("    ) -> Message {\n");
854    out.push_str("        Message(\n");
855    out.push_str("            connectionId: connId,\n");
856    out.push_str("            payload: .requestMessage(\n");
857    out.push_str("                .init(\n");
858    out.push_str("                    id: requestId,\n");
859    out.push_str("                    body: .call(.init(methodId: methodId, metadata: metadata, args: .init(payload), schemas: schemas))\n");
860    out.push_str("                ))\n");
861    out.push_str("        )\n");
862    out.push_str("    }\n\n");
863
864    out.push_str("    static func response(\n");
865    out.push_str("        connId: UInt64,\n");
866    out.push_str("        requestId: UInt64,\n");
867    out.push_str("        metadata: [MetadataEntry],\n");
868    out.push_str("        schemas: [UInt8] = [],\n");
869    out.push_str("        payload: [UInt8]\n");
870    out.push_str("    ) -> Message {\n");
871    out.push_str("        Message(\n");
872    out.push_str("            connectionId: connId,\n");
873    out.push_str("            payload: .requestMessage(\n");
874    out.push_str("                .init(\n");
875    out.push_str("                    id: requestId,\n");
876    out.push_str("                    body: .response(.init(metadata: metadata, ret: .init(payload), schemas: schemas))\n");
877    out.push_str("                ))\n");
878    out.push_str("        )\n");
879    out.push_str("    }\n\n");
880
881    out.push_str("    static func cancel(connId: UInt64, requestId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
882    out.push_str("        Message(\n");
883    out.push_str("            connectionId: connId,\n");
884    out.push_str("            payload: .requestMessage(\n");
885    out.push_str("                .init(\n");
886    out.push_str("                    id: requestId,\n");
887    out.push_str("                    body: .cancel(.init(metadata: metadata))\n");
888    out.push_str("                ))\n");
889    out.push_str("        )\n");
890    out.push_str("    }\n\n");
891
892    out.push_str(
893        "    static func data(connId: UInt64, channelId: UInt64, payload: [UInt8]) -> Message {\n",
894    );
895    out.push_str("        Message(\n");
896    out.push_str("            connectionId: connId,\n");
897    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .item(.init(item: .init(payload)))))\n");
898    out.push_str("        )\n");
899    out.push_str("    }\n\n");
900
901    out.push_str("    static func close(connId: UInt64, channelId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
902    out.push_str("        Message(\n");
903    out.push_str("            connectionId: connId,\n");
904    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .close(.init(metadata: metadata))))\n");
905    out.push_str("        )\n");
906    out.push_str("    }\n\n");
907
908    out.push_str("    static func reset(connId: UInt64, channelId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
909    out.push_str("        Message(\n");
910    out.push_str("            connectionId: connId,\n");
911    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .reset(.init(metadata: metadata))))\n");
912    out.push_str("        )\n");
913    out.push_str("    }\n\n");
914
915    out.push_str(
916        "    static func credit(connId: UInt64, channelId: UInt64, bytes: UInt32) -> Message {\n",
917    );
918    out.push_str("        Message(\n");
919    out.push_str("            connectionId: connId,\n");
920    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .grantCredit(.init(additional: bytes))))\n");
921    out.push_str("        )\n");
922    out.push_str("    }\n");
923
924    out.push_str("}\n");
925    out
926}