Skip to main content

vox_codegen/targets/swift/
schema.rs

1//! Swift binding-schema generation for runtime channel binding.
2//!
3//! Generates runtime schema information for channel discovery and wire schema exchange.
4
5use facet_core::{Facet, ScalarType, Shape};
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use vox_types::{
8    EnumInfo, ServiceDescriptor, ShapeKind, StructInfo, TypeRef, VariantKind, VoxError,
9    classify_shape, classify_variant, extract_schemas, is_bytes,
10};
11
12use crate::code_writer::CodeWriter;
13use crate::cw_writeln;
14
15/// Generate complete schema code (method schemas + serializers + wire schemas).
16pub fn generate_schemas(service: &ServiceDescriptor) -> String {
17    let mut out = String::new();
18    out.push_str(&generate_method_schemas(service));
19    out.push_str(&generate_wire_schemas(service));
20    out.push_str(&generate_serializers(service));
21    out
22}
23
24/// Generate method schemas for runtime channel binding.
25fn generate_method_schemas(service: &ServiceDescriptor) -> String {
26    let mut out = String::new();
27    let service_name = service.service_name.to_lower_camel_case();
28
29    out.push_str(&format!(
30        "public let {service_name}_schemas: [String: MethodBindingSchema] = [\n"
31    ));
32
33    for method in service.methods {
34        let method_name = method.method_name.to_lower_camel_case();
35        out.push_str(&format!(
36            "    \"{method_name}\": MethodBindingSchema(args: ["
37        ));
38
39        let schemas: Vec<String> = method
40            .args
41            .iter()
42            .map(|a| shape_to_schema(a.shape))
43            .collect();
44        out.push_str(&schemas.join(", "));
45
46        out.push_str("]),\n");
47    }
48
49    out.push_str("]\n\n");
50    out
51}
52
53/// Generate wire schema infrastructure for protocol schema exchange.
54///
55/// Generates:
56/// 1. A global schema registry containing all schemas for all methods (deduplicated)
57/// 2. Per-method schema ID lists and root TypeRefs for args and response
58///
59/// At runtime, the Swift code filters schemas through SchemaSendTracker before encoding.
60fn generate_wire_schemas(service: &ServiceDescriptor) -> String {
61    use crate::render::hex_u64;
62    use std::collections::HashMap;
63    use vox_types::{Schema, SchemaHash};
64
65    let service_name = service.service_name.to_lower_camel_case();
66
67    // Extract Result and VoxError schemas once (used for wrapping responses)
68    let result_extracted =
69        extract_schemas(<Result<bool, u32> as Facet<'static>>::SHAPE).expect("Result schema");
70    let result_type_id = match &result_extracted.root {
71        TypeRef::Concrete { type_id, .. } => *type_id,
72        _ => panic!("Result root should be concrete"),
73    };
74
75    let vox_error_extracted =
76        extract_schemas(<VoxError<std::convert::Infallible> as Facet<'static>>::SHAPE)
77            .expect("VoxError schema");
78    let vox_error_type_id = match &vox_error_extracted.root {
79        TypeRef::Concrete { type_id, .. } => *type_id,
80        _ => panic!("VoxError root should be concrete"),
81    };
82
83    // Collect all schemas across all methods into a global registry
84    let mut global_schemas: HashMap<SchemaHash, Schema> = HashMap::new();
85
86    // Add Result and VoxError schemas
87    for schema in result_extracted.schemas.iter() {
88        global_schemas.insert(schema.id, schema.clone());
89    }
90    for schema in vox_error_extracted.schemas.iter() {
91        global_schemas.insert(schema.id, schema.clone());
92    }
93
94    // Per-method info: (args_schema_ids, args_root, response_schema_ids, response_root)
95    struct MethodSchemaInfo {
96        args_schema_ids: Vec<SchemaHash>,
97        args_root: TypeRef,
98        response_schema_ids: Vec<SchemaHash>,
99        response_root: TypeRef,
100    }
101    let mut method_infos: Vec<(u64, MethodSchemaInfo)> = Vec::new();
102
103    for method in service.methods {
104        let method_id = crate::method_id(method);
105
106        // Extract args schemas
107        let args_extracted = extract_schemas(method.args_shape).expect("args schema extraction");
108        let args_schema_ids: Vec<SchemaHash> =
109            args_extracted.schemas.iter().map(|s| s.id).collect();
110        for schema in args_extracted.schemas.iter().cloned() {
111            global_schemas.insert(schema.id, schema);
112        }
113
114        // Extract response schemas - wrap in Result<T, VoxError<E>>
115        let (ok_extracted, err_extracted) = match classify_shape(method.return_shape) {
116            ShapeKind::Result { ok, err } => (
117                extract_schemas(ok).expect("ok schema"),
118                extract_schemas(err).expect("err schema"),
119            ),
120            _ => (
121                extract_schemas(method.return_shape).expect("return schema"),
122                extract_schemas(<std::convert::Infallible as Facet<'static>>::SHAPE)
123                    .expect("Infallible schema"),
124            ),
125        };
126
127        // Collect response schema IDs (including Result and VoxError)
128        let mut response_schema_ids: Vec<SchemaHash> = Vec::new();
129        for schema in result_extracted.schemas.iter() {
130            response_schema_ids.push(schema.id);
131        }
132        for schema in vox_error_extracted.schemas.iter() {
133            response_schema_ids.push(schema.id);
134        }
135        for schema in ok_extracted.schemas.iter().cloned() {
136            response_schema_ids.push(schema.id);
137            global_schemas.insert(schema.id, schema);
138        }
139        for schema in err_extracted.schemas.iter().cloned() {
140            response_schema_ids.push(schema.id);
141            global_schemas.insert(schema.id, schema);
142        }
143
144        // Deduplicate schema IDs (smaller codegen output)
145        let mut seen = std::collections::HashSet::new();
146        response_schema_ids.retain(|id| seen.insert(*id));
147
148        // Build the response root: Result<ok_root, VoxError<err_root>>
149        let vox_error_ref = TypeRef::generic(vox_error_type_id, vec![err_extracted.root.clone()]);
150        let response_root = TypeRef::generic(
151            result_type_id,
152            vec![ok_extracted.root.clone(), vox_error_ref],
153        );
154
155        method_infos.push((
156            method_id,
157            MethodSchemaInfo {
158                args_schema_ids,
159                args_root: args_extracted.root.clone(),
160                response_schema_ids,
161                response_root,
162            },
163        ));
164    }
165
166    let mut out = String::new();
167
168    // Generate global schema registry
169    out.push_str("/// Global schema registry containing all schemas for this service.\n");
170    out.push_str(&format!(
171        "nonisolated(unsafe) public let {service_name}_schema_registry: [UInt64: Schema] = [\n"
172    ));
173
174    let mut sorted_schemas: Vec<_> = global_schemas.into_iter().collect();
175    sorted_schemas.sort_by_key(|(id, _)| *id);
176
177    for (schema_id, schema) in &sorted_schemas {
178        out.push_str(&format!(
179            "    {}: {},\n",
180            hex_u64(schema_id.0),
181            format_swift_schema(schema)
182        ));
183    }
184    out.push_str("]\n\n");
185
186    // Generate per-method schema info
187    out.push_str("/// Per-method schema information for wire protocol.\n");
188    out.push_str(&format!(
189        "nonisolated(unsafe) public let {service_name}_method_schemas: [UInt64: MethodSchemaInfo] = [\n"
190    ));
191
192    for (method_id, info) in &method_infos {
193        out.push_str(&format!("    {}: MethodSchemaInfo(\n", hex_u64(*method_id)));
194        out.push_str(&format!(
195            "        argsSchemaIds: [{}],\n",
196            info.args_schema_ids
197                .iter()
198                .map(|id| hex_u64(id.0))
199                .collect::<Vec<_>>()
200                .join(", ")
201        ));
202        out.push_str(&format!(
203            "        argsRoot: {},\n",
204            format_swift_type_ref(&info.args_root)
205        ));
206        out.push_str(&format!(
207            "        responseSchemaIds: [{}],\n",
208            info.response_schema_ids
209                .iter()
210                .map(|id| hex_u64(id.0))
211                .collect::<Vec<_>>()
212                .join(", ")
213        ));
214        out.push_str(&format!(
215            "        responseRoot: {}\n",
216            format_swift_type_ref(&info.response_root)
217        ));
218        out.push_str("    ),\n");
219    }
220    out.push_str("]\n\n");
221
222    out
223}
224
225/// Format a Schema as Swift code.
226fn format_swift_schema(schema: &vox_types::Schema) -> String {
227    use crate::render::hex_u64;
228
229    let type_params = if schema.type_params.is_empty() {
230        "[]".to_string()
231    } else {
232        format!(
233            "[{}]",
234            schema
235                .type_params
236                .iter()
237                .map(|p| format!("\"{}\"", p.as_str()))
238                .collect::<Vec<_>>()
239                .join(", ")
240        )
241    };
242
243    format!(
244        "Schema(id: {}, typeParams: {}, kind: {})",
245        hex_u64(schema.id.0),
246        type_params,
247        format_swift_schema_kind(&schema.kind)
248    )
249}
250
251/// Format a SchemaKind as Swift code.
252fn format_swift_schema_kind(kind: &vox_types::SchemaKind) -> String {
253    use vox_types::SchemaKind;
254
255    match kind {
256        SchemaKind::Struct { name, fields } => {
257            let fields_str = fields
258                .iter()
259                .map(|f| {
260                    format!(
261                        "FieldSchema(name: \"{}\", typeRef: {}, required: {})",
262                        f.name,
263                        format_swift_type_ref(&f.type_ref),
264                        f.required
265                    )
266                })
267                .collect::<Vec<_>>()
268                .join(", ");
269            format!(".struct(name: \"{}\", fields: [{}])", name, fields_str)
270        }
271        SchemaKind::Enum { name, variants } => {
272            let variants_str = variants
273                .iter()
274                .map(|v| {
275                    format!(
276                        "VariantSchema(name: \"{}\", index: {}, payload: {})",
277                        v.name,
278                        v.index,
279                        format_swift_variant_payload(&v.payload)
280                    )
281                })
282                .collect::<Vec<_>>()
283                .join(", ");
284            format!(".enum(name: \"{}\", variants: [{}])", name, variants_str)
285        }
286        SchemaKind::Tuple { elements } => {
287            let elems_str = elements
288                .iter()
289                .map(format_swift_type_ref)
290                .collect::<Vec<_>>()
291                .join(", ");
292            format!(".tuple(elements: [{}])", elems_str)
293        }
294        SchemaKind::List { element } => {
295            format!(".list(element: {})", format_swift_type_ref(element))
296        }
297        SchemaKind::Map { key, value } => {
298            format!(
299                ".map(key: {}, value: {})",
300                format_swift_type_ref(key),
301                format_swift_type_ref(value)
302            )
303        }
304        SchemaKind::Array { element, length } => {
305            format!(
306                ".array(element: {}, length: {})",
307                format_swift_type_ref(element),
308                length
309            )
310        }
311        SchemaKind::Option { element } => {
312            format!(".option(element: {})", format_swift_type_ref(element))
313        }
314        SchemaKind::Channel { direction, element } => {
315            let dir = match direction {
316                vox_types::ChannelDirection::Tx => ".tx",
317                vox_types::ChannelDirection::Rx => ".rx",
318            };
319            format!(
320                ".channel(direction: {}, element: {})",
321                dir,
322                format_swift_type_ref(element)
323            )
324        }
325        SchemaKind::Primitive { primitive_type } => {
326            format!(".primitive({})", format_swift_primitive(primitive_type))
327        }
328    }
329}
330
331/// Format a VariantPayload as Swift code.
332fn format_swift_variant_payload(payload: &vox_types::VariantPayload) -> String {
333    use vox_types::VariantPayload;
334
335    match payload {
336        VariantPayload::Unit => ".unit".to_string(),
337        VariantPayload::Newtype { type_ref } => {
338            format!(".newtype(typeRef: {})", format_swift_type_ref(type_ref))
339        }
340        VariantPayload::Tuple { types } => {
341            let types_str = types
342                .iter()
343                .map(format_swift_type_ref)
344                .collect::<Vec<_>>()
345                .join(", ");
346            format!(".tuple(types: [{}])", types_str)
347        }
348        VariantPayload::Struct { fields } => {
349            let fields_str = fields
350                .iter()
351                .map(|f| {
352                    format!(
353                        "FieldSchema(name: \"{}\", typeRef: {}, required: {})",
354                        f.name,
355                        format_swift_type_ref(&f.type_ref),
356                        f.required
357                    )
358                })
359                .collect::<Vec<_>>()
360                .join(", ");
361            format!(".struct(fields: [{}])", fields_str)
362        }
363    }
364}
365
366/// Format a TypeRef as Swift code.
367fn format_swift_type_ref(type_ref: &TypeRef) -> String {
368    use crate::render::hex_u64;
369
370    match type_ref {
371        TypeRef::Concrete { type_id, args } => {
372            if args.is_empty() {
373                format!(".concrete({})", hex_u64(type_id.0))
374            } else {
375                let args_str = args
376                    .iter()
377                    .map(format_swift_type_ref)
378                    .collect::<Vec<_>>()
379                    .join(", ");
380                format!(".generic({}, args: [{}])", hex_u64(type_id.0), args_str)
381            }
382        }
383        TypeRef::Var { name } => {
384            format!(".var(name: \"{}\")", name.as_str())
385        }
386    }
387}
388
389/// Format a PrimitiveType as Swift code.
390fn format_swift_primitive(prim: &vox_types::PrimitiveType) -> String {
391    use vox_types::PrimitiveType;
392
393    match prim {
394        PrimitiveType::Bool => ".bool",
395        PrimitiveType::U8 => ".u8",
396        PrimitiveType::U16 => ".u16",
397        PrimitiveType::U32 => ".u32",
398        PrimitiveType::U64 => ".u64",
399        PrimitiveType::U128 => ".u128",
400        PrimitiveType::I8 => ".i8",
401        PrimitiveType::I16 => ".i16",
402        PrimitiveType::I32 => ".i32",
403        PrimitiveType::I64 => ".i64",
404        PrimitiveType::I128 => ".i128",
405        PrimitiveType::F32 => ".f32",
406        PrimitiveType::F64 => ".f64",
407        PrimitiveType::Char => ".char",
408        PrimitiveType::String => ".string",
409        PrimitiveType::Unit => ".unit",
410        PrimitiveType::Never => ".never",
411        PrimitiveType::Bytes => ".bytes",
412        PrimitiveType::Payload => ".payload",
413    }
414    .to_string()
415}
416
417/// Convert a Shape to its Swift binding-schema representation.
418fn shape_to_schema(shape: &'static Shape) -> String {
419    if is_bytes(shape) {
420        return ".bytes".into();
421    }
422
423    match classify_shape(shape) {
424        ShapeKind::Scalar(scalar) => match scalar {
425            ScalarType::Bool => ".bool".into(),
426            ScalarType::U8 => ".u8".into(),
427            ScalarType::U16 => ".u16".into(),
428            ScalarType::U32 => ".u32".into(),
429            ScalarType::U64 => ".u64".into(),
430            ScalarType::I8 => ".i8".into(),
431            ScalarType::I16 => ".i16".into(),
432            ScalarType::I32 => ".i32".into(),
433            ScalarType::I64 => ".i64".into(),
434            ScalarType::F32 => ".f32".into(),
435            ScalarType::F64 => ".f64".into(),
436            ScalarType::Str | ScalarType::CowStr | ScalarType::String => ".string".into(),
437            ScalarType::Unit => ".tuple(elements: [])".into(),
438            _ => ".bytes".into(), // fallback
439        },
440        ShapeKind::List { element } | ShapeKind::Slice { element } => {
441            format!(".vec(element: {})", shape_to_schema(element))
442        }
443        ShapeKind::Option { inner } => {
444            format!(".option(inner: {})", shape_to_schema(inner))
445        }
446        ShapeKind::Map { key, value } => {
447            format!(
448                ".map(key: {}, value: {})",
449                shape_to_schema(key),
450                shape_to_schema(value)
451            )
452        }
453        ShapeKind::Tx { inner } => format!(".tx(element: {})", shape_to_schema(inner)),
454        ShapeKind::Rx { inner } => format!(".rx(element: {})", shape_to_schema(inner)),
455        ShapeKind::Tuple { elements } => {
456            let inner: Vec<String> = elements.iter().map(|p| shape_to_schema(p.shape)).collect();
457            format!(".tuple(elements: [{}])", inner.join(", "))
458        }
459        ShapeKind::Struct(StructInfo { fields, .. }) => {
460            let field_strs: Vec<String> = fields
461                .iter()
462                .map(|f| format!("(\"{}\", {})", f.name, shape_to_schema(f.shape())))
463                .collect();
464            format!(".struct(fields: [{}])", field_strs.join(", "))
465        }
466        ShapeKind::Enum(EnumInfo { variants, .. }) => {
467            let variant_strs: Vec<String> = variants
468                .iter()
469                .map(|v| {
470                    let fields: Vec<String> = match classify_variant(v) {
471                        VariantKind::Unit => vec![],
472                        VariantKind::Newtype { inner } => vec![shape_to_schema(inner)],
473                        VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
474                            fields.iter().map(|f| shape_to_schema(f.shape())).collect()
475                        }
476                    };
477                    format!("(\"{}\", [{}])", v.name, fields.join(", "))
478                })
479                .collect();
480            format!(".enum(variants: [{}])", variant_strs.join(", "))
481        }
482        _ => ".bytes".into(), // fallback for unknown types
483    }
484}
485
486/// Generate serializers for runtime channel binding.
487fn generate_serializers(service: &ServiceDescriptor) -> String {
488    let mut out = String::new();
489    let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
490    let service_name_upper = service.service_name.to_upper_camel_case();
491
492    cw_writeln!(
493        w,
494        "public struct {service_name_upper}Serializers: BindingSerializers {{"
495    )
496    .unwrap();
497    {
498        let _indent = w.indent();
499        w.writeln("public init() {}").unwrap();
500        w.blank_line().unwrap();
501
502        // txSerializer — returns (Any) -> [UInt8] by wrapping the ByteBuffer-based encoders
503        w.writeln(
504            "public func txSerializer(for schema: BindingSchema) -> @Sendable (Any) -> [UInt8] {",
505        )
506        .unwrap();
507        {
508            let _indent = w.indent();
509            w.writeln("switch schema {").unwrap();
510            w.writeln("case .bool: return { var b = ByteBufferAllocator().buffer(capacity: 1); encodeBool($0 as! Bool, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
511            w.writeln("case .u8: return { var b = ByteBufferAllocator().buffer(capacity: 1); encodeU8($0 as! UInt8, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
512            w.writeln("case .i8: return { var b = ByteBufferAllocator().buffer(capacity: 1); encodeI8($0 as! Int8, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
513            w.writeln("case .u16: return { var b = ByteBufferAllocator().buffer(capacity: 2); encodeU16($0 as! UInt16, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
514            w.writeln("case .i16: return { var b = ByteBufferAllocator().buffer(capacity: 2); encodeI16($0 as! Int16, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
515            w.writeln("case .u32: return { var b = ByteBufferAllocator().buffer(capacity: 4); encodeU32($0 as! UInt32, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
516            w.writeln("case .i32: return { var b = ByteBufferAllocator().buffer(capacity: 4); encodeI32($0 as! Int32, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
517            w.writeln("case .u64: return { var b = ByteBufferAllocator().buffer(capacity: 9); encodeVarint($0 as! UInt64, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
518            w.writeln("case .i64: return { var b = ByteBufferAllocator().buffer(capacity: 9); encodeI64($0 as! Int64, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
519            w.writeln("case .f32: return { var b = ByteBufferAllocator().buffer(capacity: 4); encodeF32($0 as! Float, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
520            w.writeln("case .f64: return { var b = ByteBufferAllocator().buffer(capacity: 8); encodeF64($0 as! Double, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
521            w.writeln("case .string: return { var b = ByteBufferAllocator().buffer(capacity: 64); encodeString($0 as! String, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
522            w.writeln("case .bytes: return { [UInt8]($0 as! Data) }")
523                .unwrap();
524            w.writeln(
525                "case .tx(_, _), .rx(_, _): fatalError(\"Channel schemas are not serialized directly\")",
526            )
527            .unwrap();
528            w.writeln(
529                "default: fatalError(\"Unsupported schema for Tx serialization: \\(schema)\")",
530            )
531            .unwrap();
532            w.writeln("}").unwrap();
533        }
534        w.writeln("}").unwrap();
535        w.blank_line().unwrap();
536
537        // rxDeserializer — takes [UInt8], wraps in ByteBuffer for decoding
538        w.writeln(
539            "public func rxDeserializer(for schema: BindingSchema) -> @Sendable ([UInt8]) throws -> Any {",
540        )
541        .unwrap();
542        {
543            let _indent = w.indent();
544            w.writeln("switch schema {").unwrap();
545            w.writeln("case .bool: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeBool(from: &b) }").unwrap();
546            w.writeln("case .u8: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeU8(from: &b) }").unwrap();
547            w.writeln("case .i8: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeI8(from: &b) }").unwrap();
548            w.writeln("case .u16: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeU16(from: &b) }").unwrap();
549            w.writeln("case .i16: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeI16(from: &b) }").unwrap();
550            w.writeln("case .u32: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeU32(from: &b) }").unwrap();
551            w.writeln("case .i32: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeI32(from: &b) }").unwrap();
552            w.writeln("case .u64: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeVarint(from: &b) }").unwrap();
553            w.writeln("case .i64: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeI64(from: &b) }").unwrap();
554            w.writeln("case .f32: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeF32(from: &b) }").unwrap();
555            w.writeln("case .f64: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeF64(from: &b) }").unwrap();
556            w.writeln("case .string: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeString(from: &b) }").unwrap();
557            w.writeln("case .bytes: return { Data($0) }").unwrap();
558            w.writeln(
559                "case .tx(_, _), .rx(_, _): fatalError(\"Channel schemas are not deserialized directly\")",
560            )
561            .unwrap();
562            w.writeln(
563                "default: fatalError(\"Unsupported schema for Rx deserialization: \\(schema)\")",
564            )
565            .unwrap();
566            w.writeln("}").unwrap();
567        }
568        w.writeln("}").unwrap();
569    }
570    w.writeln("}").unwrap();
571    w.blank_line().unwrap();
572
573    out
574}