simplicity/types/
arrow.rs

1// SPDX-License-Identifier: CC0-1.0
2
//! Type Arrows
4//!
5//! Every Simplicity expression has two types associated with it: a source and
6//! a target. See the `types` module for more information. We refer to this
7//! pair of types as an "arrow", since the expression can be thought of as
8//! mapping a value of the source type to a value of the target type.
9//!
10//! This module defines the specific arrows associated with each kind of node.
11//!
12//! See the `types` module above this one for more information.
13
14use std::fmt;
15use std::sync::Arc;
16
17use crate::jet::Jet;
18use crate::node::{
19    CoreConstructible, DisconnectConstructible, JetConstructible, NoDisconnect,
20    WitnessConstructible,
21};
22use crate::types::{Context, Error, Final, Type};
23use crate::value::Word;
24
25use super::variable::new_name;
26
27/// A container for an expression's source and target types, whether or not
28/// these types are complete.
/// A container for an expression's source and target types, whether or not
/// these types are complete.
///
/// The source and target are stored together with the type-inference
/// [`Context`] they belong to. Note that `Clone` on this type is shallow; see
/// [`Arrow::shallow_clone`].
#[derive(Debug)]
pub struct Arrow<'brand> {
    /// The source type (the type of values the expression consumes)
    pub source: Type<'brand>,
    /// The target type (the type of values the expression produces)
    pub target: Type<'brand>,
    /// Type inference context for both types.
    pub inference_context: Context<'brand>,
}
38
39// Having `Clone` makes it easier to derive Clone on structures
40// that contain Arrow, even though it is potentially confusing
41// to use `.clone` to mean a shallow clone.
42impl Clone for Arrow<'_> {
43    fn clone(&self) -> Self {
44        self.shallow_clone()
45    }
46}
47
48impl fmt::Display for Arrow<'_> {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        write!(f, "{} → {}", self.source, self.target)
51    }
52}
53
/// A container for the type data associated with an expression's source and
/// target types, if both types are complete.
///
/// Both types are reference-counted, so cloning a `FinalArrow` only bumps two
/// `Arc` counters. The derived `PartialOrd`/`Ord` compare `source` first and
/// then `target` (declaration order).
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
pub struct FinalArrow {
    /// The source type
    pub source: Arc<Final>,
    /// The target type
    pub target: Arc<Final>,
}
63
64impl fmt::Display for FinalArrow {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        write!(f, "{} → {}", self.source, self.target)
67    }
68}
69
70impl FinalArrow {
71    /// Same as [`Self::clone`] but named to make it clearer that this is cheap
72    pub fn shallow_clone(&self) -> Self {
73        FinalArrow {
74            source: Arc::clone(&self.source),
75            target: Arc::clone(&self.target),
76        }
77    }
78}
79
80impl<'brand> Arrow<'brand> {
81    /// Finalize the source and target types in the arrow
82    pub fn finalize(&self) -> Result<FinalArrow, Error> {
83        Ok(FinalArrow {
84            source: self.source.finalize()?,
85            target: self.target.finalize()?,
86        })
87    }
88
89    /// Same as [`Self::clone`] but named to make it clearer that this is cheap
90    pub fn shallow_clone(&self) -> Self {
91        Arrow {
92            source: self.source.shallow_clone(),
93            target: self.target.shallow_clone(),
94            inference_context: self.inference_context.shallow_clone(),
95        }
96    }
97
98    /// Create a unification arrow for a fresh `case` combinator
99    ///
100    /// Either child may be `None`, in which case the combinator is assumed to be
101    /// an assertion, which for type-inference purposes means there are no bounds
102    /// on the missing child.
103    ///
104    /// # Panics
105    ///
106    /// If neither child is provided, this function will panic.
107    fn for_case(lchild_arrow: Option<&Self>, rchild_arrow: Option<&Self>) -> Result<Self, Error> {
108        if let (Some(left), Some(right)) = (lchild_arrow, rchild_arrow) {
109            left.inference_context.check_eq(&right.inference_context)?;
110        }
111
112        let ctx = match (lchild_arrow, rchild_arrow) {
113            (Some(left), _) => left.inference_context.shallow_clone(),
114            (_, Some(right)) => right.inference_context.shallow_clone(),
115            (None, None) => panic!("called `for_case` with no children"),
116        };
117
118        let a = Type::free(&ctx, new_name("case_a_"));
119        let b = Type::free(&ctx, new_name("case_b_"));
120        let c = Type::free(&ctx, new_name("case_c_"));
121
122        let sum_a_b = Type::sum(&ctx, a.shallow_clone(), b.shallow_clone());
123        let prod_sum_a_b_c = Type::product(&ctx, sum_a_b, c.shallow_clone());
124
125        let target = Type::free(&ctx, String::new());
126        if let Some(lchild_arrow) = lchild_arrow {
127            ctx.bind_product(
128                &lchild_arrow.source,
129                &a,
130                &c,
131                "case combinator: left source = A × C",
132            )?;
133            ctx.unify(&target, &lchild_arrow.target, "").unwrap();
134        }
135        if let Some(rchild_arrow) = rchild_arrow {
136            ctx.bind_product(
137                &rchild_arrow.source,
138                &b,
139                &c,
140                "case combinator: left source = B × C",
141            )?;
142            ctx.unify(
143                &target,
144                &rchild_arrow.target,
145                "case combinator: left target = right target",
146            )?;
147        }
148
149        Ok(Arrow {
150            source: prod_sum_a_b_c,
151            target,
152            inference_context: ctx,
153        })
154    }
155
156    /// Helper function to combine code for the two `DisconnectConstructible` impls for [`Arrow`].
157    fn for_disconnect(lchild_arrow: &Self, rchild_arrow: &Self) -> Result<Self, Error> {
158        lchild_arrow
159            .inference_context
160            .check_eq(&rchild_arrow.inference_context)?;
161
162        let ctx = lchild_arrow.inference_context();
163        let a = Type::free(ctx, new_name("disconnect_a_"));
164        let b = Type::free(ctx, new_name("disconnect_b_"));
165        let c = rchild_arrow.source.shallow_clone();
166        let d = rchild_arrow.target.shallow_clone();
167
168        ctx.bind_product(
169            &lchild_arrow.source,
170            &Type::two_two_n(ctx, 8),
171            &a,
172            "disconnect combinator: left source = 2^256 × A",
173        )?;
174        ctx.bind_product(
175            &lchild_arrow.target,
176            &b,
177            &c,
178            "disconnect combinator: left target = B × C",
179        )?;
180
181        let prod_b_d = Type::product(ctx, b, d);
182
183        Ok(Arrow {
184            source: a,
185            target: prod_b_d,
186            inference_context: lchild_arrow.inference_context.shallow_clone(),
187        })
188    }
189}
190
191impl<'brand> CoreConstructible<'brand> for Arrow<'brand> {
192    fn iden(inference_context: &Context<'brand>) -> Self {
193        // Throughout this module, when two types are the same, we reuse a
194        // pointer to them rather than creating distinct types and unifying
195        // them. This theoretically could lead to more confusing errors for
196        // the user during type inference, but in practice type inference
197        // is completely opaque and there's no harm in making it moreso.
198        let new = Type::free(inference_context, new_name("iden_src_"));
199        Arrow {
200            source: new.shallow_clone(),
201            target: new,
202            inference_context: inference_context.shallow_clone(),
203        }
204    }
205
206    fn unit(inference_context: &Context<'brand>) -> Self {
207        Arrow {
208            source: Type::free(inference_context, new_name("unit_src_")),
209            target: Type::unit(inference_context),
210            inference_context: inference_context.shallow_clone(),
211        }
212    }
213
214    fn injl(child: &Self) -> Self {
215        Arrow {
216            source: child.source.shallow_clone(),
217            target: Type::sum(
218                &child.inference_context,
219                child.target.shallow_clone(),
220                Type::free(&child.inference_context, new_name("injl_tgt_")),
221            ),
222            inference_context: child.inference_context.shallow_clone(),
223        }
224    }
225
226    fn injr(child: &Self) -> Self {
227        Arrow {
228            source: child.source.shallow_clone(),
229            target: Type::sum(
230                &child.inference_context,
231                Type::free(&child.inference_context, new_name("injr_tgt_")),
232                child.target.shallow_clone(),
233            ),
234            inference_context: child.inference_context.shallow_clone(),
235        }
236    }
237
238    fn take(child: &Self) -> Self {
239        Arrow {
240            source: Type::product(
241                &child.inference_context,
242                child.source.shallow_clone(),
243                Type::free(&child.inference_context, new_name("take_src_")),
244            ),
245            target: child.target.shallow_clone(),
246            inference_context: child.inference_context.shallow_clone(),
247        }
248    }
249
250    fn drop_(child: &Self) -> Self {
251        Arrow {
252            source: Type::product(
253                &child.inference_context,
254                Type::free(&child.inference_context, new_name("drop_src_")),
255                child.source.shallow_clone(),
256            ),
257            target: child.target.shallow_clone(),
258            inference_context: child.inference_context.shallow_clone(),
259        }
260    }
261
262    fn comp(left: &Self, right: &Self) -> Result<Self, Error> {
263        left.inference_context.check_eq(&right.inference_context)?;
264        left.inference_context.unify(
265            &left.target,
266            &right.source,
267            "comp combinator: left target = right source",
268        )?;
269        Ok(Arrow {
270            source: left.source.shallow_clone(),
271            target: right.target.shallow_clone(),
272            inference_context: left.inference_context.shallow_clone(),
273        })
274    }
275
276    fn case(left: &Self, right: &Self) -> Result<Self, Error> {
277        Self::for_case(Some(left), Some(right))
278    }
279
280    fn assertl(left: &Self, _: crate::Cmr) -> Result<Self, Error> {
281        Self::for_case(Some(left), None)
282    }
283
284    fn assertr(_: crate::Cmr, right: &Self) -> Result<Self, Error> {
285        Self::for_case(None, Some(right))
286    }
287
288    fn pair(left: &Self, right: &Self) -> Result<Self, Error> {
289        left.inference_context.check_eq(&right.inference_context)?;
290        left.inference_context.unify(
291            &left.source,
292            &right.source,
293            "pair combinator: left source = right source",
294        )?;
295        Ok(Arrow {
296            source: left.source.shallow_clone(),
297            target: Type::product(
298                &left.inference_context,
299                left.target.shallow_clone(),
300                right.target.shallow_clone(),
301            ),
302            inference_context: left.inference_context.shallow_clone(),
303        })
304    }
305
306    fn fail(inference_context: &Context<'brand>, _: crate::FailEntropy) -> Self {
307        Arrow {
308            source: Type::free(inference_context, new_name("fail_src_")),
309            target: Type::free(inference_context, new_name("fail_tgt_")),
310            inference_context: inference_context.shallow_clone(),
311        }
312    }
313
314    fn const_word(inference_context: &Context<'brand>, word: Word) -> Self {
315        Arrow {
316            source: Type::unit(inference_context),
317            target: Type::two_two_n(inference_context, word.n() as usize), // cast safety: 32-bit machine or higher
318            inference_context: inference_context.shallow_clone(),
319        }
320    }
321
322    fn inference_context(&self) -> &Context<'brand> {
323        &self.inference_context
324    }
325}
326
327impl<'brand> DisconnectConstructible<'brand, Arrow<'brand>> for Arrow<'brand> {
328    fn disconnect(left: &Self, right: &Self) -> Result<Self, Error> {
329        Self::for_disconnect(left, right)
330    }
331}
332
333impl<'brand> DisconnectConstructible<'brand, NoDisconnect> for Arrow<'brand> {
334    fn disconnect(left: &Self, _: &NoDisconnect) -> Result<Self, Error> {
335        let source = Type::free(&left.inference_context, "disc_src".into());
336        let target = Type::free(&left.inference_context, "disc_tgt".into());
337        Self::for_disconnect(
338            left,
339            &Arrow {
340                source,
341                target,
342                inference_context: left.inference_context.shallow_clone(),
343            },
344        )
345    }
346}
347
348impl<'brand> DisconnectConstructible<'brand, Option<&Arrow<'brand>>> for Arrow<'brand> {
349    fn disconnect(left: &Self, right: &Option<&Self>) -> Result<Self, Error> {
350        match *right {
351            Some(right) => Self::disconnect(left, right),
352            None => Self::disconnect(left, &NoDisconnect),
353        }
354    }
355}
356
357impl<'brand, J: Jet> JetConstructible<'brand, J> for Arrow<'brand> {
358    fn jet(inference_context: &Context<'brand>, jet: J) -> Self {
359        Arrow {
360            source: jet.source_ty().to_type(inference_context),
361            target: jet.target_ty().to_type(inference_context),
362            inference_context: inference_context.shallow_clone(),
363        }
364    }
365}
366
367impl<'brand, W> WitnessConstructible<'brand, W> for Arrow<'brand> {
368    fn witness(inference_context: &Context<'brand>, _: W) -> Self {
369        Arrow {
370            source: Type::free(inference_context, new_name("witness_src_")),
371            target: Type::free(inference_context, new_name("witness_tgt_")),
372            inference_context: inference_context.shallow_clone(),
373        }
374    }
375}