Skip to main content

simplicity/types/
context.rs

// SPDX-License-Identifier: CC0-1.0

//! Type Inference Context
//!
//! When constructing a Simplicity program, you must first create a type inference
//! context, in which type inference occurs incrementally during construction. Each
//! leaf node (e.g. `unit` and `iden`) must explicitly refer to the type inference
//! context, while combinator nodes (e.g. `comp`) infer the context from their
//! children, raising an error if there are multiple children whose contexts don't
//! match.
//!
//! This helps to prevent situations in which users attempt to construct multiple
//! independent programs, but types in one program accidentally refer to types in
//! the other.
//!
16
17use std::any::TypeId;
18use std::fmt;
19use std::marker::PhantomData;
20use std::sync::{Arc, Mutex, MutexGuard};
21
22use ghost_cell::GhostToken;
23
24use crate::dag::{Dag, DagLike};
25use crate::jet::Jet;
26
27use super::{
28    Bound, CompleteBound, Error, Final, Incomplete, Type, TypeInner, UbElement, WithGhostToken,
29};
30
// Copied from ghost_cell source. See
//     https://arhan.sh/blog/the-generativity-pattern-in-rust/
// in particular the box labeled "Throughout the years lifetime invariance has
// been achieved in several other ways." for some detail about this.
//
// The `fn(&'brand ()) -> &'brand ()` payload mentions `'brand` in both argument
// and return position, which is what makes the marker invariant in `'brand`.
type InvariantLifetime<'brand> = PhantomData<fn(&'brand ()) -> &'brand ()>;
36
/// Type inference context, or handle to a context.
///
/// Can be cheaply cloned with [`Context::shallow_clone`]. These clones will
/// refer to the same underlying type inference context, and can be used as
/// handles to each other. The derived [`Context::clone`] has the same effect.
///
/// There is currently no way to create an independent context with the same
/// type inference variables (i.e. a deep clone). If you need this functionality,
/// please file an issue.
#[derive(Clone)]
pub struct Context<'brand> {
    // Shared, mutex-guarded state. Two handles compare equal exactly when
    // they point at the same allocation.
    inner: Arc<Mutex<WithGhostToken<'brand, ContextInner<'brand>>>>,
}
50
/// State stored behind the context's mutex.
struct ContextInner<'brand> {
    /// Every bound allocated in this context. A [`BoundRef`] is an index
    /// into this vector.
    slab: Vec<Bound<'brand>>,
    /// Concrete jet type registered in this context, if any.
    jet_type: Option<TypeId>,
}
56
57impl fmt::Debug for Context<'_> {
58    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
59        let id = Arc::as_ptr(&self.inner) as usize;
60        write!(f, "inference_ctx_{:08x}", id)
61    }
62}
63
64impl PartialEq for Context<'_> {
65    fn eq(&self, other: &Self) -> bool {
66        Arc::ptr_eq(&self.inner, &other.inner)
67    }
68}
69impl Eq for Context<'_> {}
70
71impl<'brand> Context<'brand> {
72    /// Creates a scope with a new empty type inference context.
73    pub fn with_context<R, F>(fun: F) -> R
74    where
75        F: for<'new_brand> FnOnce(Context<'new_brand>) -> R,
76    {
77        GhostToken::new(|token| {
78            let ctx = Context::new(token);
79            fun(ctx)
80        })
81    }
82
83    /// Creates a new empty type inference context.
84    pub fn new(token: GhostToken<'brand>) -> Self {
85        Context {
86            inner: Arc::new(Mutex::new(WithGhostToken {
87                token,
88                inner: ContextInner {
89                    slab: vec![],
90                    jet_type: None,
91                },
92            })),
93        }
94    }
95
96    /// Helper function to allocate a bound and return a reference to it.
97    fn alloc_bound(&self, bound: Bound<'brand>) -> BoundRef<'brand> {
98        let mut lock = self.lock();
99        lock.alloc_bound(bound)
100    }
101
102    /// Allocate a new free type bound, and return a reference to it.
103    pub fn alloc_free(&self, name: String) -> BoundRef<'brand> {
104        self.alloc_bound(Bound::Free(name))
105    }
106
107    /// Allocate a new unit type bound, and return a reference to it.
108    pub fn alloc_unit(&self) -> BoundRef<'brand> {
109        self.alloc_bound(Bound::Complete(Final::unit()))
110    }
111
112    /// Allocate a new unit type bound, and return a reference to it.
113    pub fn alloc_complete(&self, data: Arc<Final>) -> BoundRef<'brand> {
114        self.alloc_bound(Bound::Complete(data))
115    }
116
117    /// Allocate a new sum-type bound, and return a reference to it.
118    pub fn alloc_sum(&self, left: Type<'brand>, right: Type<'brand>) -> BoundRef<'brand> {
119        let mut lock = self.lock();
120        if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) {
121            lock.alloc_bound(Bound::Complete(Final::sum(data1, data2)))
122        } else {
123            lock.alloc_bound(Bound::Sum(left.inner, right.inner))
124        }
125    }
126
127    /// Allocate a new product-type bound, and return a reference to it.
128    pub fn alloc_product(&self, left: Type<'brand>, right: Type<'brand>) -> BoundRef<'brand> {
129        let mut lock = self.lock();
130        if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) {
131            lock.alloc_bound(Bound::Complete(Final::product(data1, data2)))
132        } else {
133            lock.alloc_bound(Bound::Product(left.inner, right.inner))
134        }
135    }
136
137    /// Creates a new handle to the context.
138    ///
139    /// This handle holds a reference to the underlying context and will keep
140    /// it alive. The context will only be dropped once all handles, including
141    /// the original context object, are dropped.
142    pub fn shallow_clone(&self) -> Self {
143        Self {
144            inner: Arc::clone(&self.inner),
145        }
146    }
147
148    /// Checks whether two inference contexts are equal, and returns an error if not.
149    pub fn check_eq(&self, other: &Self) -> Result<(), super::Error> {
150        if self == other {
151            Ok(())
152        } else {
153            Err(super::Error::InferenceContextMismatch)
154        }
155    }
156
157    /// Asserts that all jets in this context have the same concrete type.
158    ///
159    /// Records the jet's type on first call, panics on subsequent calls with
160    /// a different concrete type.
161    pub fn check_jet(&self, jet: &dyn Jet) {
162        let new_id = jet.as_any().type_id();
163        let mut lock = self.lock();
164
165        if let Some(existing_id) = lock.inner.jet_type {
166            assert!(existing_id == new_id, "mixed jet types in context");
167
168            return;
169        }
170
171        lock.inner.jet_type = Some(new_id);
172    }
173
174    /// Accesses a bound.
175    pub(super) fn get(&self, bound: &BoundRef<'brand>) -> Bound<'brand> {
176        let lock = self.lock();
177        lock.inner.slab[bound.index].shallow_clone()
178    }
179
180    /// Accesses a bound through a union-bound element.
181    pub(super) fn get_root_ref(
182        &self,
183        bound: &UbElement<'brand, BoundRef<'brand>>,
184    ) -> BoundRef<'brand> {
185        let mut lock = self.lock();
186        bound.root(&mut lock.token)
187    }
188
189    /// Reassigns a bound to a different bound.
190    ///
191    /// # Panics
192    ///
193    /// Panics if called on a complete type. This is a sanity-check to avoid
194    /// replacing already-completed types, which can cause inefficiencies in
195    /// the union-bound algorithm (and if our replacement changes the type,
196    /// this is probably a bug.
197    ///
198    /// Also panics if passed a `BoundRef` that was not allocated by this context.
199    pub(super) fn reassign_non_complete(&self, bound: BoundRef<'brand>, new: Bound<'brand>) {
200        let mut lock = self.lock();
201        lock.reassign_non_complete(bound, new);
202    }
203
204    /// Binds the type to a product bound formed by the two inner types. If this
205    /// fails, attach the provided hint to the error.
206    ///
207    /// Fails if the type has an existing incompatible bound.
208    pub fn bind_product(
209        &self,
210        existing: &Type<'brand>,
211        prod_l: &Type<'brand>,
212        prod_r: &Type<'brand>,
213        hint: &'static str,
214    ) -> Result<(), Error> {
215        let mut lock = self.lock();
216        let existing_root = existing.inner.bound.root(&mut lock.token);
217        let new_bound = Bound::Product(prod_l.inner.shallow_clone(), prod_r.inner.shallow_clone());
218
219        lock.bind(existing_root, new_bound).map_err(|e| {
220            let new_bound = lock.alloc_bound(e.new);
221            drop(lock);
222            Error::Bind {
223                existing_bound: Incomplete::from_bound_ref(self, e.existing),
224                new_bound: Incomplete::from_bound_ref(self, new_bound),
225                hint,
226            }
227        })
228    }
229
230    /// Unify the type with another one.
231    ///
232    /// Fails if the bounds on the two types are incompatible
233    pub fn unify(
234        &self,
235        ty1: &Type<'brand>,
236        ty2: &Type<'brand>,
237        hint: &'static str,
238    ) -> Result<(), Error> {
239        let mut lock = self.lock();
240        lock.unify(&ty1.inner, &ty2.inner).map_err(|e| {
241            let new_bound = lock.alloc_bound(e.new);
242            drop(lock);
243            Error::Bind {
244                existing_bound: Incomplete::from_bound_ref(self, e.existing),
245                new_bound: Incomplete::from_bound_ref(self, new_bound),
246                hint,
247            }
248        })
249    }
250
251    /// Locks the underlying slab mutex.
252    fn lock(&self) -> MutexGuard<'_, WithGhostToken<'brand, ContextInner<'brand>>> {
253        self.inner.lock().unwrap()
254    }
255}
256
/// Reference to a bound stored in a [`Context`]'s slab.
///
/// Holds the slab index of the bound, tagged with the `'brand` lifetime of
/// the context it belongs to.
#[derive(Debug, Clone)]
pub struct BoundRef<'brand> {
    // Zero-sized marker carrying the `'brand` lifetime.
    phantom: InvariantLifetime<'brand>,
    // Position of the referenced bound within the slab.
    index: usize,
}
262
263impl<'brand> BoundRef<'brand> {
264    /// Creates an "occurs-check ID" which is just a copy of the [`BoundRef`]
265    /// with `PartialEq` and `Eq` implemented in terms of underlying pointer
266    /// equality.
267    pub fn occurs_check_id(&self) -> OccursCheckId<'brand> {
268        OccursCheckId {
269            phantom: InvariantLifetime::default(),
270            index: self.index,
271        }
272    }
273}
274
275impl super::PointerLike for BoundRef<'_> {
276    fn ptr_eq(&self, other: &Self) -> bool {
277        self.index == other.index
278    }
279
280    fn shallow_clone(&self) -> Self {
281        BoundRef {
282            phantom: InvariantLifetime::default(),
283            index: self.index,
284        }
285    }
286}
287
288impl<'brand> DagLike for (&'_ Context<'brand>, BoundRef<'brand>) {
289    type Node = BoundRef<'brand>;
290    fn data(&self) -> &BoundRef<'brand> {
291        &self.1
292    }
293
294    fn as_dag_node(&self) -> Dag<Self> {
295        match self.0.get(&self.1) {
296            Bound::Free(..) | Bound::Complete(..) => Dag::Nullary,
297            Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => {
298                let root1 = self.0.get_root_ref(&ty1.bound);
299                let root2 = self.0.get_root_ref(&ty2.bound);
300                Dag::Binary((self.0, root1), (self.0, root2))
301            }
302        }
303    }
304}
305
/// Identifier of a bound for use in occurs-checks.
///
/// Carries the same slab index as the [`BoundRef`] it was created from, but
/// is `Copy` and derives `PartialEq`, `Eq` and `Hash`.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct OccursCheckId<'brand> {
    // Zero-sized marker carrying the `'brand` lifetime.
    phantom: InvariantLifetime<'brand>,
    // Slab index copied from the originating `BoundRef`.
    index: usize,
}
311
/// Internal error produced when a bound cannot be bound to another.
///
/// The public [`Context`] methods convert this into `Error::Bind`.
struct BindError<'brand> {
    /// Reference to the bound that was already in place.
    existing: BoundRef<'brand>,
    /// The bound that could not be applied on top of `existing`.
    new: Bound<'brand>,
}
316
317impl<'brand> ContextInner<'brand> {
318    fn alloc_bound(&mut self, bound: Bound<'brand>) -> BoundRef<'brand> {
319        self.slab.push(bound);
320        let index = self.slab.len() - 1;
321
322        BoundRef {
323            phantom: InvariantLifetime::default(),
324            index,
325        }
326    }
327
328    fn reassign_non_complete(&mut self, bound: BoundRef<'brand>, new: Bound<'brand>) {
329        assert!(
330            !matches!(self.slab[bound.index], Bound::Complete(..)),
331            "tried to modify finalized type",
332        );
333        self.slab[bound.index] = new;
334    }
335}
336
impl<'brand> WithGhostToken<'brand, ContextInner<'brand>> {
    /// It is a common situation that we are pairing two types, and in the
    /// case that they are both complete, we want to pair the complete types.
    ///
    /// This method deals with all the annoying/complicated member variable
    /// paths to get the actual complete data out.
    ///
    /// Returns `None` unless the root bounds of both types are
    /// `Bound::Complete`.
    fn complete_pair_data(
        &mut self,
        inn1: &TypeInner<'brand>,
        inn2: &TypeInner<'brand>,
    ) -> Option<(Arc<Final>, Arc<Final>)> {
        // Resolve both union-bound elements to their roots before looking
        // anything up in the slab.
        let idx1 = inn1.bound.root(&mut self.token).index;
        let idx2 = inn2.bound.root(&mut self.token).index;
        let bound1 = &self.slab[idx1];
        let bound2 = &self.slab[idx2];
        if let (Bound::Complete(ref data1), Bound::Complete(ref data2)) = (bound1, bound2) {
            Some((Arc::clone(data1), Arc::clone(data2)))
        } else {
            None
        }
    }

    /// Unify the type with another one.
    ///
    /// Fails if the bounds on the two types are incompatible
    fn unify(
        &mut self,
        existing: &TypeInner<'brand>,
        other: &TypeInner<'brand>,
    ) -> Result<(), BindError<'brand>> {
        // The callback binds the first bound it is handed to a shallow clone
        // of the slab entry of the second.
        existing
            .bound
            .unify(self, &other.bound, |self_, x_bound, y_bound| {
                self_.bind(x_bound, self_.slab[y_bound.index].shallow_clone())
            })
    }

    /// Binds the slab entry referenced by `existing` to the bound `new`.
    ///
    /// Recurses into the children of sum and product bounds as needed.
    ///
    /// On failure, returns a [`BindError`] holding a copy of the `existing`
    /// reference and a shallow clone of `new`.
    fn bind(
        &mut self,
        existing: BoundRef<'brand>,
        new: Bound<'brand>,
    ) -> Result<(), BindError<'brand>> {
        let existing_bound = self.slab[existing.index].shallow_clone();
        // Built lazily so the clones are only made on the error paths.
        let bind_error = || BindError {
            existing: existing.clone(),
            new: new.shallow_clone(),
        };

        match (&existing_bound, &new) {
            // Binding a free type to anything is a no-op
            (_, Bound::Free(_)) => Ok(()),
            // Free types are simply dropped and replaced by the new bound
            (Bound::Free(_), _) => {
                // Free means non-finalized, so reassign_non_complete() is ok.
                self.reassign_non_complete(existing, new);
                Ok(())
            }
            // Binding complete->complete shouldn't ever happen, but if so, we just
            // compare the two types and return a pass/fail
            (Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => {
                if existing_final == new_final {
                    Ok(())
                } else {
                    Err(bind_error())
                }
            }
            // Binding an incomplete to a complete type requires recursion.
            //
            // NOTE(review): when `existing` is the incomplete side, this arm
            // binds its children but does not rewrite the `existing` slab
            // entry to `Bound::Complete` -- confirm that this is intended.
            (Bound::Complete(complete), incomplete) | (incomplete, Bound::Complete(complete)) => {
                match (complete.bound(), incomplete) {
                    // A unit might match a Bound::Free(..) or a Bound::Complete(..),
                    // and both cases were handled above. So this is an error.
                    (CompleteBound::Unit, _) => Err(bind_error()),
                    (
                        CompleteBound::Product(ref comp1, ref comp2),
                        Bound::Product(ref ty1, ref ty2),
                    )
                    | (CompleteBound::Sum(ref comp1, ref comp2), Bound::Sum(ref ty1, ref ty2)) => {
                        // Bind each child of the incomplete type to the
                        // matching child of the complete one.
                        let bound1 = ty1.bound.root(&mut self.token);
                        let bound2 = ty2.bound.root(&mut self.token);
                        self.bind(bound1, Bound::Complete(Arc::clone(comp1)))?;
                        self.bind(bound2, Bound::Complete(Arc::clone(comp2)))
                    }
                    // Shape mismatch, e.g. a complete sum against a product.
                    _ => Err(bind_error()),
                }
            }
            (Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2))
            | (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => {
                self.unify(x1, y1)?;
                self.unify(x2, y2)?;
                // This type was not complete, but it may be after unification, giving us
                // an opportunity to finalize it. We do this eagerly to make sure that
                // "complete" (no free children) is always equivalent to "finalized" (the
                // bound field having variant Bound::Complete(..)), even during inference.
                //
                // It also gives the user access to more information about the type,
                // prior to finalization.
                if let Some((data1, data2)) = self.complete_pair_data(y1, y2) {
                    self.reassign_non_complete(
                        existing,
                        Bound::Complete(if let Bound::Sum(..) = existing_bound {
                            Final::sum(data1, data2)
                        } else {
                            Final::product(data1, data2)
                        }),
                    );
                }
                Ok(())
            }
            // Remaining combinations pair a sum with a product.
            (_, _) => Err(bind_error()),
        }
    }
}