1use facet_core::{Facet, ScalarType, Shape};
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use vox_types::{
8 EnumInfo, ServiceDescriptor, ShapeKind, StructInfo, TypeRef, VariantKind, VoxError,
9 classify_shape, classify_variant, extract_schemas, is_bytes,
10};
11
12use crate::code_writer::CodeWriter;
13use crate::cw_writeln;
14
15pub fn generate_schemas(service: &ServiceDescriptor) -> String {
17 let mut out = String::new();
18 out.push_str(&generate_method_schemas(service));
19 out.push_str(&generate_wire_schemas(service));
20 out.push_str(&generate_serializers(service));
21 out
22}
23
24fn generate_method_schemas(service: &ServiceDescriptor) -> String {
26 let mut out = String::new();
27 let service_name = service.service_name.to_lower_camel_case();
28
29 out.push_str(&format!(
30 "public let {service_name}_schemas: [String: MethodBindingSchema] = [\n"
31 ));
32
33 for method in service.methods {
34 let method_name = method.method_name.to_lower_camel_case();
35 out.push_str(&format!(
36 " \"{method_name}\": MethodBindingSchema(args: ["
37 ));
38
39 let schemas: Vec<String> = method
40 .args
41 .iter()
42 .map(|a| shape_to_schema(a.shape))
43 .collect();
44 out.push_str(&schemas.join(", "));
45
46 out.push_str("]),\n");
47 }
48
49 out.push_str("]\n\n");
50 out
51}
52
53fn generate_wire_schemas(service: &ServiceDescriptor) -> String {
61 use crate::render::hex_u64;
62 use std::collections::HashMap;
63 use vox_types::{Schema, SchemaHash};
64
65 let service_name = service.service_name.to_lower_camel_case();
66
67 let result_extracted =
69 extract_schemas(<Result<bool, u32> as Facet<'static>>::SHAPE).expect("Result schema");
70 let result_type_id = match &result_extracted.root {
71 TypeRef::Concrete { type_id, .. } => *type_id,
72 _ => panic!("Result root should be concrete"),
73 };
74
75 let vox_error_extracted =
76 extract_schemas(<VoxError<std::convert::Infallible> as Facet<'static>>::SHAPE)
77 .expect("VoxError schema");
78 let vox_error_type_id = match &vox_error_extracted.root {
79 TypeRef::Concrete { type_id, .. } => *type_id,
80 _ => panic!("VoxError root should be concrete"),
81 };
82
83 let mut global_schemas: HashMap<SchemaHash, Schema> = HashMap::new();
85
86 for schema in result_extracted.schemas.iter() {
88 global_schemas.insert(schema.id, schema.clone());
89 }
90 for schema in vox_error_extracted.schemas.iter() {
91 global_schemas.insert(schema.id, schema.clone());
92 }
93
94 struct MethodSchemaInfo {
96 args_schema_ids: Vec<SchemaHash>,
97 args_root: TypeRef,
98 response_schema_ids: Vec<SchemaHash>,
99 response_root: TypeRef,
100 }
101 let mut method_infos: Vec<(u64, MethodSchemaInfo)> = Vec::new();
102
103 for method in service.methods {
104 let method_id = crate::method_id(method);
105
106 let args_extracted = extract_schemas(method.args_shape).expect("args schema extraction");
108 let args_schema_ids: Vec<SchemaHash> =
109 args_extracted.schemas.iter().map(|s| s.id).collect();
110 for schema in args_extracted.schemas.iter().cloned() {
111 global_schemas.insert(schema.id, schema);
112 }
113
114 let (ok_extracted, err_extracted) = match classify_shape(method.return_shape) {
116 ShapeKind::Result { ok, err } => (
117 extract_schemas(ok).expect("ok schema"),
118 extract_schemas(err).expect("err schema"),
119 ),
120 _ => (
121 extract_schemas(method.return_shape).expect("return schema"),
122 extract_schemas(<std::convert::Infallible as Facet<'static>>::SHAPE)
123 .expect("Infallible schema"),
124 ),
125 };
126
127 let mut response_schema_ids: Vec<SchemaHash> = Vec::new();
129 for schema in result_extracted.schemas.iter() {
130 response_schema_ids.push(schema.id);
131 }
132 for schema in vox_error_extracted.schemas.iter() {
133 response_schema_ids.push(schema.id);
134 }
135 for schema in ok_extracted.schemas.iter().cloned() {
136 response_schema_ids.push(schema.id);
137 global_schemas.insert(schema.id, schema);
138 }
139 for schema in err_extracted.schemas.iter().cloned() {
140 response_schema_ids.push(schema.id);
141 global_schemas.insert(schema.id, schema);
142 }
143
144 let mut seen = std::collections::HashSet::new();
146 response_schema_ids.retain(|id| seen.insert(*id));
147
148 let vox_error_ref = TypeRef::generic(vox_error_type_id, vec![err_extracted.root.clone()]);
150 let response_root = TypeRef::generic(
151 result_type_id,
152 vec![ok_extracted.root.clone(), vox_error_ref],
153 );
154
155 method_infos.push((
156 method_id,
157 MethodSchemaInfo {
158 args_schema_ids,
159 args_root: args_extracted.root.clone(),
160 response_schema_ids,
161 response_root,
162 },
163 ));
164 }
165
166 let mut out = String::new();
167
168 out.push_str("/// Global schema registry containing all schemas for this service.\n");
170 out.push_str(&format!(
171 "nonisolated(unsafe) public let {service_name}_schema_registry: [UInt64: Schema] = [\n"
172 ));
173
174 let mut sorted_schemas: Vec<_> = global_schemas.into_iter().collect();
175 sorted_schemas.sort_by_key(|(id, _)| *id);
176
177 for (schema_id, schema) in &sorted_schemas {
178 out.push_str(&format!(
179 " {}: {},\n",
180 hex_u64(schema_id.0),
181 format_swift_schema(schema)
182 ));
183 }
184 out.push_str("]\n\n");
185
186 out.push_str("/// Per-method schema information for wire protocol.\n");
188 out.push_str(&format!(
189 "nonisolated(unsafe) public let {service_name}_method_schemas: [UInt64: MethodSchemaInfo] = [\n"
190 ));
191
192 for (method_id, info) in &method_infos {
193 out.push_str(&format!(" {}: MethodSchemaInfo(\n", hex_u64(*method_id)));
194 out.push_str(&format!(
195 " argsSchemaIds: [{}],\n",
196 info.args_schema_ids
197 .iter()
198 .map(|id| hex_u64(id.0))
199 .collect::<Vec<_>>()
200 .join(", ")
201 ));
202 out.push_str(&format!(
203 " argsRoot: {},\n",
204 format_swift_type_ref(&info.args_root)
205 ));
206 out.push_str(&format!(
207 " responseSchemaIds: [{}],\n",
208 info.response_schema_ids
209 .iter()
210 .map(|id| hex_u64(id.0))
211 .collect::<Vec<_>>()
212 .join(", ")
213 ));
214 out.push_str(&format!(
215 " responseRoot: {}\n",
216 format_swift_type_ref(&info.response_root)
217 ));
218 out.push_str(" ),\n");
219 }
220 out.push_str("]\n\n");
221
222 out
223}
224
225fn format_swift_schema(schema: &vox_types::Schema) -> String {
227 use crate::render::hex_u64;
228
229 let type_params = if schema.type_params.is_empty() {
230 "[]".to_string()
231 } else {
232 format!(
233 "[{}]",
234 schema
235 .type_params
236 .iter()
237 .map(|p| format!("\"{}\"", p.as_str()))
238 .collect::<Vec<_>>()
239 .join(", ")
240 )
241 };
242
243 format!(
244 "Schema(id: {}, typeParams: {}, kind: {})",
245 hex_u64(schema.id.0),
246 type_params,
247 format_swift_schema_kind(&schema.kind)
248 )
249}
250
251fn format_swift_schema_kind(kind: &vox_types::SchemaKind) -> String {
253 use vox_types::SchemaKind;
254
255 match kind {
256 SchemaKind::Struct { name, fields } => {
257 let fields_str = fields
258 .iter()
259 .map(|f| {
260 format!(
261 "FieldSchema(name: \"{}\", typeRef: {}, required: {})",
262 f.name,
263 format_swift_type_ref(&f.type_ref),
264 f.required
265 )
266 })
267 .collect::<Vec<_>>()
268 .join(", ");
269 format!(".struct(name: \"{}\", fields: [{}])", name, fields_str)
270 }
271 SchemaKind::Enum { name, variants } => {
272 let variants_str = variants
273 .iter()
274 .map(|v| {
275 format!(
276 "VariantSchema(name: \"{}\", index: {}, payload: {})",
277 v.name,
278 v.index,
279 format_swift_variant_payload(&v.payload)
280 )
281 })
282 .collect::<Vec<_>>()
283 .join(", ");
284 format!(".enum(name: \"{}\", variants: [{}])", name, variants_str)
285 }
286 SchemaKind::Tuple { elements } => {
287 let elems_str = elements
288 .iter()
289 .map(format_swift_type_ref)
290 .collect::<Vec<_>>()
291 .join(", ");
292 format!(".tuple(elements: [{}])", elems_str)
293 }
294 SchemaKind::List { element } => {
295 format!(".list(element: {})", format_swift_type_ref(element))
296 }
297 SchemaKind::Map { key, value } => {
298 format!(
299 ".map(key: {}, value: {})",
300 format_swift_type_ref(key),
301 format_swift_type_ref(value)
302 )
303 }
304 SchemaKind::Array { element, length } => {
305 format!(
306 ".array(element: {}, length: {})",
307 format_swift_type_ref(element),
308 length
309 )
310 }
311 SchemaKind::Option { element } => {
312 format!(".option(element: {})", format_swift_type_ref(element))
313 }
314 SchemaKind::Channel { direction, element } => {
315 let dir = match direction {
316 vox_types::ChannelDirection::Tx => ".tx",
317 vox_types::ChannelDirection::Rx => ".rx",
318 };
319 format!(
320 ".channel(direction: {}, element: {})",
321 dir,
322 format_swift_type_ref(element)
323 )
324 }
325 SchemaKind::Primitive { primitive_type } => {
326 format!(".primitive({})", format_swift_primitive(primitive_type))
327 }
328 }
329}
330
331fn format_swift_variant_payload(payload: &vox_types::VariantPayload) -> String {
333 use vox_types::VariantPayload;
334
335 match payload {
336 VariantPayload::Unit => ".unit".to_string(),
337 VariantPayload::Newtype { type_ref } => {
338 format!(".newtype(typeRef: {})", format_swift_type_ref(type_ref))
339 }
340 VariantPayload::Tuple { types } => {
341 let types_str = types
342 .iter()
343 .map(format_swift_type_ref)
344 .collect::<Vec<_>>()
345 .join(", ");
346 format!(".tuple(types: [{}])", types_str)
347 }
348 VariantPayload::Struct { fields } => {
349 let fields_str = fields
350 .iter()
351 .map(|f| {
352 format!(
353 "FieldSchema(name: \"{}\", typeRef: {}, required: {})",
354 f.name,
355 format_swift_type_ref(&f.type_ref),
356 f.required
357 )
358 })
359 .collect::<Vec<_>>()
360 .join(", ");
361 format!(".struct(fields: [{}])", fields_str)
362 }
363 }
364}
365
366fn format_swift_type_ref(type_ref: &TypeRef) -> String {
368 use crate::render::hex_u64;
369
370 match type_ref {
371 TypeRef::Concrete { type_id, args } => {
372 if args.is_empty() {
373 format!(".concrete({})", hex_u64(type_id.0))
374 } else {
375 let args_str = args
376 .iter()
377 .map(format_swift_type_ref)
378 .collect::<Vec<_>>()
379 .join(", ");
380 format!(".generic({}, args: [{}])", hex_u64(type_id.0), args_str)
381 }
382 }
383 TypeRef::Var { name } => {
384 format!(".var(name: \"{}\")", name.as_str())
385 }
386 }
387}
388
389fn format_swift_primitive(prim: &vox_types::PrimitiveType) -> String {
391 use vox_types::PrimitiveType;
392
393 match prim {
394 PrimitiveType::Bool => ".bool",
395 PrimitiveType::U8 => ".u8",
396 PrimitiveType::U16 => ".u16",
397 PrimitiveType::U32 => ".u32",
398 PrimitiveType::U64 => ".u64",
399 PrimitiveType::U128 => ".u128",
400 PrimitiveType::I8 => ".i8",
401 PrimitiveType::I16 => ".i16",
402 PrimitiveType::I32 => ".i32",
403 PrimitiveType::I64 => ".i64",
404 PrimitiveType::I128 => ".i128",
405 PrimitiveType::F32 => ".f32",
406 PrimitiveType::F64 => ".f64",
407 PrimitiveType::Char => ".char",
408 PrimitiveType::String => ".string",
409 PrimitiveType::Unit => ".unit",
410 PrimitiveType::Never => ".never",
411 PrimitiveType::Bytes => ".bytes",
412 PrimitiveType::Payload => ".payload",
413 }
414 .to_string()
415}
416
417fn shape_to_schema(shape: &'static Shape) -> String {
419 if is_bytes(shape) {
420 return ".bytes".into();
421 }
422
423 match classify_shape(shape) {
424 ShapeKind::Scalar(scalar) => match scalar {
425 ScalarType::Bool => ".bool".into(),
426 ScalarType::U8 => ".u8".into(),
427 ScalarType::U16 => ".u16".into(),
428 ScalarType::U32 => ".u32".into(),
429 ScalarType::U64 => ".u64".into(),
430 ScalarType::I8 => ".i8".into(),
431 ScalarType::I16 => ".i16".into(),
432 ScalarType::I32 => ".i32".into(),
433 ScalarType::I64 => ".i64".into(),
434 ScalarType::F32 => ".f32".into(),
435 ScalarType::F64 => ".f64".into(),
436 ScalarType::Str | ScalarType::CowStr | ScalarType::String => ".string".into(),
437 ScalarType::Unit => ".tuple(elements: [])".into(),
438 _ => ".bytes".into(), },
440 ShapeKind::List { element } | ShapeKind::Slice { element } => {
441 format!(".vec(element: {})", shape_to_schema(element))
442 }
443 ShapeKind::Option { inner } => {
444 format!(".option(inner: {})", shape_to_schema(inner))
445 }
446 ShapeKind::Map { key, value } => {
447 format!(
448 ".map(key: {}, value: {})",
449 shape_to_schema(key),
450 shape_to_schema(value)
451 )
452 }
453 ShapeKind::Tx { inner } => format!(".tx(element: {})", shape_to_schema(inner)),
454 ShapeKind::Rx { inner } => format!(".rx(element: {})", shape_to_schema(inner)),
455 ShapeKind::Tuple { elements } => {
456 let inner: Vec<String> = elements.iter().map(|p| shape_to_schema(p.shape)).collect();
457 format!(".tuple(elements: [{}])", inner.join(", "))
458 }
459 ShapeKind::Struct(StructInfo { fields, .. }) => {
460 let field_strs: Vec<String> = fields
461 .iter()
462 .map(|f| format!("(\"{}\", {})", f.name, shape_to_schema(f.shape())))
463 .collect();
464 format!(".struct(fields: [{}])", field_strs.join(", "))
465 }
466 ShapeKind::Enum(EnumInfo { variants, .. }) => {
467 let variant_strs: Vec<String> = variants
468 .iter()
469 .map(|v| {
470 let fields: Vec<String> = match classify_variant(v) {
471 VariantKind::Unit => vec![],
472 VariantKind::Newtype { inner } => vec![shape_to_schema(inner)],
473 VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
474 fields.iter().map(|f| shape_to_schema(f.shape())).collect()
475 }
476 };
477 format!("(\"{}\", [{}])", v.name, fields.join(", "))
478 })
479 .collect();
480 format!(".enum(variants: [{}])", variant_strs.join(", "))
481 }
482 _ => ".bytes".into(), }
484}
485
486fn generate_serializers(service: &ServiceDescriptor) -> String {
488 let mut out = String::new();
489 let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
490 let service_name_upper = service.service_name.to_upper_camel_case();
491
492 cw_writeln!(
493 w,
494 "public struct {service_name_upper}Serializers: BindingSerializers {{"
495 )
496 .unwrap();
497 {
498 let _indent = w.indent();
499 w.writeln("public init() {}").unwrap();
500 w.blank_line().unwrap();
501
502 w.writeln(
504 "public func txSerializer(for schema: BindingSchema) -> @Sendable (Any) -> [UInt8] {",
505 )
506 .unwrap();
507 {
508 let _indent = w.indent();
509 w.writeln("switch schema {").unwrap();
510 w.writeln("case .bool: return { var b = ByteBufferAllocator().buffer(capacity: 1); encodeBool($0 as! Bool, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
511 w.writeln("case .u8: return { var b = ByteBufferAllocator().buffer(capacity: 1); encodeU8($0 as! UInt8, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
512 w.writeln("case .i8: return { var b = ByteBufferAllocator().buffer(capacity: 1); encodeI8($0 as! Int8, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
513 w.writeln("case .u16: return { var b = ByteBufferAllocator().buffer(capacity: 2); encodeU16($0 as! UInt16, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
514 w.writeln("case .i16: return { var b = ByteBufferAllocator().buffer(capacity: 2); encodeI16($0 as! Int16, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
515 w.writeln("case .u32: return { var b = ByteBufferAllocator().buffer(capacity: 4); encodeU32($0 as! UInt32, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
516 w.writeln("case .i32: return { var b = ByteBufferAllocator().buffer(capacity: 4); encodeI32($0 as! Int32, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
517 w.writeln("case .u64: return { var b = ByteBufferAllocator().buffer(capacity: 9); encodeVarint($0 as! UInt64, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
518 w.writeln("case .i64: return { var b = ByteBufferAllocator().buffer(capacity: 9); encodeI64($0 as! Int64, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
519 w.writeln("case .f32: return { var b = ByteBufferAllocator().buffer(capacity: 4); encodeF32($0 as! Float, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
520 w.writeln("case .f64: return { var b = ByteBufferAllocator().buffer(capacity: 8); encodeF64($0 as! Double, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
521 w.writeln("case .string: return { var b = ByteBufferAllocator().buffer(capacity: 64); encodeString($0 as! String, into: &b); return b.readBytes(length: b.readableBytes) ?? [] }").unwrap();
522 w.writeln("case .bytes: return { [UInt8]($0 as! Data) }")
523 .unwrap();
524 w.writeln(
525 "case .tx(_, _), .rx(_, _): fatalError(\"Channel schemas are not serialized directly\")",
526 )
527 .unwrap();
528 w.writeln(
529 "default: fatalError(\"Unsupported schema for Tx serialization: \\(schema)\")",
530 )
531 .unwrap();
532 w.writeln("}").unwrap();
533 }
534 w.writeln("}").unwrap();
535 w.blank_line().unwrap();
536
537 w.writeln(
539 "public func rxDeserializer(for schema: BindingSchema) -> @Sendable ([UInt8]) throws -> Any {",
540 )
541 .unwrap();
542 {
543 let _indent = w.indent();
544 w.writeln("switch schema {").unwrap();
545 w.writeln("case .bool: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeBool(from: &b) }").unwrap();
546 w.writeln("case .u8: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeU8(from: &b) }").unwrap();
547 w.writeln("case .i8: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeI8(from: &b) }").unwrap();
548 w.writeln("case .u16: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeU16(from: &b) }").unwrap();
549 w.writeln("case .i16: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeI16(from: &b) }").unwrap();
550 w.writeln("case .u32: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeU32(from: &b) }").unwrap();
551 w.writeln("case .i32: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeI32(from: &b) }").unwrap();
552 w.writeln("case .u64: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeVarint(from: &b) }").unwrap();
553 w.writeln("case .i64: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeI64(from: &b) }").unwrap();
554 w.writeln("case .f32: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeF32(from: &b) }").unwrap();
555 w.writeln("case .f64: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeF64(from: &b) }").unwrap();
556 w.writeln("case .string: return { var b = ByteBufferAllocator().buffer(bytes: $0); return try decodeString(from: &b) }").unwrap();
557 w.writeln("case .bytes: return { Data($0) }").unwrap();
558 w.writeln(
559 "case .tx(_, _), .rx(_, _): fatalError(\"Channel schemas are not deserialized directly\")",
560 )
561 .unwrap();
562 w.writeln(
563 "default: fatalError(\"Unsupported schema for Rx deserialization: \\(schema)\")",
564 )
565 .unwrap();
566 w.writeln("}").unwrap();
567 }
568 w.writeln("}").unwrap();
569 }
570 w.writeln("}").unwrap();
571 w.blank_line().unwrap();
572
573 out
574}