simfony/
compile.rs

//! Compile the parsed AST into a Simplicity program.
2
3use std::sync::Arc;
4
5use either::Either;
6use simplicity::jet::Elements;
7use simplicity::node::{CoreConstructible as _, JetConstructible as _};
8use simplicity::{Cmr, FailEntropy};
9
10use crate::array::{BTreeSlice, Partition};
11use crate::ast::{
12    Call, CallName, Expression, ExpressionInner, Match, Program, SingleExpression,
13    SingleExpressionInner, Statement,
14};
15use crate::debug::CallTracker;
16use crate::error::{Error, RichError, Span, WithSpan};
17use crate::named::{CoreExt, PairBuilder};
18use crate::num::{NonZeroPow2Usize, Pow2Usize};
19use crate::pattern::{BasePattern, Pattern};
20use crate::str::WitnessName;
21use crate::types::{StructuralType, TypeDeconstructible};
22use crate::value::StructuralValue;
23use crate::witness::Arguments;
24use crate::{ProgNode, Value};
25
26/// Each Simfony expression expects an _input value_.
27/// A Simfony expression is translated into a Simplicity expression
28/// that similarly expects an _input value_.
29///
30/// Simfony variable names are translated into Simplicity expressions
31/// that extract the seeked value from the _input value_.
32///
33/// Each (nested) block expression introduces a new scope.
34/// Bindings from inner scopes overwrite bindings from outer scopes.
35/// Bindings live as long as their scope.
36#[derive(Debug, Clone)]
37struct Scope {
38    /// For each scope, the set of assigned variables.
39    ///
40    /// A stack of scopes. Each scope is a stack of patterns.
41    /// New patterns are pushed onto the top _(current, innermost)_ scope.
42    ///
43    /// ## Input pattern
44    ///
45    /// The stack of scopes corresponds to an _input pattern_.
46    /// All valid input values match the input pattern.
47    ///
48    /// ## Example
49    ///
50    /// The stack `[[p1], [p2, p3]]` corresponds to a nested product pattern:
51    ///
52    /// ```text
53    ///    .
54    ///   / \
55    /// p3   .
56    ///     / \
57    ///   p2   p1
58    /// ```
59    ///
60    /// Inner scopes occur higher in the tree than outer scopes.
61    /// Later assignments occur higher in the tree than earlier assignments.
62    /// ```
63    variables: Vec<Vec<Pattern>>,
64    ctx: simplicity::types::Context,
65    /// Tracker of function calls.
66    call_tracker: Arc<CallTracker>,
67    /// Values for parameters inside the Simfony program.
68    arguments: Arguments,
69}
70
71impl Scope {
72    /// Create the main scope.
73    ///
74    /// _This function should be called at the start of the compilation and then never again._
75    ///
76    ///  ## Precondition
77    ///
78    /// The supplied `arguments` are consistent with the program's parameters.
79    /// Call [`Arguments::is_consistent`] before calling this method!
80    pub fn new(call_tracker: Arc<CallTracker>, arguments: Arguments) -> Self {
81        Self {
82            variables: vec![vec![Pattern::Ignore]],
83            ctx: simplicity::types::Context::new(),
84            call_tracker,
85            arguments,
86        }
87    }
88
89    /// Create a child scope for a function that takes `input` of the given pattern.
90    pub fn child(&self, input: Pattern) -> Self {
91        Self {
92            variables: vec![vec![input]],
93            ctx: self.ctx.shallow_clone(),
94            call_tracker: Arc::clone(&self.call_tracker),
95            arguments: self.arguments.clone(),
96        }
97    }
98
99    /// Push a new scope onto the stack.
100    pub fn push_scope(&mut self) {
101        self.variables.push(Vec::new());
102    }
103
104    /// Pop the current scope from the stack.
105    ///
106    /// # Panics
107    ///
108    /// The stack is empty.
109    pub fn pop_scope(&mut self) {
110        self.variables.pop().expect("Empty stack");
111    }
112
113    /// Push an assignment to the current scope.
114    ///
115    /// Update the input pattern accordingly:
116    ///
117    /// ```text
118    ///   .
119    ///  / \
120    /// p   previous
121    /// ```
122    ///
123    /// ## Panics
124    ///
125    /// The stack is empty.
126    pub fn insert(&mut self, pattern: Pattern) {
127        self.variables
128            .last_mut()
129            .expect("Empty stack")
130            .push(pattern);
131    }
132
133    /// Get the input pattern.
134    ///
135    /// All valid input values match the input pattern.
136    ///
137    /// ## Panics
138    ///
139    /// The stack is empty.
140    fn get_input_pattern(&self) -> Pattern {
141        let mut it = self.variables.iter().flat_map(|scope| scope.iter());
142        let first = it.next().expect("Empty stack");
143        it.cloned()
144            .fold(first.clone(), |acc, next| Pattern::product(next, acc))
145    }
146
147    /// Compute a Simplicity expression that takes a valid input value (that matches the input pattern)
148    /// and that produces as output a value that matches the `target` pattern.
149    ///
150    /// ## Example
151    ///
152    /// ```
153    /// let a: u8 = 0;
154    /// let b = {
155    ///     let b: u8 = 1;
156    ///     let c: u8 = 2;
157    ///     (a, b)  // here we seek the value of `(a, b)`
158    /// };
159    /// ```
160    ///
161    /// The input pattern looks like this:
162    ///
163    /// ```text
164    ///   .
165    ///  / \
166    /// c   .
167    ///    / \
168    ///   b   .
169    ///      / \
170    ///     a   _
171    /// ```
172    ///
173    /// The expression `drop (IOH & OH)` returns the seeked value.
174    pub fn get(&self, target: &BasePattern) -> Option<PairBuilder<ProgNode>> {
175        BasePattern::from(&self.get_input_pattern()).translate(&self.ctx, target)
176    }
177
178    /// Access the Simplicity type inference context.
179    pub fn ctx(&self) -> &simplicity::types::Context {
180        &self.ctx
181    }
182
183    /// Attach a debug symbol to the function body.
184    /// This debug symbol can be used by the Simplicity runtime to print the call arguments
185    /// during execution.
186    ///
187    /// The debug symbol is attached in such a way that a Simplicity runtime without support
188    /// for debug symbols will simply ignore it. The semantics of the program remain unchanged.
189    pub fn with_debug_symbol<S: AsRef<Span>>(
190        &mut self,
191        args: PairBuilder<ProgNode>,
192        body: &ProgNode,
193        span: &S,
194    ) -> Result<PairBuilder<ProgNode>, RichError> {
195        match self.call_tracker.get_cmr(span.as_ref()) {
196            Some(cmr) => {
197                let false_and_args = ProgNode::bit(self.ctx(), false).pair(args);
198                let nop_assert = ProgNode::assertl_drop(body, cmr);
199                false_and_args.comp(&nop_assert).with_span(span)
200            }
201            None => args.comp(body).with_span(span),
202        }
203    }
204
205    pub fn get_argument(&self, name: &WitnessName) -> &Value {
206        self.arguments
207            .get(name)
208            .expect("Precondition: Arguments are consistent with parameters")
209    }
210}
211
212fn compile_blk(
213    stmts: &[Statement],
214    scope: &mut Scope,
215    index: usize,
216    last_expr: Option<&Expression>,
217) -> Result<PairBuilder<ProgNode>, RichError> {
218    if index >= stmts.len() {
219        return match last_expr {
220            Some(expr) => expr.compile(scope),
221            None => Ok(PairBuilder::unit(scope.ctx())),
222        };
223    }
224    match &stmts[index] {
225        Statement::Assignment(assignment) => {
226            let expr = assignment.expression().compile(scope)?;
227            scope.insert(assignment.pattern().clone());
228            let left = expr.pair(PairBuilder::iden(scope.ctx()));
229            let right = compile_blk(stmts, scope, index + 1, last_expr)?;
230            left.comp(&right).with_span(assignment)
231        }
232        Statement::Expression(expression) => {
233            let left = expression.compile(scope)?;
234            let right = compile_blk(stmts, scope, index + 1, last_expr)?;
235            let pair = left.pair(right);
236            let drop_iden = ProgNode::drop_(&ProgNode::iden(scope.ctx()));
237            pair.comp(&drop_iden).with_span(expression)
238        }
239    }
240}
241
242impl Program {
243    /// Compile the Simfony source code to Simplicity target code.
244    ///
245    /// ## Precondition
246    ///
247    /// The supplied `arguments` are consistent with the program's parameters.
248    /// Call [`Arguments::is_consistent`] before calling this method!
249    pub fn compile(&self, arguments: Arguments) -> Result<ProgNode, RichError> {
250        let mut scope = Scope::new(Arc::clone(self.call_tracker()), arguments);
251        self.main().compile(&mut scope).map(PairBuilder::build)
252    }
253}
254
255impl Expression {
256    fn compile(&self, scope: &mut Scope) -> Result<PairBuilder<ProgNode>, RichError> {
257        match self.inner() {
258            ExpressionInner::Block(stmts, expr) => {
259                scope.push_scope();
260                let res = compile_blk(stmts, scope, 0, expr.as_ref().map(Arc::as_ref));
261                scope.pop_scope();
262                res
263            }
264            ExpressionInner::Single(e) => e.compile(scope),
265        }
266    }
267}
268
269impl SingleExpression {
270    fn compile(&self, scope: &mut Scope) -> Result<PairBuilder<ProgNode>, RichError> {
271        let expr = match self.inner() {
272            SingleExpressionInner::Constant(value) => {
273                let value = StructuralValue::from(value);
274                PairBuilder::unit_scribe(scope.ctx(), value.as_ref())
275            }
276            SingleExpressionInner::Witness(name) => PairBuilder::witness(scope.ctx(), name.clone()),
277            SingleExpressionInner::Parameter(name) => {
278                let value = StructuralValue::from(scope.get_argument(name));
279                PairBuilder::unit_scribe(scope.ctx(), value.as_ref())
280            }
281            SingleExpressionInner::Variable(identifier) => scope
282                .get(&BasePattern::Identifier(identifier.clone()))
283                .ok_or(Error::UndefinedVariable(identifier.clone()))
284                .with_span(self)?,
285            SingleExpressionInner::Expression(expr) => expr.compile(scope)?,
286            SingleExpressionInner::Tuple(elements) | SingleExpressionInner::Array(elements) => {
287                let compiled = elements
288                    .iter()
289                    .map(|e| e.compile(scope))
290                    .collect::<Result<Vec<PairBuilder<ProgNode>>, RichError>>()?;
291                let tree = BTreeSlice::from_slice(&compiled);
292                tree.fold(PairBuilder::pair)
293                    .unwrap_or_else(|| PairBuilder::unit(scope.ctx()))
294            }
295            SingleExpressionInner::List(elements) => {
296                let compiled = elements
297                    .iter()
298                    .map(|e| e.compile(scope))
299                    .collect::<Result<Vec<PairBuilder<ProgNode>>, RichError>>()?;
300                let bound = self.ty().as_list().unwrap().1;
301                let partition = Partition::from_slice(&compiled, bound);
302                partition.fold(
303                    |block, _size: usize| {
304                        let tree = BTreeSlice::from_slice(block);
305                        match tree.fold(PairBuilder::pair) {
306                            None => PairBuilder::unit(scope.ctx()).injl(),
307                            Some(pair) => pair.injr(),
308                        }
309                    },
310                    PairBuilder::pair,
311                )
312            }
313            SingleExpressionInner::Option(None) => PairBuilder::unit(scope.ctx()).injl(),
314            SingleExpressionInner::Either(Either::Left(inner)) => {
315                inner.compile(scope).map(PairBuilder::injl)?
316            }
317            SingleExpressionInner::Either(Either::Right(inner))
318            | SingleExpressionInner::Option(Some(inner)) => {
319                inner.compile(scope).map(PairBuilder::injr)?
320            }
321            SingleExpressionInner::Call(call) => call.compile(scope)?,
322            SingleExpressionInner::Match(match_) => match_.compile(scope)?,
323        };
324
325        scope
326            .ctx()
327            .unify(
328                &expr.as_ref().cached_data().arrow().target,
329                &StructuralType::from(self.ty()).to_unfinalized(scope.ctx()),
330                "",
331            )
332            .with_span(self)?;
333        Ok(expr)
334    }
335}
336
337impl Call {
338    fn compile(&self, scope: &mut Scope) -> Result<PairBuilder<ProgNode>, RichError> {
339        let args_ast = SingleExpression::tuple(self.args().clone(), *self.as_ref());
340        let args = args_ast.compile(scope)?;
341
342        match self.name() {
343            CallName::Jet(name) => {
344                let jet = ProgNode::jet(scope.ctx(), *name);
345                scope.with_debug_symbol(args, &jet, self)
346            }
347            CallName::UnwrapLeft(..) => {
348                let input_and_unit =
349                    PairBuilder::iden(scope.ctx()).pair(PairBuilder::unit(scope.ctx()));
350                let extract_inner = ProgNode::assertl_take(
351                    &ProgNode::iden(scope.ctx()),
352                    Cmr::fail(FailEntropy::ZERO),
353                );
354                let body = input_and_unit.comp(&extract_inner).with_span(self)?;
355                scope.with_debug_symbol(args, body.as_ref(), self)
356            }
357            CallName::UnwrapRight(..) | CallName::Unwrap => {
358                let input_and_unit =
359                    PairBuilder::iden(scope.ctx()).pair(PairBuilder::unit(scope.ctx()));
360                let extract_inner = ProgNode::assertr_take(
361                    Cmr::fail(FailEntropy::ZERO),
362                    &ProgNode::iden(scope.ctx()),
363                );
364                let body = input_and_unit.comp(&extract_inner).with_span(self)?;
365                scope.with_debug_symbol(args, body.as_ref(), self)
366            }
367            CallName::IsNone(..) => {
368                let input_and_unit =
369                    PairBuilder::iden(scope.ctx()).pair(PairBuilder::unit(scope.ctx()));
370                let is_right = ProgNode::case_true_false(scope.ctx());
371                let body = input_and_unit.comp(&is_right).with_span(self)?;
372                args.comp(&body).with_span(self)
373            }
374            CallName::Assert => {
375                let jet = ProgNode::jet(scope.ctx(), Elements::Verify);
376                scope.with_debug_symbol(args, &jet, self)
377            }
378            CallName::Panic => {
379                // panic! ignores its arguments
380                let fail = ProgNode::fail(scope.ctx(), FailEntropy::ZERO);
381                scope.with_debug_symbol(args, &fail, self)
382            }
383            CallName::Debug => {
384                // dbg! computes the identity function
385                let iden = ProgNode::iden(scope.ctx());
386                scope.with_debug_symbol(args, &iden, self)
387            }
388            CallName::TypeCast(..) => {
389                // A cast converts between two structurally equal types.
390                // Structural equality of Simfony types A and B means
391                // exact equality of the underlying Simplicity types of A and of B.
392                // Therefore, a Simfony cast is a NOP in Simplicity.
393                Ok(args)
394            }
395            CallName::Custom(function) => {
396                let mut function_scope = scope.child(function.params_pattern());
397                let body = function.body().compile(&mut function_scope)?;
398                args.comp(&body).with_span(self)
399            }
400            CallName::Fold(function, bound) => {
401                let mut function_scope = scope.child(function.params_pattern());
402                let body = function.body().compile(&mut function_scope)?;
403                let fold_body = list_fold(*bound, body.as_ref()).with_span(self)?;
404                args.comp(&fold_body).with_span(self)
405            }
406            CallName::ForWhile(function, bit_width) => {
407                let mut function_scope = scope.child(function.params_pattern());
408                let body = function.body().compile(&mut function_scope)?;
409                let fold_body = for_while(*bit_width, body).with_span(self)?;
410                args.comp(&fold_body).with_span(self)
411            }
412        }
413    }
414}
415
416/// Fold a list of less than `2^n` elements using function `f`.
417///
418/// Function `f: E × A → A`
419/// takes a list element of type `E` and an accumulator of type `A`,
420/// and it produces an updated accumulator of type `A`.
421///
422/// The fold `(fold f)_n : E^(<2^n) × A → A`
423/// takes the list of type `E^(<2^n)` and an initial accumulator of type `A`,
424/// and it produces the final accumulator of type `A`.
425fn list_fold(bound: NonZeroPow2Usize, f: &ProgNode) -> Result<ProgNode, simplicity::types::Error> {
426    /* f_0 :  E × A → A
427     * f_0 := f
428     */
429    let mut f_array = f.clone();
430
431    /* (fold f)_1 :  E^<2 × A → A
432     * (fold f)_1 := case IH f_0
433     */
434    let ctx = f.inference_context();
435    let ioh = ProgNode::i().h(ctx);
436    let mut f_fold = ProgNode::case(ioh.as_ref(), &f_array)?;
437    let mut i = NonZeroPow2Usize::TWO;
438
439    fn next_f_array(f_array: &ProgNode) -> Result<ProgNode, simplicity::types::Error> {
440        /* f_(n + 1) :  E^(2^(n + 1)) × A → A
441         * f_(n + 1) := OIH ▵ (OOH ▵ IH; f_n); f_n
442         */
443        let ctx = f_array.inference_context();
444        let half1_acc = ProgNode::o().o().h(ctx).pair(ProgNode::i().h(ctx));
445        let updated_acc = half1_acc.comp(f_array)?;
446        let half2_acc = ProgNode::o().i().h(ctx).pair(updated_acc);
447        half2_acc.comp(f_array).map(PairBuilder::build)
448    }
449    fn next_f_fold(
450        f_array: &ProgNode,
451        f_fold: &ProgNode,
452    ) -> Result<ProgNode, simplicity::types::Error> {
453        /* (fold f)_(n + 1) :  E<2^(n + 1) × A → A
454         * (fold f)_(n + 1) := OOH ▵ (OIH ▵ IH);
455         *                     case (drop (fold f)_n)
456         *                          ((IOH ▵ (OH ▵ IIH; f_n)); (fold f)_n)
457         */
458        let ctx = f_array.inference_context();
459        let case_input = ProgNode::o()
460            .o()
461            .h(ctx)
462            .pair(ProgNode::o().i().h(ctx).pair(ProgNode::i().h(ctx)));
463        let case_left = ProgNode::drop_(f_fold);
464
465        let f_n_input = ProgNode::o().h(ctx).pair(ProgNode::i().i().h(ctx));
466        let f_n_output = f_n_input.comp(f_array)?;
467        let fold_n_input = ProgNode::i().o().h(ctx).pair(f_n_output);
468        let case_right = fold_n_input.comp(f_fold)?;
469
470        case_input
471            .comp(&ProgNode::case(&case_left, case_right.as_ref())?)
472            .map(PairBuilder::build)
473    }
474
475    while i < bound {
476        f_array = next_f_array(&f_array)?;
477        f_fold = next_f_fold(&f_array, &f_fold)?;
478        i = i.mul2();
479    }
480
481    Ok(f_fold)
482}
483
484/// Run a function at most `2^(2^n)` times and return the first successful output.
485///
486/// Function `f: A × (C × 2^(2^(2^n))) → B + A`
487/// takes an accumulator of type `A`, a readonly context of type `C`,
488/// and a counter of type `2^(2^(2^n))` (unsigned integer of 2^n bits).
489///
490/// `f` may return a left `B` value, which is a successful output value.
491/// In this case, the loop exists and returns this value.
492///
493/// Otherwise, the `f` returns a right `A` value, which is the updated accumulator.
494/// In this case, the loop continues without returning anything.
495/// The loop returns the final iterator after the final iteration
496/// if `f` never returned a successful output.
497fn for_while(
498    bit_width: Pow2Usize,
499    f: PairBuilder<ProgNode>,
500) -> Result<PairBuilder<ProgNode>, simplicity::types::Error> {
501    /* for_while_0 f :  E × A → A
502     * for_while_0 f := (OH ▵ (IH ▵ false); f) ▵ IH;
503     *                  case (injl OH)
504     *                       (OH ▵ (IH ▵ true); f)
505     */
506    fn for_while_0(f: &ProgNode) -> Result<PairBuilder<ProgNode>, simplicity::types::Error> {
507        let ctx = f.inference_context();
508        let f_output = ProgNode::o()
509            .h(ctx)
510            .pair(ProgNode::i().h(ctx).pair(ProgNode::bit(ctx, false)))
511            .comp(f)?;
512        let case_input = f_output.pair(ProgNode::i().h(ctx));
513
514        let x = ProgNode::injl(ProgNode::o().h(ctx).as_ref());
515        let f_output = ProgNode::o()
516            .h(ctx)
517            .pair(ProgNode::i().h(ctx).pair(ProgNode::bit(ctx, true)))
518            .comp(f)?;
519        let case_output = ProgNode::case(&x, f_output.as_ref())?;
520
521        case_input.comp(&case_output)
522    }
523
524    /* adapt f :  A × ((C × 2^(2^n)) × 2^(2^n)) → B + A
525     * adapt f := OH ▵ (IOOH ▵ (IOIH ▵ IIH)); f
526     * where
527     *       f :  A × (C × 2^(2^(n + 1))) → B + A
528     */
529    fn adapt_f(f: &ProgNode) -> Result<PairBuilder<ProgNode>, simplicity::types::Error> {
530        let ctx = f.inference_context();
531        let f_input = ProgNode::o().h(ctx).pair(
532            ProgNode::i()
533                .o()
534                .o()
535                .h(ctx)
536                .pair(ProgNode::i().o().i().h(ctx).pair(ProgNode::i().i().h(ctx))),
537        );
538        f_input.comp(f)
539    }
540
541    /* for_while_(n + 1) f :  E × A → A
542     * for_while_(n + 1) f := for_while_n $ for_while_n $ adapt $ f
543     *
544     * If we write "0" for "for_while_0" and "1" for "adapt" and "." for function composition,
545     * then the extended pattern looks like this:
546     *
547     * for_while_0 f := 0 . f
548     * for_while_1 f := 0 . 0 . 1 . f
549     * for_while_2 f := 0 . 0 . 1 . 0 . 0 . 1 . 1 . f
550     * for_while_3 f := 0 . 0 . 1 . 0 . 0 . 1 . 1 . 0 . 0 . 1 . 0 . 0 . 1 . 1 . 1 . f
551     *
552     * The sequence of zeroes and ones starts with a single 0.
553     * The next sequence is two copies of the previous sequence plus a final 1.
554     *
555     * The following Rust code implements this behavior:
556     * First, a stack of zeroes is allocated. We know its final size, so we allocate exactly once.
557     * The stack is repeatedly copied into itself to produce the seeked sequence of zeroes and ones.
558     * Finally, "for_while_0" and "adapt" are applied to "f" by popping from the stack.
559     */
560    #[derive(Debug, Copy, Clone)]
561    enum Task {
562        /// "Zero"
563        ForWhile0,
564        /// "One"
565        Adapt,
566    }
567    let max_stack = bit_width.mul2().get() - 1;
568    let mut stack = vec![Task::ForWhile0; max_stack];
569
570    let mut i = Pow2Usize::ONE.mul2();
571    while i <= bit_width {
572        let index = i.get() - 1;
573        let (prefix, tail) = stack.as_mut_slice().split_at_mut(index);
574        let suffix = &mut tail[..index];
575        debug_assert_eq!(prefix.len(), suffix.len());
576        suffix.copy_from_slice(prefix);
577        tail[index] = Task::Adapt;
578        i = i.mul2();
579    }
580
581    let mut for_while_f = f;
582
583    while let Some(task) = stack.pop() {
584        match task {
585            Task::ForWhile0 => {
586                for_while_f = for_while_0(for_while_f.as_ref())?;
587            }
588            Task::Adapt => {
589                for_while_f = adapt_f(for_while_f.as_ref())?;
590            }
591        }
592    }
593
594    Ok(for_while_f)
595}
596
597impl Match {
598    fn compile(&self, scope: &mut Scope) -> Result<PairBuilder<ProgNode>, RichError> {
599        scope.push_scope();
600        scope.insert(
601            self.left()
602                .pattern()
603                .as_variable()
604                .cloned()
605                .map(Pattern::Identifier)
606                .unwrap_or(Pattern::Ignore),
607        );
608        let left = self.left().expression().compile(scope)?;
609        scope.pop_scope();
610
611        scope.push_scope();
612        scope.insert(
613            self.right()
614                .pattern()
615                .as_variable()
616                .cloned()
617                .map(Pattern::Identifier)
618                .unwrap_or(Pattern::Ignore),
619        );
620        let right = self.right().expression().compile(scope)?;
621        scope.pop_scope();
622
623        let scrutinee = self.scrutinee().compile(scope)?;
624        let input = scrutinee.pair(PairBuilder::iden(scope.ctx()));
625        let output = ProgNode::case(left.as_ref(), right.as_ref()).with_span(self)?;
626        input.comp(&output).with_span(self)
627    }
628}