Skip to main content

vox_codegen/targets/swift/
schema.rs

1//! Swift schema generation for wire schema exchange.
2//!
3//! Generates runtime schema information for protocol wire schema exchange
4//! using the canonical `Schema`/`SchemaKind`/`TypeRef` representation.
5//! Type-level channel binding is driven by the same schema registry.
6
7use facet_core::Facet;
8use heck::ToLowerCamelCase;
9use vox_types::{ServiceDescriptor, ShapeKind, TypeRef, VoxError, classify_shape, extract_schemas};
10
11/// Generate schema code (wire schemas only — binding uses the same registry).
12pub fn generate_schemas(service: &ServiceDescriptor) -> String {
13    generate_wire_schemas(service)
14}
15
16/// Generate wire schema infrastructure for protocol schema exchange.
17///
18/// Generates:
19/// 1. A global schema registry containing all schemas for all methods (deduplicated)
20/// 2. Per-method schema ID lists and root TypeRefs for args and response
21///
22/// At runtime, the Swift code filters schemas through SchemaSendTracker before encoding.
23fn generate_wire_schemas(service: &ServiceDescriptor) -> String {
24    use crate::render::hex_u64;
25    use std::collections::HashMap;
26    use vox_types::{Schema, SchemaHash};
27
28    let service_name = service.service_name.to_lower_camel_case();
29
30    // Extract Result and VoxError schemas once (used for wrapping responses)
31    let result_extracted =
32        extract_schemas(<Result<bool, u32> as Facet<'static>>::SHAPE).expect("Result schema");
33    let result_type_id = match &result_extracted.root {
34        TypeRef::Concrete { type_id, .. } => *type_id,
35        _ => panic!("Result root should be concrete"),
36    };
37
38    let vox_error_extracted =
39        extract_schemas(<VoxError<std::convert::Infallible> as Facet<'static>>::SHAPE)
40            .expect("VoxError schema");
41    let vox_error_type_id = match &vox_error_extracted.root {
42        TypeRef::Concrete { type_id, .. } => *type_id,
43        _ => panic!("VoxError root should be concrete"),
44    };
45
46    // Collect all schemas across all methods into a global registry
47    let mut global_schemas: HashMap<SchemaHash, Schema> = HashMap::new();
48
49    // Add Result and VoxError schemas
50    for schema in result_extracted.schemas.iter() {
51        global_schemas.insert(schema.id, schema.clone());
52    }
53    for schema in vox_error_extracted.schemas.iter() {
54        global_schemas.insert(schema.id, schema.clone());
55    }
56
57    // Per-method info: (args_schema_ids, args_root, response_schema_ids, response_root)
58    struct MethodSchemaInfo {
59        args_schema_ids: Vec<SchemaHash>,
60        args_root: TypeRef,
61        response_schema_ids: Vec<SchemaHash>,
62        response_root: TypeRef,
63    }
64    let mut method_infos: Vec<(u64, MethodSchemaInfo)> = Vec::new();
65
66    for method in service.methods {
67        let method_id = crate::method_id(method);
68
69        // Extract args schemas
70        let args_extracted = extract_schemas(method.args_shape).expect("args schema extraction");
71        let args_schema_ids: Vec<SchemaHash> =
72            args_extracted.schemas.iter().map(|s| s.id).collect();
73        for schema in args_extracted.schemas.iter().cloned() {
74            global_schemas.insert(schema.id, schema);
75        }
76
77        // Extract response schemas - wrap in Result<T, VoxError<E>>
78        let (ok_extracted, err_extracted) = match classify_shape(method.return_shape) {
79            ShapeKind::Result { ok, err } => (
80                extract_schemas(ok).expect("ok schema"),
81                extract_schemas(err).expect("err schema"),
82            ),
83            _ => (
84                extract_schemas(method.return_shape).expect("return schema"),
85                extract_schemas(<std::convert::Infallible as Facet<'static>>::SHAPE)
86                    .expect("Infallible schema"),
87            ),
88        };
89
90        // Collect response schema IDs (including Result and VoxError)
91        let mut response_schema_ids: Vec<SchemaHash> = Vec::new();
92        for schema in result_extracted.schemas.iter() {
93            response_schema_ids.push(schema.id);
94        }
95        for schema in vox_error_extracted.schemas.iter() {
96            response_schema_ids.push(schema.id);
97        }
98        for schema in ok_extracted.schemas.iter().cloned() {
99            response_schema_ids.push(schema.id);
100            global_schemas.insert(schema.id, schema);
101        }
102        for schema in err_extracted.schemas.iter().cloned() {
103            response_schema_ids.push(schema.id);
104            global_schemas.insert(schema.id, schema);
105        }
106
107        // Deduplicate schema IDs (smaller codegen output)
108        let mut seen = std::collections::HashSet::new();
109        response_schema_ids.retain(|id| seen.insert(*id));
110
111        // Build the response root: Result<ok_root, VoxError<err_root>>
112        let vox_error_ref = TypeRef::generic(vox_error_type_id, vec![err_extracted.root.clone()]);
113        let response_root = TypeRef::generic(
114            result_type_id,
115            vec![ok_extracted.root.clone(), vox_error_ref],
116        );
117
118        method_infos.push((
119            method_id,
120            MethodSchemaInfo {
121                args_schema_ids,
122                args_root: args_extracted.root.clone(),
123                response_schema_ids,
124                response_root,
125            },
126        ));
127    }
128
129    let mut out = String::new();
130
131    // Generate global schema registry
132    out.push_str("/// Global schema registry containing all schemas for this service.\n");
133    out.push_str(&format!(
134        "nonisolated(unsafe) public let {service_name}_schema_registry: [UInt64: Schema] = [\n"
135    ));
136
137    let mut sorted_schemas: Vec<_> = global_schemas.into_iter().collect();
138    sorted_schemas.sort_by_key(|(id, _)| *id);
139
140    for (schema_id, schema) in &sorted_schemas {
141        out.push_str(&format!(
142            "    {}: {},\n",
143            hex_u64(schema_id.0),
144            format_swift_schema(schema)
145        ));
146    }
147    out.push_str("]\n\n");
148
149    // Generate per-method schema info
150    out.push_str("/// Per-method schema information for wire protocol.\n");
151    out.push_str(&format!(
152        "nonisolated(unsafe) public let {service_name}_method_schemas: [UInt64: MethodSchemaInfo] = [\n"
153    ));
154
155    for (method_id, info) in &method_infos {
156        out.push_str(&format!("    {}: MethodSchemaInfo(\n", hex_u64(*method_id)));
157        out.push_str(&format!(
158            "        argsSchemaIds: [{}],\n",
159            info.args_schema_ids
160                .iter()
161                .map(|id| hex_u64(id.0))
162                .collect::<Vec<_>>()
163                .join(", ")
164        ));
165        out.push_str(&format!(
166            "        argsRoot: {},\n",
167            format_swift_type_ref(&info.args_root)
168        ));
169        out.push_str(&format!(
170            "        responseSchemaIds: [{}],\n",
171            info.response_schema_ids
172                .iter()
173                .map(|id| hex_u64(id.0))
174                .collect::<Vec<_>>()
175                .join(", ")
176        ));
177        out.push_str(&format!(
178            "        responseRoot: {}\n",
179            format_swift_type_ref(&info.response_root)
180        ));
181        out.push_str("    ),\n");
182    }
183    out.push_str("]\n\n");
184
185    out
186}
187
188/// Format a Schema as Swift code.
189fn format_swift_schema(schema: &vox_types::Schema) -> String {
190    use crate::render::hex_u64;
191
192    let type_params = if schema.type_params.is_empty() {
193        "[]".to_string()
194    } else {
195        format!(
196            "[{}]",
197            schema
198                .type_params
199                .iter()
200                .map(|p| format!("\"{}\"", p.as_str()))
201                .collect::<Vec<_>>()
202                .join(", ")
203        )
204    };
205
206    format!(
207        "Schema(id: {}, typeParams: {}, kind: {})",
208        hex_u64(schema.id.0),
209        type_params,
210        format_swift_schema_kind(&schema.kind)
211    )
212}
213
214/// Format a SchemaKind as Swift code.
215fn format_swift_schema_kind(kind: &vox_types::SchemaKind) -> String {
216    use vox_types::SchemaKind;
217
218    match kind {
219        SchemaKind::Struct { name, fields } => {
220            let fields_str = fields
221                .iter()
222                .map(|f| {
223                    format!(
224                        "FieldSchema(name: \"{}\", typeRef: {}, required: {})",
225                        f.name,
226                        format_swift_type_ref(&f.type_ref),
227                        f.required
228                    )
229                })
230                .collect::<Vec<_>>()
231                .join(", ");
232            format!(".struct(name: \"{}\", fields: [{}])", name, fields_str)
233        }
234        SchemaKind::Enum { name, variants } => {
235            let variants_str = variants
236                .iter()
237                .map(|v| {
238                    format!(
239                        "VariantSchema(name: \"{}\", index: {}, payload: {})",
240                        v.name,
241                        v.index,
242                        format_swift_variant_payload(&v.payload)
243                    )
244                })
245                .collect::<Vec<_>>()
246                .join(", ");
247            format!(".enum(name: \"{}\", variants: [{}])", name, variants_str)
248        }
249        SchemaKind::Tuple { elements } => {
250            let elems_str = elements
251                .iter()
252                .map(format_swift_type_ref)
253                .collect::<Vec<_>>()
254                .join(", ");
255            format!(".tuple(elements: [{}])", elems_str)
256        }
257        SchemaKind::List { element } => {
258            format!(".list(element: {})", format_swift_type_ref(element))
259        }
260        SchemaKind::Map { key, value } => {
261            format!(
262                ".map(key: {}, value: {})",
263                format_swift_type_ref(key),
264                format_swift_type_ref(value)
265            )
266        }
267        SchemaKind::Array { element, length } => {
268            format!(
269                ".array(element: {}, length: {})",
270                format_swift_type_ref(element),
271                length
272            )
273        }
274        SchemaKind::Option { element } => {
275            format!(".option(element: {})", format_swift_type_ref(element))
276        }
277        SchemaKind::Channel { direction, element } => {
278            let dir = match direction {
279                vox_types::ChannelDirection::Tx => ".tx",
280                vox_types::ChannelDirection::Rx => ".rx",
281            };
282            format!(
283                ".channel(direction: {}, element: {})",
284                dir,
285                format_swift_type_ref(element)
286            )
287        }
288        SchemaKind::Primitive { primitive_type } => {
289            format!(".primitive({})", format_swift_primitive(primitive_type))
290        }
291    }
292}
293
294/// Format a VariantPayload as Swift code.
295fn format_swift_variant_payload(payload: &vox_types::VariantPayload) -> String {
296    use vox_types::VariantPayload;
297
298    match payload {
299        VariantPayload::Unit => ".unit".to_string(),
300        VariantPayload::Newtype { type_ref } => {
301            format!(".newtype(typeRef: {})", format_swift_type_ref(type_ref))
302        }
303        VariantPayload::Tuple { types } => {
304            let types_str = types
305                .iter()
306                .map(format_swift_type_ref)
307                .collect::<Vec<_>>()
308                .join(", ");
309            format!(".tuple(types: [{}])", types_str)
310        }
311        VariantPayload::Struct { fields } => {
312            let fields_str = fields
313                .iter()
314                .map(|f| {
315                    format!(
316                        "FieldSchema(name: \"{}\", typeRef: {}, required: {})",
317                        f.name,
318                        format_swift_type_ref(&f.type_ref),
319                        f.required
320                    )
321                })
322                .collect::<Vec<_>>()
323                .join(", ");
324            format!(".struct(fields: [{}])", fields_str)
325        }
326    }
327}
328
329/// Format a TypeRef as Swift code.
330fn format_swift_type_ref(type_ref: &TypeRef) -> String {
331    use crate::render::hex_u64;
332
333    match type_ref {
334        TypeRef::Concrete { type_id, args } => {
335            if args.is_empty() {
336                format!(".concrete({})", hex_u64(type_id.0))
337            } else {
338                let args_str = args
339                    .iter()
340                    .map(format_swift_type_ref)
341                    .collect::<Vec<_>>()
342                    .join(", ");
343                format!(".generic({}, args: [{}])", hex_u64(type_id.0), args_str)
344            }
345        }
346        TypeRef::Var { name } => {
347            format!(".var(name: \"{}\")", name.as_str())
348        }
349    }
350}
351
352/// Format a PrimitiveType as Swift code.
353fn format_swift_primitive(prim: &vox_types::PrimitiveType) -> String {
354    use vox_types::PrimitiveType;
355
356    match prim {
357        PrimitiveType::Bool => ".bool",
358        PrimitiveType::U8 => ".u8",
359        PrimitiveType::U16 => ".u16",
360        PrimitiveType::U32 => ".u32",
361        PrimitiveType::U64 => ".u64",
362        PrimitiveType::U128 => ".u128",
363        PrimitiveType::I8 => ".i8",
364        PrimitiveType::I16 => ".i16",
365        PrimitiveType::I32 => ".i32",
366        PrimitiveType::I64 => ".i64",
367        PrimitiveType::I128 => ".i128",
368        PrimitiveType::F32 => ".f32",
369        PrimitiveType::F64 => ".f64",
370        PrimitiveType::Char => ".char",
371        PrimitiveType::String => ".string",
372        PrimitiveType::Unit => ".unit",
373        PrimitiveType::Never => ".never",
374        PrimitiveType::Bytes => ".bytes",
375        PrimitiveType::Payload => ".payload",
376    }
377    .to_string()
378}