simplicity/types/mod.rs
1// SPDX-License-Identifier: CC0-1.0
2
3//! Types and Type Inference
4//!
5//! Every Simplicity expression has two types associated with it: a source and
6//! a target. We refer to this pair of types as an "arrow". The types are
7//! inferred from the structure of the program.
8//!
//! Simplicity types are one of three things:
10//! * A unit type, which has one value
11//! * A sum of two other types
12//! * A product of two other types
13//!
14//! During type inference, types are initially "free", meaning that there are
15//! no constraints on what they will eventually be. The program structure then
16//! imposes additional constraints (for example, the `comp` combinator requires
17//! that its left child's target type be the same as its right child's source
18//! type), and by unifying all these constraints, all types can be inferred.
19//!
20//! Type inference is done progressively during construction of Simplicity
21//! expressions. It is completed by the [`Type::finalize`] method, which
22//! recursively completes types by setting any remaining free variables to unit.
23//! If any type constraints are incompatible with each other (e.g. a type is
24//! bound to be both a product and a sum type) then inference fails at that point
25//! by returning an error.
26//!
//! In addition to completing types, [`Type::finalize`] does one additional
//! check, the "occurs check", to ensure that there are no infinitely-sized
29//! types. Such types occur when a type has itself as a child, are illegal in
30//! Simplicity, and could not be represented by our data structures.
31//!
32
33// In this module, during inference types are characterized by their [`Bound`],
34// which describes the constraints on the type. The bound of a type can be
35// obtained by the [`Type::bound`] method, and is an enum with four variants:
36//
37// * [`Bound::Free`] means that the type has no constraints; it is a free
38// variable. The type has a name which can be used to identify it in error
39// messages.
// * [`Bound::Sum`] and [`Bound::Product`] mean that the type is a sum
41// (resp. product) of two other types, which are characterized by their
42// own bounds.
43// * [`Bound::Complete`] means that the type has no free variables at all,
44// and has an already-computed [`Final`] structure suitable for use in
45// contexts that require complete types. (Unit types are always complete,
46// and therefore use this variant rather than getting their own.)
47//
48// During inference, it is possible for a type to be complete, in the sense
49// of having no free variables, without its bound being [`Bound::Complete`].
50// This occurs, for example, if a type is a sum of two incomplete types, then
51// the child types are completed during type inference on an unrelated part
52// of the type hierarchy. The type would then have a [`Bound::Sum`] with two
53// children, both of which are complete.
54//
55// The inference engine makes an effort to notice when this happens and set
56// the bound of complete types to [`Bound::Complete`], but since type inference
57// is inherently non-local this cannot always be done.
58//
59// When the distinction matters, we say a type is "finalized" only if its bound
60// is `Complete` and "complete" if it has no free variables. But the distinction
61// usually does not matter, so we prefer to use the word "complete".
62//
63// There are three main types in this module:
64// * [`Type`] is the main type representing a Simplicity type, whether it is
65// complete or not. Its main methods are [`Type::bound`] which returns the
66// current state of the type and [`Type::bind`] which adds a new constraint
67// to the type.
68// * `Final` is a mutex-free structure that can be obtained from a complete
69// type. It includes the TMR and the complete bound describing the type.
70// * `Bound` defines the structure of a type: whether it is free, complete,
71// or a sum or product of other types.
72//
73
74use self::union_bound::{PointerLike, UbElement};
75use crate::dag::{DagLike, NoSharing};
76use crate::Tmr;
77
78use std::collections::HashSet;
79use std::fmt;
80use std::sync::Arc;
81
82pub mod arrow;
83mod context;
84mod final_data;
85mod precomputed;
86mod union_bound;
87mod variable;
88
89pub use context::{BoundRef, Context};
90pub use final_data::{CompleteBound, Final};
91
92/// Error type for simplicity
93#[non_exhaustive]
94#[derive(Clone, Debug)]
95pub enum Error {
96 /// An attempt to bind a type conflicted with an existing bound on the type
97 Bind {
98 existing_bound: Type,
99 new_bound: Type,
100 hint: &'static str,
101 },
102 /// Two unequal complete types were attempted to be unified
103 CompleteTypeMismatch {
104 type1: Arc<Final>,
105 type2: Arc<Final>,
106 hint: &'static str,
107 },
108 /// A type is recursive (i.e., occurs within itself), violating the "occurs check"
109 OccursCheck { infinite_bound: Type },
110 /// Attempted to combine two nodes which had different type inference
111 /// contexts. This is probably a programming error.
112 InferenceContextMismatch,
113}
114
115impl fmt::Display for Error {
116 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117 match self {
118 Error::Bind {
119 ref existing_bound,
120 ref new_bound,
121 hint,
122 } => {
123 write!(
124 f,
125 "failed to apply bound `{}` to existing bound `{}`: {}",
126 new_bound, existing_bound, hint,
127 )
128 }
129 Error::CompleteTypeMismatch {
130 ref type1,
131 ref type2,
132 hint,
133 } => {
134 write!(
135 f,
136 "attempted to unify unequal types `{}` and `{}`: {}",
137 type1, type2, hint,
138 )
139 }
140 Error::OccursCheck { infinite_bound } => {
141 write!(f, "infinitely-sized type {}", infinite_bound,)
142 }
143 Error::InferenceContextMismatch => {
144 f.write_str("attempted to combine two nodes with different type inference contexts")
145 }
146 }
147 }
148}
149
150impl std::error::Error for Error {}
151
152/// The state of a [`Type`] based on all constraints currently imposed on it.
153#[derive(Clone)]
154enum Bound {
155 /// Fully-unconstrained type
156 Free(String),
157 /// Fully-constrained (i.e. complete) type, which has no free variables.
158 Complete(Arc<Final>),
159 /// A sum of two other types
160 Sum(TypeInner, TypeInner),
161 /// A product of two other types
162 Product(TypeInner, TypeInner),
163}
164
165impl Bound {
166 /// Clones the `Bound`.
167 ///
168 /// This is the same as just calling `.clone()` but has a different name to
169 /// emphasize that what's being cloned is (at most) a pair of ref-counted
170 /// pointers.
171 pub fn shallow_clone(&self) -> Bound {
172 self.clone()
173 }
174}
175
176/// Source or target type of a Simplicity expression.
177///
178/// Internally this type is essentially just a refcounted pointer; it is
179/// therefore quite cheap to clone, but be aware that cloning will not
180/// actually create a new independent type, just a second pointer to the
181/// first one.
182#[derive(Clone)]
183pub struct Type {
184 /// Handle to the type context.
185 ctx: Context,
186 /// The actual contents of the type.
187 inner: TypeInner,
188}
189
190#[derive(Clone)]
191struct TypeInner {
192 /// A set of constraints, which maintained by the union-bound algorithm and
193 /// is progressively tightened as type inference proceeds.
194 bound: UbElement<BoundRef>,
195}
196
197impl TypeInner {
198 fn shallow_clone(&self) -> Self {
199 self.clone()
200 }
201}
202
203impl Type {
204 /// Return an unbound type with the given name
205 pub fn free(ctx: &Context, name: String) -> Self {
206 Self::wrap_bound(ctx, ctx.alloc_free(name))
207 }
208
209 /// Create the unit type.
210 pub fn unit(ctx: &Context) -> Self {
211 Self::wrap_bound(ctx, ctx.alloc_unit())
212 }
213
214 /// Create the type `2^(2^n)` for the given `n`.
215 ///
216 /// The type is precomputed and fast to access.
217 pub fn two_two_n(ctx: &Context, n: usize) -> Self {
218 Self::complete(ctx, precomputed::nth_power_of_2(n))
219 }
220
221 /// Create the sum of the given `left` and `right` types.
222 pub fn sum(ctx: &Context, left: Self, right: Self) -> Self {
223 Self::wrap_bound(ctx, ctx.alloc_sum(left, right))
224 }
225
226 /// Create the product of the given `left` and `right` types.
227 pub fn product(ctx: &Context, left: Self, right: Self) -> Self {
228 Self::wrap_bound(ctx, ctx.alloc_product(left, right))
229 }
230
231 /// Create a complete type.
232 pub fn complete(ctx: &Context, final_data: Arc<Final>) -> Self {
233 Self::wrap_bound(ctx, ctx.alloc_complete(final_data))
234 }
235
236 fn wrap_bound(ctx: &Context, bound: BoundRef) -> Self {
237 bound.assert_matches_context(ctx);
238 Type {
239 ctx: ctx.shallow_clone(),
240 inner: TypeInner {
241 bound: UbElement::new(bound),
242 },
243 }
244 }
245
246 /// Clones the `Type`.
247 ///
248 /// This is the same as just calling `.clone()` but has a different name to
249 /// emphasize that what's being cloned is merely a ref-counted pointer.
250 pub fn shallow_clone(&self) -> Type {
251 self.clone()
252 }
253
254 /// Accessor for the TMR of this type, if it is final
255 pub fn tmr(&self) -> Option<Tmr> {
256 self.final_data().map(|data| data.tmr())
257 }
258
259 /// Accessor for the data of this type, if it is complete
260 pub fn final_data(&self) -> Option<Arc<Final>> {
261 if let Bound::Complete(ref data) = self.ctx.get(&self.inner.bound.root()) {
262 Some(Arc::clone(data))
263 } else {
264 None
265 }
266 }
267
268 /// Whether this type is known to be final
269 ///
270 /// During type inference this may be false even though the type is, in fact,
271 /// complete, since its children may have been unified to a complete type. To
272 /// ensure a type is complete, call [`Type::finalize`].
273 pub fn is_final(&self) -> bool {
274 self.final_data().is_some()
275 }
276
277 /// Attempts to finalize the type. Returns its TMR on success.
278 pub fn finalize(&self) -> Result<Arc<Final>, Error> {
279 use context::OccursCheckId;
280
281 /// Helper type for the occurs-check.
282 enum OccursCheckStack {
283 Iterate(BoundRef),
284 Complete(OccursCheckId),
285 }
286
287 // Done with sharing tracker. Actual algorithm follows.
288 let root = self.inner.bound.root();
289 let bound = self.ctx.get(&root);
290 if let Bound::Complete(ref data) = bound {
291 return Ok(Arc::clone(data));
292 }
293
294 // First, do occurs-check to ensure that we have no infinitely sized types.
295 let mut stack = vec![OccursCheckStack::Iterate(root)];
296 let mut in_progress = HashSet::new();
297 let mut completed = HashSet::new();
298 while let Some(top) = stack.pop() {
299 let bound = match top {
300 OccursCheckStack::Complete(id) => {
301 in_progress.remove(&id);
302 completed.insert(id);
303 continue;
304 }
305 OccursCheckStack::Iterate(b) => b,
306 };
307
308 let id = bound.occurs_check_id();
309 if completed.contains(&id) {
310 // Once we have iterated through a type, we don't need to check it again.
311 // Without this shortcut the occurs-check would take exponential time.
312 continue;
313 }
314 if !in_progress.insert(id) {
315 return Err(Error::OccursCheck {
316 infinite_bound: Type::wrap_bound(&self.ctx, bound),
317 });
318 }
319
320 stack.push(OccursCheckStack::Complete(id));
321 if let Some((_, child)) = (&self.ctx, bound.shallow_clone()).right_child() {
322 stack.push(OccursCheckStack::Iterate(child));
323 }
324 if let Some((_, child)) = (&self.ctx, bound).left_child() {
325 stack.push(OccursCheckStack::Iterate(child));
326 }
327 }
328
329 // Now that we know our types have finite size, we can safely use a
330 // post-order iterator to finalize them.
331 let mut finalized = vec![];
332 for data in (&self.ctx, self.inner.bound.root()).post_order_iter::<NoSharing>() {
333 let bound_get = data.node.0.get(&data.node.1);
334 let final_data = match bound_get {
335 Bound::Free(_) => Final::unit(),
336 Bound::Complete(ref arc) => Arc::clone(arc),
337 Bound::Sum(..) => Final::sum(
338 Arc::clone(&finalized[data.left_index.unwrap()]),
339 Arc::clone(&finalized[data.right_index.unwrap()]),
340 ),
341 Bound::Product(..) => Final::product(
342 Arc::clone(&finalized[data.left_index.unwrap()]),
343 Arc::clone(&finalized[data.right_index.unwrap()]),
344 ),
345 };
346
347 if !matches!(bound_get, Bound::Complete(..)) {
348 self.ctx
349 .reassign_non_complete(data.node.1, Bound::Complete(Arc::clone(&final_data)));
350 }
351 finalized.push(final_data);
352 }
353 Ok(finalized.pop().unwrap())
354 }
355
356 /// Return a vector containing the types 2^(2^i) for i from 0 to n-1.
357 pub fn powers_of_two(ctx: &Context, n: usize) -> Vec<Self> {
358 let mut ret = Vec::with_capacity(n);
359
360 let unit = Type::unit(ctx);
361 let mut two = Type::sum(ctx, unit.shallow_clone(), unit);
362 for _ in 0..n {
363 ret.push(two.shallow_clone());
364 two = Type::product(ctx, two.shallow_clone(), two);
365 }
366 ret
367 }
368}
369
370const MAX_DISPLAY_DEPTH: usize = 64;
371const MAX_DISPLAY_LENGTH: usize = 10000;
372
373impl fmt::Debug for Type {
374 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
375 for data in (&self.ctx, self.inner.bound.root())
376 .verbose_pre_order_iter::<NoSharing>(Some(MAX_DISPLAY_DEPTH))
377 {
378 if data.index > MAX_DISPLAY_LENGTH {
379 write!(f, "... [truncated type after {} nodes]", MAX_DISPLAY_LENGTH)?;
380 return Ok(());
381 }
382 if data.depth == MAX_DISPLAY_DEPTH {
383 if data.n_children_yielded == 0 {
384 f.write_str("...")?;
385 }
386 continue;
387 }
388 let bound = data.node.0.get(&data.node.1);
389 match (bound, data.n_children_yielded) {
390 (Bound::Free(ref s), _) => f.write_str(s)?,
391 (Bound::Complete(ref comp), _) => fmt::Debug::fmt(comp, f)?,
392 (Bound::Sum(..), 0) | (Bound::Product(..), 0) => {
393 if data.index > 0 {
394 f.write_str("(")?;
395 }
396 }
397 (Bound::Sum(..), 2) | (Bound::Product(..), 2) => {
398 if data.index > 0 {
399 f.write_str(")")?
400 }
401 }
402 (Bound::Sum(..), _) => f.write_str(" + ")?,
403 (Bound::Product(..), _) => f.write_str(" × ")?,
404 }
405 }
406 Ok(())
407 }
408}
409
410impl fmt::Display for Type {
411 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
412 for data in (&self.ctx, self.inner.bound.root())
413 .verbose_pre_order_iter::<NoSharing>(Some(MAX_DISPLAY_DEPTH))
414 {
415 if data.index > MAX_DISPLAY_LENGTH {
416 write!(f, "... [truncated type after {} nodes]", MAX_DISPLAY_LENGTH)?;
417 return Ok(());
418 }
419 if data.depth == MAX_DISPLAY_DEPTH {
420 if data.n_children_yielded == 0 {
421 f.write_str("...")?;
422 }
423 continue;
424 }
425 let bound = data.node.0.get(&data.node.1);
426 match (bound, data.n_children_yielded) {
427 (Bound::Free(ref s), _) => f.write_str(s)?,
428 (Bound::Complete(ref comp), _) => fmt::Display::fmt(comp, f)?,
429 (Bound::Sum(..), 0) | (Bound::Product(..), 0) => {
430 if data.index > 0 {
431 f.write_str("(")?;
432 }
433 }
434 (Bound::Sum(..), 2) | (Bound::Product(..), 2) => {
435 if data.index > 0 {
436 f.write_str(")")?
437 }
438 }
439 (Bound::Sum(..), _) => f.write_str(" + ")?,
440 (Bound::Product(..), _) => f.write_str(" × ")?,
441 }
442 }
443 Ok(())
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    use crate::jet::Core;
    use crate::node::{ConstructNode, CoreConstructible};

    /// Construct-node handle used by every test in this module.
    type CoreNode = Arc<ConstructNode<Core>>;

    /// A failed unification must fail again when retried, rather than
    /// succeeding the second time.
    #[test]
    fn inference_failure() {
        let ctx = Context::new();

        // `unit` starts out as A -> 1 for a free A ...
        let unit = CoreNode::unit(&ctx);
        // ... and composing it with itself forces A = 1, so it becomes 1 -> 1.
        CoreNode::comp(&unit, &unit).unwrap();

        // `take unit` has type 1 × B -> 1.
        let take_unit = CoreNode::take(&unit);

        // `pair` has to unify the sources 1 and 1 × B, which is impossible;
        // repeating the attempt must not change that.
        for _ in 0..2 {
            CoreNode::pair(&unit, &take_unit).unwrap_err();
        }
    }

    /// Debug-formats the source type of `case iden (drop iden)`. Per the
    /// test name, this is a regression test for a memory leak on that path.
    #[test]
    fn memory_leak() {
        let ctx = Context::new();
        let iden = CoreNode::iden(&ctx);
        let dropped = CoreNode::drop_(&iden);
        let case = CoreNode::case(&iden, &dropped).unwrap();

        let _ = format!("{:?}", case.arrow().source);
    }
}