1pub use vox_schema::*;
8
9use std::sync::Arc;
10
11use facet::Facet;
12use facet_core::{DeclId, Def, ScalarType, Shape, StructKind, Type, UserType};
13use indexmap::IndexMap;
14use std::collections::{HashMap, HashSet};
15use std::sync::Mutex;
16
17use crate::{MethodId, RequestCall, RequestResponse, is_rx, is_tx};
18
/// Errors produced while turning facet `Shape`s into wire schemas.
#[derive(Debug)]
pub enum SchemaExtractError {
    /// `ExtractCtx::extract` met a shape whose `Type` it has no schema mapping for.
    UnhandledType { type_desc: String },

    /// A pointer shape lacking type parameters (per the `Display` message).
    /// NOTE(review): not constructed by the code visible in this module.
    PointerWithoutTypeParams { shape_desc: String },

    /// A temporary cycle ID had no final content hash when IDs were finalized.
    UnresolvedTempId { temp_id: CycleSchemaIndex },

    /// A `DeclId` assignment was missing (per the `Display` message).
    /// NOTE(review): not constructed by the code visible in this module.
    MissingAssignment { context: String },
}
38
39impl std::fmt::Display for SchemaExtractError {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 Self::UnhandledType { type_desc } => {
43 write!(f, "schema extraction: unhandled type: {type_desc}")
44 }
45 Self::PointerWithoutTypeParams { shape_desc } => {
46 write!(
47 f,
48 "schema extraction: Pointer type without type_params: {shape_desc}"
49 )
50 }
51 Self::UnresolvedTempId { temp_id } => {
52 write!(
53 f,
54 "schema extraction: unresolved temp ID {temp_id:?} during finalization"
55 )
56 }
57 Self::MissingAssignment { context } => {
58 write!(f, "schema extraction: missing DeclId assignment: {context}")
59 }
60 }
61 }
62}
63
/// Something (e.g. a request call or response) that can carry the schema
/// payload for one side of a method binding.
pub trait Schematic {
    /// Which side of the method binding this value carries (arguments or response).
    fn direction(&self) -> BindingDirection;
    /// Stores `schemas` on the value; the implementations in this module
    /// overwrite any previously attached payload.
    fn attach_schemas(&mut self, schemas: CborPayload);
}
69
70impl<'payload> Schematic for RequestCall<'payload> {
71 fn direction(&self) -> BindingDirection {
72 BindingDirection::Args
73 }
74
75 fn attach_schemas(&mut self, schemas: CborPayload) {
76 self.schemas = schemas;
77 }
78}
79
80impl<'payload> Schematic for RequestResponse<'payload> {
81 fn direction(&self) -> BindingDirection {
82 BindingDirection::Response
83 }
84
85 fn attach_schemas(&mut self, schemas: CborPayload) {
86 self.schemas = schemas;
87 }
88}
89
// `SchemaExtractError` is a standard error type: its `Debug` and `Display`
// impls supply the message, and (empty impl) it reports no `source()`.
impl std::error::Error for SchemaExtractError {}
91
/// Send-side record of which method bindings and schemas have already been
/// transmitted, so each is sent at most once until `reset` is called.
/// (Presumably one tracker per connection — confirm with callers.)
pub struct SchemaSendTracker {
    /// `(method, direction)` pairs whose binding has been committed.
    sent_bindings: HashSet<(MethodId, BindingDirection)>,

    /// Content hashes of every schema already committed.
    sent_schemas: HashSet<SchemaHash>,
}
110
/// The full set of schemas needed to describe one root type, before any
/// already-sent schemas are filtered out.
#[derive(Debug, Clone)]
pub struct PreparedSchemaPlan {
    /// Every schema reachable from `root`.
    pub schemas: Vec<Schema>,
    /// Reference to the root type of the binding.
    pub root: TypeRef,
}
117
118impl PreparedSchemaPlan {
119 pub fn to_cbor(&self) -> CborPayload {
120 SchemaPayload {
121 schemas: self.schemas.clone(),
122 root: self.root.clone(),
123 }
124 .to_cbor()
125 }
126}
127
128impl SchemaSendTracker {
129 pub fn new() -> Self {
130 SchemaSendTracker {
131 sent_bindings: HashSet::new(),
132 sent_schemas: HashSet::new(),
133 }
134 }
135
136 pub fn reset(&mut self) {
138 self.sent_bindings.clear();
139 self.sent_schemas.clear();
140 }
141
142 pub fn has_sent_binding(&self, method_id: MethodId, direction: BindingDirection) -> bool {
144 self.sent_bindings.contains(&(method_id, direction))
145 }
146
147 pub fn plan_for_shape(shape: &'static Shape) -> Result<PreparedSchemaPlan, SchemaExtractError> {
150 let extracted = extract_schemas(shape)?;
151 Ok(PreparedSchemaPlan {
152 schemas: extracted.schemas.to_vec(),
153 root: extracted.root.clone(),
154 })
155 }
156
157 pub fn plan_from_source(root_type: &TypeRef, source: &dyn SchemaSource) -> PreparedSchemaPlan {
160 let mut all_schemas = Vec::new();
161 let mut visited = HashSet::new();
162 let mut queue = Vec::new();
163 root_type.collect_ids(&mut queue);
164
165 while let Some(id) = queue.pop() {
166 if !visited.insert(id) {
167 continue;
168 }
169 if let Some(schema) = source.get_schema(id) {
170 for child_id in schema_child_ids(&schema.kind) {
171 queue.push(child_id);
172 }
173 all_schemas.push(schema);
174 }
175 }
176
177 PreparedSchemaPlan {
178 schemas: all_schemas,
179 root: root_type.clone(),
180 }
181 }
182
183 fn unsent_schemas_for_prepared_plan(&self, prepared: &PreparedSchemaPlan) -> Vec<Schema> {
184 prepared
185 .schemas
186 .iter()
187 .filter(|schema| !self.sent_schemas.contains(&schema.id))
188 .cloned()
189 .collect()
190 }
191
192 pub fn preview_prepared_plan(
195 &mut self,
196 method_id: MethodId,
197 direction: BindingDirection,
198 prepared: &PreparedSchemaPlan,
199 ) -> CborPayload {
200 let key = (method_id, direction);
201 if self.sent_bindings.contains(&key) {
202 return CborPayload::default();
203 }
204 let schema_payload = SchemaPayload {
205 schemas: self.unsent_schemas_for_prepared_plan(prepared),
206 root: prepared.root.clone(),
207 };
208 schema_payload.to_cbor()
209 }
210
211 pub fn mark_prepared_plan_sent(
213 &mut self,
214 method_id: MethodId,
215 direction: BindingDirection,
216 prepared: &PreparedSchemaPlan,
217 ) {
218 let key = (method_id, direction);
219 if self.sent_bindings.contains(&key) {
220 return;
221 }
222 for schema in &prepared.schemas {
223 self.sent_schemas.insert(schema.id);
224 }
225 self.sent_bindings.insert(key);
226 }
227
228 pub fn commit_prepared_plan(
232 &mut self,
233 method_id: MethodId,
234 direction: BindingDirection,
235 prepared: PreparedSchemaPlan,
236 ) -> CborPayload {
237 let schema_payload = SchemaPayload {
238 schemas: self.unsent_schemas_for_prepared_plan(&prepared),
239 root: prepared.root.clone(),
240 };
241 dlog!(
242 "[schema] commit binding: method={:?} direction={:?} root={:?} schema_count={}",
243 method_id,
244 direction,
245 schema_payload.root,
246 schema_payload.schemas.len()
247 );
248 let cbor = schema_payload.to_cbor();
249 self.mark_prepared_plan_sent(method_id, direction, &prepared);
250 cbor
251 }
252
253 pub fn attach_schemas_for_shape_if_needed(
263 &mut self,
264 method_id: MethodId,
265 shape: &'static Shape,
266 schematic: &mut impl Schematic,
267 ) -> Result<CborPayload, SchemaExtractError> {
268 let key = (method_id, schematic.direction());
269
270 if self.sent_bindings.contains(&key) {
272 let empty = CborPayload::default();
273 schematic.attach_schemas(empty.clone());
274 return Ok(empty);
275 }
276
277 let prepared = Self::plan_for_shape(shape)?;
278 let cbor = self.commit_prepared_plan(method_id, schematic.direction(), prepared);
279 schematic.attach_schemas(cbor.clone());
280 Ok(cbor)
281 }
282
283 pub fn prepare_send(
288 &mut self,
289 method_id: MethodId,
290 direction: BindingDirection,
291 root_type: &TypeRef,
292 source: &dyn SchemaSource,
293 ) -> CborPayload {
294 let prepared = Self::plan_from_source(root_type, source);
295 self.commit_prepared_plan(method_id, direction, prepared)
296 }
297
298 pub fn commit_prepared_send(
299 &mut self,
300 method_id: MethodId,
301 direction: BindingDirection,
302 prepared: &CborPayload,
303 ) -> CborPayload {
304 let prepared_payload = SchemaPayload::from_cbor(&prepared.0)
305 .expect("prepared schema payloads must be valid CBOR");
306 self.commit_prepared_plan(
307 method_id,
308 direction,
309 PreparedSchemaPlan {
310 schemas: prepared_payload.schemas,
311 root: prepared_payload.root,
312 },
313 )
314 }
315
316 pub fn extract_schemas(
319 &mut self,
320 shape: &'static Shape,
321 ) -> Result<Arc<ExtractedSchemas>, SchemaExtractError> {
322 self::extract_schemas(shape)
323 }
324}
325
326impl Default for SchemaSendTracker {
327 fn default() -> Self {
328 Self::new()
329 }
330}
331
332impl std::fmt::Debug for SchemaSendTracker {
333 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334 f.debug_struct("SchemaSendTracker").finish_non_exhaustive()
335 }
336}
337
/// Receive-side schema state: schemas received so far, the root type of each
/// received method binding, and a type-erased cache of plans built from them.
///
/// Every field sits behind a `Mutex`, so the tracker is used through `&self`.
pub struct SchemaRecvTracker {
    /// Received schemas by content hash.
    received: Mutex<HashMap<SchemaHash, Schema>>,
    /// Root type of each method's argument binding.
    received_args_bindings: Mutex<HashMap<MethodId, TypeRef>>,
    /// Root type of each method's response binding.
    received_response_bindings: Mutex<HashMap<MethodId, TypeRef>>,
    /// Plans stored as `Box<Arc<T>>` by `insert_cached_plan` and recovered by
    /// downcasting in `get_cached_plan`.
    plan_cache: Mutex<HashMap<PlanCacheKey, Box<dyn std::any::Any + Send + Sync>>>,
}
360
/// Key of `SchemaRecvTracker`'s plan cache: one entry per method, binding
/// direction and local target shape.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct PlanCacheKey {
    pub method_id: MethodId,
    pub direction: BindingDirection,
    /// Local type the plan was built for; the derived `Eq`/`Hash` compare the
    /// pointed-to `Shape` values, not the pointers.
    pub local_shape: &'static Shape,
}
368
/// Returned by `SchemaRecvTracker::record_received` when a schema ID arrives
/// that this tracker has already recorded.
#[derive(Debug)]
pub struct DuplicateSchemaError {
    /// The repeated schema ID.
    pub type_id: SchemaHash,
}
374
375impl std::fmt::Display for DuplicateSchemaError {
376 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
377 write!(
378 f,
379 "duplicate TypeSchemaId {:?} received on same connection — protocol error",
380 self.type_id
381 )
382 }
383}
384
// Standard error type: the `Display` impl supplies the message and, having an
// empty impl body, it reports no `source()`.
impl std::error::Error for DuplicateSchemaError {}
386
387impl SchemaRecvTracker {
388 pub fn new() -> Self {
389 SchemaRecvTracker {
390 received: Mutex::new(HashMap::new()),
391 received_args_bindings: Mutex::new(HashMap::new()),
392 received_response_bindings: Mutex::new(HashMap::new()),
393 plan_cache: Mutex::new(HashMap::new()),
394 }
395 }
396
397 pub fn record_received(
402 &self,
403 method_id: MethodId,
404 direction: BindingDirection,
405 payload: SchemaPayload,
406 ) -> Result<(), DuplicateSchemaError> {
407 {
408 let mut received = self.received.lock().unwrap();
409 for schema in &payload.schemas {
410 dlog!("[schema] record_received: id={:?}", schema.id);
411 }
412 for schema in payload.schemas {
413 if let Some(existing) = received.get(&schema.id) {
414 dlog!(
415 "[schema] DUPLICATE: id={:?} existing={:?} new={:?}",
416 schema.id,
417 existing,
418 schema
419 );
420 return Err(DuplicateSchemaError { type_id: schema.id });
421 }
422 received.insert(schema.id, schema);
423 }
424 }
425 let map = match direction {
426 BindingDirection::Args => &self.received_args_bindings,
427 BindingDirection::Response => &self.received_response_bindings,
428 };
429 dlog!(
430 "[schema] record binding: method={:?} direction={:?} root={:?}",
431 method_id,
432 direction,
433 payload.root
434 );
435 map.lock().unwrap().insert(method_id, payload.root);
436 Ok(())
437 }
438
439 pub fn get_remote_args_root(&self, method_id: MethodId) -> Option<TypeRef> {
441 self.received_args_bindings
442 .lock()
443 .unwrap()
444 .get(&method_id)
445 .cloned()
446 }
447
448 pub fn get_remote_response_root(&self, method_id: MethodId) -> Option<TypeRef> {
450 self.received_response_bindings
451 .lock()
452 .unwrap()
453 .get(&method_id)
454 .cloned()
455 }
456
457 pub fn get_received(&self, type_id: &SchemaHash) -> Option<Schema> {
459 self.received.lock().unwrap().get(type_id).cloned()
460 }
461
462 pub fn received_registry(&self) -> SchemaRegistry {
464 self.received.lock().unwrap().clone()
465 }
466
467 pub fn get_cached_plan<T: Send + Sync + 'static>(
469 &self,
470 key: &PlanCacheKey,
471 ) -> Option<std::sync::Arc<T>> {
472 let cache = self.plan_cache.lock().unwrap();
473 cache.get(key)?.downcast_ref::<std::sync::Arc<T>>().cloned()
474 }
475
476 pub fn insert_cached_plan<T: Send + Sync + 'static>(
478 &self,
479 key: PlanCacheKey,
480 plan: std::sync::Arc<T>,
481 ) {
482 self.plan_cache.lock().unwrap().insert(key, Box::new(plan));
483 }
484}
485
486impl Default for SchemaRecvTracker {
487 fn default() -> Self {
488 Self::new()
489 }
490}
491
492impl SchemaSource for SchemaRecvTracker {
493 fn get_schema(&self, id: SchemaHash) -> Option<Schema> {
494 self.get_received(&id)
495 }
496}
497
498impl std::fmt::Debug for SchemaRecvTracker {
499 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
500 f.debug_struct("SchemaRecvTracker").finish_non_exhaustive()
501 }
502}
503
/// Result of extracting schemas for one root shape.
#[derive(Clone)]
pub struct ExtractedSchemas {
    /// Every emitted schema, with final content-hash IDs (deduplicated by ID).
    pub schemas: Vec<Schema>,

    /// Reference to the root type, expressed with final IDs.
    pub root: TypeRef,
}
513
514pub fn extract_schemas(shape: &'static Shape) -> Result<Arc<ExtractedSchemas>, SchemaExtractError> {
517 use std::sync::OnceLock;
518
519 static CACHE: OnceLock<Mutex<HashMap<&'static Shape, Arc<ExtractedSchemas>>>> = OnceLock::new();
520 let cache = CACHE.get_or_init(|| Mutex::new(HashMap::new()));
521
522 if let Some(cached) = cache.lock().unwrap().get(shape) {
523 return Ok(Arc::clone(cached));
524 }
525
526 let result = Arc::new(extract_schemas_uncached(shape)?);
527 cache.lock().unwrap().insert(shape, Arc::clone(&result));
528 Ok(result)
529}
530
531fn extract_schemas_uncached(shape: &'static Shape) -> Result<ExtractedSchemas, SchemaExtractError> {
532 let mut ctx = ExtractCtx {
533 next_id: CycleSchemaIndex::first(),
534 schemas: IndexMap::new(),
535 assigned: HashMap::new(),
536 seen: HashSet::new(),
537 };
538 let root_mixed_ref = ctx.extract(shape)?;
539 let schemas: Vec<MixedSchema> = ctx.schemas.into_values().collect();
540 let (finalized, temp_to_final) = finalize_content_hashes(schemas)?;
541
542 let resolve = |mid: MixedId| -> SchemaHash {
543 match mid {
544 MixedId::Final(tid) => tid,
545 MixedId::Temp(t) => temp_to_final.get(&t).copied().unwrap_or(SchemaHash(0)),
546 }
547 };
548 let root_type_ref = root_mixed_ref.map(resolve);
549
550 Ok(ExtractedSchemas {
551 schemas: finalized,
552 root: root_type_ref,
553 })
554}
555
556fn resolve_mixed(id: MixedId, temp_to_final: &HashMap<CycleSchemaIndex, SchemaHash>) -> SchemaHash {
565 match id {
566 MixedId::Final(tid) => tid,
567 MixedId::Temp(t) => temp_to_final.get(&t).copied().unwrap_or(SchemaHash(0)),
568 }
569}
570
/// Identity under which a schema is emitted during extraction (see
/// `ExtractCtx::key_for_shape`).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
enum ExtractKey {
    /// A declared type; every shape with this `DeclId` maps to the same key.
    Decl(DeclId),
    /// An anonymous tuple `(..)` of the given arity; all such tuples share one
    /// generic tuple schema.
    AnonymousTupleArity(usize),
}
576
/// Replaces every temporary ID in `schemas` with a content hash.
///
/// `schemas` is processed in the order given (the extraction context's
/// insertion order).
/// - Schemas not flagged as recursive are hashed one at a time. References
///   are resolved through the hashes computed so far.
/// - Each run of consecutive schemas flagged as recursive is hashed as a
///   group. Every member gets a preliminary hash with intra-group references
///   set to 0. The sorted preliminary hashes are folded into a group hash.
///   Each member's final hash is derived from the group hash and its rank in
///   that sorted order.
///
/// Returns the finalized schemas (later schemas whose final ID repeats an
/// earlier one are dropped) and the temp→final table. Callers use the table to
/// resolve references outside the schema list, such as the root.
///
/// # Errors
/// `SchemaExtractError::UnresolvedTempId` if any ID or type reference names a
/// temporary ID that received no final hash.
fn finalize_content_hashes(
    schemas: Vec<MixedSchema>,
) -> Result<(Vec<Schema>, HashMap<CycleSchemaIndex, SchemaHash>), SchemaExtractError> {
    // Position in `schemas` of each schema that still has a temporary ID.
    let temp_to_idx: HashMap<CycleSchemaIndex, usize> = schemas
        .iter()
        .enumerate()
        .filter_map(|(i, s)| match s.id {
            MixedId::Temp(t) => Some((t, i)),
            MixedId::Final(_) => None,
        })
        .collect();

    // Every ID referenced from any type ref inside `kind`.
    fn collect_refs(kind: &MixedSchemaKind) -> Vec<MixedId> {
        let mut refs = Vec::new();
        kind.for_each_type_ref(&mut |tr: &TypeRef<MixedId>| tr.collect_ids(&mut refs));
        refs
    }

    let n = schemas.len();
    // Which schemas are treated as members of a recursive group.
    let mut in_recursive_group: Vec<bool> = vec![false; n];

    // A reference from schema `i` to a temp schema at index >= `i` (including
    // a self-reference) is taken as a sign of recursion; both ends are flagged.
    //
    // NOTE(review): only the two endpoints of such edges are flagged, not every
    // member of the cycle. Assuming post-order emission, A→B→C→A is emitted as
    // C, B, A. Only C and A are flagged, and they form two separate one-member
    // groups. B is hashed in the non-recursive pass below with its reference to
    // C resolved to the placeholder 0, so B's hash does not reflect C's
    // content. A strongly-connected-components pass would group cycles
    // exactly. Confirm before relying on these hashes being collision-free.
    for (i, schema) in schemas.iter().enumerate() {
        if matches!(schema.id, MixedId::Final(_)) {
            continue;
        }
        for r in collect_refs(&schema.kind) {
            if let MixedId::Temp(t) = r
                && let Some(&ref_idx) = temp_to_idx.get(&t)
                && ref_idx >= i
            {
                in_recursive_group[i] = true;
                in_recursive_group[ref_idx] = true;
            }
        }
    }

    let mut temp_to_final: HashMap<CycleSchemaIndex, SchemaHash> = HashMap::new();

    // Pass 1: hash non-recursive schemas in order. `resolve_mixed` maps
    // references to not-yet-hashed schemas to the placeholder SchemaHash(0).
    for (i, schema) in schemas.iter().enumerate() {
        if in_recursive_group[i] {
            continue;
        }
        if let MixedId::Temp(temp) = schema.id {
            let final_id = compute_content_hash(&schema.kind, &schema.type_params, &|mid| {
                resolve_mixed(mid, &temp_to_final)
            });
            temp_to_final.insert(temp, final_id);
        }
    }

    // Pass 2: hash each maximal run of consecutive flagged schemas as a group.
    let mut i = 0;
    while i < n {
        if !in_recursive_group[i] {
            i += 1;
            continue;
        }

        let group_start = i;
        while i < n && in_recursive_group[i] {
            i += 1;
        }
        let group_end = i;

        let group_temp_ids: HashSet<CycleSchemaIndex> = schemas[group_start..group_end]
            .iter()
            .filter_map(|s| match s.id {
                MixedId::Temp(t) => Some(t),
                _ => None,
            })
            .collect();

        // Preliminary hashes: references inside the group become 0, so they do
        // not depend on the (not yet known) final IDs of other members.
        let mut prelim_hashes: Vec<SchemaHash> = Vec::new();
        for schema in &schemas[group_start..group_end] {
            let prelim =
                compute_content_hash(&schema.kind, &schema.type_params, &|mid| match mid {
                    MixedId::Final(tid) => tid,
                    MixedId::Temp(t) => {
                        if group_temp_ids.contains(&t) {
                            SchemaHash(0)
                        } else {
                            temp_to_final.get(&t).copied().unwrap_or(SchemaHash(0))
                        }
                    }
                });
            prelim_hashes.push(prelim);
        }

        // Rank members by preliminary hash, so the result does not depend on
        // emission order. `sort_by_key` is stable, so ties keep emission order.
        let mut order: Vec<usize> = (0..prelim_hashes.len()).collect();
        order.sort_by_key(|&i| prelim_hashes[i].0);

        // Group hash: blake3 over the sorted preliminary hashes, truncated to
        // the first 8 bytes (little-endian u64).
        let mut group_hasher = blake3::Hasher::new();
        for &idx in &order {
            group_hasher.update(&prelim_hashes[idx].0.to_le_bytes());
        }
        let gh = group_hasher.finalize();
        let group_hash = u64::from_le_bytes(gh.as_bytes()[0..8].try_into().unwrap());

        // Final hash of each member: blake3(group_hash, rank), truncated the same way.
        for (position, &idx) in order.iter().enumerate() {
            let mut fh = blake3::Hasher::new();
            fh.update(&group_hash.to_le_bytes());
            fh.update(&(position as u64).to_le_bytes());
            let fo = fh.finalize();
            let final_hash =
                SchemaHash(u64::from_le_bytes(fo.as_bytes()[0..8].try_into().unwrap()));

            if let MixedId::Temp(t) = schemas[group_start + idx].id {
                temp_to_final.insert(t, final_hash);
            }
        }
    }

    // Strict resolution for the output: any temp ID left unresolved is an error.
    let resolve = |mid: MixedId| -> Result<SchemaHash, SchemaExtractError> {
        match mid {
            MixedId::Final(tid) => Ok(tid),
            MixedId::Temp(t) => temp_to_final
                .get(&t)
                .copied()
                .ok_or(SchemaExtractError::UnresolvedTempId { temp_id: t }),
        }
    };

    let mut resolve_type_ref =
        |type_ref: TypeRef<MixedId>| -> Result<TypeRef<SchemaHash>, SchemaExtractError> {
            type_ref.try_map(&resolve)
        };

    // Rewrite IDs and type refs, then keep only the first schema per final ID.
    let mut seen_ids = HashSet::new();
    let finalized: Vec<Schema> = schemas
        .into_iter()
        .map(|s| {
            let type_id = resolve(s.id)?;
            Ok(Schema {
                id: type_id,
                type_params: s.type_params,
                kind: s.kind.try_map_type_refs(&mut resolve_type_ref)?,
            })
        })
        .collect::<Result<Vec<_>, _>>()?
        .into_iter()
        .filter(|s| seen_ids.insert(s.id))
        .collect();

    Ok((finalized, temp_to_final))
}
741
/// Mutable state threaded through one schema-extraction walk.
struct ExtractCtx {
    /// Next temporary index handed out by `id_for_key`.
    next_id: CycleSchemaIndex,
    /// Emitted schemas by key, in insertion order. `extract` emits a schema
    /// after recursing into its children; finalization depends on this order.
    schemas: IndexMap<ExtractKey, MixedSchema>,
    /// ID assigned to each key, possibly before its schema is emitted, so a
    /// recursive reference can point at a schema still being built.
    assigned: HashMap<ExtractKey, MixedId>,
    /// Shapes already visited; a repeat visit yields a reference instead of recursing.
    seen: HashSet<&'static Shape>,
}
755
756impl ExtractCtx {
757 fn id_for_key(&mut self, key: ExtractKey) -> MixedId {
759 if let Some(&id) = self.assigned.get(&key) {
760 return id;
761 }
762 let id = MixedId::Temp(self.next_id.next_index());
763 self.assigned.insert(key, id);
764 id
765 }
766
767 fn emit_schema(&mut self, key: ExtractKey, schema: MixedSchema) {
769 self.schemas.entry(key).or_insert(schema);
770 }
771
772 fn key_for_shape(&self, shape: &'static Shape) -> ExtractKey {
773 match anonymous_tuple_arity(shape) {
774 Some(arity) => ExtractKey::AnonymousTupleArity(arity),
775 None => ExtractKey::Decl(shape.decl_id),
776 }
777 }
778
779 fn type_ref_for_shape(
782 &mut self,
783 shape: &'static Shape,
784 param_map: &[(&'static Shape, TypeParamName)],
785 ) -> Result<TypeRef<MixedId>, SchemaExtractError> {
786 if let Some((_, name)) = param_map
787 .iter()
788 .find(|(param_shape, _)| shape.is_shape(param_shape))
789 {
790 self.extract(shape)?;
793 Ok(TypeRef::Var { name: name.clone() })
794 } else {
795 self.extract(shape)
796 }
797 }
798
799 fn extract(&mut self, shape: &'static Shape) -> Result<TypeRef<MixedId>, SchemaExtractError> {
802 if is_tx(shape) || is_rx(shape) {
804 let direction = if is_tx(shape) {
805 ChannelDirection::Tx
806 } else {
807 ChannelDirection::Rx
808 };
809 if let Some(inner) = shape.type_params.first() {
810 let elem_ref = self.extract(inner.shape)?;
811 let key = self.key_for_shape(shape);
812 let id = self.id_for_key(key);
813 let type_params = vec![TypeParamName("T".to_string())];
816 self.emit_schema(
817 key,
818 MixedSchema {
819 id,
820 type_params,
821 kind: SchemaKind::Channel {
822 direction,
823 element: TypeRef::Var {
824 name: TypeParamName("T".to_string()),
825 },
826 },
827 },
828 );
829 self.seen.insert(shape);
830 return Ok(TypeRef::Concrete {
831 type_id: id,
832 args: vec![elem_ref],
833 });
834 }
835 }
836
837 if shape.is_transparent()
839 && let Some(inner) = shape.inner
840 {
841 return self.extract(inner);
842 }
843
844 if let Def::Pointer(ptr_def) = shape.def
847 && let Some(pointee) = ptr_def.pointee
848 {
849 return self.extract(pointee);
850 }
851
852 let key = self.key_for_shape(shape);
853 let id = self.id_for_key(key);
854
855 if !self.seen.insert(shape) {
859 let args = self.extract_instantiation_args(shape)?;
862 return Ok(if args.is_empty() {
863 TypeRef::concrete(id)
864 } else {
865 TypeRef::generic(id, args)
866 });
867 }
868
869 let already_emitted = self.schemas.contains_key(&key);
872 if already_emitted {
873 let args = self.extract_instantiation_args(shape)?;
874 return Ok(if args.is_empty() {
875 TypeRef::concrete(id)
876 } else {
877 TypeRef::generic(id, args)
878 });
879 }
880
881 let param_map: Vec<(&'static Shape, TypeParamName)> = shape
884 .type_params
885 .iter()
886 .map(|tp| (tp.shape, TypeParamName(tp.name.to_string())))
887 .collect();
888 let type_param_names: Vec<TypeParamName> = shape
889 .type_params
890 .iter()
891 .map(|tp| TypeParamName(tp.name.to_string()))
892 .collect();
893
894 if let Some(scalar) = shape.scalar_type() {
897 self.emit_schema(
898 key,
899 MixedSchema {
900 id,
901 type_params: vec![],
902 kind: SchemaKind::Primitive {
903 primitive_type: scalar_to_primitive(scalar),
904 },
905 },
906 );
907 return Ok(TypeRef::concrete(id));
908 }
909
910 match shape.def {
913 Def::List(list_def) => {
914 if let Some(ScalarType::U8) = list_def.t().scalar_type() {
915 self.emit_schema(
916 key,
917 MixedSchema {
918 id,
919 type_params: vec![],
920 kind: SchemaKind::Primitive {
921 primitive_type: PrimitiveType::Bytes,
922 },
923 },
924 );
925 return Ok(TypeRef::concrete(id));
926 }
927 let elem_ref = self.type_ref_for_shape(list_def.t(), ¶m_map)?;
928 let args = self.extract_type_args(shape)?;
929 self.emit_schema(
930 key,
931 MixedSchema {
932 id,
933 type_params: type_param_names,
934 kind: SchemaKind::List { element: elem_ref },
935 },
936 );
937 return Ok(if args.is_empty() {
938 TypeRef::concrete(id)
939 } else {
940 TypeRef::generic(id, args)
941 });
942 }
943 Def::Array(array_def) => {
944 let elem_ref = self.type_ref_for_shape(array_def.t(), ¶m_map)?;
945 let args = self.extract_type_args(shape)?;
946 self.emit_schema(
947 key,
948 MixedSchema {
949 id,
950 type_params: type_param_names,
951 kind: SchemaKind::Array {
952 element: elem_ref,
953 length: array_def.n as u64,
954 },
955 },
956 );
957 return Ok(if args.is_empty() {
958 TypeRef::concrete(id)
959 } else {
960 TypeRef::generic(id, args)
961 });
962 }
963 Def::Slice(slice_def) => {
964 if let Some(ScalarType::U8) = slice_def.t().scalar_type() {
965 self.emit_schema(
966 key,
967 MixedSchema {
968 id,
969 type_params: vec![],
970 kind: SchemaKind::Primitive {
971 primitive_type: PrimitiveType::Bytes,
972 },
973 },
974 );
975 return Ok(TypeRef::concrete(id));
976 }
977 let elem_ref = self.type_ref_for_shape(slice_def.t(), ¶m_map)?;
978 let args = self.extract_type_args(shape)?;
979 self.emit_schema(
980 key,
981 MixedSchema {
982 id,
983 type_params: type_param_names,
984 kind: SchemaKind::List { element: elem_ref },
985 },
986 );
987 return Ok(if args.is_empty() {
988 TypeRef::concrete(id)
989 } else {
990 TypeRef::generic(id, args)
991 });
992 }
993 Def::Map(map_def) => {
994 let key_ref = self.type_ref_for_shape(map_def.k(), ¶m_map)?;
995 let val_ref = self.type_ref_for_shape(map_def.v(), ¶m_map)?;
996 let args = self.extract_type_args(shape)?;
997 self.emit_schema(
998 key,
999 MixedSchema {
1000 id,
1001 type_params: type_param_names,
1002 kind: SchemaKind::Map {
1003 key: key_ref,
1004 value: val_ref,
1005 },
1006 },
1007 );
1008 return Ok(if args.is_empty() {
1009 TypeRef::concrete(id)
1010 } else {
1011 TypeRef::generic(id, args)
1012 });
1013 }
1014 Def::Set(set_def) => {
1015 let elem_ref = self.type_ref_for_shape(set_def.t(), ¶m_map)?;
1016 let args = self.extract_type_args(shape)?;
1017 self.emit_schema(
1018 key,
1019 MixedSchema {
1020 id,
1021 type_params: type_param_names,
1022 kind: SchemaKind::List { element: elem_ref },
1023 },
1024 );
1025 return Ok(if args.is_empty() {
1026 TypeRef::concrete(id)
1027 } else {
1028 TypeRef::generic(id, args)
1029 });
1030 }
1031 Def::Option(opt_def) => {
1032 let elem_ref = self.type_ref_for_shape(opt_def.t(), ¶m_map)?;
1033 let args = self.extract_type_args(shape)?;
1034 self.emit_schema(
1035 key,
1036 MixedSchema {
1037 id,
1038 type_params: type_param_names,
1039 kind: SchemaKind::Option { element: elem_ref },
1040 },
1041 );
1042 return Ok(if args.is_empty() {
1043 TypeRef::concrete(id)
1044 } else {
1045 TypeRef::generic(id, args)
1046 });
1047 }
1048 Def::Result(result_def) => {
1049 let ok_ref = self.type_ref_for_shape(result_def.t(), ¶m_map)?;
1050 let err_ref = self.type_ref_for_shape(result_def.e(), ¶m_map)?;
1051 let args = self.extract_type_args(shape)?;
1052 self.emit_schema(
1053 key,
1054 MixedSchema {
1055 id,
1056 type_params: type_param_names,
1057 kind: SchemaKind::Enum {
1058 name: shape.type_identifier.to_string(),
1059 variants: vec![
1060 VariantSchema {
1061 name: "Ok".to_string(),
1062 index: 0,
1063 payload: VariantPayload::Newtype { type_ref: ok_ref },
1064 },
1065 VariantSchema {
1066 name: "Err".to_string(),
1067 index: 1,
1068 payload: VariantPayload::Newtype { type_ref: err_ref },
1069 },
1070 ],
1071 },
1072 },
1073 );
1074 return Ok(if args.is_empty() {
1075 TypeRef::concrete(id)
1076 } else {
1077 TypeRef::generic(id, args)
1078 });
1079 }
1080 _ => {}
1081 }
1082
1083 let kind = match shape.ty {
1085 Type::User(UserType::Struct(struct_type)) => match struct_type.kind {
1088 StructKind::Unit => {
1089 let primitive_type = if is_infallible_shape(shape) {
1090 PrimitiveType::Never
1091 } else {
1092 PrimitiveType::Unit
1093 };
1094 SchemaKind::Primitive { primitive_type }
1095 }
1096 StructKind::TupleStruct | StructKind::Tuple => {
1097 if let Some(arity) = anonymous_tuple_arity(shape) {
1098 let args = self.extract_instantiation_args(shape)?;
1099 let type_params = tuple_type_params(arity);
1100 let elements = type_params
1101 .iter()
1102 .cloned()
1103 .map(|name| TypeRef::Var { name })
1104 .collect();
1105 self.emit_schema(
1106 key,
1107 MixedSchema {
1108 id,
1109 type_params,
1110 kind: SchemaKind::Tuple { elements },
1111 },
1112 );
1113 return Ok(TypeRef::generic(id, args));
1114 }
1115 let mut elements = Vec::with_capacity(struct_type.fields.len());
1116 for f in struct_type.fields {
1117 elements.push(self.type_ref_for_shape(f.shape(), ¶m_map)?);
1118 }
1119 SchemaKind::Tuple { elements }
1120 }
1121 StructKind::Struct => {
1122 let mut fields = Vec::with_capacity(struct_type.fields.len());
1123 for f in struct_type.fields {
1124 fields.push(FieldSchema {
1125 name: f.name.to_string(),
1126 type_ref: self.type_ref_for_shape(f.shape(), ¶m_map)?,
1127 required: f.default.is_none(),
1128 });
1129 }
1130 SchemaKind::Struct {
1131 name: shape.type_identifier.to_string(),
1132 fields,
1133 }
1134 }
1135 },
1136 Type::User(UserType::Enum(enum_type)) => {
1138 let mut variants = Vec::with_capacity(enum_type.variants.len());
1139 for (i, v) in enum_type.variants.iter().enumerate() {
1140 let payload = match v.data.kind {
1141 StructKind::Unit => VariantPayload::Unit,
1142 StructKind::TupleStruct | StructKind::Tuple => {
1143 if v.data.fields.len() == 1 {
1144 VariantPayload::Newtype {
1145 type_ref: self
1146 .type_ref_for_shape(v.data.fields[0].shape(), ¶m_map)?,
1147 }
1148 } else {
1149 let mut types = Vec::with_capacity(v.data.fields.len());
1150 for f in v.data.fields {
1151 types.push(self.type_ref_for_shape(f.shape(), ¶m_map)?);
1152 }
1153 VariantPayload::Tuple { types }
1154 }
1155 }
1156 StructKind::Struct => {
1157 let mut fields = Vec::with_capacity(v.data.fields.len());
1158 for f in v.data.fields {
1159 fields.push(FieldSchema {
1160 name: f.name.to_string(),
1161 type_ref: self.type_ref_for_shape(f.shape(), ¶m_map)?,
1162 required: true,
1163 });
1164 }
1165 VariantPayload::Struct { fields }
1166 }
1167 };
1168 variants.push(VariantSchema {
1169 name: v.name.to_string(),
1170 index: i as u32,
1171 payload,
1172 });
1173 }
1174 SchemaKind::Enum {
1175 name: shape.type_identifier.to_string(),
1176 variants,
1177 }
1178 }
1179 Type::User(UserType::Opaque) => SchemaKind::Primitive {
1180 primitive_type: PrimitiveType::Payload,
1181 },
1182 other => {
1183 return Err(SchemaExtractError::UnhandledType {
1184 type_desc: format!("{other:?} for shape {shape} (def={:?})", shape.def),
1185 });
1186 }
1187 };
1188
1189 let args = self.extract_type_args(shape)?;
1190 self.emit_schema(
1191 key,
1192 MixedSchema {
1193 id,
1194 type_params: type_param_names,
1195 kind,
1196 },
1197 );
1198
1199 Ok(if args.is_empty() {
1200 TypeRef::concrete(id)
1201 } else {
1202 TypeRef::generic(id, args)
1203 })
1204 }
1205
1206 fn extract_type_args(
1210 &mut self,
1211 shape: &'static Shape,
1212 ) -> Result<Vec<TypeRef<MixedId>>, SchemaExtractError> {
1213 if shape.type_params.is_empty() {
1214 return Ok(vec![]);
1215 }
1216 let mut args = Vec::with_capacity(shape.type_params.len());
1217 for tp in shape.type_params {
1218 args.push(self.extract(tp.shape)?);
1219 }
1220 Ok(args)
1221 }
1222
1223 fn extract_instantiation_args(
1229 &mut self,
1230 shape: &'static Shape,
1231 ) -> Result<Vec<TypeRef<MixedId>>, SchemaExtractError> {
1232 if anonymous_tuple_arity(shape).is_some()
1233 && let Type::User(UserType::Struct(struct_type)) = shape.ty
1234 {
1235 let mut args = Vec::with_capacity(struct_type.fields.len());
1236 for field in struct_type.fields {
1237 args.push(self.extract(field.shape())?);
1238 }
1239 return Ok(args);
1240 }
1241 self.extract_type_args(shape)
1242 }
1243}
1244
1245fn anonymous_tuple_arity(shape: &'static Shape) -> Option<usize> {
1246 match shape.ty {
1247 Type::User(UserType::Struct(struct_type))
1248 if struct_type.kind == StructKind::Tuple && shape.type_identifier.starts_with('(') =>
1249 {
1250 Some(struct_type.fields.len())
1251 }
1252 _ => None,
1253 }
1254}
1255
1256fn tuple_type_params(arity: usize) -> Vec<TypeParamName> {
1257 (0..arity)
1258 .map(|index| TypeParamName(format!("T{index}")))
1259 .collect()
1260}
1261
1262fn is_infallible_shape(shape: &'static Shape) -> bool {
1263 shape.is_shape(<std::convert::Infallible as Facet<'static>>::SHAPE)
1264}
1265
1266fn scalar_to_primitive(scalar: ScalarType) -> PrimitiveType {
1267 match scalar {
1268 ScalarType::Unit => PrimitiveType::Unit,
1269 ScalarType::Bool => PrimitiveType::Bool,
1270 ScalarType::Char => PrimitiveType::Char,
1271 ScalarType::Str | ScalarType::String | ScalarType::CowStr => PrimitiveType::String,
1272 ScalarType::F32 => PrimitiveType::F32,
1273 ScalarType::F64 => PrimitiveType::F64,
1274 ScalarType::U8 => PrimitiveType::U8,
1275 ScalarType::U16 => PrimitiveType::U16,
1276 ScalarType::U32 => PrimitiveType::U32,
1277 ScalarType::U64 => PrimitiveType::U64,
1278 ScalarType::U128 => PrimitiveType::U128,
1279 ScalarType::USize => PrimitiveType::U64,
1280 ScalarType::I8 => PrimitiveType::I8,
1281 ScalarType::I16 => PrimitiveType::I16,
1282 ScalarType::I32 => PrimitiveType::I32,
1283 ScalarType::I64 => PrimitiveType::I64,
1284 ScalarType::I128 => PrimitiveType::I128,
1285 ScalarType::ISize => PrimitiveType::I64,
1286 ScalarType::ConstTypeId => PrimitiveType::U64,
1287 _ => PrimitiveType::Unit,
1288 }
1289}
1290
1291#[cfg(test)]
1292mod tests {
1293 use super::*;
1294 use facet::Facet;
1295
    /// Minimal `Schematic` for tests: reports a fixed direction and keeps the
    /// last attached payload.
    struct TestSchematic {
        direction: BindingDirection,
        // Shape under test; not read by the code visible here.
        shape: &'static Shape,
        // Last payload passed to `attach_schemas` (empty until then).
        attached: CborPayload,
    }

    impl TestSchematic {
        /// Builds a schematic with an empty attached payload.
        fn new(direction: BindingDirection, shape: &'static Shape) -> Self {
            Self {
                direction,
                shape,
                attached: CborPayload::default(),
            }
        }
    }

    impl Schematic for TestSchematic {
        fn direction(&self) -> BindingDirection {
            self.direction
        }

        fn attach_schemas(&mut self, schemas: CborPayload) {
            self.attached = schemas;
        }
    }
1321
    /// `SchemaHash` wraps a plain u64 and compares by value.
    #[test]
    fn type_ids_are_u64_content_hashes() {
        let id = SchemaHash(42);
        assert_eq!(id.0, 42);
        assert_eq!(id, SchemaHash(42));
        assert_ne!(id, SchemaHash(43));
    }
1330
    /// Encoding a one-schema payload to CBOR and decoding it back preserves
    /// the schema count, the schema ID and the root.
    #[test]
    fn cbor_round_trip() {
        let schema = Schema {
            id: SchemaHash(1),
            type_params: vec![],
            kind: SchemaKind::Primitive {
                primitive_type: PrimitiveType::U32,
            },
        };
        let bytes = SchemaPayload {
            schemas: vec![schema.clone()],
            root: TypeRef::concrete(schema.id),
        }
        .to_cbor();
        let payload = SchemaPayload::from_cbor(&bytes.0).expect("should parse CBOR");
        assert_eq!(payload.schemas.len(), 1);
        assert_eq!(payload.schemas[0].id, schema.id);
        assert_eq!(payload.root, TypeRef::concrete(schema.id));
    }
1352
    /// `u32` extracts to exactly one U32 primitive schema.
    #[test]
    fn primitive_u32() {
        let schemas = extract_schemas(<u32 as Facet>::SHAPE)
            .unwrap()
            .schemas
            .clone();
        assert_eq!(schemas.len(), 1);
        assert!(matches!(
            schemas[0].kind,
            SchemaKind::Primitive {
                primitive_type: PrimitiveType::U32
            }
        ));
    }
1368
1369 #[test]
1370 fn primitive_string() {
1371 let schemas = extract_schemas(<String as Facet>::SHAPE)
1372 .unwrap()
1373 .schemas
1374 .clone();
1375 assert_eq!(schemas.len(), 1);
1376 assert!(matches!(
1377 schemas[0].kind,
1378 SchemaKind::Primitive {
1379 primitive_type: PrimitiveType::String
1380 }
1381 ));
1382 }
1383
1384 #[test]
1385 fn primitive_bool() {
1386 let schemas = extract_schemas(<bool as Facet>::SHAPE)
1387 .unwrap()
1388 .schemas
1389 .clone();
1390 assert_eq!(schemas.len(), 1);
1391 assert!(matches!(
1392 schemas[0].kind,
1393 SchemaKind::Primitive {
1394 primitive_type: PrimitiveType::Bool
1395 }
1396 ));
1397 }
1398
1399 #[test]
1401 fn simple_struct() {
1402 #[derive(Facet)]
1403 struct Point {
1404 x: f64,
1405 y: f64,
1406 }
1407
1408 let schemas = extract_schemas(Point::SHAPE).unwrap().schemas.clone();
1409 assert!(schemas.len() >= 2);
1410
1411 let point_schema = schemas.last().unwrap();
1412 match &point_schema.kind {
1413 SchemaKind::Struct { name, fields } => {
1414 assert!(
1415 name.contains("Point"),
1416 "expected name to contain Point, got {name}"
1417 );
1418 assert_eq!(fields.len(), 2);
1419 assert_eq!(fields[0].name, "x");
1420 assert_eq!(fields[1].name, "y");
1421 assert!(fields[0].required);
1422 assert_eq!(fields[0].type_ref, fields[1].type_ref);
1423 }
1424 other => panic!("expected Struct, got {other:?}"),
1425 }
1426 }
1427
1428 #[test]
1430 fn simple_enum() {
1431 #[derive(Facet)]
1432 #[repr(u8)]
1433 enum Color {
1434 Red,
1435 Green,
1436 Blue,
1437 }
1438
1439 let schemas = extract_schemas(Color::SHAPE).unwrap().schemas.clone();
1440 let color_schema = schemas.last().unwrap();
1441 match &color_schema.kind {
1442 SchemaKind::Enum { variants, .. } => {
1443 assert_eq!(variants.len(), 3);
1444 assert_eq!(variants[0].name, "Red");
1445 assert_eq!(variants[1].name, "Green");
1446 assert_eq!(variants[2].name, "Blue");
1447 assert!(matches!(variants[0].payload, VariantPayload::Unit));
1448 }
1449 other => panic!("expected Enum, got {other:?}"),
1450 }
1451 }
1452
1453 #[test]
1455 fn enum_with_payloads() {
1456 #[derive(Facet)]
1457 #[repr(u8)]
1458 #[allow(dead_code)]
1459 enum Shape {
1460 Circle(f64),
1461 Rect { w: f64, h: f64 },
1462 Empty,
1463 }
1464
1465 let schemas = extract_schemas(Shape::SHAPE).unwrap().schemas.clone();
1466 let shape_schema = schemas.last().unwrap();
1467 match &shape_schema.kind {
1468 SchemaKind::Enum { variants, .. } => {
1469 assert_eq!(variants.len(), 3);
1470 assert!(matches!(
1471 variants[0].payload,
1472 VariantPayload::Newtype { .. }
1473 ));
1474 match &variants[1].payload {
1475 VariantPayload::Struct { fields } => {
1476 assert_eq!(fields.len(), 2);
1477 assert_eq!(fields[0].name, "w");
1478 assert_eq!(fields[1].name, "h");
1479 }
1480 other => panic!("expected Struct variant, got {other:?}"),
1481 }
1482 assert!(matches!(variants[2].payload, VariantPayload::Unit));
1483 }
1484 other => panic!("expected Enum, got {other:?}"),
1485 }
1486 }
1487
1488 #[test]
1490 fn container_vec() {
1491 let schemas = extract_schemas(<Vec<u32> as Facet>::SHAPE)
1492 .unwrap()
1493 .schemas
1494 .clone();
1495 assert_eq!(schemas.len(), 2);
1496 assert!(matches!(
1497 schemas[0].kind,
1498 SchemaKind::Primitive {
1499 primitive_type: PrimitiveType::U32
1500 }
1501 ));
1502 assert!(matches!(schemas[1].kind, SchemaKind::List { .. }));
1503 }
1504
1505 #[test]
1507 fn container_option() {
1508 let schemas = extract_schemas(<Option<String> as Facet>::SHAPE)
1509 .unwrap()
1510 .schemas
1511 .clone();
1512 assert_eq!(schemas.len(), 2);
1513 assert!(matches!(
1514 schemas[0].kind,
1515 SchemaKind::Primitive {
1516 primitive_type: PrimitiveType::String
1517 }
1518 ));
1519 assert!(matches!(schemas[1].kind, SchemaKind::Option { .. }));
1520 }
1521
1522 #[test]
1524 fn recursive_type_terminates() {
1525 #[derive(Facet)]
1526 struct Node {
1527 value: u32,
1528 next: Option<Box<Node>>,
1529 }
1530
1531 let schemas = extract_schemas(Node::SHAPE).unwrap().schemas.clone();
1532 assert!(schemas.len() >= 2);
1533
1534 let node_schema = schemas.last().unwrap();
1535 assert!(matches!(node_schema.kind, SchemaKind::Struct { .. }));
1536 }
1537
1538 #[test]
1540 fn vec_u8_is_bytes() {
1541 let schemas = extract_schemas(<Vec<u8> as Facet>::SHAPE)
1542 .unwrap()
1543 .schemas
1544 .clone();
1545 assert_eq!(schemas.len(), 1);
1546 assert!(matches!(
1547 schemas[0].kind,
1548 SchemaKind::Primitive {
1549 primitive_type: PrimitiveType::Bytes
1550 }
1551 ));
1552 }
1553
1554 #[test]
1555 fn slice_u8_is_bytes() {
1556 let schemas = extract_schemas(<&[u8] as Facet>::SHAPE)
1557 .unwrap()
1558 .schemas
1559 .clone();
1560 assert_eq!(schemas.len(), 1);
1561 assert!(matches!(
1562 schemas[0].kind,
1563 SchemaKind::Primitive {
1564 primitive_type: PrimitiveType::Bytes
1565 }
1566 ));
1567 }
1568
1569 #[test]
1570 fn cbor_payload_is_bytes() {
1571 let schemas = extract_schemas(CborPayload::SHAPE).unwrap().schemas.clone();
1572 assert_eq!(schemas.len(), 1);
1573 assert!(matches!(
1574 schemas[0].kind,
1575 SchemaKind::Primitive {
1576 primitive_type: PrimitiveType::Bytes
1577 }
1578 ));
1579 }
1580
1581 #[test]
1583 fn opaque_payload_is_payload_primitive() {
1584 let schemas = extract_schemas(crate::Payload::<'static>::SHAPE)
1585 .unwrap()
1586 .schemas
1587 .clone();
1588 assert_eq!(schemas.len(), 1);
1589 assert!(matches!(
1590 schemas[0].kind,
1591 SchemaKind::Primitive {
1592 primitive_type: PrimitiveType::Payload
1593 }
1594 ));
1595 }
1596
1597 #[test]
1598 fn infallible_is_never_primitive() {
1599 let schemas = extract_schemas(<std::convert::Infallible as Facet>::SHAPE)
1600 .unwrap()
1601 .schemas
1602 .clone();
1603 assert_eq!(schemas.len(), 1);
1604 assert!(matches!(
1605 schemas[0].kind,
1606 SchemaKind::Primitive {
1607 primitive_type: PrimitiveType::Never
1608 }
1609 ));
1610 }
1611
1612 #[test]
1614 fn deduplication_two_u32_fields() {
1615 #[derive(Facet)]
1616 struct TwoU32 {
1617 a: u32,
1618 b: u32,
1619 }
1620
1621 let schemas = extract_schemas(TwoU32::SHAPE).unwrap().schemas.clone();
1622 let u32_count = schemas
1623 .iter()
1624 .filter(|s| {
1625 matches!(
1626 s.kind,
1627 SchemaKind::Primitive {
1628 primitive_type: PrimitiveType::U32
1629 }
1630 )
1631 })
1632 .count();
1633 assert_eq!(u32_count, 1, "u32 schema should appear exactly once");
1634 assert_eq!(schemas.len(), 2);
1635 }
1636
1637 #[test]
1639 fn container_map() {
1640 let schemas = extract_schemas(<std::collections::HashMap<String, u32> as Facet>::SHAPE)
1641 .unwrap()
1642 .schemas
1643 .clone();
1644 let map_schema = schemas.last().unwrap();
1645 assert!(matches!(map_schema.kind, SchemaKind::Map { .. }));
1646 }
1647
1648 #[test]
1650 fn container_array() {
1651 let schemas = extract_schemas(<[u32; 4] as Facet>::SHAPE)
1652 .unwrap()
1653 .schemas
1654 .clone();
1655 let arr_schema = schemas.last().unwrap();
1656 match &arr_schema.kind {
1657 SchemaKind::Array { length, .. } => assert_eq!(*length, 4),
1658 other => panic!("expected Array, got {other:?}"),
1659 }
1660 }
1661
1662 #[test]
1664 fn tuple_type() {
1665 let schemas = extract_schemas(<(u32, String) as Facet>::SHAPE)
1666 .unwrap()
1667 .schemas
1668 .clone();
1669 let tuple_schema = schemas.last().unwrap();
1670 match &tuple_schema.kind {
1671 SchemaKind::Tuple { elements } => {
1672 assert_eq!(elements.len(), 2);
1673 assert_ne!(elements[0], elements[1]);
1674 }
1675 other => panic!("expected Tuple, got {other:?}"),
1676 }
1677 }
1678
1679 #[test]
1681 fn extract_schemas_returns_all_kinds() {
1682 #[derive(Facet)]
1683 struct Mixed {
1684 count: u32,
1685 tags: Vec<String>,
1686 pair: (u8, u8),
1687 }
1688
1689 let schemas = extract_schemas(Mixed::SHAPE).unwrap().schemas.clone();
1690 assert!(schemas.len() >= 4);
1691 }
1692
1693 #[test]
1696 fn tracker_prepare_send_returns_payload_then_empty() {
1697 let mut tracker = SchemaSendTracker::new();
1698 let method = MethodId(1);
1699 let mut schematic = TestSchematic::new(BindingDirection::Args, <u32 as Facet>::SHAPE);
1700 let first = tracker
1701 .attach_schemas_for_shape_if_needed(method, schematic.shape, &mut schematic)
1702 .unwrap();
1703 assert!(
1704 !first.is_empty(),
1705 "first prepare_send should return payload"
1706 );
1707 assert_eq!(schematic.attached.0, first.0);
1708 let second = tracker
1709 .attach_schemas_for_shape_if_needed(method, schematic.shape, &mut schematic)
1710 .unwrap();
1711 assert!(
1712 second.is_empty(),
1713 "second prepare_send for same method should return empty"
1714 );
1715 assert!(schematic.attached.is_empty());
1716 }
1717
1718 #[test]
1721 fn tracker_prepare_send_includes_transitive_deps() {
1722 #[derive(Facet)]
1723 struct Outer {
1724 inner: u32,
1725 name: String,
1726 }
1727
1728 let mut tracker = SchemaSendTracker::new();
1729 let method = MethodId(1);
1730 let mut schematic = TestSchematic::new(BindingDirection::Args, Outer::SHAPE);
1731 let first = tracker
1732 .attach_schemas_for_shape_if_needed(method, schematic.shape, &mut schematic)
1733 .unwrap();
1734 assert!(!first.is_empty(), "should return schemas");
1735 let parsed = SchemaPayload::from_cbor(&first.0).expect("should parse CBOR");
1736 assert!(
1737 parsed.schemas.len() >= 3,
1738 "should include transitive deps, got {}",
1739 parsed.schemas.len()
1740 );
1741
1742 schematic.shape = <u32 as Facet>::SHAPE;
1744 let again = tracker
1745 .attach_schemas_for_shape_if_needed(method, schematic.shape, &mut schematic)
1746 .unwrap();
1747 assert!(
1748 again.is_empty(),
1749 "u32 was already sent as transitive dep, method already bound"
1750 );
1751 }
1752
1753 #[test]
1755 fn tracker_record_and_get_received() {
1756 let tracker = SchemaRecvTracker::new();
1757 let schemas = extract_schemas(<u32 as Facet>::SHAPE)
1758 .unwrap()
1759 .schemas
1760 .clone();
1761 let id = schemas[0].id;
1762 assert!(tracker.get_received(&id).is_none());
1763 tracker
1764 .record_received(
1765 MethodId(7),
1766 BindingDirection::Args,
1767 SchemaPayload {
1768 schemas,
1769 root: TypeRef::concrete(id),
1770 },
1771 )
1772 .expect("first record should succeed");
1773 assert!(tracker.get_received(&id).is_some());
1774 assert_eq!(
1775 tracker.get_remote_args_root(MethodId(7)),
1776 Some(TypeRef::concrete(id))
1777 );
1778 }
1779
1780 #[test]
1783 fn type_ids_are_content_hashes() {
1784 let mut tracker = SchemaSendTracker::new();
1785 let extracted = tracker
1786 .extract_schemas(<(u32, String) as Facet>::SHAPE)
1787 .unwrap();
1788 let schemas = extracted.schemas.clone();
1789 assert!(schemas.len() >= 3);
1790
1791 let mut tracker2 = SchemaSendTracker::new();
1793 let schemas2 = tracker2
1794 .extract_schemas(<(u32, String) as Facet>::SHAPE)
1795 .unwrap()
1796 .schemas
1797 .clone();
1798 assert_eq!(schemas.len(), schemas2.len());
1799 for (a, b) in schemas.iter().zip(schemas2.iter()) {
1800 assert_eq!(a.id, b.id, "content hash should be deterministic");
1801 }
1802
1803 let mut tracker3 = SchemaSendTracker::new();
1805 let extracted3 = tracker3
1806 .extract_schemas(<(u64, String) as Facet>::SHAPE)
1807 .unwrap();
1808 assert_ne!(
1809 extracted.root, extracted3.root,
1810 "different types should produce different root refs"
1811 );
1812 }
1813
1814 #[test]
1816 fn primitive_content_hashes_are_stable() {
1817 let primitives = [
1820 PrimitiveType::Bool,
1821 PrimitiveType::U8,
1822 PrimitiveType::U16,
1823 PrimitiveType::U32,
1824 PrimitiveType::U64,
1825 PrimitiveType::U128,
1826 PrimitiveType::I8,
1827 PrimitiveType::I16,
1828 PrimitiveType::I32,
1829 PrimitiveType::I64,
1830 PrimitiveType::I128,
1831 PrimitiveType::F32,
1832 PrimitiveType::F64,
1833 PrimitiveType::Char,
1834 PrimitiveType::String,
1835 PrimitiveType::Unit,
1836 PrimitiveType::Never,
1837 PrimitiveType::Bytes,
1838 PrimitiveType::Payload,
1839 ];
1840
1841 let hashes: Vec<SchemaHash> = primitives
1843 .iter()
1844 .map(|p| {
1845 compute_content_hash(&SchemaKind::Primitive { primitive_type: *p }, &[], &|id| id)
1846 })
1847 .collect();
1848 let unique: HashSet<SchemaHash> = hashes.iter().copied().collect();
1849 assert_eq!(
1850 unique.len(),
1851 hashes.len(),
1852 "all primitive hashes must be unique"
1853 );
1854
1855 for (i, p) in primitives.iter().enumerate() {
1857 let hash2 =
1858 compute_content_hash(&SchemaKind::Primitive { primitive_type: *p }, &[], &|id| id);
1859 assert_eq!(hashes[i], hash2, "hash for {:?} must be deterministic", p);
1860 }
1861 }
1862
1863 #[test]
1865 fn struct_hash_is_deterministic() {
1866 #[derive(Facet)]
1867 struct Point {
1868 x: f64,
1869 y: f64,
1870 }
1871
1872 let schemas1 = extract_schemas(Point::SHAPE).unwrap().schemas.clone();
1873 let schemas2 = extract_schemas(Point::SHAPE).unwrap().schemas.clone();
1874 assert_eq!(
1875 schemas1.last().unwrap().id,
1876 schemas2.last().unwrap().id,
1877 "same struct must produce the same content hash"
1878 );
1879 }
1880
1881 #[test]
1883 fn recursive_type_hash_is_deterministic() {
1884 #[derive(Facet)]
1885 struct TreeNode {
1886 label: String,
1887 children: Vec<TreeNode>,
1888 }
1889
1890 let schemas1 = extract_schemas(TreeNode::SHAPE).unwrap().schemas.clone();
1891 let schemas2 = extract_schemas(TreeNode::SHAPE).unwrap().schemas.clone();
1892
1893 assert!(schemas1.len() >= 2);
1895
1896 let root1 = schemas1.last().unwrap().id;
1898 let root2 = schemas2.last().unwrap().id;
1899 assert_eq!(root1, root2, "recursive type hash must be deterministic");
1900
1901 for s in &schemas1 {
1903 assert_ne!(s.id.0, 0, "content hash must not be zero");
1904 }
1905 }
1906
1907 #[test]
1908 fn bidirectional_bindings_are_independent() {
1909 let mut tracker = SchemaSendTracker::new();
1910 let method = MethodId(1);
1911
1912 let mut args_schematic = TestSchematic::new(BindingDirection::Args, <u32 as Facet>::SHAPE);
1914 let args = tracker
1915 .attach_schemas_for_shape_if_needed(method, args_schematic.shape, &mut args_schematic)
1916 .unwrap();
1917 assert!(!args.is_empty(), "should send args");
1918 let args_parsed = SchemaPayload::from_cbor(&args.0).expect("parse args CBOR");
1919
1920 let mut response_schematic =
1922 TestSchematic::new(BindingDirection::Response, <String as Facet>::SHAPE);
1923 let response = tracker
1924 .attach_schemas_for_shape_if_needed(
1925 method,
1926 response_schematic.shape,
1927 &mut response_schematic,
1928 )
1929 .unwrap();
1930 assert!(!response.is_empty(), "should send response");
1931 let response_parsed = SchemaPayload::from_cbor(&response.0).expect("parse response CBOR");
1932 assert_ne!(args_parsed.root, response_parsed.root);
1933
1934 let recv_tracker = SchemaRecvTracker::new();
1936 recv_tracker
1937 .record_received(
1938 MethodId(42),
1939 BindingDirection::Args,
1940 SchemaPayload {
1941 schemas: extract_schemas(<u64 as Facet>::SHAPE)
1942 .unwrap()
1943 .schemas
1944 .clone(),
1945 root: TypeRef::concrete(SchemaHash(100)),
1946 },
1947 )
1948 .expect("record should succeed");
1949 recv_tracker
1950 .record_received(
1951 MethodId(42),
1952 BindingDirection::Response,
1953 SchemaPayload {
1954 schemas: vec![],
1955 root: TypeRef::concrete(SchemaHash(200)),
1956 },
1957 )
1958 .expect("record should succeed");
1959
1960 assert_eq!(
1961 recv_tracker.get_remote_args_root(MethodId(42)),
1962 Some(TypeRef::concrete(SchemaHash(100)))
1963 );
1964 assert_eq!(
1965 recv_tracker.get_remote_response_root(MethodId(42)),
1966 Some(TypeRef::concrete(SchemaHash(200)))
1967 );
1968 }
1969
1970 #[test]
1971 fn duplicate_schema_is_protocol_error() {
1972 let tracker = SchemaRecvTracker::new();
1973 let schemas = extract_schemas(<u32 as Facet>::SHAPE)
1974 .unwrap()
1975 .schemas
1976 .clone();
1977 tracker
1978 .record_received(
1979 MethodId(9),
1980 BindingDirection::Args,
1981 SchemaPayload {
1982 schemas: schemas.clone(),
1983 root: TypeRef::concrete(schemas[0].id),
1984 },
1985 )
1986 .expect("first record should succeed");
1987 let err = tracker
1988 .record_received(
1989 MethodId(9),
1990 BindingDirection::Args,
1991 SchemaPayload {
1992 schemas: schemas.clone(),
1993 root: TypeRef::concrete(schemas[0].id),
1994 },
1995 )
1996 .expect_err("duplicate should fail");
1997 assert_eq!(err.type_id, schemas[0].id);
1998 }
1999
2000 #[test]
2001 fn send_tracker_reset_clears_all_state() {
2002 let mut tracker = SchemaSendTracker::new();
2003 let method = MethodId(1);
2004 let mut schematic = TestSchematic::new(BindingDirection::Args, <u32 as Facet>::SHAPE);
2005 let first = tracker
2006 .attach_schemas_for_shape_if_needed(method, schematic.shape, &mut schematic)
2007 .unwrap();
2008 assert!(!first.is_empty(), "first should return payload");
2009
2010 tracker.reset();
2011
2012 let after_reset = tracker
2013 .attach_schemas_for_shape_if_needed(method, schematic.shape, &mut schematic)
2014 .unwrap();
2015 assert!(
2016 !after_reset.is_empty(),
2017 "after reset, prepare_send should return payload again"
2018 );
2019 }
2020
2021 #[test]
2026 fn generic_vec_uses_var_in_body() {
2027 let schemas = extract_schemas(<Vec<u32> as Facet>::SHAPE)
2028 .unwrap()
2029 .schemas
2030 .clone();
2031 let list_schema = schemas
2032 .iter()
2033 .find(|s| matches!(s.kind, SchemaKind::List { .. }))
2034 .unwrap();
2035 assert_eq!(
2036 list_schema.type_params.len(),
2037 1,
2038 "Vec should have 1 type param"
2039 );
2040 match &list_schema.kind {
2041 SchemaKind::List { element } => {
2042 assert!(
2043 matches!(element, TypeRef::Var { .. }),
2044 "element should be Var, got {element:?}"
2045 );
2046 }
2047 other => panic!("expected List, got {other:?}"),
2048 }
2049 }
2050
2051 #[test]
2052 fn generic_option_uses_var_in_body() {
2053 let schemas = extract_schemas(<Option<String> as Facet>::SHAPE)
2054 .unwrap()
2055 .schemas
2056 .clone();
2057 let opt_schema = schemas
2058 .iter()
2059 .find(|s| matches!(s.kind, SchemaKind::Option { .. }))
2060 .unwrap();
2061 assert_eq!(
2062 opt_schema.type_params.len(),
2063 1,
2064 "Option should have 1 type param"
2065 );
2066 match &opt_schema.kind {
2067 SchemaKind::Option { element } => {
2068 assert!(
2069 matches!(element, TypeRef::Var { .. }),
2070 "element should be Var, got {element:?}"
2071 );
2072 }
2073 other => panic!("expected Option, got {other:?}"),
2074 }
2075 }
2076
2077 #[test]
2078 fn generic_tuple_uses_vars_in_body() {
2079 let schemas = extract_schemas(<(u32, String) as Facet>::SHAPE)
2080 .unwrap()
2081 .schemas
2082 .clone();
2083 let tuple_schema = schemas
2084 .iter()
2085 .find(|s| matches!(s.kind, SchemaKind::Tuple { .. }))
2086 .unwrap();
2087 assert_eq!(
2088 tuple_schema.type_params.len(),
2089 2,
2090 "tuple arity 2 should have 2 type params"
2091 );
2092 match &tuple_schema.kind {
2093 SchemaKind::Tuple { elements } => {
2094 assert_eq!(elements.len(), 2);
2095 assert!(matches!(elements[0], TypeRef::Var { .. }));
2096 assert!(matches!(elements[1], TypeRef::Var { .. }));
2097 }
2098 other => panic!("expected Tuple, got {other:?}"),
2099 }
2100 }
2101
2102 #[test]
2103 fn generic_vox_error_uses_var_in_user_payload() {
2104 use crate::VoxError;
2105
2106 let schemas = extract_schemas(<VoxError<::core::convert::Infallible> as Facet>::SHAPE)
2107 .unwrap()
2108 .schemas
2109 .clone();
2110 let vox_error_schema = schemas
2111 .iter()
2112 .find(|s| matches!(&s.kind, SchemaKind::Enum { name, .. } if name == "VoxError"))
2113 .expect("VoxError schema should be present");
2114 match &vox_error_schema.kind {
2115 SchemaKind::Enum { variants, .. } => {
2116 let user = variants
2117 .iter()
2118 .find(|variant| variant.name == "User")
2119 .expect("VoxError should have User variant");
2120 let VariantPayload::Newtype { type_ref } = &user.payload else {
2121 panic!("User variant should be newtype");
2122 };
2123 assert!(
2124 matches!(type_ref, TypeRef::Var { .. }),
2125 "User payload should be a type variable, got {type_ref:?}"
2126 );
2127 }
2128 other => panic!("expected enum, got {other:?}"),
2129 }
2130 }
2131
2132 #[test]
2133 fn vec_of_option_of_u32_deduplicates() {
2134 let schemas = extract_schemas(<Vec<Option<u32>> as Facet>::SHAPE)
2137 .unwrap()
2138 .schemas
2139 .clone();
2140
2141 let list_count = schemas
2142 .iter()
2143 .filter(|s| matches!(s.kind, SchemaKind::List { .. }))
2144 .count();
2145 let option_count = schemas
2146 .iter()
2147 .filter(|s| matches!(s.kind, SchemaKind::Option { .. }))
2148 .count();
2149 assert_eq!(list_count, 1, "should have exactly 1 List schema");
2150 assert_eq!(option_count, 1, "should have exactly 1 Option schema");
2151 }
2152
2153 #[test]
2154 fn vec_u32_and_vec_string_share_one_list_schema() {
2155 #[derive(Facet)]
2156 struct Both {
2157 a: Vec<u32>,
2158 b: Vec<String>,
2159 }
2160
2161 let schemas = extract_schemas(Both::SHAPE).unwrap().schemas.clone();
2162 let list_count = schemas
2163 .iter()
2164 .filter(|s| matches!(s.kind, SchemaKind::List { .. }))
2165 .count();
2166 assert_eq!(
2167 list_count, 1,
2168 "Vec<u32> and Vec<String> should share one List schema"
2169 );
2170 }
2171
2172 #[test]
2173 fn resolve_kind_substitutes_vars() {
2174 let schemas = extract_schemas(<Vec<u32> as Facet>::SHAPE)
2175 .unwrap()
2176 .schemas
2177 .clone();
2178 let registry = build_registry(&schemas);
2179
2180 let root = schemas.last().unwrap();
2182 assert!(matches!(root.kind, SchemaKind::List { .. }));
2183
2184 let u32_schema = schemas
2186 .iter()
2187 .find(|s| {
2188 matches!(
2189 s.kind,
2190 SchemaKind::Primitive {
2191 primitive_type: PrimitiveType::U32
2192 }
2193 )
2194 })
2195 .unwrap();
2196 let type_ref = TypeRef::generic(root.id, vec![TypeRef::concrete(u32_schema.id)]);
2197
2198 let resolved = type_ref.resolve_kind(®istry).expect("should resolve");
2200 match &resolved {
2201 SchemaKind::List { element } => match element {
2202 TypeRef::Concrete { type_id, args } => {
2203 assert_eq!(*type_id, u32_schema.id);
2204 assert!(args.is_empty());
2205 }
2206 other => panic!("expected concrete after resolution, got {other:?}"),
2207 },
2208 other => panic!("expected List, got {other:?}"),
2209 }
2210 }
2211
2212 #[test]
2213 fn extract_result_tuple_root_preserves_ok_tuple() {
2214 use crate::VoxError;
2215
2216 let extracted = extract_schemas(
2217 <Result<(String, i32), VoxError<::core::convert::Infallible>> as Facet>::SHAPE,
2218 )
2219 .unwrap();
2220 let registry = build_registry(&extracted.schemas);
2221 let root = extracted
2222 .root
2223 .resolve_kind(®istry)
2224 .expect("result root should resolve");
2225
2226 let SchemaKind::Enum { variants, .. } = root else {
2227 panic!("expected Result enum root");
2228 };
2229 let ok_variant = variants
2230 .iter()
2231 .find(|variant| variant.name == "Ok")
2232 .expect("Result should have Ok variant");
2233 let VariantPayload::Newtype { type_ref } = &ok_variant.payload else {
2234 panic!("Ok variant should be newtype");
2235 };
2236 let ok_kind = type_ref
2237 .resolve_kind(®istry)
2238 .expect("Ok payload should resolve");
2239 match ok_kind {
2240 SchemaKind::Tuple { elements } => {
2241 assert_eq!(elements.len(), 2, "Ok tuple should have two elements");
2242 }
2243 other => panic!("expected Ok payload to be tuple, got {other:?}"),
2244 }
2245 }
2246
2247 #[test]
2248 fn result_ok_tuple_uses_generic_tuple_schema() {
2249 use crate::VoxError;
2250
2251 let result_shape =
2252 <Result<(String, i32), VoxError<::core::convert::Infallible>> as Facet>::SHAPE;
2253 let ok_shape = result_shape.type_params[0].shape;
2254 let extracted = extract_schemas(
2255 <Result<(String, i32), VoxError<::core::convert::Infallible>> as Facet>::SHAPE,
2256 )
2257 .unwrap();
2258 let TypeRef::Concrete { args, .. } = &extracted.root else {
2259 panic!("Result root should be concrete");
2260 };
2261 assert_eq!(
2262 args.len(),
2263 2,
2264 "Result root should have Ok and Err type args"
2265 );
2266 let TypeRef::Concrete { args: ok_args, .. } = &args[0] else {
2267 panic!("Ok type arg should be concrete tuple ref");
2268 };
2269 assert_eq!(
2270 ok_args.len(),
2271 2,
2272 "Ok tuple ref should carry concrete tuple element args; root={:?}; ok_shape={}; ok_shape_ty={:?}",
2273 extracted.root,
2274 ok_shape.type_identifier,
2275 ok_shape.ty
2276 );
2277 }
2278
2279 #[test]
2280 fn unary_tuple_root_preserves_nested_tuple() {
2281 let extracted = extract_schemas(<((i32, String),) as Facet>::SHAPE).unwrap();
2282 let registry = build_registry(&extracted.schemas);
2283
2284 let root = extracted
2285 .root
2286 .resolve_kind(®istry)
2287 .expect("root should resolve");
2288 let SchemaKind::Tuple { elements } = root else {
2289 panic!("expected unary tuple root");
2290 };
2291 assert_eq!(elements.len(), 1, "outer tuple should remain unary");
2292
2293 let inner = elements[0]
2294 .resolve_kind(®istry)
2295 .expect("inner tuple should resolve");
2296 match inner {
2297 SchemaKind::Tuple { elements } => {
2298 assert_eq!(elements.len(), 2, "inner tuple should remain binary");
2299 }
2300 other => panic!("expected inner tuple, got {other:?}"),
2301 }
2302
2303 let tuple_count = extracted
2304 .schemas
2305 .iter()
2306 .filter(|schema| matches!(schema.kind, SchemaKind::Tuple { .. }))
2307 .count();
2308 assert_eq!(tuple_count, 2, "should emit one tuple schema per arity");
2309 }
2310
2311 #[test]
2312 fn nested_generic_vec_of_vec_of_u32() {
2313 let schemas = extract_schemas(<Vec<Vec<u32>> as Facet>::SHAPE)
2315 .unwrap()
2316 .schemas
2317 .clone();
2318 let list_count = schemas
2319 .iter()
2320 .filter(|s| matches!(s.kind, SchemaKind::List { .. }))
2321 .count();
2322 assert_eq!(
2323 list_count, 1,
2324 "Vec<Vec<u32>> should have exactly 1 List schema (Vec<T>)"
2325 );
2326 }
2327
2328 #[test]
2329 fn recursive_type_with_option_box() {
2330 #[derive(Facet)]
2331 struct Node {
2332 value: u32,
2333 next: Option<Box<Node>>,
2334 }
2335
2336 let schemas = extract_schemas(Node::SHAPE).unwrap().schemas.clone();
2337 let option_count = schemas
2339 .iter()
2340 .filter(|s| matches!(s.kind, SchemaKind::Option { .. }))
2341 .count();
2342 assert_eq!(option_count, 1, "should have exactly 1 Option schema");
2343
2344 let opt_schema = schemas
2346 .iter()
2347 .find(|s| matches!(s.kind, SchemaKind::Option { .. }))
2348 .unwrap();
2349 match &opt_schema.kind {
2350 SchemaKind::Option { element } => {
2351 assert!(
2352 matches!(element, TypeRef::Var { .. }),
2353 "element should be Var"
2354 );
2355 }
2356 _ => unreachable!(),
2357 }
2358
2359 for s in &schemas {
2361 assert_ne!(s.id.0, 0, "content hash must not be zero: {:?}", s.kind);
2362 }
2363 }
2364
2365 #[test]
2366 fn map_schema_is_generic() {
2367 let schemas = extract_schemas(<std::collections::HashMap<String, u32> as Facet>::SHAPE)
2368 .unwrap()
2369 .schemas
2370 .clone();
2371 let map_schema = schemas
2372 .iter()
2373 .find(|s| matches!(s.kind, SchemaKind::Map { .. }))
2374 .unwrap();
2375 assert_eq!(
2376 map_schema.type_params.len(),
2377 2,
2378 "HashMap should have 2 type params"
2379 );
2380 match &map_schema.kind {
2381 SchemaKind::Map { key, value } => {
2382 assert!(matches!(key, TypeRef::Var { .. }), "key should be Var");
2383 assert!(matches!(value, TypeRef::Var { .. }), "value should be Var");
2384 }
2385 _ => unreachable!(),
2386 }
2387 }
2388
2389 #[test]
2390 fn schema_payload_cbor_round_trip() {
2391 let payload = SchemaPayload {
2392 schemas: vec![],
2393 root: TypeRef::Concrete {
2394 type_id: SchemaHash(123),
2395 args: vec![TypeRef::concrete(SchemaHash(456))],
2396 },
2397 };
2398 let bytes = payload.to_cbor();
2399 let parsed = SchemaPayload::from_cbor(&bytes.0).expect("should parse CBOR");
2400 match &parsed.root {
2401 TypeRef::Concrete { type_id, args } => {
2402 assert_eq!(*type_id, SchemaHash(123));
2403 assert_eq!(args.len(), 1);
2404 match &args[0] {
2405 TypeRef::Concrete { type_id, args } => {
2406 assert_eq!(*type_id, SchemaHash(456));
2407 assert!(args.is_empty());
2408 }
2409 other => panic!("expected concrete arg, got {other:?}"),
2410 }
2411 }
2412 other => panic!("expected concrete root, got {other:?}"),
2413 }
2414 }
2415}