Skip to main content

vox_codegen/targets/swift/
schema.rs

//! Swift binding-schema generation for runtime channel binding.
//!
//! Generates runtime schema information for channel discovery and wire schema exchange.
4
5use facet_core::{Facet, ScalarType, Shape};
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use vox_types::{
8    EnumInfo, ServiceDescriptor, ShapeKind, StructInfo, TypeRef, VariantKind, VoxError,
9    classify_shape, classify_variant, extract_schemas, is_bytes,
10};
11
12use crate::code_writer::CodeWriter;
13use crate::cw_writeln;
14
15/// Generate complete schema code (method schemas + serializers + wire schemas).
16pub fn generate_schemas(service: &ServiceDescriptor) -> String {
17    let mut out = String::new();
18    out.push_str(&generate_method_schemas(service));
19    out.push_str(&generate_wire_schemas(service));
20    out.push_str(&generate_serializers(service));
21    out
22}
23
24/// Generate method schemas for runtime channel binding.
25fn generate_method_schemas(service: &ServiceDescriptor) -> String {
26    let mut out = String::new();
27    let service_name = service.service_name.to_lower_camel_case();
28
29    out.push_str(&format!(
30        "public let {service_name}_schemas: [String: MethodBindingSchema] = [\n"
31    ));
32
33    for method in service.methods {
34        let method_name = method.method_name.to_lower_camel_case();
35        out.push_str(&format!(
36            "    \"{method_name}\": MethodBindingSchema(args: ["
37        ));
38
39        let schemas: Vec<String> = method
40            .args
41            .iter()
42            .map(|a| shape_to_schema(a.shape))
43            .collect();
44        out.push_str(&schemas.join(", "));
45
46        out.push_str("]),\n");
47    }
48
49    out.push_str("]\n\n");
50    out
51}
52
53/// Generate wire schema infrastructure for protocol schema exchange.
54///
55/// Generates:
56/// 1. A global schema registry containing all schemas for all methods (deduplicated)
57/// 2. Per-method schema ID lists and root TypeRefs for args and response
58///
59/// At runtime, the Swift code filters schemas through SchemaSendTracker before encoding.
60fn generate_wire_schemas(service: &ServiceDescriptor) -> String {
61    use crate::render::hex_u64;
62    use std::collections::HashMap;
63    use vox_types::{Schema, SchemaHash};
64
65    let service_name = service.service_name.to_lower_camel_case();
66
67    // Extract Result and VoxError schemas once (used for wrapping responses)
68    let result_extracted =
69        extract_schemas(<Result<bool, u32> as Facet<'static>>::SHAPE).expect("Result schema");
70    let result_type_id = match &result_extracted.root {
71        TypeRef::Concrete { type_id, .. } => *type_id,
72        _ => panic!("Result root should be concrete"),
73    };
74
75    let vox_error_extracted =
76        extract_schemas(<VoxError<std::convert::Infallible> as Facet<'static>>::SHAPE)
77            .expect("VoxError schema");
78    let vox_error_type_id = match &vox_error_extracted.root {
79        TypeRef::Concrete { type_id, .. } => *type_id,
80        _ => panic!("VoxError root should be concrete"),
81    };
82
83    // Collect all schemas across all methods into a global registry
84    let mut global_schemas: HashMap<SchemaHash, Schema> = HashMap::new();
85
86    // Add Result and VoxError schemas
87    for schema in result_extracted.schemas.iter() {
88        global_schemas.insert(schema.id, schema.clone());
89    }
90    for schema in vox_error_extracted.schemas.iter() {
91        global_schemas.insert(schema.id, schema.clone());
92    }
93
94    // Per-method info: (args_schema_ids, args_root, response_schema_ids, response_root)
95    struct MethodSchemaInfo {
96        args_schema_ids: Vec<SchemaHash>,
97        args_root: TypeRef,
98        response_schema_ids: Vec<SchemaHash>,
99        response_root: TypeRef,
100    }
101    let mut method_infos: Vec<(u64, MethodSchemaInfo)> = Vec::new();
102
103    for method in service.methods {
104        let method_id = crate::method_id(method);
105
106        // Extract args schemas
107        let args_extracted = extract_schemas(method.args_shape).expect("args schema extraction");
108        let args_schema_ids: Vec<SchemaHash> =
109            args_extracted.schemas.iter().map(|s| s.id).collect();
110        for schema in args_extracted.schemas {
111            global_schemas.insert(schema.id, schema);
112        }
113
114        // Extract response schemas - wrap in Result<T, VoxError<E>>
115        let (ok_extracted, err_extracted) = match classify_shape(method.return_shape) {
116            ShapeKind::Result { ok, err } => (
117                extract_schemas(ok).expect("ok schema"),
118                extract_schemas(err).expect("err schema"),
119            ),
120            _ => (
121                extract_schemas(method.return_shape).expect("return schema"),
122                extract_schemas(<std::convert::Infallible as Facet<'static>>::SHAPE)
123                    .expect("Infallible schema"),
124            ),
125        };
126
127        // Collect response schema IDs (including Result and VoxError)
128        let mut response_schema_ids: Vec<SchemaHash> = Vec::new();
129        for schema in result_extracted.schemas.iter() {
130            response_schema_ids.push(schema.id);
131        }
132        for schema in vox_error_extracted.schemas.iter() {
133            response_schema_ids.push(schema.id);
134        }
135        for schema in ok_extracted.schemas {
136            response_schema_ids.push(schema.id);
137            global_schemas.insert(schema.id, schema);
138        }
139        for schema in err_extracted.schemas {
140            response_schema_ids.push(schema.id);
141            global_schemas.insert(schema.id, schema);
142        }
143
144        // Deduplicate schema IDs (smaller codegen output)
145        let mut seen = std::collections::HashSet::new();
146        response_schema_ids.retain(|id| seen.insert(*id));
147
148        // Build the response root: Result<ok_root, VoxError<err_root>>
149        let vox_error_ref = TypeRef::generic(vox_error_type_id, vec![err_extracted.root]);
150        let response_root =
151            TypeRef::generic(result_type_id, vec![ok_extracted.root, vox_error_ref]);
152
153        method_infos.push((
154            method_id,
155            MethodSchemaInfo {
156                args_schema_ids,
157                args_root: args_extracted.root,
158                response_schema_ids,
159                response_root,
160            },
161        ));
162    }
163
164    let mut out = String::new();
165
166    // Generate global schema registry
167    out.push_str("/// Global schema registry containing all schemas for this service.\n");
168    out.push_str(&format!(
169        "public let {service_name}_schema_registry: [UInt64: Schema] = [\n"
170    ));
171
172    let mut sorted_schemas: Vec<_> = global_schemas.into_iter().collect();
173    sorted_schemas.sort_by_key(|(id, _)| *id);
174
175    for (schema_id, schema) in &sorted_schemas {
176        out.push_str(&format!(
177            "    {}: {},\n",
178            hex_u64(schema_id.0),
179            format_swift_schema(schema)
180        ));
181    }
182    out.push_str("]\n\n");
183
184    // Generate per-method schema info
185    out.push_str("/// Per-method schema information for wire protocol.\n");
186    out.push_str(&format!(
187        "public let {service_name}_method_schemas: [UInt64: MethodSchemaInfo] = [\n"
188    ));
189
190    for (method_id, info) in &method_infos {
191        out.push_str(&format!("    {}: MethodSchemaInfo(\n", hex_u64(*method_id)));
192        out.push_str(&format!(
193            "        argsSchemaIds: [{}],\n",
194            info.args_schema_ids
195                .iter()
196                .map(|id| hex_u64(id.0))
197                .collect::<Vec<_>>()
198                .join(", ")
199        ));
200        out.push_str(&format!(
201            "        argsRoot: {},\n",
202            format_swift_type_ref(&info.args_root)
203        ));
204        out.push_str(&format!(
205            "        responseSchemaIds: [{}],\n",
206            info.response_schema_ids
207                .iter()
208                .map(|id| hex_u64(id.0))
209                .collect::<Vec<_>>()
210                .join(", ")
211        ));
212        out.push_str(&format!(
213            "        responseRoot: {}\n",
214            format_swift_type_ref(&info.response_root)
215        ));
216        out.push_str("    ),\n");
217    }
218    out.push_str("]\n\n");
219
220    out
221}
222
223/// Format a Schema as Swift code.
224fn format_swift_schema(schema: &vox_types::Schema) -> String {
225    use crate::render::hex_u64;
226
227    let type_params = if schema.type_params.is_empty() {
228        "[]".to_string()
229    } else {
230        format!(
231            "[{}]",
232            schema
233                .type_params
234                .iter()
235                .map(|p| format!("\"{}\"", p.as_str()))
236                .collect::<Vec<_>>()
237                .join(", ")
238        )
239    };
240
241    format!(
242        "Schema(id: {}, typeParams: {}, kind: {})",
243        hex_u64(schema.id.0),
244        type_params,
245        format_swift_schema_kind(&schema.kind)
246    )
247}
248
249/// Format a SchemaKind as Swift code.
250fn format_swift_schema_kind(kind: &vox_types::SchemaKind) -> String {
251    use vox_types::SchemaKind;
252
253    match kind {
254        SchemaKind::Struct { name, fields } => {
255            let fields_str = fields
256                .iter()
257                .map(|f| {
258                    format!(
259                        "FieldSchema(name: \"{}\", typeRef: {}, required: {})",
260                        f.name,
261                        format_swift_type_ref(&f.type_ref),
262                        f.required
263                    )
264                })
265                .collect::<Vec<_>>()
266                .join(", ");
267            format!(".struct(name: \"{}\", fields: [{}])", name, fields_str)
268        }
269        SchemaKind::Enum { name, variants } => {
270            let variants_str = variants
271                .iter()
272                .map(|v| {
273                    format!(
274                        "VariantSchema(name: \"{}\", index: {}, payload: {})",
275                        v.name,
276                        v.index,
277                        format_swift_variant_payload(&v.payload)
278                    )
279                })
280                .collect::<Vec<_>>()
281                .join(", ");
282            format!(".enum(name: \"{}\", variants: [{}])", name, variants_str)
283        }
284        SchemaKind::Tuple { elements } => {
285            let elems_str = elements
286                .iter()
287                .map(format_swift_type_ref)
288                .collect::<Vec<_>>()
289                .join(", ");
290            format!(".tuple(elements: [{}])", elems_str)
291        }
292        SchemaKind::List { element } => {
293            format!(".list(element: {})", format_swift_type_ref(element))
294        }
295        SchemaKind::Map { key, value } => {
296            format!(
297                ".map(key: {}, value: {})",
298                format_swift_type_ref(key),
299                format_swift_type_ref(value)
300            )
301        }
302        SchemaKind::Array { element, length } => {
303            format!(
304                ".array(element: {}, length: {})",
305                format_swift_type_ref(element),
306                length
307            )
308        }
309        SchemaKind::Option { element } => {
310            format!(".option(element: {})", format_swift_type_ref(element))
311        }
312        SchemaKind::Channel { direction, element } => {
313            let dir = match direction {
314                vox_types::ChannelDirection::Tx => ".tx",
315                vox_types::ChannelDirection::Rx => ".rx",
316            };
317            format!(
318                ".channel(direction: {}, element: {})",
319                dir,
320                format_swift_type_ref(element)
321            )
322        }
323        SchemaKind::Primitive { primitive_type } => {
324            format!(".primitive({})", format_swift_primitive(primitive_type))
325        }
326    }
327}
328
329/// Format a VariantPayload as Swift code.
330fn format_swift_variant_payload(payload: &vox_types::VariantPayload) -> String {
331    use vox_types::VariantPayload;
332
333    match payload {
334        VariantPayload::Unit => ".unit".to_string(),
335        VariantPayload::Newtype { type_ref } => {
336            format!(".newtype(typeRef: {})", format_swift_type_ref(type_ref))
337        }
338        VariantPayload::Tuple { types } => {
339            let types_str = types
340                .iter()
341                .map(format_swift_type_ref)
342                .collect::<Vec<_>>()
343                .join(", ");
344            format!(".tuple(types: [{}])", types_str)
345        }
346        VariantPayload::Struct { fields } => {
347            let fields_str = fields
348                .iter()
349                .map(|f| {
350                    format!(
351                        "FieldSchema(name: \"{}\", typeRef: {}, required: {})",
352                        f.name,
353                        format_swift_type_ref(&f.type_ref),
354                        f.required
355                    )
356                })
357                .collect::<Vec<_>>()
358                .join(", ");
359            format!(".struct(fields: [{}])", fields_str)
360        }
361    }
362}
363
364/// Format a TypeRef as Swift code.
365fn format_swift_type_ref(type_ref: &TypeRef) -> String {
366    use crate::render::hex_u64;
367
368    match type_ref {
369        TypeRef::Concrete { type_id, args } => {
370            if args.is_empty() {
371                format!(".concrete({})", hex_u64(type_id.0))
372            } else {
373                let args_str = args
374                    .iter()
375                    .map(format_swift_type_ref)
376                    .collect::<Vec<_>>()
377                    .join(", ");
378                format!(".generic({}, args: [{}])", hex_u64(type_id.0), args_str)
379            }
380        }
381        TypeRef::Var { name } => {
382            format!(".var(name: \"{}\")", name.as_str())
383        }
384    }
385}
386
387/// Format a PrimitiveType as Swift code.
388fn format_swift_primitive(prim: &vox_types::PrimitiveType) -> String {
389    use vox_types::PrimitiveType;
390
391    match prim {
392        PrimitiveType::Bool => ".bool",
393        PrimitiveType::U8 => ".u8",
394        PrimitiveType::U16 => ".u16",
395        PrimitiveType::U32 => ".u32",
396        PrimitiveType::U64 => ".u64",
397        PrimitiveType::U128 => ".u128",
398        PrimitiveType::I8 => ".i8",
399        PrimitiveType::I16 => ".i16",
400        PrimitiveType::I32 => ".i32",
401        PrimitiveType::I64 => ".i64",
402        PrimitiveType::I128 => ".i128",
403        PrimitiveType::F32 => ".f32",
404        PrimitiveType::F64 => ".f64",
405        PrimitiveType::Char => ".char",
406        PrimitiveType::String => ".string",
407        PrimitiveType::Unit => ".unit",
408        PrimitiveType::Never => ".never",
409        PrimitiveType::Bytes => ".bytes",
410        PrimitiveType::Payload => ".payload",
411    }
412    .to_string()
413}
414
415/// Convert a Shape to its Swift binding-schema representation.
416fn shape_to_schema(shape: &'static Shape) -> String {
417    if is_bytes(shape) {
418        return ".bytes".into();
419    }
420
421    match classify_shape(shape) {
422        ShapeKind::Scalar(scalar) => match scalar {
423            ScalarType::Bool => ".bool".into(),
424            ScalarType::U8 => ".u8".into(),
425            ScalarType::U16 => ".u16".into(),
426            ScalarType::U32 => ".u32".into(),
427            ScalarType::U64 => ".u64".into(),
428            ScalarType::I8 => ".i8".into(),
429            ScalarType::I16 => ".i16".into(),
430            ScalarType::I32 => ".i32".into(),
431            ScalarType::I64 => ".i64".into(),
432            ScalarType::F32 => ".f32".into(),
433            ScalarType::F64 => ".f64".into(),
434            ScalarType::Str | ScalarType::CowStr | ScalarType::String => ".string".into(),
435            ScalarType::Unit => ".tuple(elements: [])".into(),
436            _ => ".bytes".into(), // fallback
437        },
438        ShapeKind::List { element } | ShapeKind::Slice { element } => {
439            format!(".vec(element: {})", shape_to_schema(element))
440        }
441        ShapeKind::Option { inner } => {
442            format!(".option(inner: {})", shape_to_schema(inner))
443        }
444        ShapeKind::Map { key, value } => {
445            format!(
446                ".map(key: {}, value: {})",
447                shape_to_schema(key),
448                shape_to_schema(value)
449            )
450        }
451        ShapeKind::Tx { inner } => format!(".tx(element: {})", shape_to_schema(inner)),
452        ShapeKind::Rx { inner } => format!(".rx(element: {})", shape_to_schema(inner)),
453        ShapeKind::Tuple { elements } => {
454            let inner: Vec<String> = elements.iter().map(|p| shape_to_schema(p.shape)).collect();
455            format!(".tuple(elements: [{}])", inner.join(", "))
456        }
457        ShapeKind::Struct(StructInfo { fields, .. }) => {
458            let field_strs: Vec<String> = fields
459                .iter()
460                .map(|f| format!("(\"{}\", {})", f.name, shape_to_schema(f.shape())))
461                .collect();
462            format!(".struct(fields: [{}])", field_strs.join(", "))
463        }
464        ShapeKind::Enum(EnumInfo { variants, .. }) => {
465            let variant_strs: Vec<String> = variants
466                .iter()
467                .map(|v| {
468                    let fields: Vec<String> = match classify_variant(v) {
469                        VariantKind::Unit => vec![],
470                        VariantKind::Newtype { inner } => vec![shape_to_schema(inner)],
471                        VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
472                            fields.iter().map(|f| shape_to_schema(f.shape())).collect()
473                        }
474                    };
475                    format!("(\"{}\", [{}])", v.name, fields.join(", "))
476                })
477                .collect();
478            format!(".enum(variants: [{}])", variant_strs.join(", "))
479        }
480        _ => ".bytes".into(), // fallback for unknown types
481    }
482}
483
484/// Generate serializers for runtime channel binding.
485fn generate_serializers(service: &ServiceDescriptor) -> String {
486    let mut out = String::new();
487    let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
488    let service_name_upper = service.service_name.to_upper_camel_case();
489
490    cw_writeln!(
491        w,
492        "public struct {service_name_upper}Serializers: BindingSerializers {{"
493    )
494    .unwrap();
495    {
496        let _indent = w.indent();
497        w.writeln("public init() {}").unwrap();
498        w.blank_line().unwrap();
499
500        // txSerializer
501        w.writeln(
502            "public func txSerializer(for schema: BindingSchema) -> @Sendable (Any) -> [UInt8] {",
503        )
504        .unwrap();
505        {
506            let _indent = w.indent();
507            w.writeln("switch schema {").unwrap();
508            w.writeln("case .bool: return { encodeBool($0 as! Bool) }")
509                .unwrap();
510            w.writeln("case .u8: return { encodeU8($0 as! UInt8) }")
511                .unwrap();
512            w.writeln("case .i8: return { encodeI8($0 as! Int8) }")
513                .unwrap();
514            w.writeln("case .u16: return { encodeU16($0 as! UInt16) }")
515                .unwrap();
516            w.writeln("case .i16: return { encodeI16($0 as! Int16) }")
517                .unwrap();
518            w.writeln("case .u32: return { encodeU32($0 as! UInt32) }")
519                .unwrap();
520            w.writeln("case .i32: return { encodeI32($0 as! Int32) }")
521                .unwrap();
522            w.writeln("case .u64: return { encodeVarint($0 as! UInt64) }")
523                .unwrap();
524            w.writeln("case .i64: return { encodeI64($0 as! Int64) }")
525                .unwrap();
526            w.writeln("case .f32: return { encodeF32($0 as! Float) }")
527                .unwrap();
528            w.writeln("case .f64: return { encodeF64($0 as! Double) }")
529                .unwrap();
530            w.writeln("case .string: return { encodeString($0 as! String) }")
531                .unwrap();
532            w.writeln("case .bytes: return { [UInt8]($0 as! Data) }")
533                .unwrap();
534            w.writeln(
535                "case .tx(_, _), .rx(_, _): fatalError(\"Channel schemas are not serialized directly\")",
536            )
537            .unwrap();
538            w.writeln(
539                "default: fatalError(\"Unsupported schema for Tx serialization: \\(schema)\")",
540            )
541            .unwrap();
542            w.writeln("}").unwrap();
543        }
544        w.writeln("}").unwrap();
545        w.blank_line().unwrap();
546
547        // rxDeserializer
548        w.writeln(
549            "public func rxDeserializer(for schema: BindingSchema) -> @Sendable ([UInt8]) throws -> Any {",
550        )
551        .unwrap();
552        {
553            let _indent = w.indent();
554            w.writeln("switch schema {").unwrap();
555            w.writeln("case .bool: return { var o = 0; return try decodeBool(from: Data($0), offset: &o) }").unwrap();
556            w.writeln(
557                "case .u8: return { var o = 0; return try decodeU8(from: Data($0), offset: &o) }",
558            )
559            .unwrap();
560            w.writeln(
561                "case .i8: return { var o = 0; return try decodeI8(from: Data($0), offset: &o) }",
562            )
563            .unwrap();
564            w.writeln(
565                "case .u16: return { var o = 0; return try decodeU16(from: Data($0), offset: &o) }",
566            )
567            .unwrap();
568            w.writeln(
569                "case .i16: return { var o = 0; return try decodeI16(from: Data($0), offset: &o) }",
570            )
571            .unwrap();
572            w.writeln(
573                "case .u32: return { var o = 0; return try decodeU32(from: Data($0), offset: &o) }",
574            )
575            .unwrap();
576            w.writeln(
577                "case .i32: return { var o = 0; return try decodeI32(from: Data($0), offset: &o) }",
578            )
579            .unwrap();
580            w.writeln("case .u64: return { var o = 0; return try decodeVarint(from: Data($0), offset: &o) }").unwrap();
581            w.writeln(
582                "case .i64: return { var o = 0; return try decodeI64(from: Data($0), offset: &o) }",
583            )
584            .unwrap();
585            w.writeln(
586                "case .f32: return { var o = 0; return try decodeF32(from: Data($0), offset: &o) }",
587            )
588            .unwrap();
589            w.writeln(
590                "case .f64: return { var o = 0; return try decodeF64(from: Data($0), offset: &o) }",
591            )
592            .unwrap();
593            w.writeln("case .string: return { var o = 0; return try decodeString(from: Data($0), offset: &o) }").unwrap();
594            w.writeln("case .bytes: return { Data($0) }").unwrap();
595            w.writeln(
596                "case .tx(_, _), .rx(_, _): fatalError(\"Channel schemas are not deserialized directly\")",
597            )
598            .unwrap();
599            w.writeln(
600                "default: fatalError(\"Unsupported schema for Rx deserialization: \\(schema)\")",
601            )
602            .unwrap();
603            w.writeln("}").unwrap();
604        }
605        w.writeln("}").unwrap();
606    }
607    w.writeln("}").unwrap();
608    w.blank_line().unwrap();
609
610    out
611}