// SPDX-License-Identifier: CC0-1.0

//! Types and Type Inference
//!
//! Every Simplicity expression has two types associated with it: a source and
//! a target. We refer to this pair of types as an "arrow". The types are
//! inferred from the structure of the program.
//!
//! A Simplicity type is one of three things:
//!   * A unit type, which has one value
//!   * A sum of two other types
//!   * A product of two other types
//!
//! During type inference, types are initially "free", meaning that there are
//! no constraints on what they will eventually be. The program structure then
//! imposes additional constraints (for example, the `comp` combinator requires
//! that its left child's target type be the same as its right child's source
//! type), and by unifying all these constraints, all types can be inferred.
//!
//! In this module, during inference types are characterized by their [`Bound`],
//! which describes the constraints on the type. The bound of a type can be
//! obtained by the [`Type::bound`] method, and is an enum with four variants:
//!
//! * [`Bound::Free`] means that the type has no constraints; it is a free
//!   variable. The type has a name which can be used to identify it in error
//!   messages.
//! * [`Bound::Sum`] and [`Bound::Product`] mean that the type is a sum
//!   (resp. product) of two other types, which are characterized by their
//!   own bounds.
//! * [`Bound::Complete`] means that the type has no free variables at all,
//!   and has an already-computed [`Final`] structure suitable for use in
//!   contexts that require complete types. (Unit types are always complete,
//!   and therefore use this variant rather than getting their own.)
//!
//! During inference, it is possible for a type to be complete, in the sense
//! of having no free variables, without its bound being [`Bound::Complete`].
//! This occurs, for example, if a type is a sum of two incomplete types and
//! the child types are later completed during type inference on an unrelated
//! part of the type hierarchy. The type would then have a [`Bound::Sum`] with
//! two children, both of which are complete.
//!
//! The inference engine makes an effort to notice when this happens and set
//! the bound of complete types to [`Bound::Complete`], but since type inference
//! is inherently non-local this cannot always be done.
//!
//! When the distinction matters, we say a type is "finalized" only if its bound
//! is `Complete` and "complete" if it has no free variables. But the distinction
//! usually does not matter, so we prefer to use the word "complete".
//!
//! Type inference is done progressively during construction of Simplicity
//! expressions. It is completed by the [`Type::finalize`] method, which
//! recursively completes types by setting any remaining free variables to unit.
//! If any type constraints are incompatible with each other (e.g. a type is
//! bound to be both a product and a sum type) then inference fails at that point
//! by returning an error.
//!
//! In addition to completing types, [`Type::finalize`] performs one additional
//! check, the "occurs check", to ensure that there are no infinitely-sized
//! types. Such types occur when a type has itself as a child; they are illegal
//! in Simplicity and could not be represented by our data structures.
//!
//! There are three main types in this module:
//!   * [`Type`] is the main type representing a Simplicity type, whether it is
//!     complete or not. Its main methods are [`Type::bound`] which returns the
//!     current state of the type and [`Type::bind`] which adds a new constraint
//!     to the type.
//!   * [`Final`] is a mutex-free structure that can be obtained from a complete
//!     type. It includes the TMR (type Merkle root) and the complete bound
//!     describing the type.
//!   * [`Bound`] defines the structure of a type: whether it is free, complete,
//!     or a sum or product of other types.
//!

use self::union_bound::UbElement;
use crate::dag::{Dag, DagLike, NoSharing};
use crate::Tmr;

use std::collections::HashSet;
use std::fmt;
use std::sync::Arc;

pub mod arrow;
mod final_data;
mod precomputed;
mod union_bound;
mod variable;

pub use final_data::{CompleteBound, Final};

/// Error arising during Simplicity type inference
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum Error {
    /// An attempt to bind a type conflicted with an existing bound on the type
    Bind {
        existing_bound: Bound,
        new_bound: Bound,
        hint: &'static str,
    },
    /// An attempt was made to unify two unequal complete types
    CompleteTypeMismatch {
        type1: Arc<Final>,
        type2: Arc<Final>,
        hint: &'static str,
    },
    /// A type is recursive (i.e., occurs within itself), violating the "occurs check"
    OccursCheck,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Error::Bind {
                ref existing_bound,
                ref new_bound,
                hint,
            } => {
                write!(
                    f,
                    "failed to apply bound `{}` to existing bound `{}`: {}",
                    new_bound, existing_bound, hint,
                )
            }
            Error::CompleteTypeMismatch {
                ref type1,
                ref type2,
                hint,
            } => {
                write!(
                    f,
                    "attempted to unify unequal types `{}` and `{}`: {}",
                    type1, type2, hint,
                )
            }
            Error::OccursCheck => f.write_str("detected infinitely-sized type"),
        }
    }
}

impl std::error::Error for Error {}

mod bound_mutex {
    use super::{Bound, CompleteBound, Error, Final};
    use std::sync::{Arc, Mutex};

    /// A mutex-protected, ref-counted [`Bound`] holding the current constraints
    /// on the source or target type of a Simplicity expression
    #[derive(Debug)]
    pub struct BoundMutex {
        /// The type's status according to the union-bound algorithm.
        inner: Mutex<Arc<Bound>>,
    }

    impl BoundMutex {
        pub fn new(bound: Bound) -> Self {
            BoundMutex {
                inner: Mutex::new(Arc::new(bound)),
            }
        }

        pub fn get(&self) -> Arc<Bound> {
            Arc::clone(&self.inner.lock().unwrap())
        }

        pub fn set(&self, new: Arc<Bound>) {
            let mut lock = self.inner.lock().unwrap();
            assert!(
                !matches!(**lock, Bound::Complete(..)),
                "tried to modify finalized type",
            );
            *lock = new;
        }

        pub fn bind(&self, bound: Arc<Bound>, hint: &'static str) -> Result<(), Error> {
            let existing_bound = self.get();
            let bind_error = || Error::Bind {
                existing_bound: existing_bound.shallow_clone(),
                new_bound: bound.shallow_clone(),
                hint,
            };

            match (existing_bound.as_ref(), bound.as_ref()) {
                // Binding a free type to anything is a no-op
                (_, Bound::Free(_)) => Ok(()),
                // Free types are simply dropped and replaced by the new bound
                (Bound::Free(_), _) => {
                    // Free means non-finalized, so set() is ok.
                    self.set(bound);
                    Ok(())
                }
                // Binding complete->complete shouldn't ever happen, but if so, we just
                // compare the two types and return a pass/fail
                (Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => {
                    if existing_final == new_final {
                        Ok(())
                    } else {
                        Err(bind_error())
                    }
                }
                // Binding an incomplete to a complete type requires recursion.
                (Bound::Complete(complete), incomplete)
                | (incomplete, Bound::Complete(complete)) => {
                    match (complete.bound(), incomplete) {
                        // A unit might match a Bound::Free(..) or a Bound::Complete(..),
                        // and both cases were handled above. So this is an error.
                        (CompleteBound::Unit, _) => Err(bind_error()),
                        (
                            CompleteBound::Product(ref comp1, ref comp2),
                            Bound::Product(ref ty1, ref ty2),
                        )
                        | (
                            CompleteBound::Sum(ref comp1, ref comp2),
                            Bound::Sum(ref ty1, ref ty2),
                        ) => {
                            ty1.bind(Arc::new(Bound::Complete(Arc::clone(comp1))), hint)?;
                            ty2.bind(Arc::new(Bound::Complete(Arc::clone(comp2))), hint)
                        }
                        _ => Err(bind_error()),
                    }
                }
                (Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2))
                | (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => {
                    x1.unify(y1, hint)?;
                    x2.unify(y2, hint)?;
                    // This type was not complete, but it may be after unification, giving us
                    // an opportunity to finalize it. We do this eagerly to make sure that
                    // "complete" (no free children) is always equivalent to "finalized" (the
                    // bound field having variant Bound::Complete(..)), even during inference.
                    //
                    // It also gives the user access to more information about the type,
                    // prior to finalization.
                    if let (Some(data1), Some(data2)) = (y1.final_data(), y2.final_data()) {
                        self.set(Arc::new(Bound::Complete(Arc::new(
                            if let Bound::Sum(..) = *bound {
                                Final::sum(data1, data2)
                            } else {
                                Final::product(data1, data2)
                            },
                        ))));
                    }
                    Ok(())
                }
                (x, y) => Err(Error::Bind {
                    existing_bound: x.shallow_clone(),
                    new_bound: y.shallow_clone(),
                    hint,
                }),
            }
        }
    }
}

/// The state of a [`Type`] based on all constraints currently imposed on it.
#[derive(Clone)]
pub enum Bound {
    /// Fully-unconstrained type
    Free(String),
    /// Fully-constrained (i.e. complete) type, which has no free variables.
    Complete(Arc<Final>),
    /// A sum of two other types
    Sum(Type, Type),
    /// A product of two other types
    Product(Type, Type),
}

impl Bound {
    /// Clones the `Bound`.
    ///
    /// This is the same as just calling `.clone()` but has a different name to
    /// emphasize that what's being cloned is (at most) a pair of ref-counted
    /// pointers.
    pub fn shallow_clone(&self) -> Bound {
        self.clone()
    }

    fn free(name: String) -> Self {
        Bound::Free(name)
    }

    fn unit() -> Self {
        Bound::Complete(Arc::new(Final::unit()))
    }

    fn sum(a: Type, b: Type) -> Self {
        if let (Some(adata), Some(bdata)) = (a.final_data(), b.final_data()) {
            Bound::Complete(Arc::new(Final::sum(adata, bdata)))
        } else {
            Bound::Sum(a, b)
        }
    }

    fn product(a: Type, b: Type) -> Self {
        if let (Some(adata), Some(bdata)) = (a.final_data(), b.final_data()) {
            Bound::Complete(Arc::new(Final::product(adata, bdata)))
        } else {
            Bound::Product(a, b)
        }
    }
}

impl fmt::Debug for Bound {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let arc = Arc::new(self.shallow_clone());
        for data in arc.verbose_pre_order_iter::<NoSharing>() {
            match (&*data.node, data.n_children_yielded) {
                (Bound::Free(ref s), _) => f.write_str(s)?,
                (Bound::Complete(ref comp), _) => fmt::Debug::fmt(comp, f)?,
                (Bound::Sum(..), 0) | (Bound::Product(..), 0) => f.write_str("(")?,
                (Bound::Sum(..), 2) | (Bound::Product(..), 2) => f.write_str(")")?,
                (Bound::Sum(..), _) => f.write_str(" + ")?,
                (Bound::Product(..), _) => f.write_str(" × ")?,
            }
        }
        Ok(())
    }
}

impl fmt::Display for Bound {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let arc = Arc::new(self.shallow_clone());
        for data in arc.verbose_pre_order_iter::<NoSharing>() {
            match (&*data.node, data.n_children_yielded) {
                (Bound::Free(ref s), _) => f.write_str(s)?,
                (Bound::Complete(ref comp), _) => fmt::Display::fmt(comp, f)?,
                (Bound::Sum(..), 0) | (Bound::Product(..), 0) => {
                    if data.index > 0 {
                        f.write_str("(")?
                    }
                }
                (Bound::Sum(..), 2) | (Bound::Product(..), 2) => {
                    if data.index > 0 {
                        f.write_str(")")?
                    }
                }
                (Bound::Sum(..), _) => f.write_str(" + ")?,
                (Bound::Product(..), _) => f.write_str(" × ")?,
            }
        }
        Ok(())
    }
}

impl DagLike for Arc<Bound> {
    type Node = Bound;
    fn data(&self) -> &Bound {
        self
    }

    fn as_dag_node(&self) -> Dag<Self> {
        match **self {
            Bound::Free(..) | Bound::Complete(..) => Dag::Nullary,
            Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => {
                Dag::Binary(ty1.bound.root().get(), ty2.bound.root().get())
            }
        }
    }
}

/// Source or target type of a Simplicity expression.
///
/// Internally this type is essentially just a refcounted pointer; it is
/// therefore quite cheap to clone, but be aware that cloning will not
/// actually create a new independent type, just a second pointer to the
/// first one.
#[derive(Clone, Debug)]
pub struct Type {
    /// A set of constraints, which is maintained by the union-bound algorithm
    /// and is progressively tightened as type inference proceeds.
    bound: UbElement<bound_mutex::BoundMutex>,
}

impl Type {
    /// Return an unbound type with the given name
    pub fn free(name: String) -> Self {
        Type::from(Bound::free(name))
    }

    /// Return a unit type.
    pub fn unit() -> Self {
        Type::from(Bound::unit())
    }

    /// Return a precomputed copy of 2^(2^n), for given n.
    pub fn two_two_n(n: usize) -> Self {
        precomputed::nth_power_of_2(n)
    }

    /// Return the sum of the given two types.
    pub fn sum(a: Self, b: Self) -> Self {
        Type::from(Bound::sum(a, b))
    }

    /// Return the product of the given two types.
    pub fn product(a: Self, b: Self) -> Self {
        Type::from(Bound::product(a, b))
    }

    /// Clones the `Type`.
    ///
    /// This is the same as just calling `.clone()` but has a different name to
    /// emphasize that what's being cloned is merely a ref-counted pointer.
    pub fn shallow_clone(&self) -> Type {
        self.clone()
    }

    /// Binds the type to a given bound. If this fails, the provided hint is
    /// attached to the error.
    ///
    /// Fails if the type has an existing incompatible bound.
    pub fn bind(&self, bound: Arc<Bound>, hint: &'static str) -> Result<(), Error> {
        let root = self.bound.root();
        root.bind(bound, hint)
    }

    /// Unify the type with another one.
    ///
    /// Fails if the bounds on the two types are incompatible
    pub fn unify(&self, other: &Self, hint: &'static str) -> Result<(), Error> {
        self.bound.unify(&other.bound, |x_bound, y_bound| {
            x_bound.bind(y_bound.get(), hint)
        })
    }

    /// Accessor for this type's bound
    pub fn bound(&self) -> Arc<Bound> {
        self.bound.root().get()
    }

    /// Accessor for the TMR of this type, if it is final
    pub fn tmr(&self) -> Option<Tmr> {
        self.final_data().map(|data| data.tmr())
    }

    /// Accessor for the data of this type, if it is complete
    pub fn final_data(&self) -> Option<Arc<Final>> {
        if let Bound::Complete(ref data) = *self.bound.root().get() {
            Some(Arc::clone(data))
        } else {
            None
        }
    }

    /// Whether this type is known to be final
    ///
    /// During type inference this may be false even though the type is, in fact,
    /// complete, since its children may have been unified to a complete type. To
    /// ensure a type is complete, call [`Type::finalize`].
    pub fn is_final(&self) -> bool {
        matches!(*self.bound.root().get(), Bound::Complete(..))
    }

    /// Attempts to finalize the type. Returns its [`Final`] data on success.
    pub fn finalize(&self) -> Result<Arc<Final>, Error> {
        let root = self.bound.root();
        let bound = root.get();
        if let Bound::Complete(ref data) = *bound {
            return Ok(Arc::clone(data));
        }

        // First, do occurs-check to ensure that we have no infinitely sized types.
        let mut occurs_check = HashSet::new();
        for data in bound.verbose_pre_order_iter::<NoSharing>() {
            if data.is_complete {
                occurs_check.remove(&(data.node.as_ref() as *const _));
            } else if data.n_children_yielded == 0
                && !occurs_check.insert(data.node.as_ref() as *const _)
            {
                return Err(Error::OccursCheck);
            }
        }

        // Now that we know our types have finite size, we can safely use a
        // post-order iterator to finalize them.
        let mut finalized = vec![];
        for data in self.shallow_clone().post_order_iter::<NoSharing>() {
            let bound = data.node.bound.root();
            let bound_get = bound.get();
            let final_data = match *bound_get {
                Bound::Free(_) => Arc::new(Final::unit()),
                Bound::Complete(ref arc) => Arc::clone(arc),
                Bound::Sum(..) => Arc::new(Final::sum(
                    Arc::clone(&finalized[data.left_index.unwrap()]),
                    Arc::clone(&finalized[data.right_index.unwrap()]),
                )),
                Bound::Product(..) => Arc::new(Final::product(
                    Arc::clone(&finalized[data.left_index.unwrap()]),
                    Arc::clone(&finalized[data.right_index.unwrap()]),
                )),
            };

            if !matches!(*bound_get, Bound::Complete(..)) {
                // set() ok because we are if-guarded on this variable not being complete
                bound.set(Arc::new(Bound::Complete(Arc::clone(&final_data))));
            }
            finalized.push(final_data);
        }
        Ok(finalized.pop().unwrap())
    }

    /// Return a vector containing the types 2^(2^i) for i from 0 to n-1.
    pub fn powers_of_two(n: usize) -> Vec<Self> {
        let mut ret = Vec::with_capacity(n);

        let unit = Type::unit();
        let mut two = Type::sum(unit.shallow_clone(), unit);
        for _ in 0..n {
            ret.push(two.shallow_clone());
            two = Type::product(two.shallow_clone(), two);
        }
        ret
    }
}

impl fmt::Display for Type {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(&self.bound.root().get(), f)
    }
}

impl From<Bound> for Type {
    /// Promotes a `Bound` to a type defined by that constraint
    fn from(bound: Bound) -> Type {
        Type {
            bound: UbElement::new(Arc::new(bound_mutex::BoundMutex::new(bound))),
        }
    }
}

impl DagLike for Type {
    type Node = Type;
    fn data(&self) -> &Type {
        self
    }

    fn as_dag_node(&self) -> Dag<Self> {
        match *self.bound.root().get() {
            Bound::Free(..) | Bound::Complete(..) => Dag::Nullary,
            Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => {
                Dag::Binary(ty1.shallow_clone(), ty2.shallow_clone())
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::jet::Core;
    use crate::node::{ConstructNode, CoreConstructible};

    #[test]
    fn inference_failure() {
        // unit: A -> 1
        let unit = Arc::<ConstructNode<Core>>::unit(); // 1 -> 1

        // Force unit to be 1->1
        Arc::<ConstructNode<Core>>::comp(&unit, &unit).unwrap();

        // take unit: 1 * B -> 1
        let take_unit = Arc::<ConstructNode<Core>>::take(&unit); // 1*1 -> 1

        // Pair will try to unify 1 and 1*B
        Arc::<ConstructNode<Core>>::pair(&unit, &take_unit).unwrap_err();
        // Trying to do it again should not work.
        Arc::<ConstructNode<Core>>::pair(&unit, &take_unit).unwrap_err();
    }
}