1use facet_core::{Field, ScalarType, Shape};
12use heck::ToLowerCamelCase;
13use vox_types::{
14 EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape, classify_variant,
15 extract_schemas, is_bytes,
16};
17
/// A Rust type paired with the name it receives in the generated Swift code.
pub struct WireType {
    /// Name used for the generated Swift declaration of this type.
    pub swift_name: String,
    /// Shape describing the Rust type; the generators in this file classify it
    /// to decide what Swift code to emit.
    pub shape: &'static Shape,
}
25
26pub fn generate_wire_types(types: &[WireType]) -> String {
28 let mut out = String::new();
29 out.push_str("// @generated by vox-codegen\n");
30 out.push_str("// DO NOT EDIT — regenerate with `cargo xtask codegen --swift-wire`\n\n");
31 out.push_str("import Foundation\n\n");
32
33 out.push_str(&generate_preamble());
35
36 out.push_str("public typealias Metadata = [MetadataEntry]\n\n");
38
39 for wt in types {
41 out.push_str(&generate_one_type(&wt.swift_name, wt.shape, types));
42 out.push('\n');
43 }
44
45 out.push_str(&generate_factory_methods(types));
47
48 if let Some(root) = types
49 .iter()
50 .find(|wt| wt.swift_name == "Message")
51 .or_else(|| types.last())
52 {
53 let extracted = extract_schemas(root.shape).expect("wire schema extraction should succeed");
54 let cbor_bytes = facet_cbor::to_vec(&extracted.schemas)
55 .expect("wire schema CBOR serialization should succeed");
56 let body = cbor_bytes
57 .iter()
58 .map(|b| b.to_string())
59 .collect::<Vec<_>>()
60 .join(", ");
61 out.push('\n');
62 out.push_str(&format!(
63 "public let wireMessageSchemasCbor: [UInt8] = [{body}]\n"
64 ));
65 }
66
67 out
68}
69
/// Returns the fixed Swift preamble shared by all generated wire types.
///
/// The preamble declares, in order: the `WireError` enum, the `OpaquePayload`
/// struct, and the private helpers `decodeWireVarintU32`, `decodeWireString`
/// and `decodeWireBytes`. The text is constant; nothing is interpolated.
fn generate_preamble() -> String {
    // Error cases thrown by the generated decoders.
    const WIRE_ERROR: &[&str] = &[
        "public enum WireError: Error, Equatable {\n",
        " case truncated\n",
        " case unknownVariant(UInt64)\n",
        " case overflow\n",
        " case invalidUtf8\n",
        " case trailingBytes\n",
        "}\n\n",
    ];

    // Byte payload with both a length-prefixed and a trailing encoding.
    const OPAQUE_PAYLOAD: &[&str] = &[
        "public struct OpaquePayload: Sendable, Equatable {\n",
        " public var bytes: [UInt8]\n\n",
        " public init(_ bytes: [UInt8]) {\n",
        " self.bytes = bytes\n",
        " }\n\n",
        " func encode() -> [UInt8] {\n",
        " let len = UInt32(bytes.count)\n",
        " return [\n",
        " UInt8(truncatingIfNeeded: len),\n",
        " UInt8(truncatingIfNeeded: len >> 8),\n",
        " UInt8(truncatingIfNeeded: len >> 16),\n",
        " UInt8(truncatingIfNeeded: len >> 24),\n",
        " ] + bytes\n",
        " }\n\n",
        " static func decode(from data: Data, offset: inout Int) throws -> Self {\n",
        " guard offset + 4 <= data.count else { throw WireError.truncated }\n",
        " let start = data.startIndex + offset\n",
        " let len = Int(UInt32(data[start]) | (UInt32(data[start + 1]) << 8) | (UInt32(data[start + 2]) << 16) | (UInt32(data[start + 3]) << 24))\n",
        " offset += 4\n",
        " guard offset + len <= data.count else { throw WireError.truncated }\n",
        " let payloadStart = data.startIndex + offset\n",
        " let payloadEnd = payloadStart + len\n",
        " let payload = Array(data[payloadStart..<payloadEnd])\n",
        " offset += len\n",
        " return .init(payload)\n",
        " }\n\n",
        " /// Encode without a length prefix — for trailing fields only.\n",
        " func encodeTrailing() -> [UInt8] {\n",
        " bytes\n",
        " }\n\n",
        " /// Decode by consuming all remaining bytes — for trailing fields only.\n",
        " static func decodeTrailing(from data: Data, offset: inout Int) -> Self {\n",
        " let start = data.startIndex + offset\n",
        " let remaining = Array(data[start...])\n",
        " offset = data.count\n",
        " return .init(remaining)\n",
        " }\n",
        "}\n\n",
    ];

    // Varint decode that rejects values above `UInt32.max`.
    const VARINT_U32: &[&str] = &[
        "@inline(__always)\n",
        "private func decodeWireVarintU32(from data: Data, offset: inout Int) throws -> UInt32 {\n",
        " let value = try decodeVarint(from: data, offset: &offset)\n",
        " guard value <= UInt64(UInt32.max) else {\n",
        " throw WireError.overflow\n",
        " }\n",
        " return UInt32(value)\n",
        "}\n\n",
    ];

    // String decode that rethrows selected `PostcardError`s as `WireError`s.
    const WIRE_STRING: &[&str] = &[
        "@inline(__always)\n",
        "private func decodeWireString(from data: Data, offset: inout Int) throws -> String {\n",
        " do {\n",
        " return try decodeString(from: data, offset: &offset)\n",
        " } catch PostcardError.invalidUtf8 {\n",
        " throw WireError.invalidUtf8\n",
        " } catch PostcardError.truncated {\n",
        " throw WireError.truncated\n",
        " } catch {\n",
        " throw error\n",
        " }\n",
        "}\n\n",
    ];

    // Bytes decode that rethrows `PostcardError.truncated` as `WireError.truncated`.
    const WIRE_BYTES: &[&str] = &[
        "@inline(__always)\n",
        "private func decodeWireBytes(from data: Data, offset: inout Int) throws -> Data {\n",
        " do {\n",
        " return try decodeBytes(from: data, offset: &offset)\n",
        " } catch PostcardError.truncated {\n",
        " throw WireError.truncated\n",
        " } catch {\n",
        " throw error\n",
        " }\n",
        "}\n\n",
    ];

    [WIRE_ERROR, OPAQUE_PAYLOAD, VARINT_U32, WIRE_STRING, WIRE_BYTES]
        .iter()
        .flat_map(|section| section.iter().copied())
        .collect()
}
165
166fn swift_wire_type(shape: &'static Shape, _field: Option<&Field>, types: &[WireType]) -> String {
168 if is_bytes(shape) {
169 return "[UInt8]".into();
170 }
171
172 match classify_shape(shape) {
173 ShapeKind::Scalar(scalar) => swift_scalar_type(scalar),
174 ShapeKind::List { element } | ShapeKind::Slice { element } => {
175 format!("[{}]", swift_wire_type(element, None, types))
176 }
177 ShapeKind::Option { inner } => {
178 format!("{}?", swift_wire_type(inner, None, types))
179 }
180 ShapeKind::Array { element, .. } => {
181 format!("[{}]", swift_wire_type(element, None, types))
182 }
183 ShapeKind::Struct(StructInfo {
184 name: Some(name), ..
185 }) => lookup_wire_name(name, types),
186 ShapeKind::Enum(EnumInfo {
187 name: Some(name), ..
188 }) => lookup_wire_name(name, types),
189 ShapeKind::Pointer { pointee } => swift_wire_type(pointee, _field, types),
190 ShapeKind::Opaque => "OpaquePayload".into(),
191 ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
192 swift_wire_type(fields[0].shape(), None, types)
193 }
194 _ => "Any /* unsupported */".into(),
195 }
196}
197
198fn swift_scalar_type(scalar: ScalarType) -> String {
199 match scalar {
200 ScalarType::Bool => "Bool".into(),
201 ScalarType::U8 => "UInt8".into(),
202 ScalarType::U16 => "UInt16".into(),
203 ScalarType::U32 => "UInt32".into(),
204 ScalarType::U64 | ScalarType::USize => "UInt64".into(),
205 ScalarType::I8 => "Int8".into(),
206 ScalarType::I16 => "Int16".into(),
207 ScalarType::I32 => "Int32".into(),
208 ScalarType::I64 | ScalarType::ISize => "Int64".into(),
209 ScalarType::F32 => "Float".into(),
210 ScalarType::F64 => "Double".into(),
211 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
212 "String".into()
213 }
214 ScalarType::Unit => "Void".into(),
215 _ => "Data".into(),
216 }
217}
218
219fn lookup_wire_name(rust_name: &str, types: &[WireType]) -> String {
222 for wt in types {
223 if wt.shape.type_identifier.ends_with(rust_name) {
225 return wt.swift_name.clone();
226 }
227 }
228 rust_name.to_string()
229}
230
231fn generate_one_type(swift_name: &str, shape: &'static Shape, types: &[WireType]) -> String {
233 match classify_shape(shape) {
234 ShapeKind::Struct(StructInfo { fields, .. }) => {
235 generate_struct(swift_name, fields, types, swift_name == "Message")
236 }
237 ShapeKind::Enum(EnumInfo { variants, .. }) => generate_enum(swift_name, variants, types),
238 _ => format!("// Unsupported shape for {swift_name}\n"),
239 }
240}
241
242fn generate_struct(name: &str, fields: &[Field], types: &[WireType], is_top_level: bool) -> String {
244 let mut out = String::new();
245
246 out.push_str(&format!("public struct {name}: Sendable, Equatable {{\n"));
248 for f in fields {
249 let field_name = f.name.to_lower_camel_case();
250 let field_type = swift_wire_type(f.shape(), Some(f), types);
251 out.push_str(&format!(" public var {field_name}: {field_type}\n"));
252 }
253
254 if !fields.is_empty() {
256 out.push_str("\n public init(");
257 for (i, f) in fields.iter().enumerate() {
258 if i > 0 {
259 out.push_str(", ");
260 }
261 let field_name = f.name.to_lower_camel_case();
262 let field_type = swift_wire_type(f.shape(), Some(f), types);
263 out.push_str(&format!("{field_name}: {field_type}"));
264 }
265 out.push_str(") {\n");
266 for f in fields {
267 let field_name = f.name.to_lower_camel_case();
268 out.push_str(&format!(" self.{field_name} = {field_name}\n"));
269 }
270 out.push_str(" }\n");
271 }
272
273 let vis = if is_top_level { "public " } else { "" };
275 out.push_str(&format!("\n {vis}func encode() -> [UInt8] {{\n"));
276 if fields.is_empty() {
277 out.push_str(" []\n");
278 } else if fields.len() == 1 {
279 let f = &fields[0];
280 let expr = encode_field_expr(f, types);
281 out.push_str(&format!(" {expr}\n"));
282 } else {
283 out.push_str(" var out: [UInt8] = []\n");
284 for f in fields {
285 let expr = encode_field_expr(f, types);
286 out.push_str(&format!(" out += {expr}\n"));
287 }
288 out.push_str(" return out\n");
289 }
290 out.push_str(" }\n");
291
292 out.push_str(&format!(
294 "\n {vis}static func decode(from data: Data, offset: inout Int) throws -> Self {{\n"
295 ));
296 for f in fields {
297 let field_name = f.name.to_lower_camel_case();
298 let decode_expr = decode_field_expr(f, types);
299 out.push_str(&format!(" let {field_name} = {decode_expr}\n"));
300 }
301 let field_names: Vec<String> = fields
302 .iter()
303 .map(|f| {
304 let n = f.name.to_lower_camel_case();
305 format!("{n}: {n}")
306 })
307 .collect();
308 out.push_str(&format!(
309 " return .init({})\n",
310 field_names.join(", ")
311 ));
312 out.push_str(" }\n");
313
314 if is_top_level {
316 out.push_str("\n public static func decode(from data: Data) throws -> Self {\n");
317 out.push_str(" var offset = 0\n");
318 out.push_str(" let result = try decode(from: data, offset: &offset)\n");
319 out.push_str(" guard offset == data.count else {\n");
320 out.push_str(" throw WireError.trailingBytes\n");
321 out.push_str(" }\n");
322 out.push_str(" return result\n");
323 out.push_str(" }\n");
324 }
325
326 out.push_str("}\n");
327 out
328}
329
330fn generate_enum(name: &str, variants: &[facet_core::Variant], types: &[WireType]) -> String {
332 let mut out = String::new();
333
334 out.push_str(&format!("public enum {name}: Sendable, Equatable {{\n"));
336 for v in variants {
337 let variant_name = v.name.to_lower_camel_case();
338 match classify_variant(v) {
339 VariantKind::Unit => {
340 out.push_str(&format!(" case {variant_name}\n"));
341 }
342 VariantKind::Newtype { inner } => {
343 let inner_type = swift_wire_type(inner, v.data.fields.first(), types);
344 out.push_str(&format!(" case {variant_name}({inner_type})\n"));
345 }
346 VariantKind::Tuple { fields } => {
347 let field_types: Vec<String> = fields
348 .iter()
349 .map(|f| swift_wire_type(f.shape(), Some(f), types))
350 .collect();
351 out.push_str(&format!(
352 " case {variant_name}({})\n",
353 field_types.join(", ")
354 ));
355 }
356 VariantKind::Struct { fields } => {
357 let field_decls: Vec<String> = fields
358 .iter()
359 .map(|f| {
360 format!(
361 "{}: {}",
362 f.name.to_lower_camel_case(),
363 swift_wire_type(f.shape(), Some(f), types)
364 )
365 })
366 .collect();
367 out.push_str(&format!(
368 " case {variant_name}({})\n",
369 field_decls.join(", ")
370 ));
371 }
372 }
373 }
374
375 out.push_str("\n func encode() -> [UInt8] {\n");
377 out.push_str(" switch self {\n");
378 for (i, v) in variants.iter().enumerate() {
379 let variant_name = v.name.to_lower_camel_case();
380 match classify_variant(v) {
381 VariantKind::Unit => {
382 out.push_str(&format!(
383 " case .{variant_name}:\n return encodeVarint(UInt64({i}))\n"
384 ));
385 }
386 VariantKind::Newtype { inner } => {
387 let encode = encode_shape_expr(inner, "val", v.data.fields.first(), types);
388 out.push_str(&format!(
389 " case .{variant_name}(let val):\n return encodeVarint(UInt64({i})) + {encode}\n"
390 ));
391 }
392 VariantKind::Tuple { fields } => {
393 let bindings: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
394 let binding_str = bindings
395 .iter()
396 .map(|b| format!("let {b}"))
397 .collect::<Vec<_>>()
398 .join(", ");
399 let encodes: Vec<String> = fields
400 .iter()
401 .enumerate()
402 .map(|(j, f)| encode_shape_expr(f.shape(), &format!("f{j}"), Some(f), types))
403 .collect();
404 out.push_str(&format!(
405 " case .{variant_name}({binding_str}):\n return encodeVarint(UInt64({i})) + {}\n",
406 encodes.join(" + ")
407 ));
408 }
409 VariantKind::Struct { fields } => {
410 let bindings: Vec<String> = fields
411 .iter()
412 .map(|f| f.name.to_lower_camel_case())
413 .collect();
414 let binding_str = bindings
415 .iter()
416 .map(|b| format!("let {b}"))
417 .collect::<Vec<_>>()
418 .join(", ");
419 let encodes: Vec<String> = fields
420 .iter()
421 .map(|f| {
422 encode_shape_expr(f.shape(), &f.name.to_lower_camel_case(), Some(f), types)
423 })
424 .collect();
425 out.push_str(&format!(
426 " case .{variant_name}({binding_str}):\n return encodeVarint(UInt64({i})) + {}\n",
427 encodes.join(" + ")
428 ));
429 }
430 }
431 }
432 out.push_str(" }\n");
433 out.push_str(" }\n");
434
435 out.push_str("\n static func decode(from data: Data, offset: inout Int) throws -> Self {\n");
437 out.push_str(" let disc = try decodeVarint(from: data, offset: &offset)\n");
438 out.push_str(" switch disc {\n");
439 for (i, v) in variants.iter().enumerate() {
440 let variant_name = v.name.to_lower_camel_case();
441 out.push_str(&format!(" case {i}:\n"));
442 match classify_variant(v) {
443 VariantKind::Unit => {
444 out.push_str(&format!(" return .{variant_name}\n"));
445 }
446 VariantKind::Newtype { inner } => {
447 let decode = decode_shape_expr(inner, v.data.fields.first(), types);
448 out.push_str(&format!(" return .{variant_name}({decode})\n"));
449 }
450 VariantKind::Tuple { fields } => {
451 for (j, f) in fields.iter().enumerate() {
452 let decode = decode_shape_expr(f.shape(), Some(f), types);
453 out.push_str(&format!(" let f{j} = {decode}\n"));
454 }
455 let args: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
456 out.push_str(&format!(
457 " return .{variant_name}({})\n",
458 args.join(", ")
459 ));
460 }
461 VariantKind::Struct { fields } => {
462 for f in fields {
463 let field_name = f.name.to_lower_camel_case();
464 let decode = decode_shape_expr(f.shape(), Some(f), types);
465 out.push_str(&format!(" let {field_name} = {decode}\n"));
466 }
467 let args: Vec<String> = fields
468 .iter()
469 .map(|f| {
470 let n = f.name.to_lower_camel_case();
471 format!("{n}: {n}")
472 })
473 .collect();
474 out.push_str(&format!(
475 " return .{variant_name}({})\n",
476 args.join(", ")
477 ));
478 }
479 }
480 }
481 out.push_str(" default:\n");
482 out.push_str(" throw WireError.unknownVariant(disc)\n");
483 out.push_str(" }\n");
484 out.push_str(" }\n");
485
486 out.push_str("}\n");
487 out
488}
489
490fn encode_field_expr(field: &Field, types: &[WireType]) -> String {
492 let field_name = field.name.to_lower_camel_case();
493 encode_shape_expr(field.shape(), &field_name, Some(field), types)
494}
495
496fn encode_shape_expr(
498 shape: &'static Shape,
499 value: &str,
500 field: Option<&Field>,
501 types: &[WireType],
502) -> String {
503 let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
504
505 if matches!(classify_shape(shape), ShapeKind::Opaque) {
507 return if is_trailing {
508 format!("{value}.encodeTrailing()")
509 } else {
510 format!("{value}.encode()")
511 };
512 }
513
514 if is_bytes(shape) {
515 return format!("encodeBytes({value})");
516 }
517
518 match classify_shape(shape) {
519 ShapeKind::Scalar(scalar) => encode_scalar(scalar, value),
520 ShapeKind::List { element } | ShapeKind::Slice { element } => {
521 let inner = encode_element_closure(element, types);
522 format!("encodeVec({value}, encoder: {inner})")
523 }
524 ShapeKind::Option { inner } => {
525 let inner_closure = encode_element_closure(inner, types);
526 format!("encodeOption({value}, encoder: {inner_closure})")
527 }
528 ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
529 format!("{value}.encode()")
530 }
531 ShapeKind::Pointer { pointee } => encode_shape_expr(pointee, value, field, types),
532 ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
533 encode_shape_expr(fields[0].shape(), value, field, types)
534 }
535 _ => format!("[] /* unsupported encode for {value} */"),
536 }
537}
538
539fn encode_scalar(scalar: ScalarType, value: &str) -> String {
540 match scalar {
541 ScalarType::Bool => format!("encodeBool({value})"),
542 ScalarType::U8 => format!("encodeU8({value})"),
543 ScalarType::I8 => format!("encodeI8({value})"),
544 ScalarType::U16 => format!("encodeU16({value})"),
545 ScalarType::I16 => format!("encodeI16({value})"),
546 ScalarType::U32 => format!("encodeVarint(UInt64({value}))"),
547 ScalarType::I32 => format!("encodeI32({value})"),
548 ScalarType::U64 | ScalarType::USize => format!("encodeVarint({value})"),
549 ScalarType::I64 | ScalarType::ISize => format!("encodeI64({value})"),
550 ScalarType::F32 => format!("encodeF32({value})"),
551 ScalarType::F64 => format!("encodeF64({value})"),
552 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
553 format!("encodeString({value})")
554 }
555 _ => "[] /* unsupported scalar */".to_string(),
556 }
557}
558
559fn encode_element_closure(shape: &'static Shape, _types: &[WireType]) -> String {
561 if is_bytes(shape) {
562 return "{ encodeBytes($0) }".into();
563 }
564
565 match classify_shape(shape) {
566 ShapeKind::Scalar(scalar) => {
567 let expr = encode_scalar(scalar, "$0");
568 format!("{{ {expr} }}")
569 }
570 ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
571 "{ $0.encode() }".into()
572 }
573 ShapeKind::List { element } | ShapeKind::Slice { element } => {
574 let inner = encode_element_closure(element, _types);
575 format!("{{ encodeVec($0, encoder: {inner}) }}")
576 }
577 ShapeKind::Opaque => "{ $0.encode() }".into(),
578 ShapeKind::Pointer { pointee } => encode_element_closure(pointee, _types),
579 _ => "{ _ in [] }".into(),
580 }
581}
582
583fn decode_field_expr(field: &Field, types: &[WireType]) -> String {
585 decode_shape_expr(field.shape(), Some(field), types)
586}
587
588fn decode_shape_expr(shape: &'static Shape, field: Option<&Field>, types: &[WireType]) -> String {
590 let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
591
592 if matches!(classify_shape(shape), ShapeKind::Opaque) {
594 return if is_trailing {
595 "OpaquePayload.decodeTrailing(from: data, offset: &offset)".into()
596 } else {
597 "try OpaquePayload.decode(from: data, offset: &offset)".into()
598 };
599 }
600
601 if is_bytes(shape) {
602 return "Array(try decodeWireBytes(from: data, offset: &offset))".into();
603 }
604
605 match classify_shape(shape) {
606 ShapeKind::Scalar(scalar) => decode_scalar(scalar),
607 ShapeKind::List { element } | ShapeKind::Slice { element } => {
608 let inner = decode_element_closure(element, types);
609 format!("try decodeVec(from: data, offset: &offset, decoder: {inner})")
610 }
611 ShapeKind::Option { inner } => {
612 let inner_closure = decode_element_closure(inner, types);
613 format!("try decodeOption(from: data, offset: &offset, decoder: {inner_closure})")
614 }
615 ShapeKind::Struct(StructInfo {
616 name: Some(name), ..
617 }) => {
618 let swift_name = lookup_wire_name(name, types);
619 format!("try {swift_name}.decode(from: data, offset: &offset)")
620 }
621 ShapeKind::Enum(EnumInfo {
622 name: Some(name), ..
623 }) => {
624 let swift_name = lookup_wire_name(name, types);
625 format!("try {swift_name}.decode(from: data, offset: &offset)")
626 }
627 ShapeKind::Pointer { pointee } => decode_shape_expr(pointee, field, types),
628 ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
629 decode_shape_expr(fields[0].shape(), field, types)
630 }
631 _ => "nil /* unsupported decode */".into(),
632 }
633}
634
635fn decode_scalar(scalar: ScalarType) -> String {
636 match scalar {
637 ScalarType::Bool => "try decodeBool(from: data, offset: &offset)".into(),
638 ScalarType::U8 => "try decodeU8(from: data, offset: &offset)".into(),
639 ScalarType::I8 => "try decodeI8(from: data, offset: &offset)".into(),
640 ScalarType::U16 => "try decodeU16(from: data, offset: &offset)".into(),
641 ScalarType::I16 => "try decodeI16(from: data, offset: &offset)".into(),
642 ScalarType::U32 => "try decodeWireVarintU32(from: data, offset: &offset)".into(),
643 ScalarType::I32 => "try decodeI32(from: data, offset: &offset)".into(),
644 ScalarType::U64 | ScalarType::USize => {
645 "try decodeVarint(from: data, offset: &offset)".into()
646 }
647 ScalarType::I64 | ScalarType::ISize => "try decodeI64(from: data, offset: &offset)".into(),
648 ScalarType::F32 => "try decodeF32(from: data, offset: &offset)".into(),
649 ScalarType::F64 => "try decodeF64(from: data, offset: &offset)".into(),
650 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
651 "try decodeWireString(from: data, offset: &offset)".into()
652 }
653 _ => "nil /* unsupported scalar decode */".into(),
654 }
655}
656
657fn decode_element_closure(shape: &'static Shape, types: &[WireType]) -> String {
659 if is_bytes(shape) {
660 return "{ data, off in Array(try decodeWireBytes(from: data, offset: &off)) }".into();
661 }
662
663 match classify_shape(shape) {
664 ShapeKind::Scalar(scalar) => {
665 let expr = decode_scalar_with_params(scalar, "data", "off");
666 format!("{{ data, off in {expr} }}")
667 }
668 ShapeKind::Struct(StructInfo {
669 name: Some(name), ..
670 }) => {
671 let swift_name = lookup_wire_name(name, types);
672 format!("{{ data, off in try {swift_name}.decode(from: data, offset: &off) }}")
673 }
674 ShapeKind::Enum(EnumInfo {
675 name: Some(name), ..
676 }) => {
677 let swift_name = lookup_wire_name(name, types);
678 format!("{{ data, off in try {swift_name}.decode(from: data, offset: &off) }}")
679 }
680 ShapeKind::List { element } | ShapeKind::Slice { element } => {
681 let inner = decode_element_closure(element, types);
682 format!("{{ data, off in try decodeVec(from: data, offset: &off, decoder: {inner}) }}")
683 }
684 ShapeKind::Opaque => {
685 "{ data, off in try OpaquePayload.decode(from: data, offset: &off) }".into()
686 }
687 ShapeKind::Pointer { pointee } => decode_element_closure(pointee, types),
688 _ => "{ _, _ in throw WireError.truncated }".into(),
689 }
690}
691
692fn decode_scalar_with_params(scalar: ScalarType, data: &str, offset: &str) -> String {
693 match scalar {
694 ScalarType::Bool => format!("try decodeBool(from: {data}, offset: &{offset})"),
695 ScalarType::U8 => format!("try decodeU8(from: {data}, offset: &{offset})"),
696 ScalarType::I8 => format!("try decodeI8(from: {data}, offset: &{offset})"),
697 ScalarType::U16 => format!("try decodeU16(from: {data}, offset: &{offset})"),
698 ScalarType::I16 => format!("try decodeI16(from: {data}, offset: &{offset})"),
699 ScalarType::U32 => format!("try decodeWireVarintU32(from: {data}, offset: &{offset})"),
700 ScalarType::I32 => format!("try decodeI32(from: {data}, offset: &{offset})"),
701 ScalarType::U64 | ScalarType::USize => {
702 format!("try decodeVarint(from: {data}, offset: &{offset})")
703 }
704 ScalarType::I64 | ScalarType::ISize => {
705 format!("try decodeI64(from: {data}, offset: &{offset})")
706 }
707 ScalarType::F32 => format!("try decodeF32(from: {data}, offset: &{offset})"),
708 ScalarType::F64 => format!("try decodeF64(from: {data}, offset: &{offset})"),
709 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
710 format!("try decodeWireString(from: {data}, offset: &{offset})")
711 }
712 _ => "nil /* unsupported scalar */".to_string(),
713 }
714}
715
716fn generate_factory_methods(types: &[WireType]) -> String {
718 let payload_wt = types.iter().find(|wt| wt.swift_name == "MessagePayload");
720 let payload_wt = match payload_wt {
721 Some(wt) => wt,
722 None => return String::new(),
723 };
724
725 let variants = match classify_shape(payload_wt.shape) {
726 ShapeKind::Enum(EnumInfo { variants, .. }) => variants,
727 _ => return String::new(),
728 };
729
730 let mut out = String::new();
731 out.push_str("public extension Message {\n");
732
733 for v in variants {
734 let variant_name = v.name.to_lower_camel_case();
735 if let VariantKind::Newtype { inner } = classify_variant(v) {
736 let inner_swift = swift_wire_type(inner, None, types);
737 let is_control = matches!(
739 v.name,
740 "Hello" | "HelloYourself" | "ProtocolError" | "Ping" | "Pong"
741 );
742
743 if is_control {
744 out.push_str(&format!(
745 " static func {variant_name}(_ value: {inner_swift}) -> Message {{\n"
746 ));
747 out.push_str(&format!(
748 " Message(connectionId: 0, payload: .{variant_name}(value))\n"
749 ));
750 out.push_str(" }\n\n");
751 } else {
752 out.push_str(&format!(
753 " static func {variant_name}(connId: UInt64, _ value: {inner_swift}) -> Message {{\n"
754 ));
755 out.push_str(&format!(
756 " Message(connectionId: connId, payload: .{variant_name}(value))\n"
757 ));
758 out.push_str(" }\n\n");
759 }
760 }
761 }
762
763 out.push_str(" static func protocolError(description: String) -> Message {\n");
766 out.push_str(" Message(connectionId: 0, payload: .protocolError(.init(description: description)))\n");
767 out.push_str(" }\n\n");
768
769 out.push_str(" static func connectionOpen(connId: UInt64, settings: ConnectionSettings, metadata: [MetadataEntry]) -> Message {\n");
770 out.push_str(" Message(connectionId: connId, payload: .connectionOpen(.init(connectionSettings: settings, metadata: metadata)))\n");
771 out.push_str(" }\n\n");
772
773 out.push_str(" static func connectionAccept(connId: UInt64, settings: ConnectionSettings, metadata: [MetadataEntry]) -> Message {\n");
774 out.push_str(" Message(connectionId: connId, payload: .connectionAccept(.init(connectionSettings: settings, metadata: metadata)))\n");
775 out.push_str(" }\n\n");
776
777 out.push_str(" static func connectionReject(connId: UInt64, metadata: [MetadataEntry]) -> Message {\n");
778 out.push_str(" Message(connectionId: connId, payload: .connectionReject(.init(metadata: metadata)))\n");
779 out.push_str(" }\n\n");
780
781 out.push_str(
782 " static func connectionClose(connId: UInt64, metadata: [MetadataEntry]) -> Message {\n",
783 );
784 out.push_str(" Message(connectionId: connId, payload: .connectionClose(.init(metadata: metadata)))\n");
785 out.push_str(" }\n\n");
786
787 out.push_str(" static func request(\n");
788 out.push_str(" connId: UInt64,\n");
789 out.push_str(" requestId: UInt64,\n");
790 out.push_str(" methodId: UInt64,\n");
791 out.push_str(" metadata: [MetadataEntry],\n");
792 out.push_str(" schemas: [UInt8] = [],\n");
793 out.push_str(" payload: [UInt8]\n");
794 out.push_str(" ) -> Message {\n");
795 out.push_str(" Message(\n");
796 out.push_str(" connectionId: connId,\n");
797 out.push_str(" payload: .requestMessage(\n");
798 out.push_str(" .init(\n");
799 out.push_str(" id: requestId,\n");
800 out.push_str(" body: .call(.init(methodId: methodId, metadata: metadata, args: .init(payload), schemas: schemas))\n");
801 out.push_str(" ))\n");
802 out.push_str(" )\n");
803 out.push_str(" }\n\n");
804
805 out.push_str(" static func response(\n");
806 out.push_str(" connId: UInt64,\n");
807 out.push_str(" requestId: UInt64,\n");
808 out.push_str(" metadata: [MetadataEntry],\n");
809 out.push_str(" schemas: [UInt8] = [],\n");
810 out.push_str(" payload: [UInt8]\n");
811 out.push_str(" ) -> Message {\n");
812 out.push_str(" Message(\n");
813 out.push_str(" connectionId: connId,\n");
814 out.push_str(" payload: .requestMessage(\n");
815 out.push_str(" .init(\n");
816 out.push_str(" id: requestId,\n");
817 out.push_str(" body: .response(.init(metadata: metadata, ret: .init(payload), schemas: schemas))\n");
818 out.push_str(" ))\n");
819 out.push_str(" )\n");
820 out.push_str(" }\n\n");
821
822 out.push_str(" static func cancel(connId: UInt64, requestId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
823 out.push_str(" Message(\n");
824 out.push_str(" connectionId: connId,\n");
825 out.push_str(" payload: .requestMessage(\n");
826 out.push_str(" .init(\n");
827 out.push_str(" id: requestId,\n");
828 out.push_str(" body: .cancel(.init(metadata: metadata))\n");
829 out.push_str(" ))\n");
830 out.push_str(" )\n");
831 out.push_str(" }\n\n");
832
833 out.push_str(
834 " static func data(connId: UInt64, channelId: UInt64, payload: [UInt8]) -> Message {\n",
835 );
836 out.push_str(" Message(\n");
837 out.push_str(" connectionId: connId,\n");
838 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .item(.init(item: .init(payload)))))\n");
839 out.push_str(" )\n");
840 out.push_str(" }\n\n");
841
842 out.push_str(" static func close(connId: UInt64, channelId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
843 out.push_str(" Message(\n");
844 out.push_str(" connectionId: connId,\n");
845 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .close(.init(metadata: metadata))))\n");
846 out.push_str(" )\n");
847 out.push_str(" }\n\n");
848
849 out.push_str(" static func reset(connId: UInt64, channelId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n");
850 out.push_str(" Message(\n");
851 out.push_str(" connectionId: connId,\n");
852 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .reset(.init(metadata: metadata))))\n");
853 out.push_str(" )\n");
854 out.push_str(" }\n\n");
855
856 out.push_str(
857 " static func credit(connId: UInt64, channelId: UInt64, bytes: UInt32) -> Message {\n",
858 );
859 out.push_str(" Message(\n");
860 out.push_str(" connectionId: connId,\n");
861 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .grantCredit(.init(additional: bytes))))\n");
862 out.push_str(" )\n");
863 out.push_str(" }\n");
864
865 out.push_str("}\n");
866 out
867}