1#![allow(clippy::type_complexity)]
16
17use crate::{
18 BitsOrderFormat, BitsStoreFormat, Field, FieldIter, PathIter, Primitive, ResolvedTypeVisitor,
19 UnhandledKind, Variant, VariantIter,
20};
21use smallvec::SmallVec;
22
23pub struct ConcreteFieldIter<'resolver, TypeId> {
25 fields: SmallVec<[Option<Field<'resolver, TypeId>>; 16]>,
26 idx: usize,
27}
28
29impl<'resolver, TypeId> Iterator for ConcreteFieldIter<'resolver, TypeId> {
30 type Item = Field<'resolver, TypeId>;
31 fn next(&mut self) -> Option<Self::Item> {
32 let field = self
33 .fields
34 .get_mut(self.idx)?
35 .take()
36 .expect("Expected a field but got None");
37 self.idx += 1;
38 Some(field)
39 }
40}
41
42impl<'resolver, TypeId> ExactSizeIterator for ConcreteFieldIter<'resolver, TypeId> {
43 fn len(&self) -> usize {
44 self.fields.len()
45 }
46}
47
/// A `ResolvedTypeVisitor` assembled from one closure per kind of type.
///
/// Build one with `new`, which starts with every `visit_*` closure forwarding to a single
/// "unhandled" closure. Then replace individual cases with the builder methods
/// (`visit_composite`, `visit_variant`, ...). The stored `context` is handed by value to
/// whichever closure ends up being called.
#[must_use = "a visitor does nothing until it is handed to a type resolver"]
pub struct ConcreteResolvedTypeVisitor<
    'resolver,
    Context,
    TypeId,
    Output,
    UnhandledFn,
    NotFoundFn,
    CompositeFn,
    VariantFn,
    SequenceFn,
    ArrayFn,
    TupleFn,
    PrimitiveFn,
    CompactFn,
    BitSequenceFn,
> {
    // `TypeId`, `Output` and `'resolver` only appear in the closures' trait bounds (on the
    // impls), so they are recorded here to keep them as used generic parameters.
    _marker: core::marker::PhantomData<(TypeId, Output, &'resolver ())>,
    /// Caller-supplied state, moved into whichever handler is invoked.
    context: Context,
    /// Fallback handler, called with the kind of type that was not otherwise handled.
    visit_unhandled: UnhandledFn,
    /// Handler for types the resolver could not find.
    visit_not_found: NotFoundFn,
    /// Handler for composite (struct-like) types.
    visit_composite: CompositeFn,
    /// Handler for variant (enum-like) types.
    visit_variant: VariantFn,
    /// Handler for sequence types.
    visit_sequence: SequenceFn,
    /// Handler for array types.
    visit_array: ArrayFn,
    /// Handler for tuple types.
    visit_tuple: TupleFn,
    /// Handler for primitive types.
    visit_primitive: PrimitiveFn,
    /// Handler for compact-encoded types.
    visit_compact: CompactFn,
    /// Handler for bit sequence types.
    visit_bit_sequence: BitSequenceFn,
}
79
80pub fn new<'resolver, Context, TypeId, Output, NewUnhandledFn>(
130 context: Context,
131 unhandled_fn: NewUnhandledFn,
132) -> ConcreteResolvedTypeVisitor<
133 'resolver,
134 Context,
135 TypeId,
136 Output,
137 NewUnhandledFn,
138 impl FnOnce(Context) -> Output,
139 impl FnOnce(
140 Context,
141 &mut dyn PathIter<'resolver>,
142 &'_ mut dyn FieldIter<'resolver, TypeId>,
143 ) -> Output,
144 impl FnOnce(
145 Context,
146 &mut dyn PathIter<'resolver>,
147 &'_ mut dyn VariantIter<'resolver, ConcreteFieldIter<'resolver, TypeId>>,
148 ) -> Output,
149 impl FnOnce(Context, &mut dyn PathIter<'resolver>, TypeId) -> Output,
150 impl FnOnce(Context, TypeId, usize) -> Output,
151 impl FnOnce(Context, &'_ mut dyn ExactSizeIterator<Item = TypeId>) -> Output,
152 impl FnOnce(Context, Primitive) -> Output,
153 impl FnOnce(Context, TypeId) -> Output,
154 impl FnOnce(Context, BitsStoreFormat, BitsOrderFormat) -> Output,
155>
156where
157 NewUnhandledFn: FnOnce(Context, UnhandledKind) -> Output + Clone,
158{
159 let visit_unhandled = unhandled_fn.clone();
160
161 let visit_not_found = {
166 let u = unhandled_fn.clone();
167 move |ctx| u(ctx, UnhandledKind::NotFound)
168 };
169 let visit_composite = {
170 let u = unhandled_fn.clone();
171 move |ctx, _: &mut dyn PathIter<'resolver>, _: &mut dyn FieldIter<'resolver, TypeId>| {
172 u(ctx, UnhandledKind::Composite)
173 }
174 };
175 let visit_variant = {
176 let u = unhandled_fn.clone();
177 move |ctx,
178 _: &mut dyn PathIter<'resolver>,
179 _: &mut dyn VariantIter<'resolver, ConcreteFieldIter<'resolver, TypeId>>| {
180 u(ctx, UnhandledKind::Variant)
181 }
182 };
183 let visit_sequence = {
184 let u = unhandled_fn.clone();
185 move |ctx, _: &mut dyn PathIter<'resolver>, _| u(ctx, UnhandledKind::Sequence)
186 };
187 let visit_array = {
188 let u = unhandled_fn.clone();
189 move |ctx, _, _| u(ctx, UnhandledKind::Array)
190 };
191 let visit_tuple = {
192 let u = unhandled_fn.clone();
193 move |ctx, _: &mut dyn ExactSizeIterator<Item = TypeId>| u(ctx, UnhandledKind::Tuple)
194 };
195 let visit_primitive = {
196 let u = unhandled_fn.clone();
197 move |ctx, _| u(ctx, UnhandledKind::Primitive)
198 };
199 let visit_compact = {
200 let u = unhandled_fn.clone();
201 move |ctx, _| u(ctx, UnhandledKind::Compact)
202 };
203 let visit_bit_sequence = {
204 let u = unhandled_fn.clone();
205 move |ctx, _, _| u(ctx, UnhandledKind::BitSequence)
206 };
207
208 ConcreteResolvedTypeVisitor {
209 _marker: core::marker::PhantomData,
210 context,
211 visit_unhandled,
212 visit_not_found,
213 visit_composite,
214 visit_variant,
215 visit_sequence,
216 visit_array,
217 visit_tuple,
218 visit_primitive,
219 visit_compact,
220 visit_bit_sequence,
221 }
222}
223
224impl<
225 'resolver,
226 Context,
227 TypeId,
228 Output,
229 UnhandledFn,
230 NotFoundFn,
231 CompositeFn,
232 VariantFn,
233 SequenceFn,
234 ArrayFn,
235 TupleFn,
236 PrimitiveFn,
237 CompactFn,
238 BitSequenceFn,
239 >
240 ConcreteResolvedTypeVisitor<
241 'resolver,
242 Context,
243 TypeId,
244 Output,
245 UnhandledFn,
246 NotFoundFn,
247 CompositeFn,
248 VariantFn,
249 SequenceFn,
250 ArrayFn,
251 TupleFn,
252 PrimitiveFn,
253 CompactFn,
254 BitSequenceFn,
255 >
256{
257 pub fn visit_not_found<NewNotFoundFn>(
259 self,
260 new_not_found_fn: NewNotFoundFn,
261 ) -> ConcreteResolvedTypeVisitor<
262 'resolver,
263 Context,
264 TypeId,
265 Output,
266 UnhandledFn,
267 NewNotFoundFn,
268 CompositeFn,
269 VariantFn,
270 SequenceFn,
271 ArrayFn,
272 TupleFn,
273 PrimitiveFn,
274 CompactFn,
275 BitSequenceFn,
276 >
277 where
278 NewNotFoundFn: FnOnce(Context) -> Output,
279 {
280 ConcreteResolvedTypeVisitor {
281 _marker: core::marker::PhantomData,
282 context: self.context,
283 visit_unhandled: self.visit_unhandled,
284 visit_not_found: new_not_found_fn,
285 visit_composite: self.visit_composite,
286 visit_variant: self.visit_variant,
287 visit_sequence: self.visit_sequence,
288 visit_array: self.visit_array,
289 visit_tuple: self.visit_tuple,
290 visit_primitive: self.visit_primitive,
291 visit_compact: self.visit_compact,
292 visit_bit_sequence: self.visit_bit_sequence,
293 }
294 }
295
296 pub fn visit_composite<NewCompositeFn>(
298 self,
299 new_composite_fn: NewCompositeFn,
300 ) -> ConcreteResolvedTypeVisitor<
301 'resolver,
302 Context,
303 TypeId,
304 Output,
305 UnhandledFn,
306 NotFoundFn,
307 NewCompositeFn,
308 VariantFn,
309 SequenceFn,
310 ArrayFn,
311 TupleFn,
312 PrimitiveFn,
313 CompactFn,
314 BitSequenceFn,
315 >
316 where
317 NewCompositeFn: FnOnce(
318 Context,
319 &mut dyn PathIter<'resolver>,
320 &mut dyn FieldIter<'resolver, TypeId>,
321 ) -> Output,
322 {
323 ConcreteResolvedTypeVisitor {
324 _marker: core::marker::PhantomData,
325 context: self.context,
326 visit_unhandled: self.visit_unhandled,
327 visit_not_found: self.visit_not_found,
328 visit_composite: new_composite_fn,
329 visit_variant: self.visit_variant,
330 visit_sequence: self.visit_sequence,
331 visit_array: self.visit_array,
332 visit_tuple: self.visit_tuple,
333 visit_primitive: self.visit_primitive,
334 visit_compact: self.visit_compact,
335 visit_bit_sequence: self.visit_bit_sequence,
336 }
337 }
338
339 pub fn visit_variant<NewVariantFn>(
341 self,
342 new_variant_fn: NewVariantFn,
343 ) -> ConcreteResolvedTypeVisitor<
344 'resolver,
345 Context,
346 TypeId,
347 Output,
348 UnhandledFn,
349 NotFoundFn,
350 CompositeFn,
351 NewVariantFn,
352 SequenceFn,
353 ArrayFn,
354 TupleFn,
355 PrimitiveFn,
356 CompactFn,
357 BitSequenceFn,
358 >
359 where
360 NewVariantFn: FnOnce(
361 Context,
362 &mut dyn PathIter<'resolver>,
363 &mut dyn VariantIter<'resolver, ConcreteFieldIter<'resolver, TypeId>>,
364 ) -> Output,
365 {
366 ConcreteResolvedTypeVisitor {
367 _marker: core::marker::PhantomData,
368 context: self.context,
369 visit_unhandled: self.visit_unhandled,
370 visit_not_found: self.visit_not_found,
371 visit_composite: self.visit_composite,
372 visit_variant: new_variant_fn,
373 visit_sequence: self.visit_sequence,
374 visit_array: self.visit_array,
375 visit_tuple: self.visit_tuple,
376 visit_primitive: self.visit_primitive,
377 visit_compact: self.visit_compact,
378 visit_bit_sequence: self.visit_bit_sequence,
379 }
380 }
381
382 pub fn visit_sequence<NewSequenceFn>(
384 self,
385 new_sequence_fn: NewSequenceFn,
386 ) -> ConcreteResolvedTypeVisitor<
387 'resolver,
388 Context,
389 TypeId,
390 Output,
391 UnhandledFn,
392 NotFoundFn,
393 CompositeFn,
394 VariantFn,
395 NewSequenceFn,
396 ArrayFn,
397 TupleFn,
398 PrimitiveFn,
399 CompactFn,
400 BitSequenceFn,
401 >
402 where
403 NewSequenceFn: FnOnce(Context, &mut dyn PathIter<'resolver>, TypeId) -> Output,
404 TypeId: 'resolver,
405 {
406 ConcreteResolvedTypeVisitor {
407 _marker: core::marker::PhantomData,
408 context: self.context,
409 visit_unhandled: self.visit_unhandled,
410 visit_not_found: self.visit_not_found,
411 visit_composite: self.visit_composite,
412 visit_variant: self.visit_variant,
413 visit_sequence: new_sequence_fn,
414 visit_array: self.visit_array,
415 visit_tuple: self.visit_tuple,
416 visit_primitive: self.visit_primitive,
417 visit_compact: self.visit_compact,
418 visit_bit_sequence: self.visit_bit_sequence,
419 }
420 }
421
422 pub fn visit_array<NewArrayFn>(
424 self,
425 new_array_fn: NewArrayFn,
426 ) -> ConcreteResolvedTypeVisitor<
427 'resolver,
428 Context,
429 TypeId,
430 Output,
431 UnhandledFn,
432 NotFoundFn,
433 CompositeFn,
434 VariantFn,
435 SequenceFn,
436 NewArrayFn,
437 TupleFn,
438 PrimitiveFn,
439 CompactFn,
440 BitSequenceFn,
441 >
442 where
443 NewArrayFn: FnOnce(Context, TypeId, usize) -> Output,
444 TypeId: 'resolver,
445 {
446 ConcreteResolvedTypeVisitor {
447 _marker: core::marker::PhantomData,
448 context: self.context,
449 visit_unhandled: self.visit_unhandled,
450 visit_not_found: self.visit_not_found,
451 visit_composite: self.visit_composite,
452 visit_variant: self.visit_variant,
453 visit_sequence: self.visit_sequence,
454 visit_array: new_array_fn,
455 visit_tuple: self.visit_tuple,
456 visit_primitive: self.visit_primitive,
457 visit_compact: self.visit_compact,
458 visit_bit_sequence: self.visit_bit_sequence,
459 }
460 }
461
462 pub fn visit_tuple<NewTupleFn>(
464 self,
465 new_tuple_fn: NewTupleFn,
466 ) -> ConcreteResolvedTypeVisitor<
467 'resolver,
468 Context,
469 TypeId,
470 Output,
471 UnhandledFn,
472 NotFoundFn,
473 CompositeFn,
474 VariantFn,
475 SequenceFn,
476 ArrayFn,
477 NewTupleFn,
478 PrimitiveFn,
479 CompactFn,
480 BitSequenceFn,
481 >
482 where
483 NewTupleFn: FnOnce(Context, &mut dyn ExactSizeIterator<Item = TypeId>) -> Output,
484 {
485 ConcreteResolvedTypeVisitor {
486 _marker: core::marker::PhantomData,
487 context: self.context,
488 visit_unhandled: self.visit_unhandled,
489 visit_not_found: self.visit_not_found,
490 visit_composite: self.visit_composite,
491 visit_variant: self.visit_variant,
492 visit_sequence: self.visit_sequence,
493 visit_array: self.visit_array,
494 visit_tuple: new_tuple_fn,
495 visit_primitive: self.visit_primitive,
496 visit_compact: self.visit_compact,
497 visit_bit_sequence: self.visit_bit_sequence,
498 }
499 }
500
501 pub fn visit_primitive<NewPrimitiveFn>(
503 self,
504 new_primitive_fn: NewPrimitiveFn,
505 ) -> ConcreteResolvedTypeVisitor<
506 'resolver,
507 Context,
508 TypeId,
509 Output,
510 UnhandledFn,
511 NotFoundFn,
512 CompositeFn,
513 VariantFn,
514 SequenceFn,
515 ArrayFn,
516 TupleFn,
517 NewPrimitiveFn,
518 CompactFn,
519 BitSequenceFn,
520 >
521 where
522 NewPrimitiveFn: FnOnce(Context, Primitive) -> Output,
523 {
524 ConcreteResolvedTypeVisitor {
525 _marker: core::marker::PhantomData,
526 context: self.context,
527 visit_unhandled: self.visit_unhandled,
528 visit_not_found: self.visit_not_found,
529 visit_composite: self.visit_composite,
530 visit_variant: self.visit_variant,
531 visit_sequence: self.visit_sequence,
532 visit_array: self.visit_array,
533 visit_tuple: self.visit_tuple,
534 visit_primitive: new_primitive_fn,
535 visit_compact: self.visit_compact,
536 visit_bit_sequence: self.visit_bit_sequence,
537 }
538 }
539
540 pub fn visit_compact<NewCompactFn>(
542 self,
543 new_compact_fn: NewCompactFn,
544 ) -> ConcreteResolvedTypeVisitor<
545 'resolver,
546 Context,
547 TypeId,
548 Output,
549 UnhandledFn,
550 NotFoundFn,
551 CompositeFn,
552 VariantFn,
553 SequenceFn,
554 ArrayFn,
555 TupleFn,
556 PrimitiveFn,
557 NewCompactFn,
558 BitSequenceFn,
559 >
560 where
561 NewCompactFn: FnOnce(Context, TypeId) -> Output,
562 TypeId: 'resolver,
563 {
564 ConcreteResolvedTypeVisitor {
565 _marker: core::marker::PhantomData,
566 context: self.context,
567 visit_unhandled: self.visit_unhandled,
568 visit_not_found: self.visit_not_found,
569 visit_composite: self.visit_composite,
570 visit_variant: self.visit_variant,
571 visit_sequence: self.visit_sequence,
572 visit_array: self.visit_array,
573 visit_tuple: self.visit_tuple,
574 visit_primitive: self.visit_primitive,
575 visit_compact: new_compact_fn,
576 visit_bit_sequence: self.visit_bit_sequence,
577 }
578 }
579
580 pub fn visit_bit_sequence<NewBitSequenceFn>(
582 self,
583 new_bit_sequence_fn: NewBitSequenceFn,
584 ) -> ConcreteResolvedTypeVisitor<
585 'resolver,
586 Context,
587 TypeId,
588 Output,
589 UnhandledFn,
590 NotFoundFn,
591 CompositeFn,
592 VariantFn,
593 SequenceFn,
594 ArrayFn,
595 TupleFn,
596 PrimitiveFn,
597 CompactFn,
598 NewBitSequenceFn,
599 >
600 where
601 NewBitSequenceFn: FnOnce(Context, BitsStoreFormat, BitsOrderFormat) -> Output,
602 {
603 ConcreteResolvedTypeVisitor {
604 _marker: core::marker::PhantomData,
605 context: self.context,
606 visit_unhandled: self.visit_unhandled,
607 visit_not_found: self.visit_not_found,
608 visit_composite: self.visit_composite,
609 visit_variant: self.visit_variant,
610 visit_sequence: self.visit_sequence,
611 visit_array: self.visit_array,
612 visit_tuple: self.visit_tuple,
613 visit_primitive: self.visit_primitive,
614 visit_compact: self.visit_compact,
615 visit_bit_sequence: new_bit_sequence_fn,
616 }
617 }
618}
619
620impl<
622 'resolver,
623 Context,
624 TypeId,
625 Output,
626 UnhandledFn,
627 NotFoundFn,
628 CompositeFn,
629 VariantFn,
630 SequenceFn,
631 ArrayFn,
632 TupleFn,
633 PrimitiveFn,
634 CompactFn,
635 BitSequenceFn,
636 > ResolvedTypeVisitor<'resolver>
637 for ConcreteResolvedTypeVisitor<
638 'resolver,
639 Context,
640 TypeId,
641 Output,
642 UnhandledFn,
643 NotFoundFn,
644 CompositeFn,
645 VariantFn,
646 SequenceFn,
647 ArrayFn,
648 TupleFn,
649 PrimitiveFn,
650 CompactFn,
651 BitSequenceFn,
652 >
653where
654 TypeId: Clone + Default + core::fmt::Debug + 'static,
655 UnhandledFn: FnOnce(Context, UnhandledKind) -> Output,
656 NotFoundFn: FnOnce(Context) -> Output,
657 CompositeFn: FnOnce(
658 Context,
659 &mut dyn PathIter<'resolver>,
660 &mut dyn FieldIter<'resolver, TypeId>,
661 ) -> Output,
662 VariantFn: FnOnce(
663 Context,
664 &mut dyn PathIter<'resolver>,
665 &mut dyn VariantIter<'resolver, ConcreteFieldIter<'resolver, TypeId>>,
666 ) -> Output,
667 SequenceFn: FnOnce(Context, &mut dyn PathIter<'resolver>, TypeId) -> Output,
668 ArrayFn: FnOnce(Context, TypeId, usize) -> Output,
669 TupleFn: FnOnce(Context, &mut dyn ExactSizeIterator<Item = TypeId>) -> Output,
670 PrimitiveFn: FnOnce(Context, Primitive) -> Output,
671 CompactFn: FnOnce(Context, TypeId) -> Output,
672 BitSequenceFn: FnOnce(Context, BitsStoreFormat, BitsOrderFormat) -> Output,
673{
674 type TypeId = TypeId;
675 type Value = Output;
676
677 fn visit_unhandled(self, kind: UnhandledKind) -> Self::Value {
678 (self.visit_unhandled)(self.context, kind)
679 }
680
681 fn visit_not_found(self) -> Self::Value {
682 (self.visit_not_found)(self.context)
683 }
684
685 fn visit_composite<Path, Fields>(self, mut path: Path, mut fields: Fields) -> Self::Value
686 where
687 Path: PathIter<'resolver>,
688 Fields: FieldIter<'resolver, Self::TypeId>,
689 {
690 (self.visit_composite)(
691 self.context,
692 &mut path,
693 &mut fields as &mut dyn FieldIter<'resolver, Self::TypeId>,
694 )
695 }
696
697 fn visit_variant<Path, Fields, Var>(self, mut path: Path, variants: Var) -> Self::Value
698 where
699 Path: PathIter<'resolver>,
700 Fields: FieldIter<'resolver, Self::TypeId>,
701 Var: VariantIter<'resolver, Fields>,
702 {
703 let mut var_iter = variants.map(|v| Variant {
707 index: v.index,
708 name: v.name,
709 fields: ConcreteFieldIter {
710 fields: v.fields.map(Some).collect(),
711 idx: 0,
712 },
713 });
714
715 (self.visit_variant)(self.context, &mut path, &mut var_iter)
716 }
717
718 fn visit_sequence<Path>(self, mut path: Path, type_id: Self::TypeId) -> Self::Value
719 where
720 Path: PathIter<'resolver>,
721 {
722 (self.visit_sequence)(self.context, &mut path, type_id)
723 }
724
725 fn visit_array(self, type_id: Self::TypeId, len: usize) -> Self::Value {
726 (self.visit_array)(self.context, type_id, len)
727 }
728
729 fn visit_tuple<TypeIds>(self, mut type_ids: TypeIds) -> Self::Value
730 where
731 TypeIds: ExactSizeIterator<Item = Self::TypeId>,
732 {
733 (self.visit_tuple)(
734 self.context,
735 &mut type_ids as &mut dyn ExactSizeIterator<Item = Self::TypeId>,
736 )
737 }
738
739 fn visit_primitive(self, primitive: Primitive) -> Self::Value {
740 (self.visit_primitive)(self.context, primitive)
741 }
742
743 fn visit_compact(self, type_id: Self::TypeId) -> Self::Value {
744 (self.visit_compact)(self.context, type_id)
745 }
746
747 fn visit_bit_sequence(
748 self,
749 store_format: BitsStoreFormat,
750 order_format: BitsOrderFormat,
751 ) -> Self::Value {
752 (self.visit_bit_sequence)(self.context, store_format, order_format)
753 }
754}
755
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypeResolver;

    /// Every builder method must accept a closure whose argument types are left entirely to
    /// inference, and the resulting visitor must be usable with a `TypeResolver`.
    #[test]
    fn check_type_inference() {
        // Each handler returns a distinct number so the assertion below can tell which one ran.
        let visitor = new((), |_, _| 1u64)
            .visit_array(|_, _, _| 2)
            .visit_composite(|_, _, _| 3)
            .visit_bit_sequence(|_, _, _| 4)
            .visit_compact(|_, _| 5)
            .visit_not_found(|_| 6)
            .visit_primitive(|_, _| 7)
            .visit_tuple(|_, _| 8)
            .visit_variant(|_, _, _| 9)
            .visit_sequence(|_, _, _| 10);

        // A minimal resolver that reports every type as not found, so only the
        // `visit_not_found` handler should run.
        struct Foo;
        impl crate::TypeResolver for Foo {
            type TypeId = u32;
            type Error = u8;

            fn resolve_type<'this, V: ResolvedTypeVisitor<'this, TypeId = Self::TypeId>>(
                &'this self,
                _type_id: Self::TypeId,
                visitor: V,
            ) -> Result<V::Value, Self::Error> {
                Ok(visitor.visit_not_found())
            }
        }

        assert_eq!(Foo.resolve_type(123, visitor).unwrap(), 6);
    }
}