simplicity/types/
mod.rs

1// SPDX-License-Identifier: CC0-1.0
2
3//! Types and Type Inference
4//!
5//! Every Simplicity expression has two types associated with it: a source and
6//! a target. We refer to this pair of types as an "arrow". The types are
7//! inferred from the structure of the program.
8//!
9//! Simplicity types are one of three things
10//!   * A unit type, which has one value
11//!   * A sum of two other types
12//!   * A product of two other types
13//!
14//! During type inference, types are initially "free", meaning that there are
15//! no constraints on what they will eventually be. The program structure then
16//! imposes additional constraints (for example, the `comp` combinator requires
17//! that its left child's target type be the same as its right child's source
18//! type), and by unifying all these constraints, all types can be inferred.
19//!
20//! Type inference is done progressively during construction of Simplicity
21//! expressions. It is completed by the [`Type::finalize`] method, which
22//! recursively completes types by setting any remaining free variables to unit.
23//! If any type constraints are incompatible with each other (e.g. a type is
24//! bound to be both a product and a sum type) then inference fails at that point
25//! by returning an error.
26//!
//! In addition to completing types, [`Type::finalize`] does one additional
//! check, the "occurs check", to ensure that there are no infinitely-sized
//! types. Such types occur when a type has itself as a child; they are illegal
//! in Simplicity and could not be represented by our data structures.
31//!
32
33// In this module, during inference types are characterized by their [`Bound`],
34// which describes the constraints on the type. The bound of a type can be
35// obtained by the [`Type::bound`] method, and is an enum with four variants:
36//
37// * [`Bound::Free`] means that the type has no constraints; it is a free
38//   variable. The type has a name which can be used to identify it in error
39//   messages.
// * [`Bound::Sum`] and [`Bound::Product`] mean that the type is a sum
//   (resp. product) of two other types, which are characterized by their
//   own bounds.
43// * [`Bound::Complete`] means that the type has no free variables at all,
44//   and has an already-computed [`Final`] structure suitable for use in
45//   contexts that require complete types. (Unit types are always complete,
46//   and therefore use this variant rather than getting their own.)
47//
48// During inference, it is possible for a type to be complete, in the sense
49// of having no free variables, without its bound being [`Bound::Complete`].
// This occurs, for example, if a type is a sum of two incomplete types and
// the child types are later completed during type inference on an unrelated
// part of the type hierarchy. The type would then have a [`Bound::Sum`] with
// two children, both of which are complete.
54//
55// The inference engine makes an effort to notice when this happens and set
56// the bound of complete types to [`Bound::Complete`], but since type inference
57// is inherently non-local this cannot always be done.
58//
59// When the distinction matters, we say a type is "finalized" only if its bound
60// is `Complete` and "complete" if it has no free variables. But the distinction
61// usually does not matter, so we prefer to use the word "complete".
62//
63// There are three main types in this module:
64//   * [`Type`] is the main type representing a Simplicity type, whether it is
65//     complete or not. Its main methods are [`Type::bound`] which returns the
66//     current state of the type and [`Type::bind`] which adds a new constraint
67//     to the type.
68//   * `Final` is a mutex-free structure that can be obtained from a complete
69//     type. It includes the TMR and the complete bound describing the type.
70//   * `Bound` defines the structure of a type: whether it is free, complete,
71//     or a sum or product of other types.
72//
73
74use self::union_bound::{PointerLike, UbElement};
75use crate::dag::{DagLike, NoSharing};
76use crate::Tmr;
77
78use std::collections::HashSet;
79use std::fmt;
80use std::sync::Arc;
81
82pub mod arrow;
83mod context;
84mod final_data;
85mod precomputed;
86mod union_bound;
87mod variable;
88
89pub use context::{BoundRef, Context};
90pub use final_data::{CompleteBound, Final};
91
92/// Error type for simplicity
93#[non_exhaustive]
94#[derive(Clone, Debug)]
95pub enum Error {
96    /// An attempt to bind a type conflicted with an existing bound on the type
97    Bind {
98        existing_bound: Type,
99        new_bound: Type,
100        hint: &'static str,
101    },
102    /// Two unequal complete types were attempted to be unified
103    CompleteTypeMismatch {
104        type1: Arc<Final>,
105        type2: Arc<Final>,
106        hint: &'static str,
107    },
108    /// A type is recursive (i.e., occurs within itself), violating the "occurs check"
109    OccursCheck { infinite_bound: Type },
110    /// Attempted to combine two nodes which had different type inference
111    /// contexts. This is probably a programming error.
112    InferenceContextMismatch,
113}
114
115impl fmt::Display for Error {
116    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117        match self {
118            Error::Bind {
119                ref existing_bound,
120                ref new_bound,
121                hint,
122            } => {
123                write!(
124                    f,
125                    "failed to apply bound `{}` to existing bound `{}`: {}",
126                    new_bound, existing_bound, hint,
127                )
128            }
129            Error::CompleteTypeMismatch {
130                ref type1,
131                ref type2,
132                hint,
133            } => {
134                write!(
135                    f,
136                    "attempted to unify unequal types `{}` and `{}`: {}",
137                    type1, type2, hint,
138                )
139            }
140            Error::OccursCheck { infinite_bound } => {
141                write!(f, "infinitely-sized type {}", infinite_bound,)
142            }
143            Error::InferenceContextMismatch => {
144                f.write_str("attempted to combine two nodes with different type inference contexts")
145            }
146        }
147    }
148}
149
150impl std::error::Error for Error {}
151
152/// The state of a [`Type`] based on all constraints currently imposed on it.
153#[derive(Clone)]
154enum Bound {
155    /// Fully-unconstrained type
156    Free(String),
157    /// Fully-constrained (i.e. complete) type, which has no free variables.
158    Complete(Arc<Final>),
159    /// A sum of two other types
160    Sum(TypeInner, TypeInner),
161    /// A product of two other types
162    Product(TypeInner, TypeInner),
163}
164
165impl Bound {
166    /// Clones the `Bound`.
167    ///
168    /// This is the same as just calling `.clone()` but has a different name to
169    /// emphasize that what's being cloned is (at most) a pair of ref-counted
170    /// pointers.
171    pub fn shallow_clone(&self) -> Bound {
172        self.clone()
173    }
174}
175
176/// Source or target type of a Simplicity expression.
177///
178/// Internally this type is essentially just a refcounted pointer; it is
179/// therefore quite cheap to clone, but be aware that cloning will not
180/// actually create a new independent type, just a second pointer to the
181/// first one.
182#[derive(Clone)]
183pub struct Type {
184    /// Handle to the type context.
185    ctx: Context,
186    /// The actual contents of the type.
187    inner: TypeInner,
188}
189
190#[derive(Clone)]
191struct TypeInner {
192    /// A set of constraints, which maintained by the union-bound algorithm and
193    /// is progressively tightened as type inference proceeds.
194    bound: UbElement<BoundRef>,
195}
196
197impl TypeInner {
198    fn shallow_clone(&self) -> Self {
199        self.clone()
200    }
201}
202
203impl Type {
204    /// Return an unbound type with the given name
205    pub fn free(ctx: &Context, name: String) -> Self {
206        Self::wrap_bound(ctx, ctx.alloc_free(name))
207    }
208
209    /// Create the unit type.
210    pub fn unit(ctx: &Context) -> Self {
211        Self::wrap_bound(ctx, ctx.alloc_unit())
212    }
213
214    /// Create the type `2^(2^n)` for the given `n`.
215    ///
216    /// The type is precomputed and fast to access.
217    pub fn two_two_n(ctx: &Context, n: usize) -> Self {
218        Self::complete(ctx, precomputed::nth_power_of_2(n))
219    }
220
221    /// Create the sum of the given `left` and `right` types.
222    pub fn sum(ctx: &Context, left: Self, right: Self) -> Self {
223        Self::wrap_bound(ctx, ctx.alloc_sum(left, right))
224    }
225
226    /// Create the product of the given `left` and `right` types.
227    pub fn product(ctx: &Context, left: Self, right: Self) -> Self {
228        Self::wrap_bound(ctx, ctx.alloc_product(left, right))
229    }
230
231    /// Create a complete type.
232    pub fn complete(ctx: &Context, final_data: Arc<Final>) -> Self {
233        Self::wrap_bound(ctx, ctx.alloc_complete(final_data))
234    }
235
236    fn wrap_bound(ctx: &Context, bound: BoundRef) -> Self {
237        bound.assert_matches_context(ctx);
238        Type {
239            ctx: ctx.shallow_clone(),
240            inner: TypeInner {
241                bound: UbElement::new(bound),
242            },
243        }
244    }
245
246    /// Clones the `Type`.
247    ///
248    /// This is the same as just calling `.clone()` but has a different name to
249    /// emphasize that what's being cloned is merely a ref-counted pointer.
250    pub fn shallow_clone(&self) -> Type {
251        self.clone()
252    }
253
254    /// Accessor for the TMR of this type, if it is final
255    pub fn tmr(&self) -> Option<Tmr> {
256        self.final_data().map(|data| data.tmr())
257    }
258
259    /// Accessor for the data of this type, if it is complete
260    pub fn final_data(&self) -> Option<Arc<Final>> {
261        if let Bound::Complete(ref data) = self.ctx.get(&self.inner.bound.root()) {
262            Some(Arc::clone(data))
263        } else {
264            None
265        }
266    }
267
268    /// Whether this type is known to be final
269    ///
270    /// During type inference this may be false even though the type is, in fact,
271    /// complete, since its children may have been unified to a complete type. To
272    /// ensure a type is complete, call [`Type::finalize`].
273    pub fn is_final(&self) -> bool {
274        self.final_data().is_some()
275    }
276
277    /// Attempts to finalize the type. Returns its TMR on success.
278    pub fn finalize(&self) -> Result<Arc<Final>, Error> {
279        use context::OccursCheckId;
280
281        /// Helper type for the occurs-check.
282        enum OccursCheckStack {
283            Iterate(BoundRef),
284            Complete(OccursCheckId),
285        }
286
287        // Done with sharing tracker. Actual algorithm follows.
288        let root = self.inner.bound.root();
289        let bound = self.ctx.get(&root);
290        if let Bound::Complete(ref data) = bound {
291            return Ok(Arc::clone(data));
292        }
293
294        // First, do occurs-check to ensure that we have no infinitely sized types.
295        let mut stack = vec![OccursCheckStack::Iterate(root)];
296        let mut in_progress = HashSet::new();
297        let mut completed = HashSet::new();
298        while let Some(top) = stack.pop() {
299            let bound = match top {
300                OccursCheckStack::Complete(id) => {
301                    in_progress.remove(&id);
302                    completed.insert(id);
303                    continue;
304                }
305                OccursCheckStack::Iterate(b) => b,
306            };
307
308            let id = bound.occurs_check_id();
309            if completed.contains(&id) {
310                // Once we have iterated through a type, we don't need to check it again.
311                // Without this shortcut the occurs-check would take exponential time.
312                continue;
313            }
314            if !in_progress.insert(id) {
315                return Err(Error::OccursCheck {
316                    infinite_bound: Type::wrap_bound(&self.ctx, bound),
317                });
318            }
319
320            stack.push(OccursCheckStack::Complete(id));
321            if let Some((_, child)) = (&self.ctx, bound.shallow_clone()).right_child() {
322                stack.push(OccursCheckStack::Iterate(child));
323            }
324            if let Some((_, child)) = (&self.ctx, bound).left_child() {
325                stack.push(OccursCheckStack::Iterate(child));
326            }
327        }
328
329        // Now that we know our types have finite size, we can safely use a
330        // post-order iterator to finalize them.
331        let mut finalized = vec![];
332        for data in (&self.ctx, self.inner.bound.root()).post_order_iter::<NoSharing>() {
333            let bound_get = data.node.0.get(&data.node.1);
334            let final_data = match bound_get {
335                Bound::Free(_) => Final::unit(),
336                Bound::Complete(ref arc) => Arc::clone(arc),
337                Bound::Sum(..) => Final::sum(
338                    Arc::clone(&finalized[data.left_index.unwrap()]),
339                    Arc::clone(&finalized[data.right_index.unwrap()]),
340                ),
341                Bound::Product(..) => Final::product(
342                    Arc::clone(&finalized[data.left_index.unwrap()]),
343                    Arc::clone(&finalized[data.right_index.unwrap()]),
344                ),
345            };
346
347            if !matches!(bound_get, Bound::Complete(..)) {
348                self.ctx
349                    .reassign_non_complete(data.node.1, Bound::Complete(Arc::clone(&final_data)));
350            }
351            finalized.push(final_data);
352        }
353        Ok(finalized.pop().unwrap())
354    }
355
356    /// Return a vector containing the types 2^(2^i) for i from 0 to n-1.
357    pub fn powers_of_two(ctx: &Context, n: usize) -> Vec<Self> {
358        let mut ret = Vec::with_capacity(n);
359
360        let unit = Type::unit(ctx);
361        let mut two = Type::sum(ctx, unit.shallow_clone(), unit);
362        for _ in 0..n {
363            ret.push(two.shallow_clone());
364            two = Type::product(ctx, two.shallow_clone(), two);
365        }
366        ret
367    }
368}
369
370const MAX_DISPLAY_DEPTH: usize = 64;
371const MAX_DISPLAY_LENGTH: usize = 10000;
372
373impl fmt::Debug for Type {
374    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
375        for data in (&self.ctx, self.inner.bound.root())
376            .verbose_pre_order_iter::<NoSharing>(Some(MAX_DISPLAY_DEPTH))
377        {
378            if data.index > MAX_DISPLAY_LENGTH {
379                write!(f, "... [truncated type after {} nodes]", MAX_DISPLAY_LENGTH)?;
380                return Ok(());
381            }
382            if data.depth == MAX_DISPLAY_DEPTH {
383                if data.n_children_yielded == 0 {
384                    f.write_str("...")?;
385                }
386                continue;
387            }
388            let bound = data.node.0.get(&data.node.1);
389            match (bound, data.n_children_yielded) {
390                (Bound::Free(ref s), _) => f.write_str(s)?,
391                (Bound::Complete(ref comp), _) => fmt::Debug::fmt(comp, f)?,
392                (Bound::Sum(..), 0) | (Bound::Product(..), 0) => {
393                    if data.index > 0 {
394                        f.write_str("(")?;
395                    }
396                }
397                (Bound::Sum(..), 2) | (Bound::Product(..), 2) => {
398                    if data.index > 0 {
399                        f.write_str(")")?
400                    }
401                }
402                (Bound::Sum(..), _) => f.write_str(" + ")?,
403                (Bound::Product(..), _) => f.write_str(" × ")?,
404            }
405        }
406        Ok(())
407    }
408}
409
410impl fmt::Display for Type {
411    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
412        for data in (&self.ctx, self.inner.bound.root())
413            .verbose_pre_order_iter::<NoSharing>(Some(MAX_DISPLAY_DEPTH))
414        {
415            if data.index > MAX_DISPLAY_LENGTH {
416                write!(f, "... [truncated type after {} nodes]", MAX_DISPLAY_LENGTH)?;
417                return Ok(());
418            }
419            if data.depth == MAX_DISPLAY_DEPTH {
420                if data.n_children_yielded == 0 {
421                    f.write_str("...")?;
422                }
423                continue;
424            }
425            let bound = data.node.0.get(&data.node.1);
426            match (bound, data.n_children_yielded) {
427                (Bound::Free(ref s), _) => f.write_str(s)?,
428                (Bound::Complete(ref comp), _) => fmt::Display::fmt(comp, f)?,
429                (Bound::Sum(..), 0) | (Bound::Product(..), 0) => {
430                    if data.index > 0 {
431                        f.write_str("(")?;
432                    }
433                }
434                (Bound::Sum(..), 2) | (Bound::Product(..), 2) => {
435                    if data.index > 0 {
436                        f.write_str(")")?
437                    }
438                }
439                (Bound::Sum(..), _) => f.write_str(" + ")?,
440                (Bound::Product(..), _) => f.write_str(" × ")?,
441            }
442        }
443        Ok(())
444    }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    use crate::jet::Core;
    use crate::node::{ConstructNode, CoreConstructible};

    /// The node type on which every test in this module operates.
    type Node = Arc<ConstructNode<Core>>;

    /// A failed unification must keep failing when it is attempted again.
    #[test]
    fn inference_failure() {
        let ctx = Context::new();

        // `unit` has the arrow A -> 1 ...
        let unit = Node::unit(&ctx);
        // ... and composing it with itself forces it to be 1 -> 1.
        Node::comp(&unit, &unit).unwrap();

        // `take unit` has the arrow 1 * B -> 1.
        let take_unit = Node::take(&unit);

        // `pair` tries to unify 1 with 1 * B, which cannot work ...
        Node::pair(&unit, &take_unit).unwrap_err();
        // ... and it must not work on a second attempt either.
        Node::pair(&unit, &take_unit).unwrap_err();
    }

    /// Builds a `case` node and formats its source type with `Debug`.
    #[test]
    fn memory_leak() {
        let ctx = Context::new();
        let iden = Node::iden(&ctx);
        let drop = Node::drop_(&iden);
        let case = Node::case(&iden, &drop).unwrap();

        let _ = format!("{:?}", case.arrow().source);
    }
}