simplicity/types/context.rs

1// SPDX-License-Identifier: CC0-1.0
2
3//! Type Inference Context
4//!
5//! When constructing a Simplicity program, you must first create a type inference
6//! context, in which type inference occurs incrementally during construction. Each
7//! leaf node (e.g. `unit` and `iden`) must explicitly refer to the type inference
8//! context, while combinator nodes (e.g. `comp`) infer the context from their
9//! children, raising an error if there are multiple children whose contexts don't
10//! match.
11//!
12//! This helps to prevent situations in which users attempt to construct multiple
13//! independent programs, but types in one program accidentally refer to types in
14//! the other.
15//!
16
17use std::fmt;
18use std::sync::{Arc, Mutex, MutexGuard};
19
20use crate::dag::{Dag, DagLike};
21
22use super::{Bound, CompleteBound, Error, Final, Type, TypeInner};
23
24/// Type inference context, or handle to a context.
25///
26/// Can be cheaply cloned with [`Context::shallow_clone`]. These clones will
27/// refer to the same underlying type inference context, and can be used as
28/// handles to each other. The derived [`Context::clone`] has the same effect.
29///
30/// There is currently no way to create an independent context with the same
31/// type inference variables (i.e. a deep clone). If you need this functionality,
32/// please file an issue.
33#[derive(Clone, Default)]
34pub struct Context {
35    slab: Arc<Mutex<Vec<Bound>>>,
36}
37
38impl fmt::Debug for Context {
39    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
40        let id = Arc::as_ptr(&self.slab) as usize;
41        write!(f, "inference_ctx_{:08x}", id)
42    }
43}
44
45impl PartialEq for Context {
46    fn eq(&self, other: &Self) -> bool {
47        Arc::ptr_eq(&self.slab, &other.slab)
48    }
49}
50impl Eq for Context {}
51
52impl Context {
53    /// Creates a new empty type inference context.
54    pub fn new() -> Self {
55        Context {
56            slab: Arc::new(Mutex::new(vec![])),
57        }
58    }
59
60    /// Helper function to allocate a bound and return a reference to it.
61    fn alloc_bound(&self, bound: Bound) -> BoundRef {
62        let mut lock = self.lock();
63        lock.alloc_bound(bound)
64    }
65
66    /// Allocate a new free type bound, and return a reference to it.
67    pub fn alloc_free(&self, name: String) -> BoundRef {
68        self.alloc_bound(Bound::Free(name))
69    }
70
71    /// Allocate a new unit type bound, and return a reference to it.
72    pub fn alloc_unit(&self) -> BoundRef {
73        self.alloc_bound(Bound::Complete(Final::unit()))
74    }
75
76    /// Allocate a new unit type bound, and return a reference to it.
77    pub fn alloc_complete(&self, data: Arc<Final>) -> BoundRef {
78        self.alloc_bound(Bound::Complete(data))
79    }
80
81    /// Allocate a new sum-type bound, and return a reference to it.
82    ///
83    /// # Panics
84    ///
85    /// Panics if either of the child types are from a different inference context.
86    pub fn alloc_sum(&self, left: Type, right: Type) -> BoundRef {
87        assert_eq!(
88            left.ctx, *self,
89            "left type did not match inference context of sum"
90        );
91        assert_eq!(
92            right.ctx, *self,
93            "right type did not match inference context of sum"
94        );
95
96        let mut lock = self.lock();
97        if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) {
98            lock.alloc_bound(Bound::Complete(Final::sum(data1, data2)))
99        } else {
100            lock.alloc_bound(Bound::Sum(left.inner, right.inner))
101        }
102    }
103
104    /// Allocate a new product-type bound, and return a reference to it.
105    ///
106    /// # Panics
107    ///
108    /// Panics if either of the child types are from a different inference context.
109    pub fn alloc_product(&self, left: Type, right: Type) -> BoundRef {
110        assert_eq!(
111            left.ctx, *self,
112            "left type did not match inference context of product"
113        );
114        assert_eq!(
115            right.ctx, *self,
116            "right type did not match inference context of product"
117        );
118
119        let mut lock = self.lock();
120        if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) {
121            lock.alloc_bound(Bound::Complete(Final::product(data1, data2)))
122        } else {
123            lock.alloc_bound(Bound::Product(left.inner, right.inner))
124        }
125    }
126
127    /// Creates a new handle to the context.
128    ///
129    /// This handle holds a reference to the underlying context and will keep
130    /// it alive. The context will only be dropped once all handles, including
131    /// the original context object, are dropped.
132    pub fn shallow_clone(&self) -> Self {
133        Self {
134            slab: Arc::clone(&self.slab),
135        }
136    }
137
138    /// Checks whether two inference contexts are equal, and returns an error if not.
139    pub fn check_eq(&self, other: &Self) -> Result<(), super::Error> {
140        if self == other {
141            Ok(())
142        } else {
143            Err(super::Error::InferenceContextMismatch)
144        }
145    }
146
147    /// Accesses a bound.
148    ///
149    /// # Panics
150    ///
151    /// Panics if passed a `BoundRef` that was not allocated by this context.
152    pub(super) fn get(&self, bound: &BoundRef) -> Bound {
153        bound.assert_matches_context(self);
154        let lock = self.lock();
155        lock.slab[bound.index].shallow_clone()
156    }
157
158    /// Reassigns a bound to a different bound.
159    ///
160    /// # Panics
161    ///
162    /// Panics if called on a complete type. This is a sanity-check to avoid
163    /// replacing already-completed types, which can cause inefficiencies in
164    /// the union-bound algorithm (and if our replacement changes the type,
165    /// this is probably a bug.
166    ///
167    /// Also panics if passed a `BoundRef` that was not allocated by this context.
168    pub(super) fn reassign_non_complete(&self, bound: BoundRef, new: Bound) {
169        let mut lock = self.lock();
170        lock.reassign_non_complete(bound, new);
171    }
172
173    /// Binds the type to a product bound formed by the two inner types. If this
174    /// fails, attach the provided hint to the error.
175    ///
176    /// Fails if the type has an existing incompatible bound.
177    ///
178    /// # Panics
179    ///
180    /// Panics if any of the three types passed in were allocated from a different
181    /// context than this one.
182    pub fn bind_product(
183        &self,
184        existing: &Type,
185        prod_l: &Type,
186        prod_r: &Type,
187        hint: &'static str,
188    ) -> Result<(), Error> {
189        assert_eq!(
190            existing.ctx, *self,
191            "attempted to bind existing type with wrong context",
192        );
193        assert_eq!(
194            prod_l.ctx, *self,
195            "attempted to bind product whose left type had wrong context",
196        );
197        assert_eq!(
198            prod_r.ctx, *self,
199            "attempted to bind product whose right type had wrong context",
200        );
201
202        let existing_root = existing.inner.bound.root();
203        let new_bound = Bound::Product(prod_l.inner.shallow_clone(), prod_r.inner.shallow_clone());
204
205        let mut lock = self.lock();
206        lock.bind(existing_root, new_bound).map_err(|e| {
207            let new_bound = lock.alloc_bound(e.new);
208            Error::Bind {
209                existing_bound: Type::wrap_bound(self, e.existing),
210                new_bound: Type::wrap_bound(self, new_bound),
211                hint,
212            }
213        })
214    }
215
216    /// Unify the type with another one.
217    ///
218    /// Fails if the bounds on the two types are incompatible
219    pub fn unify(&self, ty1: &Type, ty2: &Type, hint: &'static str) -> Result<(), Error> {
220        assert_eq!(ty1.ctx, *self);
221        assert_eq!(ty2.ctx, *self);
222        let mut lock = self.lock();
223        lock.unify(&ty1.inner, &ty2.inner).map_err(|e| {
224            let new_bound = lock.alloc_bound(e.new);
225            Error::Bind {
226                existing_bound: Type::wrap_bound(self, e.existing),
227                new_bound: Type::wrap_bound(self, new_bound),
228                hint,
229            }
230        })
231    }
232
233    /// Locks the underlying slab mutex.
234    fn lock(&self) -> LockedContext {
235        LockedContext {
236            context: Arc::as_ptr(&self.slab),
237            slab: self.slab.lock().unwrap(),
238        }
239    }
240}
241
242#[derive(Debug, Clone)]
243pub struct BoundRef {
244    context: *const Mutex<Vec<Bound>>,
245    index: usize,
246}
247
248// SAFETY: The pointer inside `BoundRef` is always (eventually) constructed from Arc::as_ptr
249// from the slab of a type-inference context.
250//
251// Arc will prevent the pointer from ever changing, except to be deallocated when the last
252// Arc goes away. But this occurs only when the context itself goes away, which in turn
253// happens only when every type bound referring to the context goes away.
254//
255// If this were untrue, our use of `BoundRef` would lead to dereferences of a dangling
256// pointer, and `Send`/`Sync` would be the least of our concerns!
257unsafe impl Send for BoundRef {}
258// SAFETY: see comment on `Send`
259unsafe impl Sync for BoundRef {}
260
261impl BoundRef {
262    pub fn assert_matches_context(&self, ctx: &Context) {
263        assert_eq!(
264            self.context,
265            Arc::as_ptr(&ctx.slab),
266            "bound was accessed from a type inference context that did not create it",
267        );
268    }
269
270    /// Creates an "occurs-check ID" which is just a copy of the [`BoundRef`]
271    /// with `PartialEq` and `Eq` implemented in terms of underlying pointer
272    /// equality.
273    pub fn occurs_check_id(&self) -> OccursCheckId {
274        OccursCheckId {
275            context: self.context,
276            index: self.index,
277        }
278    }
279}
280
281impl super::PointerLike for BoundRef {
282    fn ptr_eq(&self, other: &Self) -> bool {
283        debug_assert_eq!(
284            self.context, other.context,
285            "tried to compare two bounds from different inference contexts"
286        );
287        self.index == other.index
288    }
289
290    fn shallow_clone(&self) -> Self {
291        BoundRef {
292            context: self.context,
293            index: self.index,
294        }
295    }
296}
297
298impl DagLike for (&'_ Context, BoundRef) {
299    type Node = BoundRef;
300    fn data(&self) -> &BoundRef {
301        &self.1
302    }
303
304    fn as_dag_node(&self) -> Dag<Self> {
305        match self.0.get(&self.1) {
306            Bound::Free(..) | Bound::Complete(..) => Dag::Nullary,
307            Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => {
308                Dag::Binary((self.0, ty1.bound.root()), (self.0, ty2.bound.root()))
309            }
310        }
311    }
312}
313
314#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
315pub struct OccursCheckId {
316    context: *const Mutex<Vec<Bound>>,
317    index: usize,
318}
319
320struct BindError {
321    existing: BoundRef,
322    new: Bound,
323}
324
325/// Structure representing an inference context with its slab allocator mutex locked.
326///
327/// This type is never exposed outside of this module and should only exist
328/// ephemerally within function calls into this module.
329struct LockedContext<'ctx> {
330    context: *const Mutex<Vec<Bound>>,
331    slab: MutexGuard<'ctx, Vec<Bound>>,
332}
333
334impl LockedContext<'_> {
335    fn alloc_bound(&mut self, bound: Bound) -> BoundRef {
336        self.slab.push(bound);
337        let index = self.slab.len() - 1;
338
339        BoundRef {
340            context: self.context,
341            index,
342        }
343    }
344
345    fn reassign_non_complete(&mut self, bound: BoundRef, new: Bound) {
346        assert!(
347            !matches!(self.slab[bound.index], Bound::Complete(..)),
348            "tried to modify finalized type",
349        );
350        self.slab[bound.index] = new;
351    }
352
353    /// It is a common situation that we are pairing two types, and in the
354    /// case that they are both complete, we want to pair the complete types.
355    ///
356    /// This method deals with all the annoying/complicated member variable
357    /// paths to get the actual complete data out.
358    fn complete_pair_data(
359        &self,
360        inn1: &TypeInner,
361        inn2: &TypeInner,
362    ) -> Option<(Arc<Final>, Arc<Final>)> {
363        let bound1 = &self.slab[inn1.bound.root().index];
364        let bound2 = &self.slab[inn2.bound.root().index];
365        if let (Bound::Complete(ref data1), Bound::Complete(ref data2)) = (bound1, bound2) {
366            Some((Arc::clone(data1), Arc::clone(data2)))
367        } else {
368            None
369        }
370    }
371
372    /// Unify the type with another one.
373    ///
374    /// Fails if the bounds on the two types are incompatible
375    fn unify(&mut self, existing: &TypeInner, other: &TypeInner) -> Result<(), BindError> {
376        existing.bound.unify(&other.bound, |x_bound, y_bound| {
377            self.bind(x_bound, self.slab[y_bound.index].shallow_clone())
378        })
379    }
380
381    fn bind(&mut self, existing: BoundRef, new: Bound) -> Result<(), BindError> {
382        let existing_bound = self.slab[existing.index].shallow_clone();
383        let bind_error = || BindError {
384            existing: existing.clone(),
385            new: new.shallow_clone(),
386        };
387
388        match (&existing_bound, &new) {
389            // Binding a free type to anything is a no-op
390            (_, Bound::Free(_)) => Ok(()),
391            // Free types are simply dropped and replaced by the new bound
392            (Bound::Free(_), _) => {
393                // Free means non-finalized, so set() is ok.
394                self.reassign_non_complete(existing, new);
395                Ok(())
396            }
397            // Binding complete->complete shouldn't ever happen, but if so, we just
398            // compare the two types and return a pass/fail
399            (Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => {
400                if existing_final == new_final {
401                    Ok(())
402                } else {
403                    Err(bind_error())
404                }
405            }
406            // Binding an incomplete to a complete type requires recursion.
407            (Bound::Complete(complete), incomplete) | (incomplete, Bound::Complete(complete)) => {
408                match (complete.bound(), incomplete) {
409                    // A unit might match a Bound::Free(..) or a Bound::Complete(..),
410                    // and both cases were handled above. So this is an error.
411                    (CompleteBound::Unit, _) => Err(bind_error()),
412                    (
413                        CompleteBound::Product(ref comp1, ref comp2),
414                        Bound::Product(ref ty1, ref ty2),
415                    )
416                    | (CompleteBound::Sum(ref comp1, ref comp2), Bound::Sum(ref ty1, ref ty2)) => {
417                        let bound1 = ty1.bound.root();
418                        let bound2 = ty2.bound.root();
419                        self.bind(bound1, Bound::Complete(Arc::clone(comp1)))?;
420                        self.bind(bound2, Bound::Complete(Arc::clone(comp2)))
421                    }
422                    _ => Err(bind_error()),
423                }
424            }
425            (Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2))
426            | (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => {
427                self.unify(x1, y1)?;
428                self.unify(x2, y2)?;
429                // This type was not complete, but it may be after unification, giving us
430                // an opportunity to finaliize it. We do this eagerly to make sure that
431                // "complete" (no free children) is always equivalent to "finalized" (the
432                // bound field having variant Bound::Complete(..)), even during inference.
433                //
434                // It also gives the user access to more information about the type,
435                // prior to finalization.
436                if let Some((data1, data2)) = self.complete_pair_data(y1, y2) {
437                    self.reassign_non_complete(
438                        existing,
439                        Bound::Complete(if let Bound::Sum(..) = existing_bound {
440                            Final::sum(data1, data2)
441                        } else {
442                            Final::product(data1, data2)
443                        }),
444                    );
445                }
446                Ok(())
447            }
448            (_, _) => Err(bind_error()),
449        }
450    }
451}