1use std::sync::Arc;
4
5use either::Either;
6use simplicity::jet::Elements;
7use simplicity::node::{CoreConstructible as _, JetConstructible as _};
8use simplicity::{Cmr, FailEntropy};
9
10use crate::array::{BTreeSlice, Partition};
11use crate::ast::{
12 Call, CallName, Expression, ExpressionInner, Match, Program, SingleExpression,
13 SingleExpressionInner, Statement,
14};
15use crate::debug::CallTracker;
16use crate::error::{Error, RichError, Span, WithSpan};
17use crate::named::{CoreExt, PairBuilder};
18use crate::num::{NonZeroPow2Usize, Pow2Usize};
19use crate::pattern::{BasePattern, Pattern};
20use crate::str::WitnessName;
21use crate::types::{StructuralType, TypeDeconstructible};
22use crate::value::StructuralValue;
23use crate::witness::Arguments;
24use crate::{ProgNode, Value};
25
26#[derive(Debug, Clone)]
37struct Scope {
38 variables: Vec<Vec<Pattern>>,
64 ctx: simplicity::types::Context,
65 call_tracker: Arc<CallTracker>,
67 arguments: Arguments,
69}
70
71impl Scope {
72 pub fn new(call_tracker: Arc<CallTracker>, arguments: Arguments) -> Self {
81 Self {
82 variables: vec![vec![Pattern::Ignore]],
83 ctx: simplicity::types::Context::new(),
84 call_tracker,
85 arguments,
86 }
87 }
88
89 pub fn child(&self, input: Pattern) -> Self {
91 Self {
92 variables: vec![vec![input]],
93 ctx: self.ctx.shallow_clone(),
94 call_tracker: Arc::clone(&self.call_tracker),
95 arguments: self.arguments.clone(),
96 }
97 }
98
99 pub fn push_scope(&mut self) {
101 self.variables.push(Vec::new());
102 }
103
104 pub fn pop_scope(&mut self) {
110 self.variables.pop().expect("Empty stack");
111 }
112
113 pub fn insert(&mut self, pattern: Pattern) {
127 self.variables
128 .last_mut()
129 .expect("Empty stack")
130 .push(pattern);
131 }
132
133 fn get_input_pattern(&self) -> Pattern {
141 let mut it = self.variables.iter().flat_map(|scope| scope.iter());
142 let first = it.next().expect("Empty stack");
143 it.cloned()
144 .fold(first.clone(), |acc, next| Pattern::product(next, acc))
145 }
146
147 pub fn get(&self, target: &BasePattern) -> Option<PairBuilder<ProgNode>> {
175 BasePattern::from(&self.get_input_pattern()).translate(&self.ctx, target)
176 }
177
178 pub fn ctx(&self) -> &simplicity::types::Context {
180 &self.ctx
181 }
182
183 pub fn with_debug_symbol<S: AsRef<Span>>(
190 &mut self,
191 args: PairBuilder<ProgNode>,
192 body: &ProgNode,
193 span: &S,
194 ) -> Result<PairBuilder<ProgNode>, RichError> {
195 match self.call_tracker.get_cmr(span.as_ref()) {
196 Some(cmr) => {
197 let false_and_args = ProgNode::bit(self.ctx(), false).pair(args);
198 let nop_assert = ProgNode::assertl_drop(body, cmr);
199 false_and_args.comp(&nop_assert).with_span(span)
200 }
201 None => args.comp(body).with_span(span),
202 }
203 }
204
205 pub fn get_argument(&self, name: &WitnessName) -> &Value {
206 self.arguments
207 .get(name)
208 .expect("Precondition: Arguments are consistent with parameters")
209 }
210}
211
212fn compile_blk(
213 stmts: &[Statement],
214 scope: &mut Scope,
215 index: usize,
216 last_expr: Option<&Expression>,
217) -> Result<PairBuilder<ProgNode>, RichError> {
218 if index >= stmts.len() {
219 return match last_expr {
220 Some(expr) => expr.compile(scope),
221 None => Ok(PairBuilder::unit(scope.ctx())),
222 };
223 }
224 match &stmts[index] {
225 Statement::Assignment(assignment) => {
226 let expr = assignment.expression().compile(scope)?;
227 scope.insert(assignment.pattern().clone());
228 let left = expr.pair(PairBuilder::iden(scope.ctx()));
229 let right = compile_blk(stmts, scope, index + 1, last_expr)?;
230 left.comp(&right).with_span(assignment)
231 }
232 Statement::Expression(expression) => {
233 let left = expression.compile(scope)?;
234 let right = compile_blk(stmts, scope, index + 1, last_expr)?;
235 let pair = left.pair(right);
236 let drop_iden = ProgNode::drop_(&ProgNode::iden(scope.ctx()));
237 pair.comp(&drop_iden).with_span(expression)
238 }
239 }
240}
241
242impl Program {
243 pub fn compile(&self, arguments: Arguments) -> Result<ProgNode, RichError> {
250 let mut scope = Scope::new(Arc::clone(self.call_tracker()), arguments);
251 self.main().compile(&mut scope).map(PairBuilder::build)
252 }
253}
254
255impl Expression {
256 fn compile(&self, scope: &mut Scope) -> Result<PairBuilder<ProgNode>, RichError> {
257 match self.inner() {
258 ExpressionInner::Block(stmts, expr) => {
259 scope.push_scope();
260 let res = compile_blk(stmts, scope, 0, expr.as_ref().map(Arc::as_ref));
261 scope.pop_scope();
262 res
263 }
264 ExpressionInner::Single(e) => e.compile(scope),
265 }
266 }
267}
268
269impl SingleExpression {
270 fn compile(&self, scope: &mut Scope) -> Result<PairBuilder<ProgNode>, RichError> {
271 let expr = match self.inner() {
272 SingleExpressionInner::Constant(value) => {
273 let value = StructuralValue::from(value);
274 PairBuilder::unit_scribe(scope.ctx(), value.as_ref())
275 }
276 SingleExpressionInner::Witness(name) => PairBuilder::witness(scope.ctx(), name.clone()),
277 SingleExpressionInner::Parameter(name) => {
278 let value = StructuralValue::from(scope.get_argument(name));
279 PairBuilder::unit_scribe(scope.ctx(), value.as_ref())
280 }
281 SingleExpressionInner::Variable(identifier) => scope
282 .get(&BasePattern::Identifier(identifier.clone()))
283 .ok_or(Error::UndefinedVariable(identifier.clone()))
284 .with_span(self)?,
285 SingleExpressionInner::Expression(expr) => expr.compile(scope)?,
286 SingleExpressionInner::Tuple(elements) | SingleExpressionInner::Array(elements) => {
287 let compiled = elements
288 .iter()
289 .map(|e| e.compile(scope))
290 .collect::<Result<Vec<PairBuilder<ProgNode>>, RichError>>()?;
291 let tree = BTreeSlice::from_slice(&compiled);
292 tree.fold(PairBuilder::pair)
293 .unwrap_or_else(|| PairBuilder::unit(scope.ctx()))
294 }
295 SingleExpressionInner::List(elements) => {
296 let compiled = elements
297 .iter()
298 .map(|e| e.compile(scope))
299 .collect::<Result<Vec<PairBuilder<ProgNode>>, RichError>>()?;
300 let bound = self.ty().as_list().unwrap().1;
301 let partition = Partition::from_slice(&compiled, bound);
302 partition.fold(
303 |block, _size: usize| {
304 let tree = BTreeSlice::from_slice(block);
305 match tree.fold(PairBuilder::pair) {
306 None => PairBuilder::unit(scope.ctx()).injl(),
307 Some(pair) => pair.injr(),
308 }
309 },
310 PairBuilder::pair,
311 )
312 }
313 SingleExpressionInner::Option(None) => PairBuilder::unit(scope.ctx()).injl(),
314 SingleExpressionInner::Either(Either::Left(inner)) => {
315 inner.compile(scope).map(PairBuilder::injl)?
316 }
317 SingleExpressionInner::Either(Either::Right(inner))
318 | SingleExpressionInner::Option(Some(inner)) => {
319 inner.compile(scope).map(PairBuilder::injr)?
320 }
321 SingleExpressionInner::Call(call) => call.compile(scope)?,
322 SingleExpressionInner::Match(match_) => match_.compile(scope)?,
323 };
324
325 scope
326 .ctx()
327 .unify(
328 &expr.as_ref().cached_data().arrow().target,
329 &StructuralType::from(self.ty()).to_unfinalized(scope.ctx()),
330 "",
331 )
332 .with_span(self)?;
333 Ok(expr)
334 }
335}
336
337impl Call {
338 fn compile(&self, scope: &mut Scope) -> Result<PairBuilder<ProgNode>, RichError> {
339 let args_ast = SingleExpression::tuple(self.args().clone(), *self.as_ref());
340 let args = args_ast.compile(scope)?;
341
342 match self.name() {
343 CallName::Jet(name) => {
344 let jet = ProgNode::jet(scope.ctx(), *name);
345 scope.with_debug_symbol(args, &jet, self)
346 }
347 CallName::UnwrapLeft(..) => {
348 let input_and_unit =
349 PairBuilder::iden(scope.ctx()).pair(PairBuilder::unit(scope.ctx()));
350 let extract_inner = ProgNode::assertl_take(
351 &ProgNode::iden(scope.ctx()),
352 Cmr::fail(FailEntropy::ZERO),
353 );
354 let body = input_and_unit.comp(&extract_inner).with_span(self)?;
355 scope.with_debug_symbol(args, body.as_ref(), self)
356 }
357 CallName::UnwrapRight(..) | CallName::Unwrap => {
358 let input_and_unit =
359 PairBuilder::iden(scope.ctx()).pair(PairBuilder::unit(scope.ctx()));
360 let extract_inner = ProgNode::assertr_take(
361 Cmr::fail(FailEntropy::ZERO),
362 &ProgNode::iden(scope.ctx()),
363 );
364 let body = input_and_unit.comp(&extract_inner).with_span(self)?;
365 scope.with_debug_symbol(args, body.as_ref(), self)
366 }
367 CallName::IsNone(..) => {
368 let input_and_unit =
369 PairBuilder::iden(scope.ctx()).pair(PairBuilder::unit(scope.ctx()));
370 let is_right = ProgNode::case_true_false(scope.ctx());
371 let body = input_and_unit.comp(&is_right).with_span(self)?;
372 args.comp(&body).with_span(self)
373 }
374 CallName::Assert => {
375 let jet = ProgNode::jet(scope.ctx(), Elements::Verify);
376 scope.with_debug_symbol(args, &jet, self)
377 }
378 CallName::Panic => {
379 let fail = ProgNode::fail(scope.ctx(), FailEntropy::ZERO);
381 scope.with_debug_symbol(args, &fail, self)
382 }
383 CallName::Debug => {
384 let iden = ProgNode::iden(scope.ctx());
386 scope.with_debug_symbol(args, &iden, self)
387 }
388 CallName::TypeCast(..) => {
389 Ok(args)
394 }
395 CallName::Custom(function) => {
396 let mut function_scope = scope.child(function.params_pattern());
397 let body = function.body().compile(&mut function_scope)?;
398 args.comp(&body).with_span(self)
399 }
400 CallName::Fold(function, bound) => {
401 let mut function_scope = scope.child(function.params_pattern());
402 let body = function.body().compile(&mut function_scope)?;
403 let fold_body = list_fold(*bound, body.as_ref()).with_span(self)?;
404 args.comp(&fold_body).with_span(self)
405 }
406 CallName::ForWhile(function, bit_width) => {
407 let mut function_scope = scope.child(function.params_pattern());
408 let body = function.body().compile(&mut function_scope)?;
409 let fold_body = for_while(*bit_width, body).with_span(self)?;
410 args.comp(&fold_body).with_span(self)
411 }
412 }
413 }
414}
415
416fn list_fold(bound: NonZeroPow2Usize, f: &ProgNode) -> Result<ProgNode, simplicity::types::Error> {
426 let mut f_array = f.clone();
430
431 let ctx = f.inference_context();
435 let ioh = ProgNode::i().h(ctx);
436 let mut f_fold = ProgNode::case(ioh.as_ref(), &f_array)?;
437 let mut i = NonZeroPow2Usize::TWO;
438
439 fn next_f_array(f_array: &ProgNode) -> Result<ProgNode, simplicity::types::Error> {
440 let ctx = f_array.inference_context();
444 let half1_acc = ProgNode::o().o().h(ctx).pair(ProgNode::i().h(ctx));
445 let updated_acc = half1_acc.comp(f_array)?;
446 let half2_acc = ProgNode::o().i().h(ctx).pair(updated_acc);
447 half2_acc.comp(f_array).map(PairBuilder::build)
448 }
449 fn next_f_fold(
450 f_array: &ProgNode,
451 f_fold: &ProgNode,
452 ) -> Result<ProgNode, simplicity::types::Error> {
453 let ctx = f_array.inference_context();
459 let case_input = ProgNode::o()
460 .o()
461 .h(ctx)
462 .pair(ProgNode::o().i().h(ctx).pair(ProgNode::i().h(ctx)));
463 let case_left = ProgNode::drop_(f_fold);
464
465 let f_n_input = ProgNode::o().h(ctx).pair(ProgNode::i().i().h(ctx));
466 let f_n_output = f_n_input.comp(f_array)?;
467 let fold_n_input = ProgNode::i().o().h(ctx).pair(f_n_output);
468 let case_right = fold_n_input.comp(f_fold)?;
469
470 case_input
471 .comp(&ProgNode::case(&case_left, case_right.as_ref())?)
472 .map(PairBuilder::build)
473 }
474
475 while i < bound {
476 f_array = next_f_array(&f_array)?;
477 f_fold = next_f_fold(&f_array, &f_fold)?;
478 i = i.mul2();
479 }
480
481 Ok(f_fold)
482}
483
484fn for_while(
498 bit_width: Pow2Usize,
499 f: PairBuilder<ProgNode>,
500) -> Result<PairBuilder<ProgNode>, simplicity::types::Error> {
501 fn for_while_0(f: &ProgNode) -> Result<PairBuilder<ProgNode>, simplicity::types::Error> {
507 let ctx = f.inference_context();
508 let f_output = ProgNode::o()
509 .h(ctx)
510 .pair(ProgNode::i().h(ctx).pair(ProgNode::bit(ctx, false)))
511 .comp(f)?;
512 let case_input = f_output.pair(ProgNode::i().h(ctx));
513
514 let x = ProgNode::injl(ProgNode::o().h(ctx).as_ref());
515 let f_output = ProgNode::o()
516 .h(ctx)
517 .pair(ProgNode::i().h(ctx).pair(ProgNode::bit(ctx, true)))
518 .comp(f)?;
519 let case_output = ProgNode::case(&x, f_output.as_ref())?;
520
521 case_input.comp(&case_output)
522 }
523
524 fn adapt_f(f: &ProgNode) -> Result<PairBuilder<ProgNode>, simplicity::types::Error> {
530 let ctx = f.inference_context();
531 let f_input = ProgNode::o().h(ctx).pair(
532 ProgNode::i()
533 .o()
534 .o()
535 .h(ctx)
536 .pair(ProgNode::i().o().i().h(ctx).pair(ProgNode::i().i().h(ctx))),
537 );
538 f_input.comp(f)
539 }
540
541 #[derive(Debug, Copy, Clone)]
561 enum Task {
562 ForWhile0,
564 Adapt,
566 }
567 let max_stack = bit_width.mul2().get() - 1;
568 let mut stack = vec![Task::ForWhile0; max_stack];
569
570 let mut i = Pow2Usize::ONE.mul2();
571 while i <= bit_width {
572 let index = i.get() - 1;
573 let (prefix, tail) = stack.as_mut_slice().split_at_mut(index);
574 let suffix = &mut tail[..index];
575 debug_assert_eq!(prefix.len(), suffix.len());
576 suffix.copy_from_slice(prefix);
577 tail[index] = Task::Adapt;
578 i = i.mul2();
579 }
580
581 let mut for_while_f = f;
582
583 while let Some(task) = stack.pop() {
584 match task {
585 Task::ForWhile0 => {
586 for_while_f = for_while_0(for_while_f.as_ref())?;
587 }
588 Task::Adapt => {
589 for_while_f = adapt_f(for_while_f.as_ref())?;
590 }
591 }
592 }
593
594 Ok(for_while_f)
595}
596
597impl Match {
598 fn compile(&self, scope: &mut Scope) -> Result<PairBuilder<ProgNode>, RichError> {
599 scope.push_scope();
600 scope.insert(
601 self.left()
602 .pattern()
603 .as_variable()
604 .cloned()
605 .map(Pattern::Identifier)
606 .unwrap_or(Pattern::Ignore),
607 );
608 let left = self.left().expression().compile(scope)?;
609 scope.pop_scope();
610
611 scope.push_scope();
612 scope.insert(
613 self.right()
614 .pattern()
615 .as_variable()
616 .cloned()
617 .map(Pattern::Identifier)
618 .unwrap_or(Pattern::Ignore),
619 );
620 let right = self.right().expression().compile(scope)?;
621 scope.pop_scope();
622
623 let scrutinee = self.scrutinee().compile(scope)?;
624 let input = scrutinee.pair(PairBuilder::iden(scope.ctx()));
625 let output = ProgNode::case(left.as_ref(), right.as_ref()).with_span(self)?;
626 input.comp(&output).with_span(self)
627 }
628}