1use facet_core::{Facet, ScalarType, Shape};
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use vox_types::{
8 EnumInfo, ServiceDescriptor, ShapeKind, StructInfo, TypeRef, VariantKind, VoxError,
9 classify_shape, classify_variant, extract_schemas, is_bytes,
10};
11
12use crate::code_writer::CodeWriter;
13use crate::cw_writeln;
14
15pub fn generate_schemas(service: &ServiceDescriptor) -> String {
17 let mut out = String::new();
18 out.push_str(&generate_method_schemas(service));
19 out.push_str(&generate_wire_schemas(service));
20 out.push_str(&generate_serializers(service));
21 out
22}
23
24fn generate_method_schemas(service: &ServiceDescriptor) -> String {
26 let mut out = String::new();
27 let service_name = service.service_name.to_lower_camel_case();
28
29 out.push_str(&format!(
30 "public let {service_name}_schemas: [String: MethodBindingSchema] = [\n"
31 ));
32
33 for method in service.methods {
34 let method_name = method.method_name.to_lower_camel_case();
35 out.push_str(&format!(
36 " \"{method_name}\": MethodBindingSchema(args: ["
37 ));
38
39 let schemas: Vec<String> = method
40 .args
41 .iter()
42 .map(|a| shape_to_schema(a.shape))
43 .collect();
44 out.push_str(&schemas.join(", "));
45
46 out.push_str("]),\n");
47 }
48
49 out.push_str("]\n\n");
50 out
51}
52
53fn generate_wire_schemas(service: &ServiceDescriptor) -> String {
61 use crate::render::hex_u64;
62 use std::collections::HashMap;
63 use vox_types::{Schema, SchemaHash};
64
65 let service_name = service.service_name.to_lower_camel_case();
66
67 let result_extracted =
69 extract_schemas(<Result<bool, u32> as Facet<'static>>::SHAPE).expect("Result schema");
70 let result_type_id = match &result_extracted.root {
71 TypeRef::Concrete { type_id, .. } => *type_id,
72 _ => panic!("Result root should be concrete"),
73 };
74
75 let vox_error_extracted =
76 extract_schemas(<VoxError<std::convert::Infallible> as Facet<'static>>::SHAPE)
77 .expect("VoxError schema");
78 let vox_error_type_id = match &vox_error_extracted.root {
79 TypeRef::Concrete { type_id, .. } => *type_id,
80 _ => panic!("VoxError root should be concrete"),
81 };
82
83 let mut global_schemas: HashMap<SchemaHash, Schema> = HashMap::new();
85
86 for schema in result_extracted.schemas.iter() {
88 global_schemas.insert(schema.id, schema.clone());
89 }
90 for schema in vox_error_extracted.schemas.iter() {
91 global_schemas.insert(schema.id, schema.clone());
92 }
93
94 struct MethodSchemaInfo {
96 args_schema_ids: Vec<SchemaHash>,
97 args_root: TypeRef,
98 response_schema_ids: Vec<SchemaHash>,
99 response_root: TypeRef,
100 }
101 let mut method_infos: Vec<(u64, MethodSchemaInfo)> = Vec::new();
102
103 for method in service.methods {
104 let method_id = crate::method_id(method);
105
106 let args_extracted = extract_schemas(method.args_shape).expect("args schema extraction");
108 let args_schema_ids: Vec<SchemaHash> =
109 args_extracted.schemas.iter().map(|s| s.id).collect();
110 for schema in args_extracted.schemas {
111 global_schemas.insert(schema.id, schema);
112 }
113
114 let (ok_extracted, err_extracted) = match classify_shape(method.return_shape) {
116 ShapeKind::Result { ok, err } => (
117 extract_schemas(ok).expect("ok schema"),
118 extract_schemas(err).expect("err schema"),
119 ),
120 _ => (
121 extract_schemas(method.return_shape).expect("return schema"),
122 extract_schemas(<std::convert::Infallible as Facet<'static>>::SHAPE)
123 .expect("Infallible schema"),
124 ),
125 };
126
127 let mut response_schema_ids: Vec<SchemaHash> = Vec::new();
129 for schema in result_extracted.schemas.iter() {
130 response_schema_ids.push(schema.id);
131 }
132 for schema in vox_error_extracted.schemas.iter() {
133 response_schema_ids.push(schema.id);
134 }
135 for schema in ok_extracted.schemas {
136 response_schema_ids.push(schema.id);
137 global_schemas.insert(schema.id, schema);
138 }
139 for schema in err_extracted.schemas {
140 response_schema_ids.push(schema.id);
141 global_schemas.insert(schema.id, schema);
142 }
143
144 let mut seen = std::collections::HashSet::new();
146 response_schema_ids.retain(|id| seen.insert(*id));
147
148 let vox_error_ref = TypeRef::generic(vox_error_type_id, vec![err_extracted.root]);
150 let response_root =
151 TypeRef::generic(result_type_id, vec![ok_extracted.root, vox_error_ref]);
152
153 method_infos.push((
154 method_id,
155 MethodSchemaInfo {
156 args_schema_ids,
157 args_root: args_extracted.root,
158 response_schema_ids,
159 response_root,
160 },
161 ));
162 }
163
164 let mut out = String::new();
165
166 out.push_str("/// Global schema registry containing all schemas for this service.\n");
168 out.push_str(&format!(
169 "public let {service_name}_schema_registry: [UInt64: Schema] = [\n"
170 ));
171
172 let mut sorted_schemas: Vec<_> = global_schemas.into_iter().collect();
173 sorted_schemas.sort_by_key(|(id, _)| *id);
174
175 for (schema_id, schema) in &sorted_schemas {
176 out.push_str(&format!(
177 " {}: {},\n",
178 hex_u64(schema_id.0),
179 format_swift_schema(schema)
180 ));
181 }
182 out.push_str("]\n\n");
183
184 out.push_str("/// Per-method schema information for wire protocol.\n");
186 out.push_str(&format!(
187 "public let {service_name}_method_schemas: [UInt64: MethodSchemaInfo] = [\n"
188 ));
189
190 for (method_id, info) in &method_infos {
191 out.push_str(&format!(" {}: MethodSchemaInfo(\n", hex_u64(*method_id)));
192 out.push_str(&format!(
193 " argsSchemaIds: [{}],\n",
194 info.args_schema_ids
195 .iter()
196 .map(|id| hex_u64(id.0))
197 .collect::<Vec<_>>()
198 .join(", ")
199 ));
200 out.push_str(&format!(
201 " argsRoot: {},\n",
202 format_swift_type_ref(&info.args_root)
203 ));
204 out.push_str(&format!(
205 " responseSchemaIds: [{}],\n",
206 info.response_schema_ids
207 .iter()
208 .map(|id| hex_u64(id.0))
209 .collect::<Vec<_>>()
210 .join(", ")
211 ));
212 out.push_str(&format!(
213 " responseRoot: {}\n",
214 format_swift_type_ref(&info.response_root)
215 ));
216 out.push_str(" ),\n");
217 }
218 out.push_str("]\n\n");
219
220 out
221}
222
223fn format_swift_schema(schema: &vox_types::Schema) -> String {
225 use crate::render::hex_u64;
226
227 let type_params = if schema.type_params.is_empty() {
228 "[]".to_string()
229 } else {
230 format!(
231 "[{}]",
232 schema
233 .type_params
234 .iter()
235 .map(|p| format!("\"{}\"", p.as_str()))
236 .collect::<Vec<_>>()
237 .join(", ")
238 )
239 };
240
241 format!(
242 "Schema(id: {}, typeParams: {}, kind: {})",
243 hex_u64(schema.id.0),
244 type_params,
245 format_swift_schema_kind(&schema.kind)
246 )
247}
248
249fn format_swift_schema_kind(kind: &vox_types::SchemaKind) -> String {
251 use vox_types::SchemaKind;
252
253 match kind {
254 SchemaKind::Struct { name, fields } => {
255 let fields_str = fields
256 .iter()
257 .map(|f| {
258 format!(
259 "FieldSchema(name: \"{}\", typeRef: {}, required: {})",
260 f.name,
261 format_swift_type_ref(&f.type_ref),
262 f.required
263 )
264 })
265 .collect::<Vec<_>>()
266 .join(", ");
267 format!(".struct(name: \"{}\", fields: [{}])", name, fields_str)
268 }
269 SchemaKind::Enum { name, variants } => {
270 let variants_str = variants
271 .iter()
272 .map(|v| {
273 format!(
274 "VariantSchema(name: \"{}\", index: {}, payload: {})",
275 v.name,
276 v.index,
277 format_swift_variant_payload(&v.payload)
278 )
279 })
280 .collect::<Vec<_>>()
281 .join(", ");
282 format!(".enum(name: \"{}\", variants: [{}])", name, variants_str)
283 }
284 SchemaKind::Tuple { elements } => {
285 let elems_str = elements
286 .iter()
287 .map(format_swift_type_ref)
288 .collect::<Vec<_>>()
289 .join(", ");
290 format!(".tuple(elements: [{}])", elems_str)
291 }
292 SchemaKind::List { element } => {
293 format!(".list(element: {})", format_swift_type_ref(element))
294 }
295 SchemaKind::Map { key, value } => {
296 format!(
297 ".map(key: {}, value: {})",
298 format_swift_type_ref(key),
299 format_swift_type_ref(value)
300 )
301 }
302 SchemaKind::Array { element, length } => {
303 format!(
304 ".array(element: {}, length: {})",
305 format_swift_type_ref(element),
306 length
307 )
308 }
309 SchemaKind::Option { element } => {
310 format!(".option(element: {})", format_swift_type_ref(element))
311 }
312 SchemaKind::Channel { direction, element } => {
313 let dir = match direction {
314 vox_types::ChannelDirection::Tx => ".tx",
315 vox_types::ChannelDirection::Rx => ".rx",
316 };
317 format!(
318 ".channel(direction: {}, element: {})",
319 dir,
320 format_swift_type_ref(element)
321 )
322 }
323 SchemaKind::Primitive { primitive_type } => {
324 format!(".primitive({})", format_swift_primitive(primitive_type))
325 }
326 }
327}
328
329fn format_swift_variant_payload(payload: &vox_types::VariantPayload) -> String {
331 use vox_types::VariantPayload;
332
333 match payload {
334 VariantPayload::Unit => ".unit".to_string(),
335 VariantPayload::Newtype { type_ref } => {
336 format!(".newtype(typeRef: {})", format_swift_type_ref(type_ref))
337 }
338 VariantPayload::Tuple { types } => {
339 let types_str = types
340 .iter()
341 .map(format_swift_type_ref)
342 .collect::<Vec<_>>()
343 .join(", ");
344 format!(".tuple(types: [{}])", types_str)
345 }
346 VariantPayload::Struct { fields } => {
347 let fields_str = fields
348 .iter()
349 .map(|f| {
350 format!(
351 "FieldSchema(name: \"{}\", typeRef: {}, required: {})",
352 f.name,
353 format_swift_type_ref(&f.type_ref),
354 f.required
355 )
356 })
357 .collect::<Vec<_>>()
358 .join(", ");
359 format!(".struct(fields: [{}])", fields_str)
360 }
361 }
362}
363
364fn format_swift_type_ref(type_ref: &TypeRef) -> String {
366 use crate::render::hex_u64;
367
368 match type_ref {
369 TypeRef::Concrete { type_id, args } => {
370 if args.is_empty() {
371 format!(".concrete({})", hex_u64(type_id.0))
372 } else {
373 let args_str = args
374 .iter()
375 .map(format_swift_type_ref)
376 .collect::<Vec<_>>()
377 .join(", ");
378 format!(".generic({}, args: [{}])", hex_u64(type_id.0), args_str)
379 }
380 }
381 TypeRef::Var { name } => {
382 format!(".var(name: \"{}\")", name.as_str())
383 }
384 }
385}
386
387fn format_swift_primitive(prim: &vox_types::PrimitiveType) -> String {
389 use vox_types::PrimitiveType;
390
391 match prim {
392 PrimitiveType::Bool => ".bool",
393 PrimitiveType::U8 => ".u8",
394 PrimitiveType::U16 => ".u16",
395 PrimitiveType::U32 => ".u32",
396 PrimitiveType::U64 => ".u64",
397 PrimitiveType::U128 => ".u128",
398 PrimitiveType::I8 => ".i8",
399 PrimitiveType::I16 => ".i16",
400 PrimitiveType::I32 => ".i32",
401 PrimitiveType::I64 => ".i64",
402 PrimitiveType::I128 => ".i128",
403 PrimitiveType::F32 => ".f32",
404 PrimitiveType::F64 => ".f64",
405 PrimitiveType::Char => ".char",
406 PrimitiveType::String => ".string",
407 PrimitiveType::Unit => ".unit",
408 PrimitiveType::Never => ".never",
409 PrimitiveType::Bytes => ".bytes",
410 PrimitiveType::Payload => ".payload",
411 }
412 .to_string()
413}
414
415fn shape_to_schema(shape: &'static Shape) -> String {
417 if is_bytes(shape) {
418 return ".bytes".into();
419 }
420
421 match classify_shape(shape) {
422 ShapeKind::Scalar(scalar) => match scalar {
423 ScalarType::Bool => ".bool".into(),
424 ScalarType::U8 => ".u8".into(),
425 ScalarType::U16 => ".u16".into(),
426 ScalarType::U32 => ".u32".into(),
427 ScalarType::U64 => ".u64".into(),
428 ScalarType::I8 => ".i8".into(),
429 ScalarType::I16 => ".i16".into(),
430 ScalarType::I32 => ".i32".into(),
431 ScalarType::I64 => ".i64".into(),
432 ScalarType::F32 => ".f32".into(),
433 ScalarType::F64 => ".f64".into(),
434 ScalarType::Str | ScalarType::CowStr | ScalarType::String => ".string".into(),
435 ScalarType::Unit => ".tuple(elements: [])".into(),
436 _ => ".bytes".into(), },
438 ShapeKind::List { element } | ShapeKind::Slice { element } => {
439 format!(".vec(element: {})", shape_to_schema(element))
440 }
441 ShapeKind::Option { inner } => {
442 format!(".option(inner: {})", shape_to_schema(inner))
443 }
444 ShapeKind::Map { key, value } => {
445 format!(
446 ".map(key: {}, value: {})",
447 shape_to_schema(key),
448 shape_to_schema(value)
449 )
450 }
451 ShapeKind::Tx { inner } => format!(".tx(element: {})", shape_to_schema(inner)),
452 ShapeKind::Rx { inner } => format!(".rx(element: {})", shape_to_schema(inner)),
453 ShapeKind::Tuple { elements } => {
454 let inner: Vec<String> = elements.iter().map(|p| shape_to_schema(p.shape)).collect();
455 format!(".tuple(elements: [{}])", inner.join(", "))
456 }
457 ShapeKind::Struct(StructInfo { fields, .. }) => {
458 let field_strs: Vec<String> = fields
459 .iter()
460 .map(|f| format!("(\"{}\", {})", f.name, shape_to_schema(f.shape())))
461 .collect();
462 format!(".struct(fields: [{}])", field_strs.join(", "))
463 }
464 ShapeKind::Enum(EnumInfo { variants, .. }) => {
465 let variant_strs: Vec<String> = variants
466 .iter()
467 .map(|v| {
468 let fields: Vec<String> = match classify_variant(v) {
469 VariantKind::Unit => vec![],
470 VariantKind::Newtype { inner } => vec![shape_to_schema(inner)],
471 VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
472 fields.iter().map(|f| shape_to_schema(f.shape())).collect()
473 }
474 };
475 format!("(\"{}\", [{}])", v.name, fields.join(", "))
476 })
477 .collect();
478 format!(".enum(variants: [{}])", variant_strs.join(", "))
479 }
480 _ => ".bytes".into(), }
482}
483
484fn generate_serializers(service: &ServiceDescriptor) -> String {
486 let mut out = String::new();
487 let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
488 let service_name_upper = service.service_name.to_upper_camel_case();
489
490 cw_writeln!(
491 w,
492 "public struct {service_name_upper}Serializers: BindingSerializers {{"
493 )
494 .unwrap();
495 {
496 let _indent = w.indent();
497 w.writeln("public init() {}").unwrap();
498 w.blank_line().unwrap();
499
500 w.writeln(
502 "public func txSerializer(for schema: BindingSchema) -> @Sendable (Any) -> [UInt8] {",
503 )
504 .unwrap();
505 {
506 let _indent = w.indent();
507 w.writeln("switch schema {").unwrap();
508 w.writeln("case .bool: return { encodeBool($0 as! Bool) }")
509 .unwrap();
510 w.writeln("case .u8: return { encodeU8($0 as! UInt8) }")
511 .unwrap();
512 w.writeln("case .i8: return { encodeI8($0 as! Int8) }")
513 .unwrap();
514 w.writeln("case .u16: return { encodeU16($0 as! UInt16) }")
515 .unwrap();
516 w.writeln("case .i16: return { encodeI16($0 as! Int16) }")
517 .unwrap();
518 w.writeln("case .u32: return { encodeU32($0 as! UInt32) }")
519 .unwrap();
520 w.writeln("case .i32: return { encodeI32($0 as! Int32) }")
521 .unwrap();
522 w.writeln("case .u64: return { encodeVarint($0 as! UInt64) }")
523 .unwrap();
524 w.writeln("case .i64: return { encodeI64($0 as! Int64) }")
525 .unwrap();
526 w.writeln("case .f32: return { encodeF32($0 as! Float) }")
527 .unwrap();
528 w.writeln("case .f64: return { encodeF64($0 as! Double) }")
529 .unwrap();
530 w.writeln("case .string: return { encodeString($0 as! String) }")
531 .unwrap();
532 w.writeln("case .bytes: return { [UInt8]($0 as! Data) }")
533 .unwrap();
534 w.writeln(
535 "case .tx(_, _), .rx(_, _): fatalError(\"Channel schemas are not serialized directly\")",
536 )
537 .unwrap();
538 w.writeln(
539 "default: fatalError(\"Unsupported schema for Tx serialization: \\(schema)\")",
540 )
541 .unwrap();
542 w.writeln("}").unwrap();
543 }
544 w.writeln("}").unwrap();
545 w.blank_line().unwrap();
546
547 w.writeln(
549 "public func rxDeserializer(for schema: BindingSchema) -> @Sendable ([UInt8]) throws -> Any {",
550 )
551 .unwrap();
552 {
553 let _indent = w.indent();
554 w.writeln("switch schema {").unwrap();
555 w.writeln("case .bool: return { var o = 0; return try decodeBool(from: Data($0), offset: &o) }").unwrap();
556 w.writeln(
557 "case .u8: return { var o = 0; return try decodeU8(from: Data($0), offset: &o) }",
558 )
559 .unwrap();
560 w.writeln(
561 "case .i8: return { var o = 0; return try decodeI8(from: Data($0), offset: &o) }",
562 )
563 .unwrap();
564 w.writeln(
565 "case .u16: return { var o = 0; return try decodeU16(from: Data($0), offset: &o) }",
566 )
567 .unwrap();
568 w.writeln(
569 "case .i16: return { var o = 0; return try decodeI16(from: Data($0), offset: &o) }",
570 )
571 .unwrap();
572 w.writeln(
573 "case .u32: return { var o = 0; return try decodeU32(from: Data($0), offset: &o) }",
574 )
575 .unwrap();
576 w.writeln(
577 "case .i32: return { var o = 0; return try decodeI32(from: Data($0), offset: &o) }",
578 )
579 .unwrap();
580 w.writeln("case .u64: return { var o = 0; return try decodeVarint(from: Data($0), offset: &o) }").unwrap();
581 w.writeln(
582 "case .i64: return { var o = 0; return try decodeI64(from: Data($0), offset: &o) }",
583 )
584 .unwrap();
585 w.writeln(
586 "case .f32: return { var o = 0; return try decodeF32(from: Data($0), offset: &o) }",
587 )
588 .unwrap();
589 w.writeln(
590 "case .f64: return { var o = 0; return try decodeF64(from: Data($0), offset: &o) }",
591 )
592 .unwrap();
593 w.writeln("case .string: return { var o = 0; return try decodeString(from: Data($0), offset: &o) }").unwrap();
594 w.writeln("case .bytes: return { Data($0) }").unwrap();
595 w.writeln(
596 "case .tx(_, _), .rx(_, _): fatalError(\"Channel schemas are not deserialized directly\")",
597 )
598 .unwrap();
599 w.writeln(
600 "default: fatalError(\"Unsupported schema for Rx deserialization: \\(schema)\")",
601 )
602 .unwrap();
603 w.writeln("}").unwrap();
604 }
605 w.writeln("}").unwrap();
606 }
607 w.writeln("}").unwrap();
608 w.blank_line().unwrap();
609
610 out
611}