sigma_compiler_core/sigma/combiners.rs
1//! This module creates and manipulates trees of basic statements
2//! combined with `AND`, `OR`, and `THRESH`.
3
4use super::types::*;
5use quote::quote;
6use std::collections::{HashMap, HashSet};
7use syn::parse::Result;
8use syn::visit::Visit;
9use syn::{parse_quote, Expr, Ident};
10
11/// For each [`Ident`](struct@syn::Ident) representing a private
12/// `Scalar` (as listed in a [`VarDict`]) that appears in an [`Expr`],
13/// call a given closure.
14pub struct PrivScalarMap<'a> {
15 /// The [`VarDict`] that maps variable names to their types
16 pub vars: &'a VarDict,
17
18 /// The closure that is called for each [`Ident`](struct@syn::Ident)
19 /// found in the [`Expr`] (provided in the call to
20 /// [`visit_expr`](PrivScalarMap::visit_expr)) that represents a
21 /// private `Scalar`
22 pub closure: &'a mut dyn FnMut(&syn::Ident) -> Result<()>,
23
24 /// The accumulated result. This will be the first
25 /// [`Err`](Result::Err) returned from the closure, or
26 /// [`Ok(())`](Result::Ok) if all calls to the closure succeeded.
27 pub result: Result<()>,
28}
29
30impl<'a> Visit<'a> for PrivScalarMap<'a> {
31 fn visit_path(&mut self, path: &'a syn::Path) {
32 // Whenever we see a `Path`, check first if it's just a bare
33 // `Ident`
34 let Some(id) = path.get_ident() else {
35 return;
36 };
37 // Then check if that `Ident` appears in the `VarDict`
38 let Some(vartype) = self.vars.get(&id.to_string()) else {
39 return;
40 };
41 // If so, and the `Ident` represents a private Scalar,
42 // call the closure if we haven't seen an `Err` returned from
43 // the closure yet.
44 if let AExprType::Scalar { is_pub: false, .. } = vartype {
45 if self.result.is_ok() {
46 self.result = (self.closure)(id);
47 }
48 }
49 }
50}
51
/// The statements in the ZKP form a tree. The leaves are basic
/// statements of various kinds; for example, equations or inequalities
/// about Scalars and Points. The interior nodes are combiners: `And`,
/// `Or`, or `Thresh` (with a given constant threshold). A leaf is true
/// if the basic statement it contains is true. An `And` node is true
/// if all of its children are true. An `Or` node is true if at least
/// one of its children is true. A `Thresh` node (with threshold `k`) is
/// true if at least `k` of its children are true.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum StatementTree {
    /// A basic statement, stored as the parsed [`Expr`].
    Leaf(Expr),
    /// True if all of the children are true.
    And(Vec<StatementTree>),
    /// True if at least one of the children is true.
    Or(Vec<StatementTree>),
    /// `Thresh(k, children)` is true if at least `k` of `children` are
    /// true. [`StatementTree::parse`] rejects thresholds outside
    /// `1..=children.len()`.
    Thresh(usize, Vec<StatementTree>),
}
68
69impl StatementTree {
70 #[cfg(not(doctest))]
71 /// Parse an [`Expr`] (which may contain nested `AND`, `OR`, or
72 /// `THRESH`) into a [`StatementTree`]. For example, the
73 /// [`Expr`] obtained from:
74 /// ```
75 /// parse_quote! {
76 /// AND (
77 /// C = c*B + r*A,
78 /// D = d*B + s*A,
79 /// OR (
80 /// AND (
81 /// C = c0*B + r0*A,
82 /// D = d0*B + s0*A,
83 /// c0 = d0,
84 /// ),
85 /// AND (
86 /// C = c1*B + r1*A,
87 /// D = d1*B + s1*A,
88 /// c1 = d1 + 1,
89 /// ),
90 /// )
91 /// )
92 /// }
93 /// ```
94 ///
95 /// would yield a [`StatementTree::And`] containing a 3-element
96 /// vector. The first two elements are [`StatementTree::Leaf`], and
97 /// the third is [`StatementTree::Or`] containing a 2-element
98 /// vector. Each element is an [`StatementTree::And`] with a vector
99 /// containing 3 [`StatementTree::Leaf`]s.
100 ///
101 /// Note that `AND`, `OR`, and `THRESH` in the expression are
102 /// case-insensitive.
103 pub fn parse(expr: &Expr) -> Result<Self> {
104 // See if the expression describes a combiner
105 if let Expr::Call(syn::ExprCall { func, args, .. }) = expr {
106 if let Expr::Path(syn::ExprPath { path, .. }) = func.as_ref() {
107 if let Some(funcname) = path.get_ident() {
108 match funcname.to_string().to_lowercase().as_str() {
109 "and" => {
110 let children: Result<Vec<StatementTree>> =
111 args.iter().map(Self::parse).collect();
112 return Ok(Self::And(children?));
113 }
114 "or" => {
115 let children: Result<Vec<StatementTree>> =
116 args.iter().map(Self::parse).collect();
117 return Ok(Self::Or(children?));
118 }
119 "thresh" => {
120 if let Some(Expr::Lit(syn::ExprLit {
121 lit: syn::Lit::Int(litint),
122 ..
123 })) = args.first()
124 {
125 let thresh = litint.base10_parse::<usize>()?;
126 // Remember that args.len() is one more
127 // than the number of expressions,
128 // because the first arg is the
129 // threshold
130 if thresh < 1 || thresh >= args.len() {
131 return Err(syn::Error::new(
132 litint.span(),
133 "threshold out of range",
134 ));
135 }
136 let children: Result<Vec<StatementTree>> =
137 args.iter().skip(1).map(Self::parse).collect();
138 return Ok(Self::Thresh(thresh, children?));
139 }
140 }
141 _ => {}
142 }
143 }
144 }
145 }
146 Ok(StatementTree::Leaf(expr.clone()))
147 }
148
149 /// A convenience function that takes a list of [`Expr`]s, and
150 /// returns the [`StatementTree`] that implicitly puts `AND` around
151 /// the [`Expr`]s. This is useful because a common thing to do is
152 /// to just write a list of [`Expr`]s in the top-level macro
153 /// invocation, having the semantics of "all of these must be true".
154 pub fn parse_andlist(exprlist: &[Expr]) -> Result<Self> {
155 let children: Result<Vec<StatementTree>> = exprlist.iter().map(Self::parse).collect();
156 Ok(StatementTree::And(children?))
157 }
158
159 /// Return a vector of references to all of the leaf expressions in
160 /// the [`StatementTree`]
161 pub fn leaves(&self) -> Vec<&Expr> {
162 match self {
163 StatementTree::Leaf(ref e) => vec![e],
164 StatementTree::And(v) | StatementTree::Or(v) | StatementTree::Thresh(_, v) => {
165 v.iter().fold(Vec::<&Expr>::new(), |mut b, st| {
166 b.extend(st.leaves());
167 b
168 })
169 }
170 }
171 }
172
173 /// Return a vector of mutable references to all of the leaf
174 /// expressions in the [`StatementTree`]
175 pub fn leaves_mut(&mut self) -> Vec<&mut Expr> {
176 match self {
177 StatementTree::Leaf(ref mut e) => vec![e],
178 StatementTree::And(v) | StatementTree::Or(v) | StatementTree::Thresh(_, v) => {
179 v.iter_mut().fold(Vec::<&mut Expr>::new(), |mut b, st| {
180 b.extend(st.leaves_mut());
181 b
182 })
183 }
184 }
185 }
186
187 /// Return a vector of mutable references to all of the leaves in
188 /// the [`StatementTree`]
189 pub fn leaves_st_mut(&mut self) -> Vec<&mut StatementTree> {
190 match self {
191 StatementTree::Leaf(_) => vec![self],
192 StatementTree::And(v) | StatementTree::Or(v) | StatementTree::Thresh(_, v) => v
193 .iter_mut()
194 .fold(Vec::<&mut StatementTree>::new(), |mut b, st| {
195 b.extend(st.leaves_st_mut());
196 b
197 }),
198 }
199 }
200
201 #[cfg(not(doctest))]
202 /// Verify whether the [`StatementTree`] satisfies the disjunction
203 /// invariant.
204 ///
205 /// A _disjunction node_ is an [`Or`](StatementTree::Or) or
206 /// [`Thresh`](StatementTree::Thresh) node in the [`StatementTree`].
207 ///
208 /// A _disjunction branch_ is a subtree rooted at a non-disjunction
209 /// node that is the child of a disjunction node or at the root of
210 /// the [`StatementTree`].
211 ///
212 /// The _disjunction invariant_ is that a private variable (which is
213 /// necessarily a `Scalar` since there are no private `Point`
214 /// variables) that appears in a disjunction branch cannot also
215 /// appear outside of that disjunction branch.
216 ///
217 /// For example, if all of the lowercase variables are private
218 /// `Scalar`s, the [`StatementTree`] created from:
219 ///
220 /// ```
221 /// AND (
222 /// C = c*B + r*A,
223 /// D = d*B + s*A,
224 /// OR (
225 /// AND (
226 /// C = c0*B + r0*A,
227 /// D = d0*B + s0*A,
228 /// c0 = d0,
229 /// ),
230 /// AND (
231 /// C = c1*B + r1*A,
232 /// D = d1*B + s1*A,
233 /// c1 = d1 + 1,
234 /// ),
235 /// )
236 /// )
237 /// ```
238 ///
239 /// satisfies the disjunction invariant, but
240 ///
241 /// ```
242 /// AND (
243 /// C = c*B + r*A,
244 /// D = d*B + s*A,
245 /// OR (
246 /// AND (
247 /// D = d0*B + s0*A,
248 /// c = d0,
249 /// ),
250 /// AND (
251 /// C = c1*B + r1*A,
252 /// D = d1*B + s1*A,
253 /// c1 = d1 + 1,
254 /// ),
255 /// )
256 /// )
257 /// ```
258 ///
259 /// does not, because `c` appears in the first child of the `OR` and
260 /// also outside of the `OR` entirely. Indeed, the reason to write
261 /// the first expression above rather than the more natural
262 ///
263 /// ```
264 /// AND (
265 /// C = c*B + r*A,
266 /// D = d*B + s*A,
267 /// OR (
268 /// c = d,
269 /// c = d + 1,
270 /// )
271 /// )
272 /// ```
273 ///
274 /// is exactly that the invariant must be satisfied.
275 ///
276 /// If you don't know that your [`StatementTree`] already satisfies
277 /// the invariant, call
278 /// [`enforce_disjunction_invariant`](super::super::enforce_disjunction_invariant),
279 /// which will transform the [`StatementTree`] so that it does (and
280 /// also call this
281 /// [`check_disjunction_invariant`](StatementTree::check_disjunction_invariant)
282 /// function as a sanity check).
283 pub fn check_disjunction_invariant(&self, vars: &VarDict) -> Result<()> {
284 let mut disjunct_map: HashMap<String, usize> = HashMap::new();
285
286 // If the recursive call returns Err, return that Err.
287 // Otherwise, we don't care about the Ok(usize) returned, so
288 // just return Ok(())
289 self.check_disjunction_invariant_rec(vars, &mut disjunct_map, 0, 0)?;
290 Ok(())
291 }
292
293 /// Internal recursive helper for
294 /// [`check_disjunction_invariant`](StatementTree::check_disjunction_invariant).
295 ///
296 /// The `disjunct_map` is a [`HashMap`] that maps the names of
297 /// variables to an identifier of which child of a disjunction node
298 /// the variable appears in (or the root if none). In the case of
299 /// nested disjunction node, the closest one to the leaf is what
300 /// matters. Nodes are numbered in pre-order fashion, starting at 0
301 /// for the root, 1 for the first child of the root, 2 for the first
302 /// child of node 1, etc. `cur_node` is the node id of `self`, and
303 /// `cur_disjunct_child` is the node id of the closest child of a
304 /// disjunction node (or 0 for the root if none). Returns the next
305 /// node id to use in the preorder traversal.
306 fn check_disjunction_invariant_rec(
307 &self,
308 vars: &VarDict,
309 disjunct_map: &mut HashMap<String, usize>,
310 cur_node: usize,
311 cur_disjunct_child: usize,
312 ) -> Result<usize> {
313 let mut next_node = cur_node;
314 match self {
315 Self::And(v) => {
316 for st in v {
317 next_node = st.check_disjunction_invariant_rec(
318 vars,
319 disjunct_map,
320 next_node + 1,
321 cur_disjunct_child,
322 )?;
323 }
324 }
325 Self::Or(v) | Self::Thresh(_, v) => {
326 for st in v {
327 next_node = st.check_disjunction_invariant_rec(
328 vars,
329 disjunct_map,
330 next_node + 1,
331 next_node + 1,
332 )?;
333 }
334 }
335 Self::Leaf(e) => {
336 let mut psmap = PrivScalarMap {
337 vars,
338 closure: &mut |ident| {
339 let varname = ident.to_string();
340 if let Some(dis_id) = disjunct_map.get(&varname) {
341 if *dis_id != cur_disjunct_child {
342 return Err(syn::Error::new(
343 ident.span(),
344 "Disjunction invariant violation: a private variable cannot appear both inside and outside a single term of an OR or THRESH"));
345 }
346 } else {
347 disjunct_map.insert(varname, cur_disjunct_child);
348 }
349 Ok(())
350 },
351 result: Ok(()),
352 };
353 psmap.visit_expr(e);
354 psmap.result?;
355 }
356 }
357 Ok(next_node)
358 }
359
360 /// Call the supplied closure for each [disjunction branch] of the
361 /// given [`StatementTree`] (including the root, if the root is a
362 /// non-disjunction node).
363 ///
364 /// The calls are in preorder traversal (parents before children).
365 /// The given `closure` will be called with the root of each
366 /// [disjunction branch] as well as a slice of [`usize`] indicating
367 /// the path through the [`StatementTree`] to that disjunction
368 /// branch. The disjunction branch at the root has path `[]`.
369 /// The disjunction branch rooted at, say, the 2nd child of an `Or`
370 /// node in the root disjunction branch will have path `[2]`. The
371 /// disjunction branch rooted at the 1st child of an `Or` node in
372 /// that disjunction branch will have path `[2,1]`, and so on.
373 ///
374 /// Abort and return `Err` if any call to the closure returns `Err`.
375 ///
376 /// [disjunction branch]: StatementTree::check_disjunction_invariant
377 pub fn for_each_disjunction_branch(
378 &mut self,
379 closure: &mut dyn FnMut(&mut StatementTree, &[usize]) -> Result<()>,
380 ) -> Result<()> {
381 let mut path: Vec<usize> = Vec::new();
382 self.for_each_disjunction_branch_rec(closure, &mut path, 0, true)?;
383 Ok(())
384 }
385
386 /// Internal recursive helper for
387 /// [`for_each_disjunction_branch`](StatementTree::for_each_disjunction_branch).
388 ///
389 /// - `path` is the path to this disjunction branch
390 /// - `last_index` is the last index used for a child of this
391 /// disjunction branch
392 /// - `is_new_branch` is `true` if this node is the start of a new
393 /// disjunction branch
394 ///
395 /// The return value (if `Ok`) is the updated value of `last_index`.
396 fn for_each_disjunction_branch_rec(
397 &mut self,
398 closure: &mut dyn FnMut(&mut StatementTree, &[usize]) -> Result<()>,
399 path: &mut Vec<usize>,
400 mut last_index: usize,
401 is_new_branch: bool,
402 ) -> Result<usize> {
403 // We're starting a new branch (and should call the closure) if
404 // and only if both is_new_branch is true, and also we're at a
405 // non-disjunction node
406 match self {
407 StatementTree::Leaf(_) | StatementTree::And(_) => {
408 if is_new_branch {
409 (closure)(self, path)?;
410 }
411 }
412 _ => {}
413 }
414 match self {
415 StatementTree::Leaf(_) => {}
416 StatementTree::And(stvec) => {
417 stvec.iter_mut().try_for_each(|st| -> Result<()> {
418 last_index =
419 st.for_each_disjunction_branch_rec(closure, path, last_index, false)?;
420 Ok(())
421 })?;
422 }
423 StatementTree::Or(stvec) | StatementTree::Thresh(_, stvec) => {
424 path.push(last_index);
425 let pathlen = path.len();
426 stvec.iter_mut().try_for_each(|st| -> Result<()> {
427 last_index += 1;
428 path[pathlen - 1] = last_index;
429 st.for_each_disjunction_branch_rec(closure, path, 0, true)?;
430 Ok(())
431 })?;
432 path.pop();
433 }
434 }
435 Ok(last_index)
436 }
437
438 /// Call the supplied closure for each [`StatementTree::Leaf`] of
439 /// the given [disjunction branch].
440 ///
441 /// Abort and return `Err` if any call to the closure returns `Err`.
442 ///
443 /// [disjunction branch]: StatementTree::check_disjunction_invariant
444 pub fn for_each_disjunction_branch_leaf(
445 &mut self,
446 closure: &mut dyn FnMut(&mut StatementTree) -> Result<()>,
447 ) -> Result<()> {
448 match self {
449 StatementTree::Leaf(_) => {
450 (closure)(self)?;
451 }
452 StatementTree::And(stvec) => {
453 stvec
454 .iter_mut()
455 .try_for_each(|st| st.for_each_disjunction_branch_leaf(closure))?;
456 }
457 StatementTree::Or(_) | StatementTree::Thresh(_, _) => {
458 // Don't recurse into Or or Thresh nodes, since the
459 // children of those nodes are in different disjunction
460 // branches.
461 }
462 }
463 Ok(())
464 }
465
466 /// Produce a [`HashSet`] of the private Scalars that appear in any
467 /// leaf of the given [disjunction branch].
468 ///
469 /// [disjunction branch]: StatementTree::check_disjunction_invariant
470 pub fn disjunction_branch_priv_scalars(&mut self, vars: &VarDict) -> HashSet<Ident> {
471 let mut priv_scalars: HashSet<Ident> = HashSet::new();
472 self.for_each_disjunction_branch_leaf(&mut |leaf| {
473 if let StatementTree::Leaf(leafexpr) = leaf {
474 let mut psmap = PrivScalarMap {
475 vars,
476 closure: &mut |ident| {
477 priv_scalars.insert(ident.clone());
478 Ok(())
479 },
480 result: Ok(()),
481 };
482 psmap.visit_expr(leafexpr);
483 }
484 Ok(())
485 })
486 .unwrap();
487 priv_scalars
488 }
489
490 #[cfg(not(doctest))]
491 /// Flatten nested `And` nodes in a [`StatementTree`].
492 ///
493 /// The underlying `sigma-proofs` crate can share `Scalars` across
494 /// statements that are direct children of the same `And` node, but
495 /// not in nested `And` nodes.
496 ///
497 /// So a [`StatementTree`] like this:
498 ///
499 /// ```
500 /// AND (
501 /// C = x*B + r*A,
502 /// AND (
503 /// D = x*B + s*A,
504 /// E = x*B + t*A,
505 /// ),
506 /// )
507 /// ```
508 ///
509 /// Needs to be flattened to:
510 ///
511 /// ```
512 /// AND (
513 /// C = x*B + r*A,
514 /// D = x*B + s*A,
515 /// E = x*B + t*A,
516 /// )
517 /// ```
518 pub fn flatten_ands(&mut self) {
519 match self {
520 StatementTree::Leaf(_) => {}
521 StatementTree::Or(svec) | StatementTree::Thresh(_, svec) => {
522 // Flatten each child
523 svec.iter_mut().for_each(|st| st.flatten_ands());
524 }
525 StatementTree::And(svec) => {
526 // Flatten each child, and if any of the children are
527 // `And`s, replace that child with the list of its
528 // children
529 let old_svec = std::mem::take(svec);
530 let mut new_svec: Vec<StatementTree> = Vec::new();
531 for mut st in old_svec {
532 st.flatten_ands();
533 match st {
534 StatementTree::And(mut child_svec) => {
535 new_svec.append(&mut child_svec);
536 }
537 _ => {
538 new_svec.push(st);
539 }
540 }
541 }
542 *self = StatementTree::And(new_svec);
543 }
544 }
545 }
546
547 /// Produce a [`StatementTree`] that represents the constant `true`
548 pub fn leaf_true() -> StatementTree {
549 StatementTree::Leaf(parse_quote! { true })
550 }
551
552 /// Test if the given [`StatementTree`] represents the constant `true`
553 pub fn is_leaf_true(&self) -> bool {
554 if let StatementTree::Leaf(Expr::Lit(exprlit)) = self {
555 if let syn::Lit::Bool(syn::LitBool { value: true, .. }) = exprlit.lit {
556 return true;
557 }
558 }
559 false
560 }
561
562 fn dump_int(&self, depth: usize) {
563 match self {
564 StatementTree::Leaf(e) => {
565 println!(
566 "{:1$}{2},",
567 "",
568 depth * 2,
569 quote! { #e }.to_string().replace('\n', " ")
570 )
571 }
572 StatementTree::And(v) => {
573 println!("{:1$}And (", "", depth * 2);
574 v.iter().for_each(|n| n.dump_int(depth + 1));
575 println!("{:1$})", "", depth * 2);
576 }
577 StatementTree::Or(v) => {
578 println!("{:1$}Or (", "", depth * 2);
579 v.iter().for_each(|n| n.dump_int(depth + 1));
580 println!("{:1$})", "", depth * 2);
581 }
582 StatementTree::Thresh(thresh, v) => {
583 println!("{:1$}Thresh ({2}", "", depth * 2, thresh);
584 v.iter().for_each(|n| n.dump_int(depth + 1));
585 println!("{:1$})", "", depth * 2);
586 }
587 }
588 }
589
590 pub fn dump(&self) {
591 self.dump_int(0);
592 }
593}
594
#[cfg(test)]
mod test {
    use super::StatementTree::*;
    use super::*;
    use quote::quote;

    #[test]
    // Only a leaf holding the literal `true` counts as the constant true.
    fn leaf_true_test() {
        assert!(StatementTree::leaf_true().is_leaf_true());
        assert!(!StatementTree::Leaf(parse_quote! { false }).is_leaf_true());
        assert!(!StatementTree::Leaf(parse_quote! { 1 }).is_leaf_true());
        assert!(!StatementTree::parse(&parse_quote! {
            OR(1=1, a=b)
        })
        .unwrap()
        .is_leaf_true());
    }

    #[test]
    // A flat list of expressions becomes an `And` of three leaves, in order.
    fn combiners_simple_test() {
        let exprlist: Vec<Expr> = vec![
            parse_quote! { C = c*B + r*A },
            parse_quote! { D = d*B + s*A },
            parse_quote! { c = d },
        ];

        let statementtree = StatementTree::parse_andlist(&exprlist).unwrap();
        let And(v) = statementtree else {
            panic!("Incorrect result");
        };
        let [Leaf(l0), Leaf(l1), Leaf(l2)] = v.as_slice() else {
            panic!("Incorrect result");
        };
        assert_eq!(quote! {#l0}.to_string(), "C = c * B + r * A");
        assert_eq!(quote! {#l1}.to_string(), "D = d * B + s * A");
        assert_eq!(quote! {#l2}.to_string(), "c = d");
    }

    #[test]
    // Nested `OR` of `AND`s inside the implicit top-level `And`.
    fn combiners_nested_test() {
        let exprlist: Vec<Expr> = vec![
            parse_quote! { C = c*B + r*A },
            parse_quote! { D = d*B + s*A },
            parse_quote! {
                OR (
                    AND (
                        C = c0*B + r0*A,
                        D = d0*B + s0*A,
                        c0 = d0,
                    ),
                    AND (
                        C = c1*B + r1*A,
                        D = d1*B + s1*A,
                        c1 = d1 + 1,
                    ),
                ) },
        ];

        let statementtree = StatementTree::parse_andlist(&exprlist).unwrap();
        let And(v0) = statementtree else {
            panic!("Incorrect result");
        };
        let [Leaf(l0), Leaf(l1), Or(v1)] = v0.as_slice() else {
            panic!("Incorrect result");
        };
        assert_eq!(quote! {#l0}.to_string(), "C = c * B + r * A");
        assert_eq!(quote! {#l1}.to_string(), "D = d * B + s * A");
        let [And(v2), And(v3)] = v1.as_slice() else {
            panic!("Incorrect result");
        };
        let [Leaf(l20), Leaf(l21), Leaf(l22)] = v2.as_slice() else {
            panic!("Incorrect result");
        };
        assert_eq!(quote! {#l20}.to_string(), "C = c0 * B + r0 * A");
        assert_eq!(quote! {#l21}.to_string(), "D = d0 * B + s0 * A");
        assert_eq!(quote! {#l22}.to_string(), "c0 = d0");
        let [Leaf(l30), Leaf(l31), Leaf(l32)] = v3.as_slice() else {
            panic!("Incorrect result");
        };
        assert_eq!(quote! {#l30}.to_string(), "C = c1 * B + r1 * A");
        assert_eq!(quote! {#l31}.to_string(), "D = d1 * B + s1 * A");
        assert_eq!(quote! {#l32}.to_string(), "c1 = d1 + 1");
    }

    #[test]
    // `THRESH(1, ...)`: the literal becomes the threshold and is not a child.
    fn combiners_thresh_test() {
        let exprlist: Vec<Expr> = vec![
            parse_quote! { C = c*B + r*A },
            parse_quote! { D = d*B + s*A },
            parse_quote! {
                THRESH (1,
                    AND (
                        C = c0*B + r0*A,
                        D = d0*B + s0*A,
                        c0 = d0,
                    ),
                    AND (
                        C = c1*B + r1*A,
                        D = d1*B + s1*A,
                        c1 = d1 + 1,
                    ),
                ) },
        ];

        let statementtree = StatementTree::parse_andlist(&exprlist).unwrap();
        let And(v0) = statementtree else {
            panic!("Incorrect result");
        };
        let [Leaf(l0), Leaf(l1), Thresh(thresh, v1)] = v0.as_slice() else {
            panic!("Incorrect result");
        };
        assert_eq!(*thresh, 1);
        assert_eq!(quote! {#l0}.to_string(), "C = c * B + r * A");
        assert_eq!(quote! {#l1}.to_string(), "D = d * B + s * A");
        let [And(v2), And(v3)] = v1.as_slice() else {
            panic!("Incorrect result");
        };
        let [Leaf(l20), Leaf(l21), Leaf(l22)] = v2.as_slice() else {
            panic!("Incorrect result");
        };
        assert_eq!(quote! {#l20}.to_string(), "C = c0 * B + r0 * A");
        assert_eq!(quote! {#l21}.to_string(), "D = d0 * B + s0 * A");
        assert_eq!(quote! {#l22}.to_string(), "c0 = d0");
        let [Leaf(l30), Leaf(l31), Leaf(l32)] = v3.as_slice() else {
            panic!("Incorrect result");
        };
        assert_eq!(quote! {#l30}.to_string(), "C = c1 * B + r1 * A");
        assert_eq!(quote! {#l31}.to_string(), "D = d1 * B + s1 * A");
        assert_eq!(quote! {#l32}.to_string(), "c1 = d1 + 1");
    }

    #[test]
    #[should_panic]
    fn combiners_bad_thresh_test() {
        // The threshold is out of range (3 > 2 children), so parsing
        // returns Err and the unwrap below panics.
        let exprlist: Vec<Expr> = vec![
            parse_quote! { C = c*B + r*A },
            parse_quote! { D = d*B + s*A },
            parse_quote! {
                THRESH (3,
                    AND (
                        C = c0*B + r0*A,
                        D = d0*B + s0*A,
                        c0 = d0,
                    ),
                    AND (
                        C = c1*B + r1*A,
                        D = d1*B + s1*A,
                        c1 = d1 + 1,
                    ),
                ) },
        ];

        StatementTree::parse_andlist(&exprlist).unwrap();
    }

    #[test]
    // Test the disjunction invariant checker
    fn disjunction_invariant_test() {
        // NOTE(review): the type codes are interpreted by
        // `vardict_from_strs` (defined elsewhere); presumably "S" is a
        // private Scalar and "pP" a public Point. Names not listed here
        // (`r`, `s`, `r0`, ...) are absent from the VarDict and are
        // therefore ignored by the checker.
        let vars: VarDict = vardict_from_strs(&[
            ("c", "S"),
            ("d", "S"),
            ("c0", "S"),
            ("c1", "S"),
            ("d0", "S"),
            ("d1", "S"),
            ("A", "pP"),
            ("B", "pP"),
            ("C", "pP"),
            ("D", "pP"),
        ]);
        // This one is OK
        let st_ok = StatementTree::parse(&parse_quote! {
            AND (
                C = c*B + r*A,
                D = d*B + s*A,
                OR (
                    AND (
                        C = c0*B + r0*A,
                        D = d0*B + s0*A,
                        c0 = d0,
                    ),
                    AND (
                        C = c1*B + r1*A,
                        D = d1*B + s1*A,
                        c1 = d1 + 1,
                    ),
                )
            )
        })
        .unwrap();
        // not OK: c0 appears in two branches of the OR
        let st_nok1 = StatementTree::parse(&parse_quote! {
            AND (
                C = c*B + r*A,
                D = d*B + s*A,
                OR (
                    AND (
                        C = c0*B + r0*A,
                        D = d0*B + s0*A,
                        c0 = d0,
                    ),
                    AND (
                        C = c0*B + r0*A,
                        D = d1*B + s1*A,
                        c0 = d1 + 1,
                    ),
                )
            )
        })
        .unwrap();
        // not OK: c appears in one branch of the OR and also outside
        // the OR
        let st_nok2 = StatementTree::parse(&parse_quote! {
            AND (
                C = c*B + r*A,
                D = d*B + s*A,
                OR (
                    AND (
                        D = d0*B + s0*A,
                        c = d0,
                    ),
                    AND (
                        C = c1*B + r1*A,
                        D = d1*B + s1*A,
                        c1 = d1 + 1,
                    ),
                )
            )
        })
        .unwrap();
        // not OK: c and d appear in both branches of the OR, and also
        // outside it
        let st_nok3 = StatementTree::parse(&parse_quote! {
            AND (
                C = c*B + r*A,
                D = d*B + s*A,
                OR (
                    c = d,
                    c = d + 1,
                )
            )
        })
        .unwrap();
        st_ok.check_disjunction_invariant(&vars).unwrap();
        st_nok1.check_disjunction_invariant(&vars).unwrap_err();
        st_nok2.check_disjunction_invariant(&vars).unwrap_err();
        st_nok3.check_disjunction_invariant(&vars).unwrap_err();
    }

    /// Parse `e`, collect every (path, branch root) pair reported by
    /// `for_each_disjunction_branch`, and compare with `expected`
    /// (whose expressions are parsed the same way).
    fn disjunction_branch_tester(e: Expr, expected: Vec<(Vec<usize>, Expr)>) {
        let mut output: Vec<(Vec<usize>, StatementTree)> = Vec::new();
        let expected_st: Vec<(Vec<usize>, StatementTree)> = expected
            .iter()
            .map(|(path, ex)| (path.clone(), StatementTree::parse(ex).unwrap()))
            .collect();
        let mut st = StatementTree::parse(&e).unwrap();
        st.for_each_disjunction_branch(&mut |db, path| {
            output.push((path.to_vec(), db.clone()));
            Ok(())
        })
        .unwrap();
        assert_eq!(output, expected_st);
    }

    /// Like `disjunction_branch_tester`, but the closure returns `Err`
    /// when it is handed a branch that is the constant `true`. Checks
    /// that the traversal aborts with `Err` and that `expected` lists
    /// exactly the branches visited before the abort.
    fn disjunction_branch_abort_tester(e: Expr, expected: Vec<(Vec<usize>, Expr)>) {
        let mut output: Vec<(Vec<usize>, StatementTree)> = Vec::new();
        let expected_st: Vec<(Vec<usize>, StatementTree)> = expected
            .iter()
            .map(|(path, ex)| (path.clone(), StatementTree::parse(ex).unwrap()))
            .collect();
        let mut st = StatementTree::parse(&e).unwrap();
        st.for_each_disjunction_branch(&mut |st, path| {
            if st.is_leaf_true() {
                return Err(syn::Error::new(proc_macro2::Span::call_site(), "true leaf"));
            }
            output.push((path.to_vec(), st.clone()));
            Ok(())
        })
        .unwrap_err();
        assert_eq!(output, expected_st);
    }

    #[test]
    fn disjunction_branch_test() {
        // A single leaf is the root branch, with the empty path.
        disjunction_branch_tester(
            parse_quote! {
                C = c*B + r*A
            },
            vec![(
                vec![],
                parse_quote! {
                    C = c*B + r*A
                },
            )],
        );

        // Each child of the OR is its own branch, numbered from 1.
        disjunction_branch_tester(
            parse_quote! {
                AND (
                    C = c*B + r*A,
                    D = d*B + s*A,
                    OR (
                        c = d,
                        c = d + 1,
                    )
                )
            },
            vec![
                (
                    vec![],
                    parse_quote! {
                        AND (
                            C = c*B + r*A,
                            D = d*B + s*A,
                            OR (
                                c = d,
                                c = d + 1,
                            )
                        )
                    },
                ),
                (
                    vec![1],
                    parse_quote! {
                        c = d
                    },
                ),
                (
                    vec![2],
                    parse_quote! {
                        c = d + 1
                    },
                ),
            ],
        );

        // A root that is itself an OR is not a branch; only its
        // children are reported.
        disjunction_branch_tester(
            parse_quote! {
                OR (
                    C = c*B + r*A,
                    D = c*B + r*A,
                )
            },
            vec![
                (vec![1], parse_quote! { C = c*B + r*A }),
                (vec![2], parse_quote! { D = c*B + r*A }),
            ],
        );

        // Nested OR inside a branch extends the path ([1, 1], [1, 2]);
        // parents are reported before their children.
        disjunction_branch_tester(
            parse_quote! {
                AND (
                    C = c*B + r*A,
                    D = d*B + s*A,
                    OR (
                        AND (
                            c = d,
                            D = a*B + b*A,
                            OR (
                                d = 5,
                                d = 6,
                            )
                        ),
                        c = d + 1,
                    )
                )
            },
            vec![
                (
                    vec![],
                    parse_quote! {
                        AND (
                            C = c*B + r*A,
                            D = d*B + s*A,
                            OR (
                                AND (
                                    c = d,
                                    D = a*B + b*A,
                                    OR (
                                        d = 5,
                                        d = 6,
                                    )
                                ),
                                c = d + 1,
                            )
                        )
                    },
                ),
                (
                    vec![1],
                    parse_quote! {
                        AND (
                            c = d,
                            D = a*B + b*A,
                            OR (
                                d = 5,
                                d = 6,
                            )
                        )
                    },
                ),
                (
                    vec![1, 1],
                    parse_quote! {
                        d = 5
                    },
                ),
                (
                    vec![1, 2],
                    parse_quote! {
                        d = 6
                    },
                ),
                (
                    vec![2],
                    parse_quote! {
                        c = d + 1
                    },
                ),
            ],
        );

        // Two ORs in the same (root) branch share one counter: the
        // second OR's children continue at [3] and [4].
        disjunction_branch_tester(
            parse_quote! {
                AND (
                    C = c*B + r*A,
                    D = d*B + s*A,
                    AND (
                        c = d + 1,
                        AND (
                            s = r,
                            OR (
                                d = 1,
                                AND (
                                    d = 2,
                                    s = 1,
                                )
                            )
                        )
                    ),
                    OR (
                        AND (
                            c = d,
                            D = a*B + b*A,
                            OR (
                                d = 5,
                                d = 6,
                            )
                        ),
                        c = d + 1,
                    )
                )
            },
            vec![
                (
                    vec![],
                    parse_quote! {
                        AND (
                            C = c*B + r*A,
                            D = d*B + s*A,
                            AND (
                                c = d + 1,
                                AND (
                                    s = r,
                                    OR (
                                        d = 1,
                                        AND (
                                            d = 2,
                                            s = 1,
                                        )
                                    )
                                )
                            ),
                            OR (
                                AND (
                                    c = d,
                                    D = a*B + b*A,
                                    OR (
                                        d = 5,
                                        d = 6,
                                    )
                                ),
                                c = d + 1,
                            )
                        )
                    },
                ),
                (vec![1], parse_quote! { d = 1 }),
                (
                    vec![2],
                    parse_quote! {
                        AND (
                            d = 2,
                            s = 1,
                        )
                    },
                ),
                (
                    vec![3],
                    parse_quote! {
                        AND (
                            c = d,
                            D = a*B + b*A,
                            OR (
                                d = 5,
                                d = 6,
                            )
                        )
                    },
                ),
                (
                    vec![3, 1],
                    parse_quote! {
                        d = 5
                    },
                ),
                (
                    vec![3, 2],
                    parse_quote! {
                        d = 6
                    },
                ),
                (
                    vec![4],
                    parse_quote! {
                        c = d + 1
                    },
                ),
            ],
        );

        // The closure fails on the `true` branch at [1, 2]; everything
        // visited before it is recorded and nothing after.
        disjunction_branch_abort_tester(
            parse_quote! {
                AND (
                    C = c*B + r*A,
                    D = d*B + s*A,
                    OR (
                        AND (
                            c = d,
                            D = a*B + b*A,
                            OR (
                                d = 5,
                                true,
                                d = 6,
                            )
                        ),
                        c = d + 1,
                    )
                )
            },
            vec![
                (
                    vec![],
                    parse_quote! {
                        AND (
                            C = c*B + r*A,
                            D = d*B + s*A,
                            OR (
                                AND (
                                    c = d,
                                    D = a*B + b*A,
                                    OR (
                                        d = 5,
                                        true,
                                        d = 6,
                                    )
                                ),
                                c = d + 1,
                            )
                        )
                    },
                ),
                (
                    vec![1],
                    parse_quote! {
                        AND (
                            c = d,
                            D = a*B + b*A,
                            OR (
                                d = 5,
                                true,
                                d = 6,
                            )
                        )
                    },
                ),
                (
                    vec![1, 1],
                    parse_quote! {
                        d = 5
                    },
                ),
            ],
        );
    }

    /// For each disjunction branch of `e`, collect its leaves (via
    /// `for_each_disjunction_branch_leaf`) and compare the
    /// (path, leaves) pairs with `expected`.
    fn disjunction_branch_leaf_tester(e: Expr, expected: Vec<(Vec<usize>, Vec<Expr>)>) {
        let mut output: Vec<(Vec<usize>, Vec<StatementTree>)> = Vec::new();
        let expected_st: Vec<(Vec<usize>, Vec<StatementTree>)> = expected
            .iter()
            .map(|(path, vex)| {
                (
                    path.clone(),
                    vex.iter()
                        .map(|ex| StatementTree::parse(ex).unwrap())
                        .collect(),
                )
            })
            .collect();
        let mut st = StatementTree::parse(&e).unwrap();
        st.for_each_disjunction_branch(&mut |db, path| {
            let mut dis_branch_output: Vec<StatementTree> = Vec::new();
            db.for_each_disjunction_branch_leaf(&mut |leaf| {
                dis_branch_output.push(leaf.clone());
                Ok(())
            })
            .unwrap();
            output.push((path.to_vec(), dis_branch_output));
            Ok(())
        })
        .unwrap();
        assert_eq!(output, expected_st);
    }

    /// Like `disjunction_branch_leaf_tester`, but the leaf closure
    /// returns `Err` on a `true` leaf, and that error is propagated out
    /// of the branch closure with `?`. The branch containing the `true`
    /// leaf is therefore not recorded, and the traversal must end in
    /// `Err`.
    fn disjunction_branch_leaf_abort_tester(e: Expr, expected: Vec<(Vec<usize>, Vec<Expr>)>) {
        let mut output: Vec<(Vec<usize>, Vec<StatementTree>)> = Vec::new();
        let expected_st: Vec<(Vec<usize>, Vec<StatementTree>)> = expected
            .iter()
            .map(|(path, vex)| {
                (
                    path.clone(),
                    vex.iter()
                        .map(|ex| StatementTree::parse(ex).unwrap())
                        .collect(),
                )
            })
            .collect();
        let mut st = StatementTree::parse(&e).unwrap();
        st.for_each_disjunction_branch(&mut |db, path| {
            let mut dis_branch_output: Vec<StatementTree> = Vec::new();
            db.for_each_disjunction_branch_leaf(&mut |leaf| {
                if leaf.is_leaf_true() {
                    return Err(syn::Error::new(proc_macro2::Span::call_site(), "true leaf"));
                }
                dis_branch_output.push(leaf.clone());
                Ok(())
            })?;
            output.push((path.to_vec(), dis_branch_output));
            Ok(())
        })
        .unwrap_err();
        assert_eq!(output, expected_st);
    }

    #[test]
    fn disjunction_branch_leaf_test() {
        // A lone leaf is its own branch's only leaf.
        disjunction_branch_leaf_tester(
            parse_quote! {
                C = c*B + r*A
            },
            vec![(vec![], vec![parse_quote! { C = c*B + r*A }])],
        );

        // The root branch's leaves exclude everything under the OR.
        disjunction_branch_leaf_tester(
            parse_quote! {
                AND (
                    C = c*B + r*A,
                    D = d*B + s*A,
                    OR (
                        c = d,
                        c = d + 1,
                    )
                )
            },
            vec![
                (
                    vec![],
                    vec![
                        parse_quote! { C = c*B + r*A },
                        parse_quote! { D = d*B + s*A },
                    ],
                ),
                (vec![1], vec![parse_quote! { c = d }]),
                (vec![2], vec![parse_quote! { c = d + 1 }]),
            ],
        );

        // An OR directly inside an OR: its children get paths [2, 1]
        // and [2, 2], and no branch is reported at [2] itself.
        disjunction_branch_leaf_tester(
            parse_quote! {
                AND (
                    C = c*B + r*A,
                    D = d*B + s*A,
                    OR (
                        c = d,
                        OR (
                            c = d + 1,
                            c = d + 2,
                        )
                    )
                )
            },
            vec![
                (
                    vec![],
                    vec![
                        parse_quote! { C = c*B + r*A },
                        parse_quote! { D = d*B + s*A },
                    ],
                ),
                (vec![1], vec![parse_quote! { c = d }]),
                (vec![2, 1], vec![parse_quote! { c = d + 1 }]),
                (vec![2, 2], vec![parse_quote! { c = d + 2 }]),
            ],
        );

        disjunction_branch_leaf_tester(
            parse_quote! {
                AND (
                    C = c*B + r*A,
                    D = d*B + s*A,
                    OR (
                        AND (
                            c = d,
                            D = a*B + b*A,
                            OR (
                                d = 5,
                                d = 6,
                            )
                        ),
                        c = d + 1,
                    )
                )
            },
            vec![
                (
                    vec![],
                    vec![
                        parse_quote! { C = c*B + r*A },
                        parse_quote! { D = d*B + s*A },
                    ],
                ),
                (
                    vec![1],
                    vec![
                        parse_quote! { c = d },
                        parse_quote! { D
                        = a*B + b*A },
                    ],
                ),
                (vec![1, 1], vec![parse_quote! { d = 5 }]),
                (vec![1, 2], vec![parse_quote! { d = 6 }]),
                (vec![2], vec![parse_quote! { c = d + 1 }]),
            ],
        );

        // The `true` leaf at [1, 2] aborts the traversal; that branch
        // and all later ones are absent from the output.
        disjunction_branch_leaf_abort_tester(
            parse_quote! {
                AND (
                    C = c*B + r*A,
                    D = d*B + s*A,
                    OR (
                        AND (
                            c = d,
                            D = a*B + b*A,
                            OR (
                                d = 5,
                                true,
                                d = 6,
                            )
                        ),
                        c = d + 1,
                    )
                )
            },
            vec![
                (
                    vec![],
                    vec![
                        parse_quote! { C = c*B + r*A },
                        parse_quote! { D = d*B + s*A },
                    ],
                ),
                (
                    vec![1],
                    vec![
                        parse_quote! { c = d },
                        parse_quote! { D
                        = a*B + b*A },
                    ],
                ),
                (vec![1, 1], vec![parse_quote! { d = 5 }]),
            ],
        );
    }

    /// Parse `e`, flatten its nested `And`s, and check the result equals
    /// the parse of `flattened_e`.
    fn flatten_ands_tester(e: Expr, flattened_e: Expr) {
        let mut st = StatementTree::parse(&e).unwrap();
        st.flatten_ands();
        assert_eq!(st, StatementTree::parse(&flattened_e).unwrap());
    }

    #[test]
    // Test flatten_ands
    fn flatten_ands_test() {
        // A leaf is unchanged.
        flatten_ands_tester(
            parse_quote! {
                C = x*B + r*A
            },
            parse_quote! {
                C = x*B + r*A
            },
        );

        // An inner AND is spliced into its parent in place.
        flatten_ands_tester(
            parse_quote! {
                AND (
                    C = x*B + r*A,
                    AND (
                        D = x*B + s*A,
                        E = x*B + t*A,
                    ),
                )
            },
            parse_quote! {
                AND (
                    C = x*B + r*A,
                    D = x*B + s*A,
                    E = x*B + t*A,
                )
            },
        );

        // An OR is kept as a single child of the flattened AND.
        flatten_ands_tester(
            parse_quote! {
                AND (
                    AND (
                        OR (
                            D = B + s*A,
                            D = s*A,
                        ),
                        D = x*B + t*A,
                    ),
                    C = x*B + r*A,
                )
            },
            parse_quote! {
                AND (
                    OR (
                        D = B + s*A,
                        D = s*A,
                    ),
                    D = x*B + t*A,
                    C = x*B + r*A,
                )
            },
        );

        // ANDs nested below an OR are flattened among themselves but
        // not merged into the OR.
        flatten_ands_tester(
            parse_quote! {
                AND (
                    AND (
                        OR (
                            D = B + s*A,
                            AND (
                                D = s*A,
                                AND (
                                    E = s*B,
                                    F = s*C,
                                ),
                            ),
                        ),
                        D = x*B + t*A,
                    ),
                    C = x*B + r*A,
                )
            },
            parse_quote! {
                AND (
                    OR (
                        D = B + s*A,
                        AND (
                            D = s*A,
                            E = s*B,
                            F = s*C,
                        )
                    ),
                    D = x*B + t*A,
                    C = x*B + r*A,
                )
            },
        );
    }
}