1use facet_core::{Field, ScalarType, Shape};
12use heck::ToLowerCamelCase;
13use vox_types::{
14 EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape, classify_variant,
15 extract_schemas, is_bytes,
16};
17
18pub struct WireType {
20 pub swift_name: String,
22 pub shape: &'static Shape,
24}
25
26pub fn generate_wire_types(types: &[WireType]) -> (String, Vec<u8>) {
30 let mut out = String::new();
31 out.push_str("// @generated by vox-codegen\n");
32 out.push_str("// DO NOT EDIT — regenerate with `cargo xtask codegen --swift-wire`\n\n");
33 out.push_str("import Foundation\n");
34 out.push_str("@preconcurrency import NIOCore\n\n");
35
36 out.push_str(&generate_preamble());
38
39 out.push_str("public typealias Metadata = [MetadataEntry]\n\n");
41
42 for wt in types {
44 out.push_str(&generate_one_type(&wt.swift_name, wt.shape, types));
45 out.push('\n');
46 }
47
48 out.push_str(&generate_factory_methods(types));
50
51 out.push_str(
53 "\n/// CBOR-encoded wire message schemas, loaded from the bundled binary resource.\n",
54 );
55 out.push_str("public let wireMessageSchemasCbor: [UInt8] = {\n");
56 out.push_str(
57 " guard let url = Bundle.module.url(forResource: \"wireMessageSchemas\", withExtension: \"bin\"),\n",
58 );
59 out.push_str(" let data = try? Data(contentsOf: url) else {\n");
60 out.push_str(
61 " preconditionFailure(\"wireMessageSchemas.bin resource not found in Bundle.module\")\n",
62 );
63 out.push_str(" }\n");
64 out.push_str(" return Array(data)\n");
65 out.push_str("}()\n");
66
67 let cbor_bytes = types
69 .iter()
70 .find(|wt| wt.swift_name == "Message")
71 .or_else(|| types.last())
72 .map(|root| {
73 let extracted =
74 extract_schemas(root.shape).expect("wire schema extraction should succeed");
75 facet_cbor::to_vec(&extracted.schemas)
76 .expect("wire schema CBOR serialization should succeed")
77 })
78 .unwrap_or_default();
79
80 (out, cbor_bytes)
81}
82
/// Returns the hand-written Swift support code emitted before the generated types.
///
/// It declares:
/// * `WireError`, the error enum thrown by generated decoders;
/// * `OpaquePayload`, a `ByteBuffer` wrapper with length-prefixed
///   (`UInt32`, little-endian) and trailing (no prefix, consumes the rest)
///   encode/decode variants;
/// * private `decodeWireVarintU32`, `decodeWireString` and `decodeWireBytes`
///   helpers that wrap the runtime's `decodeVarint`/`decodeString`/
///   `decodeBytes` and map their `PostcardError`s onto `WireError`.
///
/// Fix: the payload copy was previously emitted as `writeBuffer(©)` (a
/// mis-encoded `&copy`), which is not valid Swift. `ByteBuffer.writeBuffer`
/// takes its argument `inout`, so the mutable copy is now passed as `&copy`.
fn generate_preamble() -> String {
    const PREAMBLE: &str = concat!(
        // Errors surfaced by generated decoders.
        "public enum WireError: Error, Equatable {\n",
        " case truncated\n",
        " case unknownVariant(UInt64)\n",
        " case overflow\n",
        " case invalidUtf8\n",
        " case trailingBytes\n",
        "}\n\n",
        // Uninterpreted payload bytes.
        "public struct OpaquePayload: Sendable, Equatable {\n",
        " public var bytes: ByteBuffer\n\n",
        " public init(_ bytes: ByteBuffer) {\n",
        " self.bytes = bytes\n",
        " }\n\n",
        " /// Init from a [UInt8] for convenience (e.g. from legacy code)\n",
        " public init(_ bytes: [UInt8]) {\n",
        " var buf = ByteBufferAllocator().buffer(capacity: bytes.count)\n",
        " buf.writeBytes(bytes)\n",
        " self.bytes = buf\n",
        " }\n\n",
        " func encode(into buffer: inout ByteBuffer) {\n",
        " let len = UInt32(bytes.readableBytes)\n",
        " buffer.writeInteger(len, endianness: .little)\n",
        // `writeBuffer` needs an inout argument, hence the local copy.
        " var copy = bytes\n",
        " buffer.writeBuffer(&copy)\n",
        " }\n\n",
        " static func decode(from buffer: inout ByteBuffer) throws -> Self {\n",
        " guard let len: UInt32 = buffer.readInteger(endianness: .little) else {\n",
        " throw WireError.truncated\n",
        " }\n",
        " guard let slice = buffer.readSlice(length: Int(len)) else {\n",
        " throw WireError.truncated\n",
        " }\n",
        " return .init(slice)\n",
        " }\n\n",
        " /// Encode without a length prefix — for trailing fields only.\n",
        " func encodeTrailing(into buffer: inout ByteBuffer) {\n",
        " var copy = bytes\n",
        " buffer.writeBuffer(&copy)\n",
        " }\n\n",
        " /// Decode by consuming all remaining bytes — for trailing fields only.\n",
        " static func decodeTrailing(from buffer: inout ByteBuffer) -> Self {\n",
        " let slice = buffer.readSlice(length: buffer.readableBytes) ?? ByteBuffer()\n",
        " return .init(slice)\n",
        " }\n",
        "}\n\n",
        // u32 fields travel as varints; reject values that do not fit.
        "@inline(__always)\n",
        "private func decodeWireVarintU32(from buffer: inout ByteBuffer) throws -> UInt32 {\n",
        " let value = try decodeVarint(from: &buffer)\n",
        " guard value <= UInt64(UInt32.max) else {\n",
        " throw WireError.overflow\n",
        " }\n",
        " return UInt32(value)\n",
        "}\n\n",
        // Map runtime string errors onto WireError.
        "@inline(__always)\n",
        "private func decodeWireString(from buffer: inout ByteBuffer) throws -> String {\n",
        " do {\n",
        " return try decodeString(from: &buffer)\n",
        " } catch PostcardError.invalidUtf8 {\n",
        " throw WireError.invalidUtf8\n",
        " } catch PostcardError.truncated {\n",
        " throw WireError.truncated\n",
        " } catch {\n",
        " throw error\n",
        " }\n",
        "}\n\n",
        // Map runtime byte-sequence errors onto WireError.
        "@inline(__always)\n",
        "private func decodeWireBytes(from buffer: inout ByteBuffer) throws -> ByteBuffer {\n",
        " do {\n",
        " return try decodeBytes(from: &buffer)\n",
        " } catch PostcardError.truncated {\n",
        " throw WireError.truncated\n",
        " } catch {\n",
        " throw error\n",
        " }\n",
        "}\n\n"
    );

    PREAMBLE.to_owned()
}
181
182fn swift_wire_type(shape: &'static Shape, _field: Option<&Field>, types: &[WireType]) -> String {
184 if is_bytes(shape) {
185 return "[UInt8]".into();
186 }
187
188 match classify_shape(shape) {
189 ShapeKind::Scalar(scalar) => swift_scalar_type(scalar),
190 ShapeKind::List { element } | ShapeKind::Slice { element } => {
191 format!("[{}]", swift_wire_type(element, None, types))
192 }
193 ShapeKind::Option { inner } => {
194 format!("{}?", swift_wire_type(inner, None, types))
195 }
196 ShapeKind::Array { element, .. } => {
197 format!("[{}]", swift_wire_type(element, None, types))
198 }
199 ShapeKind::Struct(StructInfo {
200 name: Some(name), ..
201 }) => lookup_wire_name(name, types),
202 ShapeKind::Enum(EnumInfo {
203 name: Some(name), ..
204 }) => lookup_wire_name(name, types),
205 ShapeKind::Pointer { pointee } => swift_wire_type(pointee, _field, types),
206 ShapeKind::Opaque => "OpaquePayload".into(),
207 ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
208 swift_wire_type(fields[0].shape(), None, types)
209 }
210 _ => "Any /* unsupported */".into(),
211 }
212}
213
214fn swift_scalar_type(scalar: ScalarType) -> String {
215 match scalar {
216 ScalarType::Bool => "Bool".into(),
217 ScalarType::U8 => "UInt8".into(),
218 ScalarType::U16 => "UInt16".into(),
219 ScalarType::U32 => "UInt32".into(),
220 ScalarType::U64 | ScalarType::USize => "UInt64".into(),
221 ScalarType::I8 => "Int8".into(),
222 ScalarType::I16 => "Int16".into(),
223 ScalarType::I32 => "Int32".into(),
224 ScalarType::I64 | ScalarType::ISize => "Int64".into(),
225 ScalarType::F32 => "Float".into(),
226 ScalarType::F64 => "Double".into(),
227 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
228 "String".into()
229 }
230 ScalarType::Unit => "Void".into(),
231 _ => "Any".into(),
232 }
233}
234
235fn lookup_wire_name(rust_name: &str, types: &[WireType]) -> String {
238 for wt in types {
239 if wt.shape.type_identifier.ends_with(rust_name) {
241 return wt.swift_name.clone();
242 }
243 }
244 rust_name.to_string()
245}
246
247fn generate_one_type(swift_name: &str, shape: &'static Shape, types: &[WireType]) -> String {
249 match classify_shape(shape) {
250 ShapeKind::Struct(StructInfo { fields, .. }) => {
251 generate_struct(swift_name, fields, types, swift_name == "Message")
252 }
253 ShapeKind::Enum(EnumInfo { variants, .. }) => generate_enum(swift_name, variants, types),
254 _ => format!("// Unsupported shape for {swift_name}\n"),
255 }
256}
257
258fn generate_struct(name: &str, fields: &[Field], types: &[WireType], is_top_level: bool) -> String {
260 let mut out = String::new();
261
262 out.push_str(&format!("public struct {name}: Sendable, Equatable {{\n"));
264 for f in fields {
265 let field_name = f.name.to_lower_camel_case();
266 let field_type = swift_wire_type(f.shape(), Some(f), types);
267 out.push_str(&format!(" public var {field_name}: {field_type}\n"));
268 }
269
270 if !fields.is_empty() {
272 out.push_str("\n public init(");
273 for (i, f) in fields.iter().enumerate() {
274 if i > 0 {
275 out.push_str(", ");
276 }
277 let field_name = f.name.to_lower_camel_case();
278 let field_type = swift_wire_type(f.shape(), Some(f), types);
279 out.push_str(&format!("{field_name}: {field_type}"));
280 }
281 out.push_str(") {\n");
282 for f in fields {
283 let field_name = f.name.to_lower_camel_case();
284 out.push_str(&format!(" self.{field_name} = {field_name}\n"));
285 }
286 out.push_str(" }\n");
287 }
288
289 let vis = if is_top_level { "public " } else { "" };
291 out.push_str(&format!(
292 "\n {vis}func encode(into buffer: inout ByteBuffer) {{\n"
293 ));
294 if fields.is_empty() {
295 } else {
297 for f in fields {
298 let stmt = encode_field_stmt(f, types);
299 out.push_str(&format!(" {stmt}\n"));
300 }
301 }
302 out.push_str(" }\n");
303
304 if is_top_level {
306 out.push_str(
307 "\n /// Encode to a `[UInt8]` array (bridge for callers that need bytes).\n",
308 );
309 out.push_str(" public func encode() -> [UInt8] {\n");
310 out.push_str(" var buffer = ByteBufferAllocator().buffer(capacity: 64)\n");
311 out.push_str(" encode(into: &buffer)\n");
312 out.push_str(" return buffer.readBytes(length: buffer.readableBytes) ?? []\n");
313 out.push_str(" }\n");
314 }
315
316 out.push_str(&format!(
318 "\n {vis}static func decode(from buffer: inout ByteBuffer) throws -> Self {{\n"
319 ));
320 for f in fields {
321 for stmt in decode_field_stmts(f, types) {
322 out.push_str(&format!(" {stmt}\n"));
323 }
324 }
325 let field_names: Vec<String> = fields
326 .iter()
327 .map(|f| {
328 let n = f.name.to_lower_camel_case();
329 format!("{n}: {n}")
330 })
331 .collect();
332 out.push_str(&format!(
333 " return .init({})\n",
334 field_names.join(", ")
335 ));
336 out.push_str(" }\n");
337
338 if is_top_level {
340 out.push_str(
341 "\n /// Decode from a `[UInt8]` array (bridge for callers that have raw bytes).\n",
342 );
343 out.push_str(" public static func decode(fromBytes data: [UInt8]) throws -> Self {\n");
344 out.push_str(" var buffer = ByteBufferAllocator().buffer(capacity: data.count)\n");
345 out.push_str(" buffer.writeBytes(data)\n");
346 out.push_str(" let result = try decode(from: &buffer)\n");
347 out.push_str(" guard buffer.readableBytes == 0 else {\n");
348 out.push_str(" throw WireError.trailingBytes\n");
349 out.push_str(" }\n");
350 out.push_str(" return result\n");
351 out.push_str(" }\n");
352 }
353
354 out.push_str("}\n");
355 out
356}
357
358fn generate_enum(name: &str, variants: &[facet_core::Variant], types: &[WireType]) -> String {
360 let mut out = String::new();
361
362 out.push_str(&format!("public enum {name}: Sendable, Equatable {{\n"));
364 for v in variants {
365 let variant_name = v.name.to_lower_camel_case();
366 match classify_variant(v) {
367 VariantKind::Unit => {
368 out.push_str(&format!(" case {variant_name}\n"));
369 }
370 VariantKind::Newtype { inner } => {
371 let inner_type = swift_wire_type(inner, v.data.fields.first(), types);
372 out.push_str(&format!(" case {variant_name}({inner_type})\n"));
373 }
374 VariantKind::Tuple { fields } => {
375 let field_types: Vec<String> = fields
376 .iter()
377 .map(|f| swift_wire_type(f.shape(), Some(f), types))
378 .collect();
379 out.push_str(&format!(
380 " case {variant_name}({})\n",
381 field_types.join(", ")
382 ));
383 }
384 VariantKind::Struct { fields } => {
385 let field_decls: Vec<String> = fields
386 .iter()
387 .map(|f| {
388 format!(
389 "{}: {}",
390 f.name.to_lower_camel_case(),
391 swift_wire_type(f.shape(), Some(f), types)
392 )
393 })
394 .collect();
395 out.push_str(&format!(
396 " case {variant_name}({})\n",
397 field_decls.join(", ")
398 ));
399 }
400 }
401 }
402
403 out.push_str("\n func encode(into buffer: inout ByteBuffer) {\n");
405 out.push_str(" switch self {\n");
406 for (i, v) in variants.iter().enumerate() {
407 let variant_name = v.name.to_lower_camel_case();
408 match classify_variant(v) {
409 VariantKind::Unit => {
410 out.push_str(&format!(
411 " case .{variant_name}:\n encodeVarint(UInt64({i}), into: &buffer)\n"
412 ));
413 }
414 VariantKind::Newtype { inner } => {
415 let stmt = encode_shape_stmt(inner, "val", v.data.fields.first(), types);
416 out.push_str(&format!(
417 " case .{variant_name}(let val):\n encodeVarint(UInt64({i}), into: &buffer)\n {stmt}\n"
418 ));
419 }
420 VariantKind::Tuple { fields } => {
421 let bindings: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
422 let binding_str = bindings
423 .iter()
424 .map(|b| format!("let {b}"))
425 .collect::<Vec<_>>()
426 .join(", ");
427 let stmts: Vec<String> = fields
428 .iter()
429 .enumerate()
430 .map(|(j, f)| encode_shape_stmt(f.shape(), &format!("f{j}"), Some(f), types))
431 .collect();
432 out.push_str(&format!(
433 " case .{variant_name}({binding_str}):\n encodeVarint(UInt64({i}), into: &buffer)\n"
434 ));
435 for stmt in &stmts {
436 out.push_str(&format!(" {stmt}\n"));
437 }
438 }
439 VariantKind::Struct { fields } => {
440 let bindings: Vec<String> = fields
441 .iter()
442 .map(|f| f.name.to_lower_camel_case())
443 .collect();
444 let binding_str = bindings
445 .iter()
446 .map(|b| format!("let {b}"))
447 .collect::<Vec<_>>()
448 .join(", ");
449 let stmts: Vec<String> = fields
450 .iter()
451 .map(|f| {
452 encode_shape_stmt(f.shape(), &f.name.to_lower_camel_case(), Some(f), types)
453 })
454 .collect();
455 out.push_str(&format!(
456 " case .{variant_name}({binding_str}):\n encodeVarint(UInt64({i}), into: &buffer)\n"
457 ));
458 for stmt in &stmts {
459 out.push_str(&format!(" {stmt}\n"));
460 }
461 }
462 }
463 }
464 out.push_str(" }\n");
465 out.push_str(" }\n");
466
467 out.push_str("\n static func decode(from buffer: inout ByteBuffer) throws -> Self {\n");
469 out.push_str(" let disc = try decodeVarint(from: &buffer)\n");
470 out.push_str(" switch disc {\n");
471 for (i, v) in variants.iter().enumerate() {
472 let variant_name = v.name.to_lower_camel_case();
473 out.push_str(&format!(" case {i}:\n"));
474 match classify_variant(v) {
475 VariantKind::Unit => {
476 out.push_str(&format!(" return .{variant_name}\n"));
477 }
478 VariantKind::Newtype { inner } => {
479 for stmt in decode_stmts_for(inner, v.data.fields.first(), "_newtype_val", types) {
480 out.push_str(&format!(" {stmt}\n"));
481 }
482 out.push_str(&format!(
483 " return .{variant_name}(_newtype_val)\n"
484 ));
485 }
486 VariantKind::Tuple { fields } => {
487 for (j, f) in fields.iter().enumerate() {
488 for stmt in decode_stmts_for(f.shape(), Some(f), &format!("f{j}"), types) {
489 out.push_str(&format!(" {stmt}\n"));
490 }
491 }
492 let args: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
493 out.push_str(&format!(
494 " return .{variant_name}({})\n",
495 args.join(", ")
496 ));
497 }
498 VariantKind::Struct { fields } => {
499 for f in fields {
500 let field_name = f.name.to_lower_camel_case();
501 for stmt in decode_stmts_for(f.shape(), Some(f), &field_name, types) {
502 out.push_str(&format!(" {stmt}\n"));
503 }
504 }
505 let args: Vec<String> = fields
506 .iter()
507 .map(|f| {
508 let n = f.name.to_lower_camel_case();
509 format!("{n}: {n}")
510 })
511 .collect();
512 out.push_str(&format!(
513 " return .{variant_name}({})\n",
514 args.join(", ")
515 ));
516 }
517 }
518 }
519 out.push_str(" default:\n");
520 out.push_str(" throw WireError.unknownVariant(disc)\n");
521 out.push_str(" }\n");
522 out.push_str(" }\n");
523
524 out.push_str("}\n");
525 out
526}
527
528fn encode_field_stmt(field: &Field, types: &[WireType]) -> String {
530 let field_name = field.name.to_lower_camel_case();
531 encode_shape_stmt(field.shape(), &field_name, Some(field), types)
532}
533
534fn encode_shape_stmt(
537 shape: &'static Shape,
538 value: &str,
539 field: Option<&Field>,
540 types: &[WireType],
541) -> String {
542 let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
543
544 if matches!(classify_shape(shape), ShapeKind::Opaque) {
546 return if is_trailing {
547 format!("{value}.encodeTrailing(into: &buffer)")
548 } else {
549 format!("{value}.encode(into: &buffer)")
550 };
551 }
552
553 if is_bytes(shape) {
554 return format!("encodeByteSeq({value}, into: &buffer)");
555 }
556
557 match classify_shape(shape) {
558 ShapeKind::Scalar(scalar) => encode_scalar_stmt(scalar, value),
559 ShapeKind::List { element } | ShapeKind::Slice { element } => {
560 let inner = encode_element_closure(element, types);
561 format!("encodeVec({value}, into: &buffer, encoder: {inner})")
562 }
563 ShapeKind::Option { inner } => {
564 let inner_closure = encode_element_closure(inner, types);
565 format!("encodeOption({value}, into: &buffer, encoder: {inner_closure})")
566 }
567 ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
568 format!("{value}.encode(into: &buffer)")
569 }
570 ShapeKind::Pointer { pointee } => encode_shape_stmt(pointee, value, field, types),
571 ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
572 encode_shape_stmt(fields[0].shape(), value, field, types)
573 }
574 _ => format!("/* unsupported encode for {value} */"),
575 }
576}
577
578fn encode_scalar_stmt(scalar: ScalarType, value: &str) -> String {
579 match scalar {
580 ScalarType::Bool => format!("encodeBool({value}, into: &buffer)"),
581 ScalarType::U8 => format!("encodeU8({value}, into: &buffer)"),
582 ScalarType::I8 => format!("encodeI8({value}, into: &buffer)"),
583 ScalarType::U16 => format!("encodeU16({value}, into: &buffer)"),
584 ScalarType::I16 => format!("encodeI16({value}, into: &buffer)"),
585 ScalarType::U32 => format!("encodeVarint(UInt64({value}), into: &buffer)"),
586 ScalarType::I32 => format!("encodeI32({value}, into: &buffer)"),
587 ScalarType::U64 | ScalarType::USize => format!("encodeVarint({value}, into: &buffer)"),
588 ScalarType::I64 | ScalarType::ISize => format!("encodeI64({value}, into: &buffer)"),
589 ScalarType::F32 => format!("encodeF32({value}, into: &buffer)"),
590 ScalarType::F64 => format!("encodeF64({value}, into: &buffer)"),
591 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
592 format!("encodeString({value}, into: &buffer)")
593 }
594 _ => format!("/* unsupported scalar encode for {value} */"),
595 }
596}
597
598fn encode_element_closure(shape: &'static Shape, _types: &[WireType]) -> String {
601 if is_bytes(shape) {
602 return "{ val, buf in encodeByteSeq(val, into: &buf) }".into();
603 }
604
605 match classify_shape(shape) {
606 ShapeKind::Scalar(scalar) => {
607 let stmt = encode_scalar_stmt(scalar, "val").replace("into: &buffer", "into: &buf");
608 format!("{{ val, buf in {stmt} }}")
609 }
610 ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
611 "{ val, buf in val.encode(into: &buf) }".into()
612 }
613 ShapeKind::List { element } | ShapeKind::Slice { element } => {
614 let inner = encode_element_closure(element, _types);
615 format!("{{ val, buf in encodeVec(val, into: &buf, encoder: {inner}) }}")
616 }
617 ShapeKind::Opaque => "{ val, buf in val.encode(into: &buf) }".into(),
618 ShapeKind::Pointer { pointee } => encode_element_closure(pointee, _types),
619 _ => "{ _, _ in /* unsupported */ }".into(),
620 }
621}
622
623fn decode_field_stmts(field: &Field, types: &[WireType]) -> Vec<String> {
626 let field_name = field.name.to_lower_camel_case();
627 decode_stmts_for(field.shape(), Some(field), &field_name, types)
628}
629
630fn decode_stmts_for(
633 shape: &'static Shape,
634 field: Option<&Field>,
635 var_name: &str,
636 types: &[WireType],
637) -> Vec<String> {
638 if let ShapeKind::Pointer { pointee } = classify_shape(shape) {
640 return decode_stmts_for(pointee, field, var_name, types);
641 }
642 if is_bytes(shape) {
643 return vec![
644 format!("var _{var_name}Buf = try decodeWireBytes(from: &buffer)"),
645 format!(
646 "let {var_name} = _{var_name}Buf.readBytes(length: _{var_name}Buf.readableBytes) ?? []"
647 ),
648 ];
649 }
650 vec![format!(
651 "let {var_name} = {}",
652 decode_shape_expr(shape, field, types)
653 )]
654}
655
656fn decode_shape_expr(shape: &'static Shape, field: Option<&Field>, types: &[WireType]) -> String {
659 let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
660
661 if matches!(classify_shape(shape), ShapeKind::Opaque) {
663 return if is_trailing {
664 "OpaquePayload.decodeTrailing(from: &buffer)".into()
665 } else {
666 "try OpaquePayload.decode(from: &buffer)".into()
667 };
668 }
669
670 match classify_shape(shape) {
671 ShapeKind::Scalar(scalar) => decode_scalar(scalar),
672 ShapeKind::List { element } | ShapeKind::Slice { element } => {
673 let inner = decode_element_closure(element, types);
674 format!("try decodeVec(from: &buffer, decoder: {inner})")
675 }
676 ShapeKind::Option { inner } => {
677 let inner_closure = decode_element_closure(inner, types);
678 format!("try decodeOption(from: &buffer, decoder: {inner_closure})")
679 }
680 ShapeKind::Struct(StructInfo {
681 name: Some(name), ..
682 }) => {
683 let swift_name = lookup_wire_name(name, types);
684 format!("try {swift_name}.decode(from: &buffer)")
685 }
686 ShapeKind::Enum(EnumInfo {
687 name: Some(name), ..
688 }) => {
689 let swift_name = lookup_wire_name(name, types);
690 format!("try {swift_name}.decode(from: &buffer)")
691 }
692 ShapeKind::Pointer { pointee } => decode_shape_expr(pointee, field, types),
693 ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
694 decode_shape_expr(fields[0].shape(), field, types)
695 }
696 _ => "nil /* unsupported decode */".into(),
697 }
698}
699
700fn decode_scalar(scalar: ScalarType) -> String {
701 match scalar {
702 ScalarType::Bool => "try decodeBool(from: &buffer)".into(),
703 ScalarType::U8 => "try decodeU8(from: &buffer)".into(),
704 ScalarType::I8 => "try decodeI8(from: &buffer)".into(),
705 ScalarType::U16 => "try decodeU16(from: &buffer)".into(),
706 ScalarType::I16 => "try decodeI16(from: &buffer)".into(),
707 ScalarType::U32 => "try decodeWireVarintU32(from: &buffer)".into(),
708 ScalarType::I32 => "try decodeI32(from: &buffer)".into(),
709 ScalarType::U64 | ScalarType::USize => "try decodeVarint(from: &buffer)".into(),
710 ScalarType::I64 | ScalarType::ISize => "try decodeI64(from: &buffer)".into(),
711 ScalarType::F32 => "try decodeF32(from: &buffer)".into(),
712 ScalarType::F64 => "try decodeF64(from: &buffer)".into(),
713 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
714 "try decodeWireString(from: &buffer)".into()
715 }
716 _ => "nil /* unsupported scalar decode */".into(),
717 }
718}
719
720fn decode_element_closure(shape: &'static Shape, types: &[WireType]) -> String {
723 if is_bytes(shape) {
724 return "{ buf in var s = try decodeWireBytes(from: &buf); return s.readBytes(length: s.readableBytes) ?? [] }".into();
726 }
727
728 match classify_shape(shape) {
729 ShapeKind::Scalar(scalar) => {
730 let expr = decode_scalar(scalar).replace("from: &buffer", "from: &buf");
731 format!("{{ buf in {expr} }}")
732 }
733 ShapeKind::Struct(StructInfo {
734 name: Some(name), ..
735 }) => {
736 let swift_name = lookup_wire_name(name, types);
737 format!("{{ buf in try {swift_name}.decode(from: &buf) }}")
738 }
739 ShapeKind::Enum(EnumInfo {
740 name: Some(name), ..
741 }) => {
742 let swift_name = lookup_wire_name(name, types);
743 format!("{{ buf in try {swift_name}.decode(from: &buf) }}")
744 }
745 ShapeKind::List { element } | ShapeKind::Slice { element } => {
746 let inner = decode_element_closure(element, types);
747 format!("{{ buf in try decodeVec(from: &buf, decoder: {inner}) }}")
748 }
749 ShapeKind::Opaque => "{ buf in try OpaquePayload.decode(from: &buf) }".into(),
750 ShapeKind::Pointer { pointee } => decode_element_closure(pointee, types),
751 _ => "{ _ in throw WireError.truncated }".into(),
752 }
753}
754
755fn generate_factory_methods(types: &[WireType]) -> String {
757 let payload_wt = types.iter().find(|wt| wt.swift_name == "MessagePayload");
759 let payload_wt = match payload_wt {
760 Some(wt) => wt,
761 None => return String::new(),
762 };
763
764 let variants = match classify_shape(payload_wt.shape) {
765 ShapeKind::Enum(EnumInfo { variants, .. }) => variants,
766 _ => return String::new(),
767 };
768
769 let mut out = String::new();
770 out.push_str("public extension Message {\n");
771
772 for v in variants {
773 let variant_name = v.name.to_lower_camel_case();
774 if let VariantKind::Newtype { inner } = classify_variant(v) {
775 let inner_swift = swift_wire_type(inner, None, types);
776 let is_control = matches!(
778 v.name,
779 "Hello" | "HelloYourself" | "ProtocolError" | "Ping" | "Pong"
780 );
781
782 if is_control {
783 out.push_str(&format!(
784 " static func {variant_name}(_ value: {inner_swift}) -> Message {{\n"
785 ));
786 out.push_str(&format!(
787 " Message(connectionId: 0, payload: .{variant_name}(value))\n"
788 ));
789 out.push_str(" }\n\n");
790 } else {
791 out.push_str(&format!(
792 " static func {variant_name}(connId: UInt64, _ value: {inner_swift}) -> Message {{\n"
793 ));
794 out.push_str(&format!(
795 " Message(connectionId: connId, payload: .{variant_name}(value))\n"
796 ));
797 out.push_str(" }\n\n");
798 }
799 }
800 }
801
802 out.push_str(" static func protocolError(description: String) -> Message {\n");
805 out.push_str(" Message(connectionId: 0, payload: .protocolError(.init(description: description)))\n");
806 out.push_str(" }\n\n");
807
808 out.push_str(" static func connectionOpen(\n");
809 out.push_str(
810 " connId: UInt64, settings: ConnectionSettings, metadata: [MetadataEntry]\n",
811 );
812 out.push_str(" ) -> Message {\n");
813 out.push_str(" Message(\n");
814 out.push_str(" connectionId: connId,\n");
815 out.push_str(" payload: .connectionOpen(.init(connectionSettings: settings, metadata: metadata)))\n");
816 out.push_str(" }\n\n");
817
818 out.push_str(" static func connectionAccept(\n");
819 out.push_str(
820 " connId: UInt64, settings: ConnectionSettings, metadata: [MetadataEntry]\n",
821 );
822 out.push_str(" ) -> Message {\n");
823 out.push_str(" Message(\n");
824 out.push_str(" connectionId: connId,\n");
825 out.push_str(" payload: .connectionAccept(.init(connectionSettings: settings, metadata: metadata)))\n");
826 out.push_str(" }\n\n");
827
828 out.push_str(" static func connectionReject(connId: UInt64, metadata: [MetadataEntry]) -> Message {\n");
829 out.push_str(" Message(connectionId: connId, payload: .connectionReject(.init(metadata: metadata)))\n");
830 out.push_str(" }\n\n");
831
832 out.push_str(
833 " static func connectionClose(connId: UInt64, metadata: [MetadataEntry]) -> Message {\n",
834 );
835 out.push_str(" Message(connectionId: connId, payload: .connectionClose(.init(metadata: metadata)))\n");
836 out.push_str(" }\n\n");
837
838 out.push_str(" static func request(\n");
839 out.push_str(" connId: UInt64,\n");
840 out.push_str(" requestId: UInt64,\n");
841 out.push_str(" methodId: UInt64,\n");
842 out.push_str(" metadata: [MetadataEntry],\n");
843 out.push_str(" schemas: [UInt8] = [],\n");
844 out.push_str(" payload: [UInt8]\n");
845 out.push_str(" ) -> Message {\n");
846 out.push_str(" Message(\n");
847 out.push_str(" connectionId: connId,\n");
848 out.push_str(" payload: .requestMessage(\n");
849 out.push_str(" .init(\n");
850 out.push_str(" id: requestId,\n");
851 out.push_str(" body: .call(.init(methodId: methodId, metadata: metadata, args: .init(payload), schemas: schemas))\n");
852 out.push_str(" ))\n");
853 out.push_str(" )\n");
854 out.push_str(" }\n\n");
855
856 out.push_str(" static func response(\n");
857 out.push_str(" connId: UInt64,\n");
858 out.push_str(" requestId: UInt64,\n");
859 out.push_str(" metadata: [MetadataEntry],\n");
860 out.push_str(" schemas: [UInt8] = [],\n");
861 out.push_str(" payload: [UInt8]\n");
862 out.push_str(" ) -> Message {\n");
863 out.push_str(" Message(\n");
864 out.push_str(" connectionId: connId,\n");
865 out.push_str(" payload: .requestMessage(\n");
866 out.push_str(" .init(\n");
867 out.push_str(" id: requestId,\n");
868 out.push_str(" body: .response(.init(metadata: metadata, ret: .init(payload), schemas: schemas))\n");
869 out.push_str(" ))\n");
870 out.push_str(" )\n");
871 out.push_str(" }\n\n");
872
873 out.push_str(" static func cancel(connId: UInt64, requestId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
874 out.push_str(" Message(\n");
875 out.push_str(" connectionId: connId,\n");
876 out.push_str(" payload: .requestMessage(\n");
877 out.push_str(" .init(\n");
878 out.push_str(" id: requestId,\n");
879 out.push_str(" body: .cancel(.init(metadata: metadata))\n");
880 out.push_str(" ))\n");
881 out.push_str(" )\n");
882 out.push_str(" }\n\n");
883
884 out.push_str(
885 " static func data(connId: UInt64, channelId: UInt64, payload: [UInt8]) -> Message {\n",
886 );
887 out.push_str(" Message(\n");
888 out.push_str(" connectionId: connId,\n");
889 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .item(.init(item: .init(payload)))))\n");
890 out.push_str(" )\n");
891 out.push_str(" }\n\n");
892
893 out.push_str(" static func close(connId: UInt64, channelId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
894 out.push_str(" Message(\n");
895 out.push_str(" connectionId: connId,\n");
896 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .close(.init(metadata: metadata))))\n");
897 out.push_str(" )\n");
898 out.push_str(" }\n\n");
899
900 out.push_str(" static func reset(connId: UInt64, channelId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
901 out.push_str(" Message(\n");
902 out.push_str(" connectionId: connId,\n");
903 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .reset(.init(metadata: metadata))))\n");
904 out.push_str(" )\n");
905 out.push_str(" }\n\n");
906
907 out.push_str(
908 " static func credit(connId: UInt64, channelId: UInt64, bytes: UInt32) -> Message {\n",
909 );
910 out.push_str(" Message(\n");
911 out.push_str(" connectionId: connId,\n");
912 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .grantCredit(.init(additional: bytes))))\n");
913 out.push_str(" )\n");
914 out.push_str(" }\n");
915
916 out.push_str("}\n");
917 out
918}