1use super::types::swift_field_name;
12use facet_core::{Field, ScalarType, Shape};
13use vox_types::{
14 DEFAULT_INITIAL_CHANNEL_CREDIT, EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape,
15 classify_variant, extract_schemas, is_bytes,
16};
17
/// A Rust wire type paired with the name it is emitted under in the generated Swift source.
pub struct WireType {
    /// Name of the Swift declaration generated for this type.
    pub swift_name: String,
    /// Shape of the Rust type; it is classified to decide which Swift code to emit.
    pub shape: &'static Shape,
}
25
26pub fn generate_wire_types(types: &[WireType]) -> (String, Vec<u8>) {
30 let mut out = String::new();
31 out.push_str("// @generated by vox-codegen\n");
32 out.push_str("// DO NOT EDIT — regenerate with `cargo xtask codegen --swift-wire`\n\n");
33 out.push_str("import Foundation\n");
34 out.push_str("@preconcurrency import NIOCore\n\n");
35
36 out.push_str(&generate_preamble());
38
39 out.push_str("public typealias Metadata = [MetadataEntry]\n\n");
41
42 for wt in types {
44 out.push_str(&generate_one_type(&wt.swift_name, wt.shape, types));
45 out.push('\n');
46 }
47
48 out.push_str(&generate_factory_methods(types));
50
51 out.push_str(
53 "\n/// CBOR-encoded wire message schemas, loaded from the bundled binary resource.\n",
54 );
55 out.push_str("public let wireMessageSchemasCbor: [UInt8] = {\n");
56 out.push_str(
57 " guard let url = Bundle.module.url(forResource: \"wireMessageSchemas\", withExtension: \"bin\"),\n",
58 );
59 out.push_str(" let data = try? Data(contentsOf: url) else {\n");
60 out.push_str(
61 " preconditionFailure(\"wireMessageSchemas.bin resource not found in Bundle.module\")\n",
62 );
63 out.push_str(" }\n");
64 out.push_str(" return Array(data)\n");
65 out.push_str("}()\n");
66
67 let cbor_bytes = types
69 .iter()
70 .find(|wt| wt.swift_name == "Message")
71 .or_else(|| types.last())
72 .map(|root| {
73 let extracted =
74 extract_schemas(root.shape).expect("wire schema extraction should succeed");
75 facet_cbor::to_vec(&extracted.schemas)
76 .expect("wire schema CBOR serialization should succeed")
77 })
78 .unwrap_or_default();
79
80 (out, cbor_bytes)
81}
82
/// Returns the fixed Swift declarations that every generated wire file starts with.
///
/// The preamble contains, in order:
/// - the `WireError` enum thrown by the generated decoders,
/// - the `OpaquePayload` struct with length-prefixed and trailing encode/decode methods,
/// - the private helpers `decodeWireVarintU32`, `decodeWireString` and `decodeWireBytes`,
///   which rethrow selected `PostcardError` cases as `WireError`.
fn generate_preamble() -> String {
    let mut out = String::new();

    // Error type used by all generated decoders.
    out.push_str("public enum WireError: Error, Equatable {\n");
    out.push_str(" case truncated\n");
    out.push_str(" case unknownVariant(UInt64)\n");
    out.push_str(" case overflow\n");
    out.push_str(" case invalidUtf8\n");
    out.push_str(" case trailingBytes\n");
    out.push_str("}\n\n");

    // Opaque byte payload wrapper.
    out.push_str("public struct OpaquePayload: Sendable, Equatable {\n");
    out.push_str(" public var bytes: ByteBuffer\n\n");
    out.push_str(" public init(_ bytes: ByteBuffer) {\n");
    out.push_str(" self.bytes = bytes\n");
    out.push_str(" }\n\n");
    out.push_str(" /// Init from a [UInt8] for convenience (e.g. from legacy code)\n");
    out.push_str(" public init(_ bytes: [UInt8]) {\n");
    out.push_str(" var buf = ByteBufferAllocator().buffer(capacity: bytes.count)\n");
    out.push_str(" buf.writeBytes(bytes)\n");
    out.push_str(" self.bytes = buf\n");
    out.push_str(" }\n\n");
    out.push_str(" func encode(into buffer: inout ByteBuffer) {\n");
    out.push_str(" let len = UInt32(bytes.readableBytes)\n");
    out.push_str(" buffer.writeInteger(len, endianness: .little)\n");
    out.push_str(" var copy = bytes\n");
    // `writeBuffer` takes its argument in-out, so the local must be passed as `&copy`.
    out.push_str(" buffer.writeBuffer(&copy)\n");
    out.push_str(" }\n\n");
    out.push_str(" static func decode(from buffer: inout ByteBuffer) throws -> Self {\n");
    out.push_str(
        " guard let len: UInt32 = buffer.readInteger(endianness: .little) else {\n",
    );
    out.push_str(" throw WireError.truncated\n");
    out.push_str(" }\n");
    out.push_str(" guard let slice = buffer.readSlice(length: Int(len)) else {\n");
    out.push_str(" throw WireError.truncated\n");
    out.push_str(" }\n");
    out.push_str(" return .init(slice)\n");
    out.push_str(" }\n\n");
    out.push_str(" /// Encode without a length prefix — for trailing fields only.\n");
    out.push_str(" func encodeTrailing(into buffer: inout ByteBuffer) {\n");
    out.push_str(" var copy = bytes\n");
    // Same in-out argument as in `encode(into:)` above.
    out.push_str(" buffer.writeBuffer(&copy)\n");
    out.push_str(" }\n\n");
    out.push_str(" /// Decode by consuming all remaining bytes — for trailing fields only.\n");
    out.push_str(" static func decodeTrailing(from buffer: inout ByteBuffer) -> Self {\n");
    out.push_str(
        " let slice = buffer.readSlice(length: buffer.readableBytes) ?? ByteBuffer()\n",
    );
    out.push_str(" return .init(slice)\n");
    out.push_str(" }\n");
    out.push_str("}\n\n");

    // Varint decode that rejects values outside the UInt32 range.
    out.push_str("@inline(__always)\n");
    out.push_str(
        "private func decodeWireVarintU32(from buffer: inout ByteBuffer) throws -> UInt32 {\n",
    );
    out.push_str(" let value = try decodeVarint(from: &buffer)\n");
    out.push_str(" guard value <= UInt64(UInt32.max) else {\n");
    out.push_str(" throw WireError.overflow\n");
    out.push_str(" }\n");
    out.push_str(" return UInt32(value)\n");
    out.push_str("}\n\n");

    // String decode that maps postcard errors onto WireError.
    out.push_str("@inline(__always)\n");
    out.push_str(
        "private func decodeWireString(from buffer: inout ByteBuffer) throws -> String {\n",
    );
    out.push_str(" do {\n");
    out.push_str(" return try decodeString(from: &buffer)\n");
    out.push_str(" } catch PostcardError.invalidUtf8 {\n");
    out.push_str(" throw WireError.invalidUtf8\n");
    out.push_str(" } catch PostcardError.truncated {\n");
    out.push_str(" throw WireError.truncated\n");
    out.push_str(" } catch {\n");
    out.push_str(" throw error\n");
    out.push_str(" }\n");
    out.push_str("}\n\n");

    // Byte-sequence decode that maps postcard truncation onto WireError.
    out.push_str("@inline(__always)\n");
    out.push_str(
        "private func decodeWireBytes(from buffer: inout ByteBuffer) throws -> ByteBuffer {\n",
    );
    out.push_str(" do {\n");
    out.push_str(" return try decodeBytes(from: &buffer)\n");
    out.push_str(" } catch PostcardError.truncated {\n");
    out.push_str(" throw WireError.truncated\n");
    out.push_str(" } catch {\n");
    out.push_str(" throw error\n");
    out.push_str(" }\n");
    out.push_str("}\n\n");

    out
}
181
182fn swift_wire_type(shape: &'static Shape, _field: Option<&Field>, types: &[WireType]) -> String {
184 if is_bytes(shape) {
185 return "[UInt8]".into();
186 }
187
188 match classify_shape(shape) {
189 ShapeKind::Scalar(scalar) => swift_scalar_type(scalar),
190 ShapeKind::List { element } | ShapeKind::Slice { element } => {
191 format!("[{}]", swift_wire_type(element, None, types))
192 }
193 ShapeKind::Option { inner } => {
194 format!("{}?", swift_wire_type(inner, None, types))
195 }
196 ShapeKind::Array { element, .. } => {
197 format!("[{}]", swift_wire_type(element, None, types))
198 }
199 ShapeKind::Struct(StructInfo {
200 name: Some(name), ..
201 }) => lookup_wire_name(name, types),
202 ShapeKind::Enum(EnumInfo {
203 name: Some(name), ..
204 }) => lookup_wire_name(name, types),
205 ShapeKind::Pointer { pointee } => swift_wire_type(pointee, _field, types),
206 ShapeKind::Opaque => "OpaquePayload".into(),
207 ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
208 swift_wire_type(fields[0].shape(), None, types)
209 }
210 _ => "Any /* unsupported */".into(),
211 }
212}
213
214fn swift_scalar_type(scalar: ScalarType) -> String {
215 match scalar {
216 ScalarType::Bool => "Bool".into(),
217 ScalarType::U8 => "UInt8".into(),
218 ScalarType::U16 => "UInt16".into(),
219 ScalarType::U32 => "UInt32".into(),
220 ScalarType::U64 | ScalarType::USize => "UInt64".into(),
221 ScalarType::I8 => "Int8".into(),
222 ScalarType::I16 => "Int16".into(),
223 ScalarType::I32 => "Int32".into(),
224 ScalarType::I64 | ScalarType::ISize => "Int64".into(),
225 ScalarType::F32 => "Float".into(),
226 ScalarType::F64 => "Double".into(),
227 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
228 "String".into()
229 }
230 ScalarType::Unit => "Void".into(),
231 _ => "Any".into(),
232 }
233}
234
235fn lookup_wire_name(rust_name: &str, types: &[WireType]) -> String {
238 for wt in types {
239 if wt.shape.type_identifier.ends_with(rust_name) {
241 return wt.swift_name.clone();
242 }
243 }
244 rust_name.to_string()
245}
246
247fn generate_one_type(swift_name: &str, shape: &'static Shape, types: &[WireType]) -> String {
249 match classify_shape(shape) {
250 ShapeKind::Struct(StructInfo { fields, .. }) => {
251 generate_struct(swift_name, fields, types, swift_name == "Message")
252 }
253 ShapeKind::Enum(EnumInfo { variants, .. }) => generate_enum(swift_name, variants, types),
254 _ => format!("// Unsupported shape for {swift_name}\n"),
255 }
256}
257
258fn generate_struct(name: &str, fields: &[Field], types: &[WireType], is_top_level: bool) -> String {
260 let mut out = String::new();
261
262 out.push_str(&format!("public struct {name}: Sendable, Equatable {{\n"));
264 for f in fields {
265 let field_name = swift_field_name(f.name);
266 let field_type = swift_wire_type(f.shape(), Some(f), types);
267 out.push_str(&format!(" public var {field_name}: {field_type}\n"));
268 }
269
270 if !fields.is_empty() {
272 out.push_str("\n public init(");
273 for (i, f) in fields.iter().enumerate() {
274 if i > 0 {
275 out.push_str(", ");
276 }
277 let field_name = swift_field_name(f.name);
278 let field_type = swift_wire_type(f.shape(), Some(f), types);
279 out.push_str(&format!("{field_name}: {field_type}"));
280 if let Some(default_value) = swift_default_argument(f) {
281 out.push_str(&format!(" = {default_value}"));
282 }
283 }
284 out.push_str(") {\n");
285 for f in fields {
286 let field_name = swift_field_name(f.name);
287 out.push_str(&format!(" self.{field_name} = {field_name}\n"));
288 }
289 out.push_str(" }\n");
290 }
291
292 let vis = if is_top_level { "public " } else { "" };
294 out.push_str(&format!(
295 "\n {vis}func encode(into buffer: inout ByteBuffer) {{\n"
296 ));
297 if fields.is_empty() {
298 } else {
300 for f in fields {
301 let stmt = encode_field_stmt(f, types);
302 out.push_str(&format!(" {stmt}\n"));
303 }
304 }
305 out.push_str(" }\n");
306
307 if is_top_level {
309 out.push_str(
310 "\n /// Encode to a `[UInt8]` array (bridge for callers that need bytes).\n",
311 );
312 out.push_str(" public func encode() -> [UInt8] {\n");
313 out.push_str(" var buffer = ByteBufferAllocator().buffer(capacity: 64)\n");
314 out.push_str(" encode(into: &buffer)\n");
315 out.push_str(" return buffer.readBytes(length: buffer.readableBytes) ?? []\n");
316 out.push_str(" }\n");
317 }
318
319 out.push_str(&format!(
321 "\n {vis}static func decode(from buffer: inout ByteBuffer) throws -> Self {{\n"
322 ));
323 for f in fields {
324 for stmt in decode_field_stmts(f, types) {
325 out.push_str(&format!(" {stmt}\n"));
326 }
327 }
328 let field_names: Vec<String> = fields
329 .iter()
330 .map(|f| {
331 let n = swift_field_name(f.name);
332 format!("{n}: {n}")
333 })
334 .collect();
335 out.push_str(&format!(
336 " return .init({})\n",
337 field_names.join(", ")
338 ));
339 out.push_str(" }\n");
340
341 if is_top_level {
343 out.push_str(
344 "\n /// Decode from a `[UInt8]` array (bridge for callers that have raw bytes).\n",
345 );
346 out.push_str(" public static func decode(fromBytes data: [UInt8]) throws -> Self {\n");
347 out.push_str(" var buffer = ByteBufferAllocator().buffer(capacity: data.count)\n");
348 out.push_str(" buffer.writeBytes(data)\n");
349 out.push_str(" let result = try decode(from: &buffer)\n");
350 out.push_str(" guard buffer.readableBytes == 0 else {\n");
351 out.push_str(" throw WireError.trailingBytes\n");
352 out.push_str(" }\n");
353 out.push_str(" return result\n");
354 out.push_str(" }\n");
355 }
356
357 out.push_str("}\n");
358 out
359}
360
361fn swift_default_argument(field: &Field) -> Option<String> {
362 if field.name == "initial_channel_credit" && field.has_default() {
363 return Some(DEFAULT_INITIAL_CHANNEL_CREDIT.to_string());
364 }
365 None
366}
367
368fn generate_enum(name: &str, variants: &[facet_core::Variant], types: &[WireType]) -> String {
370 let mut out = String::new();
371
372 out.push_str(&format!("public enum {name}: Sendable, Equatable {{\n"));
374 for v in variants {
375 let variant_name = swift_field_name(v.name);
376 match classify_variant(v) {
377 VariantKind::Unit => {
378 out.push_str(&format!(" case {variant_name}\n"));
379 }
380 VariantKind::Newtype { inner } => {
381 let inner_type = swift_wire_type(inner, v.data.fields.first(), types);
382 out.push_str(&format!(" case {variant_name}({inner_type})\n"));
383 }
384 VariantKind::Tuple { fields } => {
385 let field_types: Vec<String> = fields
386 .iter()
387 .map(|f| swift_wire_type(f.shape(), Some(f), types))
388 .collect();
389 out.push_str(&format!(
390 " case {variant_name}({})\n",
391 field_types.join(", ")
392 ));
393 }
394 VariantKind::Struct { fields } => {
395 let field_decls: Vec<String> = fields
396 .iter()
397 .map(|f| {
398 format!(
399 "{}: {}",
400 swift_field_name(f.name),
401 swift_wire_type(f.shape(), Some(f), types)
402 )
403 })
404 .collect();
405 out.push_str(&format!(
406 " case {variant_name}({})\n",
407 field_decls.join(", ")
408 ));
409 }
410 }
411 }
412
413 out.push_str("\n func encode(into buffer: inout ByteBuffer) {\n");
415 out.push_str(" switch self {\n");
416 for (i, v) in variants.iter().enumerate() {
417 let variant_name = swift_field_name(v.name);
418 match classify_variant(v) {
419 VariantKind::Unit => {
420 out.push_str(&format!(
421 " case .{variant_name}:\n encodeVarint(UInt64({i}), into: &buffer)\n"
422 ));
423 }
424 VariantKind::Newtype { inner } => {
425 let stmt = encode_shape_stmt(inner, "val", v.data.fields.first(), types);
426 out.push_str(&format!(
427 " case .{variant_name}(let val):\n encodeVarint(UInt64({i}), into: &buffer)\n {stmt}\n"
428 ));
429 }
430 VariantKind::Tuple { fields } => {
431 let bindings: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
432 let binding_str = bindings
433 .iter()
434 .map(|b| format!("let {b}"))
435 .collect::<Vec<_>>()
436 .join(", ");
437 let stmts: Vec<String> = fields
438 .iter()
439 .enumerate()
440 .map(|(j, f)| encode_shape_stmt(f.shape(), &format!("f{j}"), Some(f), types))
441 .collect();
442 out.push_str(&format!(
443 " case .{variant_name}({binding_str}):\n encodeVarint(UInt64({i}), into: &buffer)\n"
444 ));
445 for stmt in &stmts {
446 out.push_str(&format!(" {stmt}\n"));
447 }
448 }
449 VariantKind::Struct { fields } => {
450 let bindings: Vec<String> =
451 fields.iter().map(|f| swift_field_name(f.name)).collect();
452 let binding_str = bindings
453 .iter()
454 .map(|b| format!("let {b}"))
455 .collect::<Vec<_>>()
456 .join(", ");
457 let stmts: Vec<String> = fields
458 .iter()
459 .map(|f| {
460 encode_shape_stmt(f.shape(), &swift_field_name(f.name), Some(f), types)
461 })
462 .collect();
463 out.push_str(&format!(
464 " case .{variant_name}({binding_str}):\n encodeVarint(UInt64({i}), into: &buffer)\n"
465 ));
466 for stmt in &stmts {
467 out.push_str(&format!(" {stmt}\n"));
468 }
469 }
470 }
471 }
472 out.push_str(" }\n");
473 out.push_str(" }\n");
474
475 out.push_str("\n static func decode(from buffer: inout ByteBuffer) throws -> Self {\n");
477 out.push_str(" let disc = try decodeVarint(from: &buffer)\n");
478 out.push_str(" switch disc {\n");
479 for (i, v) in variants.iter().enumerate() {
480 let variant_name = swift_field_name(v.name);
481 out.push_str(&format!(" case {i}:\n"));
482 match classify_variant(v) {
483 VariantKind::Unit => {
484 out.push_str(&format!(" return .{variant_name}\n"));
485 }
486 VariantKind::Newtype { inner } => {
487 for stmt in decode_stmts_for(inner, v.data.fields.first(), "_newtype_val", types) {
488 out.push_str(&format!(" {stmt}\n"));
489 }
490 out.push_str(&format!(
491 " return .{variant_name}(_newtype_val)\n"
492 ));
493 }
494 VariantKind::Tuple { fields } => {
495 for (j, f) in fields.iter().enumerate() {
496 for stmt in decode_stmts_for(f.shape(), Some(f), &format!("f{j}"), types) {
497 out.push_str(&format!(" {stmt}\n"));
498 }
499 }
500 let args: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
501 out.push_str(&format!(
502 " return .{variant_name}({})\n",
503 args.join(", ")
504 ));
505 }
506 VariantKind::Struct { fields } => {
507 for f in fields {
508 let field_name = swift_field_name(f.name);
509 for stmt in decode_stmts_for(f.shape(), Some(f), &field_name, types) {
510 out.push_str(&format!(" {stmt}\n"));
511 }
512 }
513 let args: Vec<String> = fields
514 .iter()
515 .map(|f| {
516 let n = swift_field_name(f.name);
517 format!("{n}: {n}")
518 })
519 .collect();
520 out.push_str(&format!(
521 " return .{variant_name}({})\n",
522 args.join(", ")
523 ));
524 }
525 }
526 }
527 out.push_str(" default:\n");
528 out.push_str(" throw WireError.unknownVariant(disc)\n");
529 out.push_str(" }\n");
530 out.push_str(" }\n");
531
532 out.push_str("}\n");
533 out
534}
535
536fn encode_field_stmt(field: &Field, types: &[WireType]) -> String {
538 let field_name = swift_field_name(field.name);
539 encode_shape_stmt(field.shape(), &field_name, Some(field), types)
540}
541
542fn encode_shape_stmt(
545 shape: &'static Shape,
546 value: &str,
547 field: Option<&Field>,
548 types: &[WireType],
549) -> String {
550 let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
551
552 if matches!(classify_shape(shape), ShapeKind::Opaque) {
554 return if is_trailing {
555 format!("{value}.encodeTrailing(into: &buffer)")
556 } else {
557 format!("{value}.encode(into: &buffer)")
558 };
559 }
560
561 if is_bytes(shape) {
562 return format!("encodeByteSeq({value}, into: &buffer)");
563 }
564
565 match classify_shape(shape) {
566 ShapeKind::Scalar(scalar) => encode_scalar_stmt(scalar, value),
567 ShapeKind::List { element } | ShapeKind::Slice { element } => {
568 let inner = encode_element_closure(element, types);
569 format!("encodeVec({value}, into: &buffer, encoder: {inner})")
570 }
571 ShapeKind::Option { inner } => {
572 let inner_closure = encode_element_closure(inner, types);
573 format!("encodeOption({value}, into: &buffer, encoder: {inner_closure})")
574 }
575 ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
576 format!("{value}.encode(into: &buffer)")
577 }
578 ShapeKind::Pointer { pointee } => encode_shape_stmt(pointee, value, field, types),
579 ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
580 encode_shape_stmt(fields[0].shape(), value, field, types)
581 }
582 _ => format!("/* unsupported encode for {value} */"),
583 }
584}
585
586fn encode_scalar_stmt(scalar: ScalarType, value: &str) -> String {
587 match scalar {
588 ScalarType::Bool => format!("encodeBool({value}, into: &buffer)"),
589 ScalarType::U8 => format!("encodeU8({value}, into: &buffer)"),
590 ScalarType::I8 => format!("encodeI8({value}, into: &buffer)"),
591 ScalarType::U16 => format!("encodeU16({value}, into: &buffer)"),
592 ScalarType::I16 => format!("encodeI16({value}, into: &buffer)"),
593 ScalarType::U32 => format!("encodeVarint(UInt64({value}), into: &buffer)"),
594 ScalarType::I32 => format!("encodeI32({value}, into: &buffer)"),
595 ScalarType::U64 | ScalarType::USize => format!("encodeVarint({value}, into: &buffer)"),
596 ScalarType::I64 | ScalarType::ISize => format!("encodeI64({value}, into: &buffer)"),
597 ScalarType::F32 => format!("encodeF32({value}, into: &buffer)"),
598 ScalarType::F64 => format!("encodeF64({value}, into: &buffer)"),
599 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
600 format!("encodeString({value}, into: &buffer)")
601 }
602 _ => format!("/* unsupported scalar encode for {value} */"),
603 }
604}
605
606fn encode_element_closure(shape: &'static Shape, _types: &[WireType]) -> String {
609 if is_bytes(shape) {
610 return "{ val, buf in encodeByteSeq(val, into: &buf) }".into();
611 }
612
613 match classify_shape(shape) {
614 ShapeKind::Scalar(scalar) => {
615 let stmt = encode_scalar_stmt(scalar, "val").replace("into: &buffer", "into: &buf");
616 format!("{{ val, buf in {stmt} }}")
617 }
618 ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
619 "{ val, buf in val.encode(into: &buf) }".into()
620 }
621 ShapeKind::List { element } | ShapeKind::Slice { element } => {
622 let inner = encode_element_closure(element, _types);
623 format!("{{ val, buf in encodeVec(val, into: &buf, encoder: {inner}) }}")
624 }
625 ShapeKind::Opaque => "{ val, buf in val.encode(into: &buf) }".into(),
626 ShapeKind::Pointer { pointee } => encode_element_closure(pointee, _types),
627 _ => "{ _, _ in /* unsupported */ }".into(),
628 }
629}
630
631fn decode_field_stmts(field: &Field, types: &[WireType]) -> Vec<String> {
634 let field_name = swift_field_name(field.name);
635 decode_stmts_for(field.shape(), Some(field), &field_name, types)
636}
637
638fn decode_stmts_for(
641 shape: &'static Shape,
642 field: Option<&Field>,
643 var_name: &str,
644 types: &[WireType],
645) -> Vec<String> {
646 if let ShapeKind::Pointer { pointee } = classify_shape(shape) {
648 return decode_stmts_for(pointee, field, var_name, types);
649 }
650 if is_bytes(shape) {
651 return vec![
652 format!("var _{var_name}Buf = try decodeWireBytes(from: &buffer)"),
653 format!(
654 "let {var_name} = _{var_name}Buf.readBytes(length: _{var_name}Buf.readableBytes) ?? []"
655 ),
656 ];
657 }
658 vec![format!(
659 "let {var_name} = {}",
660 decode_shape_expr(shape, field, types)
661 )]
662}
663
664fn decode_shape_expr(shape: &'static Shape, field: Option<&Field>, types: &[WireType]) -> String {
667 let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
668
669 if matches!(classify_shape(shape), ShapeKind::Opaque) {
671 return if is_trailing {
672 "OpaquePayload.decodeTrailing(from: &buffer)".into()
673 } else {
674 "try OpaquePayload.decode(from: &buffer)".into()
675 };
676 }
677
678 match classify_shape(shape) {
679 ShapeKind::Scalar(scalar) => decode_scalar(scalar),
680 ShapeKind::List { element } | ShapeKind::Slice { element } => {
681 let inner = decode_element_closure(element, types);
682 format!("try decodeVec(from: &buffer, decoder: {inner})")
683 }
684 ShapeKind::Option { inner } => {
685 let inner_closure = decode_element_closure(inner, types);
686 format!("try decodeOption(from: &buffer, decoder: {inner_closure})")
687 }
688 ShapeKind::Struct(StructInfo {
689 name: Some(name), ..
690 }) => {
691 let swift_name = lookup_wire_name(name, types);
692 format!("try {swift_name}.decode(from: &buffer)")
693 }
694 ShapeKind::Enum(EnumInfo {
695 name: Some(name), ..
696 }) => {
697 let swift_name = lookup_wire_name(name, types);
698 format!("try {swift_name}.decode(from: &buffer)")
699 }
700 ShapeKind::Pointer { pointee } => decode_shape_expr(pointee, field, types),
701 ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
702 decode_shape_expr(fields[0].shape(), field, types)
703 }
704 _ => "nil /* unsupported decode */".into(),
705 }
706}
707
708fn decode_scalar(scalar: ScalarType) -> String {
709 match scalar {
710 ScalarType::Bool => "try decodeBool(from: &buffer)".into(),
711 ScalarType::U8 => "try decodeU8(from: &buffer)".into(),
712 ScalarType::I8 => "try decodeI8(from: &buffer)".into(),
713 ScalarType::U16 => "try decodeU16(from: &buffer)".into(),
714 ScalarType::I16 => "try decodeI16(from: &buffer)".into(),
715 ScalarType::U32 => "try decodeWireVarintU32(from: &buffer)".into(),
716 ScalarType::I32 => "try decodeI32(from: &buffer)".into(),
717 ScalarType::U64 | ScalarType::USize => "try decodeVarint(from: &buffer)".into(),
718 ScalarType::I64 | ScalarType::ISize => "try decodeI64(from: &buffer)".into(),
719 ScalarType::F32 => "try decodeF32(from: &buffer)".into(),
720 ScalarType::F64 => "try decodeF64(from: &buffer)".into(),
721 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
722 "try decodeWireString(from: &buffer)".into()
723 }
724 _ => "nil /* unsupported scalar decode */".into(),
725 }
726}
727
728fn decode_element_closure(shape: &'static Shape, types: &[WireType]) -> String {
731 if is_bytes(shape) {
732 return "{ buf in var s = try decodeWireBytes(from: &buf); return s.readBytes(length: s.readableBytes) ?? [] }".into();
734 }
735
736 match classify_shape(shape) {
737 ShapeKind::Scalar(scalar) => {
738 let expr = decode_scalar(scalar).replace("from: &buffer", "from: &buf");
739 format!("{{ buf in {expr} }}")
740 }
741 ShapeKind::Struct(StructInfo {
742 name: Some(name), ..
743 }) => {
744 let swift_name = lookup_wire_name(name, types);
745 format!("{{ buf in try {swift_name}.decode(from: &buf) }}")
746 }
747 ShapeKind::Enum(EnumInfo {
748 name: Some(name), ..
749 }) => {
750 let swift_name = lookup_wire_name(name, types);
751 format!("{{ buf in try {swift_name}.decode(from: &buf) }}")
752 }
753 ShapeKind::List { element } | ShapeKind::Slice { element } => {
754 let inner = decode_element_closure(element, types);
755 format!("{{ buf in try decodeVec(from: &buf, decoder: {inner}) }}")
756 }
757 ShapeKind::Opaque => "{ buf in try OpaquePayload.decode(from: &buf) }".into(),
758 ShapeKind::Pointer { pointee } => decode_element_closure(pointee, types),
759 _ => "{ _ in throw WireError.truncated }".into(),
760 }
761}
762
763fn generate_factory_methods(types: &[WireType]) -> String {
765 let payload_wt = types.iter().find(|wt| wt.swift_name == "MessagePayload");
767 let payload_wt = match payload_wt {
768 Some(wt) => wt,
769 None => return String::new(),
770 };
771
772 let variants = match classify_shape(payload_wt.shape) {
773 ShapeKind::Enum(EnumInfo { variants, .. }) => variants,
774 _ => return String::new(),
775 };
776
777 let mut out = String::new();
778 out.push_str("public extension Message {\n");
779
780 for v in variants {
781 let variant_name = swift_field_name(v.name);
782 if let VariantKind::Newtype { inner } = classify_variant(v) {
783 let inner_swift = swift_wire_type(inner, None, types);
784 let is_control = matches!(
786 v.name,
787 "Hello" | "HelloYourself" | "ProtocolError" | "Ping" | "Pong"
788 );
789
790 if is_control {
791 out.push_str(&format!(
792 " static func {variant_name}(_ value: {inner_swift}) -> Message {{\n"
793 ));
794 out.push_str(&format!(
795 " Message(connectionId: 0, payload: .{variant_name}(value))\n"
796 ));
797 out.push_str(" }\n\n");
798 } else {
799 out.push_str(&format!(
800 " static func {variant_name}(connId: UInt64, _ value: {inner_swift}) -> Message {{\n"
801 ));
802 out.push_str(&format!(
803 " Message(connectionId: connId, payload: .{variant_name}(value))\n"
804 ));
805 out.push_str(" }\n\n");
806 }
807 }
808 }
809
810 out.push_str(" static func protocolError(description: String) -> Message {\n");
813 out.push_str(" Message(connectionId: 0, payload: .protocolError(.init(description: description)))\n");
814 out.push_str(" }\n\n");
815
816 out.push_str(" static func connectionOpen(\n");
817 out.push_str(
818 " connId: UInt64, settings: ConnectionSettings, metadata: [MetadataEntry]\n",
819 );
820 out.push_str(" ) -> Message {\n");
821 out.push_str(" Message(\n");
822 out.push_str(" connectionId: connId,\n");
823 out.push_str(" payload: .connectionOpen(.init(connectionSettings: settings, metadata: metadata)))\n");
824 out.push_str(" }\n\n");
825
826 out.push_str(" static func connectionAccept(\n");
827 out.push_str(
828 " connId: UInt64, settings: ConnectionSettings, metadata: [MetadataEntry]\n",
829 );
830 out.push_str(" ) -> Message {\n");
831 out.push_str(" Message(\n");
832 out.push_str(" connectionId: connId,\n");
833 out.push_str(" payload: .connectionAccept(.init(connectionSettings: settings, metadata: metadata)))\n");
834 out.push_str(" }\n\n");
835
836 out.push_str(" static func connectionReject(connId: UInt64, metadata: [MetadataEntry]) -> Message {\n");
837 out.push_str(" Message(connectionId: connId, payload: .connectionReject(.init(metadata: metadata)))\n");
838 out.push_str(" }\n\n");
839
840 out.push_str(
841 " static func connectionClose(connId: UInt64, metadata: [MetadataEntry]) -> Message {\n",
842 );
843 out.push_str(" Message(connectionId: connId, payload: .connectionClose(.init(metadata: metadata)))\n");
844 out.push_str(" }\n\n");
845
846 out.push_str(" static func request(\n");
847 out.push_str(" connId: UInt64,\n");
848 out.push_str(" requestId: UInt64,\n");
849 out.push_str(" methodId: UInt64,\n");
850 out.push_str(" metadata: [MetadataEntry],\n");
851 out.push_str(" schemas: [UInt8] = [],\n");
852 out.push_str(" payload: [UInt8]\n");
853 out.push_str(" ) -> Message {\n");
854 out.push_str(" Message(\n");
855 out.push_str(" connectionId: connId,\n");
856 out.push_str(" payload: .requestMessage(\n");
857 out.push_str(" .init(\n");
858 out.push_str(" id: requestId,\n");
859 out.push_str(" body: .call(.init(methodId: methodId, metadata: metadata, args: .init(payload), schemas: schemas))\n");
860 out.push_str(" ))\n");
861 out.push_str(" )\n");
862 out.push_str(" }\n\n");
863
864 out.push_str(" static func response(\n");
865 out.push_str(" connId: UInt64,\n");
866 out.push_str(" requestId: UInt64,\n");
867 out.push_str(" metadata: [MetadataEntry],\n");
868 out.push_str(" schemas: [UInt8] = [],\n");
869 out.push_str(" payload: [UInt8]\n");
870 out.push_str(" ) -> Message {\n");
871 out.push_str(" Message(\n");
872 out.push_str(" connectionId: connId,\n");
873 out.push_str(" payload: .requestMessage(\n");
874 out.push_str(" .init(\n");
875 out.push_str(" id: requestId,\n");
876 out.push_str(" body: .response(.init(metadata: metadata, ret: .init(payload), schemas: schemas))\n");
877 out.push_str(" ))\n");
878 out.push_str(" )\n");
879 out.push_str(" }\n\n");
880
881 out.push_str(" static func cancel(connId: UInt64, requestId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
882 out.push_str(" Message(\n");
883 out.push_str(" connectionId: connId,\n");
884 out.push_str(" payload: .requestMessage(\n");
885 out.push_str(" .init(\n");
886 out.push_str(" id: requestId,\n");
887 out.push_str(" body: .cancel(.init(metadata: metadata))\n");
888 out.push_str(" ))\n");
889 out.push_str(" )\n");
890 out.push_str(" }\n\n");
891
892 out.push_str(
893 " static func data(connId: UInt64, channelId: UInt64, payload: [UInt8]) -> Message {\n",
894 );
895 out.push_str(" Message(\n");
896 out.push_str(" connectionId: connId,\n");
897 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .item(.init(item: .init(payload)))))\n");
898 out.push_str(" )\n");
899 out.push_str(" }\n\n");
900
901 out.push_str(" static func close(connId: UInt64, channelId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
902 out.push_str(" Message(\n");
903 out.push_str(" connectionId: connId,\n");
904 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .close(.init(metadata: metadata))))\n");
905 out.push_str(" )\n");
906 out.push_str(" }\n\n");
907
908 out.push_str(" static func reset(connId: UInt64, channelId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
909 out.push_str(" Message(\n");
910 out.push_str(" connectionId: connId,\n");
911 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .reset(.init(metadata: metadata))))\n");
912 out.push_str(" )\n");
913 out.push_str(" }\n\n");
914
915 out.push_str(
916 " static func credit(connId: UInt64, channelId: UInt64, bytes: UInt32) -> Message {\n",
917 );
918 out.push_str(" Message(\n");
919 out.push_str(" connectionId: connId,\n");
920 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .grantCredit(.init(additional: bytes))))\n");
921 out.push_str(" )\n");
922 out.push_str(" }\n");
923
924 out.push_str("}\n");
925 out
926}