Skip to main content

spacetimedb_sats/
resolve_refs.rs

1use crate::{
2    typespace::TypeRefError, AlgebraicType, AlgebraicTypeRef, ArrayType, ProductType, ProductTypeElement, SumType,
3    SumTypeVariant, WithTypespace,
4};
5
6/// Resolver for [`AlgebraicTypeRef`]s within a structure.
7#[derive(Default)]
8pub struct ResolveRefState {
9    /// The stack used to handle cycle detection for [recursive types] (`μα. T`).
10    ///
11    /// [recursive types]: https://en.wikipedia.org/wiki/Recursive_data_type#Theory
12    stack: Vec<AlgebraicTypeRef>,
13}
14
15/// A trait for types that know how to resolve their [`AlgebraicTypeRef`]s
16/// provided a typing context and the resolver `state`.
17pub trait ResolveRefs {
18    /// Output type after type references have been resolved.
19    type Output;
20
21    /// Returns, if possible, an output with all [`AlgebraicTypeRef`]s
22    /// within `this` (typing context carried) resolved
23    /// using the provided resolver `state`.
24    ///
25    /// `Err` is returned if a cycle was detected, or if any `AlgebraicTypeRef` touched was invalid.
26    fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Result<Self::Output, TypeRefError>;
27}
28
29// -----------------------------------------------------------------------------
30// The interesting logic:
31// -----------------------------------------------------------------------------
32
33impl ResolveRefs for AlgebraicTypeRef {
34    type Output = AlgebraicType;
35    fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Result<Self::Output, TypeRefError> {
36        // Suppose we have `&0 = { Nil, Cons({ elem: U8, tail: &0 }) }`.
37        // This is our standard cons-list type.
38        // In this setup, when getting to `tail`,
39        // we would recurse back to expanding `tail` again, and so or...
40        // So we will never halt. This check breaks that cycle.
41        if state.stack.contains(this.ty()) {
42            return Err(TypeRefError::RecursiveTypeRef(*this.ty()));
43        }
44
45        // Push ourselves to the stack.
46        state.stack.push(*this.ty());
47
48        // Extract the `at: AlgebraicType` pointed to by `this` and then resolve `at`.
49        let ret = this
50            .typespace()
51            .get(*this.ty())
52            .ok_or(TypeRefError::InvalidTypeRef(*this.ty()))
53            .and_then(|at| this.with(at)._resolve_refs(state));
54
55        // Remove ourselves.
56        state.stack.pop();
57        ret
58    }
59}
60
61// -----------------------------------------------------------------------------
62// All the below is just plumbing:
63// -----------------------------------------------------------------------------
64
65impl ResolveRefs for AlgebraicType {
66    type Output = Self;
67    fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Result<Self::Output, TypeRefError> {
68        match this.ty() {
69            Self::Ref(r) => this.with(r)._resolve_refs(state),
70            Self::Sum(sum) => this.with(sum)._resolve_refs(state).map(Into::into),
71            Self::Product(prod) => this.with(prod)._resolve_refs(state).map(Into::into),
72            Self::Array(ty) => this.with(ty)._resolve_refs(state).map(Into::into),
73            // These types are plain and cannot have refs in them.
74            x => Ok(x.clone()),
75        }
76    }
77}
78
79impl ResolveRefs for ArrayType {
80    type Output = Self;
81    fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Result<Self::Output, TypeRefError> {
82        Ok(Self {
83            elem_ty: Box::new(this.map(|m| &*m.elem_ty)._resolve_refs(state)?),
84        })
85    }
86}
87
88impl ResolveRefs for ProductType {
89    type Output = Self;
90    fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Result<Self::Output, TypeRefError> {
91        let elements = this
92            .ty()
93            .elements
94            .iter()
95            .map(|el| this.with(el)._resolve_refs(state))
96            .collect::<Result<_, _>>()?;
97        Ok(ProductType { elements })
98    }
99}
100
101impl ResolveRefs for ProductTypeElement {
102    type Output = Self;
103    fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Result<Self::Output, TypeRefError> {
104        Ok(Self {
105            algebraic_type: this.map(|e| &e.algebraic_type)._resolve_refs(state)?,
106            name: this.ty().name.clone(),
107        })
108    }
109}
110
111impl ResolveRefs for SumType {
112    type Output = Self;
113    fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Result<Self::Output, TypeRefError> {
114        let variants = this
115            .ty()
116            .variants
117            .iter()
118            .map(|v| this.with(v)._resolve_refs(state))
119            .collect::<Result<_, _>>()?;
120        Ok(Self { variants })
121    }
122}
123
124impl ResolveRefs for SumTypeVariant {
125    type Output = Self;
126    fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Result<Self::Output, TypeRefError> {
127        Ok(Self {
128            algebraic_type: this.map(|v| &v.algebraic_type)._resolve_refs(state)?,
129            name: this.ty().name.clone(),
130        })
131    }
132}
133
134impl<T: ResolveRefs> WithTypespace<'_, T> {
135    pub fn resolve_refs(self) -> Result<T::Output, TypeRefError> {
136        T::resolve_refs(self, &mut ResolveRefState::default())
137    }
138    fn _resolve_refs(self, state: &mut ResolveRefState) -> Result<T::Output, TypeRefError> {
139        T::resolve_refs(self, state)
140    }
141}