simplicity/types/
context.rs

1// SPDX-License-Identifier: CC0-1.0
2
3//! Type Inference Context
4//!
5//! When constructing a Simplicity program, you must first create a type inference
6//! context, in which type inference occurs incrementally during construction. Each
7//! leaf node (e.g. `unit` and `iden`) must explicitly refer to the type inference
8//! context, while combinator nodes (e.g. `comp`) infer the context from their
9//! children, raising an error if there are multiple children whose contexts don't
10//! match.
11//!
12//! This helps to prevent situations in which users attempt to construct multiple
13//! independent programs, but types in one program accidentally refer to types in
14//! the other.
15//!
16
17use std::fmt;
18use std::marker::PhantomData;
19use std::sync::{Arc, Mutex, MutexGuard};
20
21use ghost_cell::GhostToken;
22
23use crate::dag::{Dag, DagLike};
24
25use super::{
26    Bound, CompleteBound, Error, Final, Incomplete, Type, TypeInner, UbElement, WithGhostToken,
27};
28
// Lifetime marker that makes `'brand` invariant. Copied from the `ghost_cell`
// source.
//
// A function pointer that both accepts and returns `&'brand ()` is
// contravariant in its argument and covariant in its return type, which taken
// together forces `'brand` to be invariant. That prevents one brand from being
// coerced into another. For background see
//     https://arhan.sh/blog/the-generativity-pattern-in-rust/
// and in particular the box labeled "Throughout the years lifetime invariance
// has been achieved in several other ways."
type InvariantLifetime<'brand> = PhantomData<fn(&'brand ()) -> &'brand ()>;
34
35/// Type inference context, or handle to a context.
36///
37/// Can be cheaply cloned with [`Context::shallow_clone`]. These clones will
38/// refer to the same underlying type inference context, and can be used as
39/// handles to each other. The derived [`Context::clone`] has the same effect.
40///
41/// There is currently no way to create an independent context with the same
42/// type inference variables (i.e. a deep clone). If you need this functionality,
43/// please file an issue.
44#[derive(Clone)]
45pub struct Context<'brand> {
46    inner: Arc<Mutex<WithGhostToken<'brand, ContextInner<'brand>>>>,
47}
48
49struct ContextInner<'brand> {
50    slab: Vec<Bound<'brand>>,
51}
52
53impl fmt::Debug for Context<'_> {
54    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
55        let id = Arc::as_ptr(&self.inner) as usize;
56        write!(f, "inference_ctx_{:08x}", id)
57    }
58}
59
60impl PartialEq for Context<'_> {
61    fn eq(&self, other: &Self) -> bool {
62        Arc::ptr_eq(&self.inner, &other.inner)
63    }
64}
65impl Eq for Context<'_> {}
66
67impl<'brand> Context<'brand> {
68    /// Creates a scope with a new empty type inference context.
69    pub fn with_context<R, F>(fun: F) -> R
70    where
71        F: for<'new_brand> FnOnce(Context<'new_brand>) -> R,
72    {
73        GhostToken::new(|token| {
74            let ctx = Context::new(token);
75            fun(ctx)
76        })
77    }
78
79    /// Creates a new empty type inference context.
80    pub fn new(token: GhostToken<'brand>) -> Self {
81        Context {
82            inner: Arc::new(Mutex::new(WithGhostToken {
83                token,
84                inner: ContextInner { slab: vec![] },
85            })),
86        }
87    }
88
89    /// Helper function to allocate a bound and return a reference to it.
90    fn alloc_bound(&self, bound: Bound<'brand>) -> BoundRef<'brand> {
91        let mut lock = self.lock();
92        lock.alloc_bound(bound)
93    }
94
95    /// Allocate a new free type bound, and return a reference to it.
96    pub fn alloc_free(&self, name: String) -> BoundRef<'brand> {
97        self.alloc_bound(Bound::Free(name))
98    }
99
100    /// Allocate a new unit type bound, and return a reference to it.
101    pub fn alloc_unit(&self) -> BoundRef<'brand> {
102        self.alloc_bound(Bound::Complete(Final::unit()))
103    }
104
105    /// Allocate a new unit type bound, and return a reference to it.
106    pub fn alloc_complete(&self, data: Arc<Final>) -> BoundRef<'brand> {
107        self.alloc_bound(Bound::Complete(data))
108    }
109
110    /// Allocate a new sum-type bound, and return a reference to it.
111    pub fn alloc_sum(&self, left: Type<'brand>, right: Type<'brand>) -> BoundRef<'brand> {
112        let mut lock = self.lock();
113        if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) {
114            lock.alloc_bound(Bound::Complete(Final::sum(data1, data2)))
115        } else {
116            lock.alloc_bound(Bound::Sum(left.inner, right.inner))
117        }
118    }
119
120    /// Allocate a new product-type bound, and return a reference to it.
121    pub fn alloc_product(&self, left: Type<'brand>, right: Type<'brand>) -> BoundRef<'brand> {
122        let mut lock = self.lock();
123        if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) {
124            lock.alloc_bound(Bound::Complete(Final::product(data1, data2)))
125        } else {
126            lock.alloc_bound(Bound::Product(left.inner, right.inner))
127        }
128    }
129
130    /// Creates a new handle to the context.
131    ///
132    /// This handle holds a reference to the underlying context and will keep
133    /// it alive. The context will only be dropped once all handles, including
134    /// the original context object, are dropped.
135    pub fn shallow_clone(&self) -> Self {
136        Self {
137            inner: Arc::clone(&self.inner),
138        }
139    }
140
141    /// Checks whether two inference contexts are equal, and returns an error if not.
142    pub fn check_eq(&self, other: &Self) -> Result<(), super::Error> {
143        if self == other {
144            Ok(())
145        } else {
146            Err(super::Error::InferenceContextMismatch)
147        }
148    }
149
150    /// Accesses a bound.
151    pub(super) fn get(&self, bound: &BoundRef<'brand>) -> Bound<'brand> {
152        let lock = self.lock();
153        lock.inner.slab[bound.index].shallow_clone()
154    }
155
156    /// Accesses a bound through a union-bound element.
157    pub(super) fn get_root_ref(
158        &self,
159        bound: &UbElement<'brand, BoundRef<'brand>>,
160    ) -> BoundRef<'brand> {
161        let mut lock = self.lock();
162        bound.root(&mut lock.token)
163    }
164
165    /// Reassigns a bound to a different bound.
166    ///
167    /// # Panics
168    ///
169    /// Panics if called on a complete type. This is a sanity-check to avoid
170    /// replacing already-completed types, which can cause inefficiencies in
171    /// the union-bound algorithm (and if our replacement changes the type,
172    /// this is probably a bug.
173    ///
174    /// Also panics if passed a `BoundRef` that was not allocated by this context.
175    pub(super) fn reassign_non_complete(&self, bound: BoundRef<'brand>, new: Bound<'brand>) {
176        let mut lock = self.lock();
177        lock.reassign_non_complete(bound, new);
178    }
179
180    /// Binds the type to a product bound formed by the two inner types. If this
181    /// fails, attach the provided hint to the error.
182    ///
183    /// Fails if the type has an existing incompatible bound.
184    pub fn bind_product(
185        &self,
186        existing: &Type<'brand>,
187        prod_l: &Type<'brand>,
188        prod_r: &Type<'brand>,
189        hint: &'static str,
190    ) -> Result<(), Error> {
191        let mut lock = self.lock();
192        let existing_root = existing.inner.bound.root(&mut lock.token);
193        let new_bound = Bound::Product(prod_l.inner.shallow_clone(), prod_r.inner.shallow_clone());
194
195        lock.bind(existing_root, new_bound).map_err(|e| {
196            let new_bound = lock.alloc_bound(e.new);
197            drop(lock);
198            Error::Bind {
199                existing_bound: Incomplete::from_bound_ref(self, e.existing),
200                new_bound: Incomplete::from_bound_ref(self, new_bound),
201                hint,
202            }
203        })
204    }
205
206    /// Unify the type with another one.
207    ///
208    /// Fails if the bounds on the two types are incompatible
209    pub fn unify(
210        &self,
211        ty1: &Type<'brand>,
212        ty2: &Type<'brand>,
213        hint: &'static str,
214    ) -> Result<(), Error> {
215        let mut lock = self.lock();
216        lock.unify(&ty1.inner, &ty2.inner).map_err(|e| {
217            let new_bound = lock.alloc_bound(e.new);
218            drop(lock);
219            Error::Bind {
220                existing_bound: Incomplete::from_bound_ref(self, e.existing),
221                new_bound: Incomplete::from_bound_ref(self, new_bound),
222                hint,
223            }
224        })
225    }
226
227    /// Locks the underlying slab mutex.
228    fn lock(&self) -> MutexGuard<'_, WithGhostToken<'brand, ContextInner<'brand>>> {
229        self.inner.lock().unwrap()
230    }
231}
232
233#[derive(Debug, Clone)]
234pub struct BoundRef<'brand> {
235    phantom: InvariantLifetime<'brand>,
236    index: usize,
237}
238
239impl<'brand> BoundRef<'brand> {
240    /// Creates an "occurs-check ID" which is just a copy of the [`BoundRef`]
241    /// with `PartialEq` and `Eq` implemented in terms of underlying pointer
242    /// equality.
243    pub fn occurs_check_id(&self) -> OccursCheckId<'brand> {
244        OccursCheckId {
245            phantom: InvariantLifetime::default(),
246            index: self.index,
247        }
248    }
249}
250
251impl super::PointerLike for BoundRef<'_> {
252    fn ptr_eq(&self, other: &Self) -> bool {
253        self.index == other.index
254    }
255
256    fn shallow_clone(&self) -> Self {
257        BoundRef {
258            phantom: InvariantLifetime::default(),
259            index: self.index,
260        }
261    }
262}
263
264impl<'brand> DagLike for (&'_ Context<'brand>, BoundRef<'brand>) {
265    type Node = BoundRef<'brand>;
266    fn data(&self) -> &BoundRef<'brand> {
267        &self.1
268    }
269
270    fn as_dag_node(&self) -> Dag<Self> {
271        match self.0.get(&self.1) {
272            Bound::Free(..) | Bound::Complete(..) => Dag::Nullary,
273            Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => {
274                let root1 = self.0.get_root_ref(&ty1.bound);
275                let root2 = self.0.get_root_ref(&ty2.bound);
276                Dag::Binary((self.0, root1), (self.0, root2))
277            }
278        }
279    }
280}
281
282#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
283pub struct OccursCheckId<'brand> {
284    phantom: InvariantLifetime<'brand>,
285    index: usize,
286}
287
288struct BindError<'brand> {
289    existing: BoundRef<'brand>,
290    new: Bound<'brand>,
291}
292
293impl<'brand> ContextInner<'brand> {
294    fn alloc_bound(&mut self, bound: Bound<'brand>) -> BoundRef<'brand> {
295        self.slab.push(bound);
296        let index = self.slab.len() - 1;
297
298        BoundRef {
299            phantom: InvariantLifetime::default(),
300            index,
301        }
302    }
303
304    fn reassign_non_complete(&mut self, bound: BoundRef<'brand>, new: Bound<'brand>) {
305        assert!(
306            !matches!(self.slab[bound.index], Bound::Complete(..)),
307            "tried to modify finalized type",
308        );
309        self.slab[bound.index] = new;
310    }
311}
312
313impl<'brand> WithGhostToken<'brand, ContextInner<'brand>> {
314    /// It is a common situation that we are pairing two types, and in the
315    /// case that they are both complete, we want to pair the complete types.
316    ///
317    /// This method deals with all the annoying/complicated member variable
318    /// paths to get the actual complete data out.
319    fn complete_pair_data(
320        &mut self,
321        inn1: &TypeInner<'brand>,
322        inn2: &TypeInner<'brand>,
323    ) -> Option<(Arc<Final>, Arc<Final>)> {
324        let idx1 = inn1.bound.root(&mut self.token).index;
325        let idx2 = inn2.bound.root(&mut self.token).index;
326        let bound1 = &self.slab[idx1];
327        let bound2 = &self.slab[idx2];
328        if let (Bound::Complete(ref data1), Bound::Complete(ref data2)) = (bound1, bound2) {
329            Some((Arc::clone(data1), Arc::clone(data2)))
330        } else {
331            None
332        }
333    }
334    /// Unify the type with another one.
335    ///
336    /// Fails if the bounds on the two types are incompatible
337    fn unify(
338        &mut self,
339        existing: &TypeInner<'brand>,
340        other: &TypeInner<'brand>,
341    ) -> Result<(), BindError<'brand>> {
342        existing
343            .bound
344            .unify(self, &other.bound, |self_, x_bound, y_bound| {
345                self_.bind(x_bound, self_.slab[y_bound.index].shallow_clone())
346            })
347    }
348
349    fn bind(
350        &mut self,
351        existing: BoundRef<'brand>,
352        new: Bound<'brand>,
353    ) -> Result<(), BindError<'brand>> {
354        let existing_bound = self.slab[existing.index].shallow_clone();
355        let bind_error = || BindError {
356            existing: existing.clone(),
357            new: new.shallow_clone(),
358        };
359
360        match (&existing_bound, &new) {
361            // Binding a free type to anything is a no-op
362            (_, Bound::Free(_)) => Ok(()),
363            // Free types are simply dropped and replaced by the new bound
364            (Bound::Free(_), _) => {
365                // Free means non-finalized, so set() is ok.
366                self.reassign_non_complete(existing, new);
367                Ok(())
368            }
369            // Binding complete->complete shouldn't ever happen, but if so, we just
370            // compare the two types and return a pass/fail
371            (Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => {
372                if existing_final == new_final {
373                    Ok(())
374                } else {
375                    Err(bind_error())
376                }
377            }
378            // Binding an incomplete to a complete type requires recursion.
379            (Bound::Complete(complete), incomplete) | (incomplete, Bound::Complete(complete)) => {
380                match (complete.bound(), incomplete) {
381                    // A unit might match a Bound::Free(..) or a Bound::Complete(..),
382                    // and both cases were handled above. So this is an error.
383                    (CompleteBound::Unit, _) => Err(bind_error()),
384                    (
385                        CompleteBound::Product(ref comp1, ref comp2),
386                        Bound::Product(ref ty1, ref ty2),
387                    )
388                    | (CompleteBound::Sum(ref comp1, ref comp2), Bound::Sum(ref ty1, ref ty2)) => {
389                        let bound1 = ty1.bound.root(&mut self.token);
390                        let bound2 = ty2.bound.root(&mut self.token);
391                        self.bind(bound1, Bound::Complete(Arc::clone(comp1)))?;
392                        self.bind(bound2, Bound::Complete(Arc::clone(comp2)))
393                    }
394                    _ => Err(bind_error()),
395                }
396            }
397            (Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2))
398            | (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => {
399                self.unify(x1, y1)?;
400                self.unify(x2, y2)?;
401                // This type was not complete, but it may be after unification, giving us
402                // an opportunity to finaliize it. We do this eagerly to make sure that
403                // "complete" (no free children) is always equivalent to "finalized" (the
404                // bound field having variant Bound::Complete(..)), even during inference.
405                //
406                // It also gives the user access to more information about the type,
407                // prior to finalization.
408                if let Some((data1, data2)) = self.complete_pair_data(y1, y2) {
409                    self.reassign_non_complete(
410                        existing,
411                        Bound::Complete(if let Bound::Sum(..) = existing_bound {
412                            Final::sum(data1, data2)
413                        } else {
414                            Final::product(data1, data2)
415                        }),
416                    );
417                }
418                Ok(())
419            }
420            (_, _) => Err(bind_error()),
421        }
422    }
423}