simplicity/types/
mod.rs

1// SPDX-License-Identifier: CC0-1.0
2
3//! Types and Type Inference
4//!
5//! Every Simplicity expression has two types associated with it: a source and
6//! a target. We refer to this pair of types as an "arrow". The types are
7//! inferred from the structure of the program.
8//!
9//! Simplicity types are one of three things
10//!   * A unit type, which has one value
11//!   * A sum of two other types
12//!   * A product of two other types
13//!
14//! During type inference, types are initially "free", meaning that there are
15//! no constraints on what they will eventually be. The program structure then
16//! imposes additional constraints (for example, the `comp` combinator requires
17//! that its left child's target type be the same as its right child's source
18//! type), and by unifying all these constraints, all types can be inferred.
19//!
20//! Type inference is done progressively during construction of Simplicity
21//! expressions. It is completed by the [`Type::finalize`] method, which
22//! recursively completes types by setting any remaining free variables to unit.
23//! If any type constraints are incompatible with each other (e.g. a type is
24//! bound to be both a product and a sum type) then inference fails at that point
25//! by returning an error.
26//!
//! In addition to completing types, [`Type::finalize`] does one additional
//! check, the "occurs check", to ensure that there are no infinitely-sized
//! types. Such types occur when a type has itself as a child; they are illegal
//! in Simplicity and could not be represented by our data structures.
31//!
32
33// In this module, during inference types are characterized by their [`Bound`],
34// which describes the constraints on the type. The bound of a type can be
35// obtained by the [`Type::bound`] method, and is an enum with four variants:
36//
37// * [`Bound::Free`] means that the type has no constraints; it is a free
38//   variable. The type has a name which can be used to identify it in error
39//   messages.
// * [`Bound::Sum`] and [`Bound::Product`] mean that the type is a sum
//   (resp. product) of two other types, which are characterized by their
//   own bounds.
43// * [`Bound::Complete`] means that the type has no free variables at all,
44//   and has an already-computed [`Final`] structure suitable for use in
45//   contexts that require complete types. (Unit types are always complete,
46//   and therefore use this variant rather than getting their own.)
47//
48// During inference, it is possible for a type to be complete, in the sense
49// of having no free variables, without its bound being [`Bound::Complete`].
50// This occurs, for example, if a type is a sum of two incomplete types, then
51// the child types are completed during type inference on an unrelated part
52// of the type hierarchy. The type would then have a [`Bound::Sum`] with two
53// children, both of which are complete.
54//
55// The inference engine makes an effort to notice when this happens and set
56// the bound of complete types to [`Bound::Complete`], but since type inference
57// is inherently non-local this cannot always be done.
58//
59// When the distinction matters, we say a type is "finalized" only if its bound
60// is `Complete` and "complete" if it has no free variables. But the distinction
61// usually does not matter, so we prefer to use the word "complete".
62//
63// There are three main types in this module:
64//   * [`Type`] is the main type representing a Simplicity type, whether it is
65//     complete or not. Its main methods are [`Type::bound`] which returns the
66//     current state of the type and [`Type::bind`] which adds a new constraint
67//     to the type.
68//   * `Final` is a mutex-free structure that can be obtained from a complete
69//     type. It includes the TMR and the complete bound describing the type.
70//   * `Bound` defines the structure of a type: whether it is free, complete,
71//     or a sum or product of other types.
72//
73
74use self::union_bound::{PointerLike, UbElement, WithGhostToken};
75use crate::dag::{DagLike, NoSharing};
76use crate::Tmr;
77
78use std::fmt;
79use std::sync::Arc;
80
81pub mod arrow;
82mod context;
83mod final_data;
84mod incomplete;
85mod precomputed;
86mod union_bound;
87mod variable;
88
89pub use context::{BoundRef, Context};
90pub use final_data::{CompleteBound, Final};
91pub use incomplete::Incomplete;
92
/// Error type for simplicity
///
/// Each variant describes one way in which type inference can fail. The enum
/// is `#[non_exhaustive]`, so further variants may be added over time.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum Error {
    /// An attempt to bind a type conflicted with an existing bound on the type
    Bind {
        /// The bound that was already present on the type.
        existing_bound: Arc<Incomplete>,
        /// The bound that could not be applied on top of the existing one.
        new_bound: Arc<Incomplete>,
        /// Static text appended to the rendered error message.
        hint: &'static str,
    },
    /// Two unequal complete types were attempted to be unified
    CompleteTypeMismatch {
        /// The first of the two mismatching complete types.
        type1: Arc<Final>,
        /// The second of the two mismatching complete types.
        type2: Arc<Final>,
        /// Static text appended to the rendered error message.
        hint: &'static str,
    },
    /// A type is recursive (i.e., occurs within itself), violating the "occurs check"
    ///
    /// `infinite_bound` is the offending bound; it is printed as part of the
    /// error message.
    OccursCheck { infinite_bound: Arc<Incomplete> },
    /// Attempted to combine two nodes which had different type inference
    /// contexts. This is probably a programming error.
    InferenceContextMismatch,
}
115
116impl fmt::Display for Error {
117    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
118        match self {
119            Error::Bind {
120                ref existing_bound,
121                ref new_bound,
122                hint,
123            } => {
124                write!(
125                    f,
126                    "failed to apply bound `{}` to existing bound `{}`: {}",
127                    new_bound, existing_bound, hint,
128                )
129            }
130            Error::CompleteTypeMismatch {
131                ref type1,
132                ref type2,
133                hint,
134            } => {
135                write!(
136                    f,
137                    "attempted to unify unequal types `{}` and `{}`: {}",
138                    type1, type2, hint,
139                )
140            }
141            Error::OccursCheck { infinite_bound } => {
142                write!(f, "infinitely-sized type {}", infinite_bound,)
143            }
144            Error::InferenceContextMismatch => {
145                f.write_str("attempted to combine two nodes with different type inference contexts")
146            }
147        }
148    }
149}
150
// Empty impl: only the trait's default methods are used, so `source()`
// reports no underlying cause.
impl std::error::Error for Error {}
152
/// The state of a [`Type`] based on all constraints currently imposed on it.
///
/// `Sum` and `Product` refer to their children through `TypeInner` handles
/// rather than containing nested `Bound` values directly.
#[derive(Clone)]
enum Bound<'brand> {
    /// Fully-unconstrained type
    ///
    /// The string is the name of the free variable; it is what gets written
    /// when the type is formatted.
    Free(String),
    /// Fully-constrained (i.e. complete) type, which has no free variables.
    Complete(Arc<Final>),
    /// A sum of two other types
    Sum(TypeInner<'brand>, TypeInner<'brand>),
    /// A product of two other types
    Product(TypeInner<'brand>, TypeInner<'brand>),
}
165
166impl Bound<'_> {
167    /// Clones the `Bound`.
168    ///
169    /// This is the same as just calling `.clone()` but has a different name to
170    /// emphasize that what's being cloned is (at most) a pair of ref-counted
171    /// pointers.
172    pub fn shallow_clone(&self) -> Self {
173        self.clone()
174    }
175}
176
/// Source or target type of a Simplicity expression.
///
/// Internally this type is essentially just a refcounted pointer; it is
/// therefore quite cheap to clone, but be aware that cloning will not
/// actually create a new independent type, just a second pointer to the
/// first one.
#[derive(Clone)]
pub struct Type<'brand> {
    /// Handle to the type context.
    ///
    /// Lookups of this type's current bound go through this context.
    ctx: Context<'brand>,
    /// The actual contents of the type.
    inner: TypeInner<'brand>,
}
190
/// The part of a [`Type`] that does not include the context handle: just the
/// handle to the type's bound.
#[derive(Clone)]
struct TypeInner<'brand> {
    /// A set of constraints, which is maintained by the union-bound algorithm
    /// and is progressively tightened as type inference proceeds.
    bound: UbElement<'brand, BoundRef<'brand>>,
}
197
198impl TypeInner<'_> {
199    fn shallow_clone(&self) -> Self {
200        self.clone()
201    }
202}
203
204impl<'brand> Type<'brand> {
205    /// Return an unbound type with the given name
206    pub fn free(ctx: &Context<'brand>, name: String) -> Self {
207        Self::wrap_bound(ctx, ctx.alloc_free(name))
208    }
209
210    /// Create the unit type.
211    pub fn unit(ctx: &Context<'brand>) -> Self {
212        Self::wrap_bound(ctx, ctx.alloc_unit())
213    }
214
215    /// Create the type `2^(2^n)` for the given `n`.
216    ///
217    /// The type is precomputed and fast to access.
218    pub fn two_two_n(ctx: &Context<'brand>, n: usize) -> Self {
219        Self::complete(ctx, precomputed::nth_power_of_2(n))
220    }
221
222    /// Create the sum of the given `left` and `right` types.
223    pub fn sum(ctx: &Context<'brand>, left: Self, right: Self) -> Self {
224        Self::wrap_bound(ctx, ctx.alloc_sum(left, right))
225    }
226
227    /// Create the product of the given `left` and `right` types.
228    pub fn product(ctx: &Context<'brand>, left: Self, right: Self) -> Self {
229        Self::wrap_bound(ctx, ctx.alloc_product(left, right))
230    }
231
232    /// Create a complete type.
233    pub fn complete(ctx: &Context<'brand>, final_data: Arc<Final>) -> Self {
234        Self::wrap_bound(ctx, ctx.alloc_complete(final_data))
235    }
236
237    fn wrap_bound(ctx: &Context<'brand>, bound: BoundRef<'brand>) -> Self {
238        Type {
239            ctx: ctx.shallow_clone(),
240            inner: TypeInner {
241                bound: UbElement::new(bound),
242            },
243        }
244    }
245
246    /// Clones the `Type`.
247    ///
248    /// This is the same as just calling `.clone()` but has a different name to
249    /// emphasize that what's being cloned is merely a ref-counted pointer.
250    pub fn shallow_clone(&self) -> Self {
251        self.clone()
252    }
253
254    /// Accessor for the TMR of this type, if it is final
255    pub fn tmr(&self) -> Option<Tmr> {
256        self.final_data().map(|data| data.tmr())
257    }
258
259    /// Accessor for the data of this type, if it is complete
260    pub fn final_data(&self) -> Option<Arc<Final>> {
261        let root = self.ctx.get_root_ref(&self.inner.bound);
262        let bound = self.ctx.get(&root);
263        if let Bound::Complete(ref data) = bound {
264            Some(Arc::clone(data))
265        } else {
266            None
267        }
268    }
269
270    /// Whether this type is known to be final
271    ///
272    /// During type inference this may be false even though the type is, in fact,
273    /// complete, since its children may have been unified to a complete type. To
274    /// ensure a type is complete, call [`Type::finalize`].
275    pub fn is_final(&self) -> bool {
276        self.final_data().is_some()
277    }
278
279    /// Converts a type as-is to an `Incomplete` type for use in an error.
280    pub fn to_incomplete(&self) -> Arc<Incomplete> {
281        let root = self.ctx.get_root_ref(&self.inner.bound);
282        Incomplete::from_bound_ref(&self.ctx, root)
283    }
284
285    /// Attempts to finalize the type. Returns its TMR on success.
286    pub fn finalize(&self) -> Result<Arc<Final>, Error> {
287        let root = self.ctx.get_root_ref(&self.inner.bound);
288        let bound = self.ctx.get(&root);
289        if let Bound::Complete(ref data) = bound {
290            return Ok(Arc::clone(data));
291        }
292
293        // First, do occurs-check to ensure that we have no infinitely sized types.
294        if let Some(infinite_bound) = Incomplete::occurs_check(&self.ctx, root.shallow_clone()) {
295            return Err(Error::OccursCheck { infinite_bound });
296        }
297
298        // Now that we know our types have finite size, we can safely use a
299        // post-order iterator to finalize them.
300        let mut finalized = vec![];
301        for data in (&self.ctx, root).post_order_iter::<NoSharing>() {
302            let bound_get = data.node.0.get(&data.node.1);
303            let final_data = match bound_get {
304                Bound::Free(_) => Final::unit(),
305                Bound::Complete(ref arc) => Arc::clone(arc),
306                Bound::Sum(..) => Final::sum(
307                    Arc::clone(&finalized[data.left_index.unwrap()]),
308                    Arc::clone(&finalized[data.right_index.unwrap()]),
309                ),
310                Bound::Product(..) => Final::product(
311                    Arc::clone(&finalized[data.left_index.unwrap()]),
312                    Arc::clone(&finalized[data.right_index.unwrap()]),
313                ),
314            };
315
316            if !matches!(bound_get, Bound::Complete(..)) {
317                self.ctx
318                    .reassign_non_complete(data.node.1, Bound::Complete(Arc::clone(&final_data)));
319            }
320            finalized.push(final_data);
321        }
322        Ok(finalized.pop().unwrap())
323    }
324
325    /// Return a vector containing the types 2^(2^i) for i from 0 to n-1.
326    pub fn powers_of_two(ctx: &Context<'brand>, n: usize) -> Vec<Self> {
327        let mut ret = Vec::with_capacity(n);
328
329        let unit = Type::unit(ctx);
330        let mut two = Type::sum(ctx, unit.shallow_clone(), unit);
331        for _ in 0..n {
332            ret.push(two.shallow_clone());
333            two = Type::product(ctx, two.shallow_clone(), two);
334        }
335        ret
336    }
337}
338
/// Depth limit handed to the pre-order iterator when formatting a [`Type`];
/// a node at exactly this depth is written as `...`.
const MAX_DISPLAY_DEPTH: usize = 64;
/// Node-index limit when formatting a [`Type`]; once an iterated node's index
/// exceeds this value, output stops with a truncation notice.
const MAX_DISPLAY_LENGTH: usize = 10000;
341
342impl fmt::Debug for Type<'_> {
343    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
344        let root = self.ctx.get_root_ref(&self.inner.bound);
345        for data in (&self.ctx, root).verbose_pre_order_iter::<NoSharing>(Some(MAX_DISPLAY_DEPTH)) {
346            if data.index > MAX_DISPLAY_LENGTH {
347                write!(f, "... [truncated type after {} nodes]", MAX_DISPLAY_LENGTH)?;
348                return Ok(());
349            }
350            if data.depth == MAX_DISPLAY_DEPTH {
351                if data.n_children_yielded == 0 {
352                    f.write_str("...")?;
353                }
354                continue;
355            }
356            let bound = data.node.0.get(&data.node.1);
357            match (bound, data.n_children_yielded) {
358                (Bound::Free(ref s), _) => f.write_str(s)?,
359                (Bound::Complete(ref comp), _) => fmt::Debug::fmt(comp, f)?,
360                (Bound::Sum(..), 0) | (Bound::Product(..), 0) => {
361                    if data.index > 0 {
362                        f.write_str("(")?;
363                    }
364                }
365                (Bound::Sum(..), 2) | (Bound::Product(..), 2) => {
366                    if data.index > 0 {
367                        f.write_str(")")?
368                    }
369                }
370                (Bound::Sum(..), _) => f.write_str(" + ")?,
371                (Bound::Product(..), _) => f.write_str(" × ")?,
372            }
373        }
374        Ok(())
375    }
376}
377
378impl fmt::Display for Type<'_> {
379    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
380        let root = self.ctx.get_root_ref(&self.inner.bound);
381        for data in (&self.ctx, root).verbose_pre_order_iter::<NoSharing>(Some(MAX_DISPLAY_DEPTH)) {
382            if data.index > MAX_DISPLAY_LENGTH {
383                write!(f, "... [truncated type after {} nodes]", MAX_DISPLAY_LENGTH)?;
384                return Ok(());
385            }
386            if data.depth == MAX_DISPLAY_DEPTH {
387                if data.n_children_yielded == 0 {
388                    f.write_str("...")?;
389                }
390                continue;
391            }
392            let bound = data.node.0.get(&data.node.1);
393            match (bound, data.n_children_yielded) {
394                (Bound::Free(ref s), _) => f.write_str(s)?,
395                (Bound::Complete(ref comp), _) => fmt::Display::fmt(comp, f)?,
396                (Bound::Sum(..), 0) | (Bound::Product(..), 0) => {
397                    if data.index > 0 {
398                        f.write_str("(")?;
399                    }
400                }
401                (Bound::Sum(..), 2) | (Bound::Product(..), 2) => {
402                    if data.index > 0 {
403                        f.write_str(")")?
404                    }
405                }
406                (Bound::Sum(..), _) => f.write_str(" + ")?,
407                (Bound::Product(..), _) => f.write_str(" × ")?,
408            }
409        }
410        Ok(())
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    use crate::jet::Core;
    use crate::node::{ConstructNode, CoreConstructible};

    /// A failed unification must keep failing when it is attempted again.
    #[test]
    fn inference_failure() {
        Context::with_context(|ctx| {
            // `unit` starts out as A -> 1 ...
            let unit_node = Arc::<ConstructNode<Core>>::unit(&ctx);

            // ... and composing it with itself forces it to be 1 -> 1.
            Arc::<ConstructNode<Core>>::comp(&unit_node, &unit_node).unwrap();

            // `take unit` has type 1 * B -> 1.
            let take_node = Arc::<ConstructNode<Core>>::take(&unit_node);

            // `pair` has to unify 1 with 1 * B, which is impossible. The
            // second attempt checks that retrying does not suddenly succeed.
            for _ in 0..2 {
                Arc::<ConstructNode<Core>>::pair(&unit_node, &take_node).unwrap_err();
            }
        });
    }

    /// Builds `case iden (drop iden)` and debug-formats its source type.
    #[test]
    fn memory_leak() {
        Context::with_context(|ctx| {
            let iden_node = Arc::<ConstructNode<Core>>::iden(&ctx);
            let drop_node = Arc::<ConstructNode<Core>>::drop_(&iden_node);
            let case_node = Arc::<ConstructNode<Core>>::case(&iden_node, &drop_node).unwrap();

            // Only the act of formatting matters; the string is discarded.
            let _ = format!("{:?}", case_node.arrow().source);
        });
    }
}