// sigma_compiler_core/sigma/combiners.rs

//! This module creates and manipulates trees of basic statements
//! combined with `AND`, `OR`, and `THRESH`.
3
4use super::types::*;
5use quote::quote;
6use std::collections::{HashMap, HashSet};
7use syn::parse::Result;
8use syn::visit::Visit;
9use syn::{parse_quote, Expr, Ident};
10
/// For each [`Ident`](struct@syn::Ident) representing a private
/// `Scalar` (as listed in a [`VarDict`]) that appears in an [`Expr`],
/// call a given closure.
///
/// Typical use, as done elsewhere in this module: construct the
/// struct with `result` set to `Ok(())`, call
/// [`visit_expr`](PrivScalarMap::visit_expr) on the [`Expr`] of
/// interest, and then inspect `result`.
pub struct PrivScalarMap<'a> {
    /// The [`VarDict`] that maps variable names to their types
    pub vars: &'a VarDict,

    /// The closure that is called for each [`Ident`](struct@syn::Ident)
    /// found in the [`Expr`] (provided in the call to
    /// [`visit_expr`](PrivScalarMap::visit_expr)) that represents a
    /// private `Scalar`
    pub closure: &'a mut dyn FnMut(&syn::Ident) -> Result<()>,

    /// The accumulated result.  This will be the first
    /// [`Err`](Result::Err) returned from the closure, or
    /// [`Ok(())`](Result::Ok) if all calls to the closure succeeded.
    /// Once this holds an `Err`, the closure is not called again.
    pub result: Result<()>,
}
29
30impl<'a> Visit<'a> for PrivScalarMap<'a> {
31    fn visit_path(&mut self, path: &'a syn::Path) {
32        // Whenever we see a `Path`, check first if it's just a bare
33        // `Ident`
34        let Some(id) = path.get_ident() else {
35            return;
36        };
37        // Then check if that `Ident` appears in the `VarDict`
38        let Some(vartype) = self.vars.get(&id.to_string()) else {
39            return;
40        };
41        // If so, and the `Ident` represents a private Scalar,
42        // call the closure if we haven't seen an `Err` returned from
43        // the closure yet.
44        if let AExprType::Scalar { is_pub: false, .. } = vartype {
45            if self.result.is_ok() {
46                self.result = (self.closure)(id);
47            }
48        }
49    }
50}
51
/// The statements in the ZKP form a tree.  The leaves are basic
/// statements of various kinds; for example, equations or inequalities
/// about Scalars and Points.  The interior nodes are combiners: `And`,
/// `Or`, or `Thresh` (with a given constant threshold).  A leaf is true
/// if the basic statement it contains is true.  An `And` node is true
/// if all of its children are true.  An `Or` node is true if at least
/// one of its children is true.  A `Thresh` node (with threshold `k`) is
/// true if at least `k` of its children are true.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum StatementTree {
    /// A basic statement, stored as the [`Expr`] it was parsed from
    Leaf(Expr),
    /// True if all of the child subtrees are true
    And(Vec<StatementTree>),
    /// True if at least one of the child subtrees is true
    Or(Vec<StatementTree>),
    /// True if at least `k` of the child subtrees are true, where `k`
    /// is the first field
    Thresh(usize, Vec<StatementTree>),
}
68
69impl StatementTree {
70    #[cfg(not(doctest))]
71    /// Parse an [`Expr`] (which may contain nested `AND`, `OR`, or
72    /// `THRESH`) into a [`StatementTree`].  For example, the
73    /// [`Expr`] obtained from:
74    /// ```
75    /// parse_quote! {
76    ///    AND (
77    ///        C = c*B + r*A,
78    ///        D = d*B + s*A,
79    ///        OR (
80    ///            AND (
81    ///                C = c0*B + r0*A,
82    ///                D = d0*B + s0*A,
83    ///                c0 = d0,
84    ///            ),
85    ///            AND (
86    ///                C = c1*B + r1*A,
87    ///                D = d1*B + s1*A,
88    ///                c1 = d1 + 1,
89    ///            ),
90    ///        )
91    ///    )
92    /// }
93    /// ```
94    ///
95    /// would yield a [`StatementTree::And`] containing a 3-element
96    /// vector.  The first two elements are [`StatementTree::Leaf`], and
97    /// the third is [`StatementTree::Or`] containing a 2-element
98    /// vector.  Each element is an [`StatementTree::And`] with a vector
99    /// containing 3 [`StatementTree::Leaf`]s.
100    ///
101    /// Note that `AND`, `OR`, and `THRESH` in the expression are
102    /// case-insensitive.
103    pub fn parse(expr: &Expr) -> Result<Self> {
104        // See if the expression describes a combiner
105        if let Expr::Call(syn::ExprCall { func, args, .. }) = expr {
106            if let Expr::Path(syn::ExprPath { path, .. }) = func.as_ref() {
107                if let Some(funcname) = path.get_ident() {
108                    match funcname.to_string().to_lowercase().as_str() {
109                        "and" => {
110                            let children: Result<Vec<StatementTree>> =
111                                args.iter().map(Self::parse).collect();
112                            return Ok(Self::And(children?));
113                        }
114                        "or" => {
115                            let children: Result<Vec<StatementTree>> =
116                                args.iter().map(Self::parse).collect();
117                            return Ok(Self::Or(children?));
118                        }
119                        "thresh" => {
120                            if let Some(Expr::Lit(syn::ExprLit {
121                                lit: syn::Lit::Int(litint),
122                                ..
123                            })) = args.first()
124                            {
125                                let thresh = litint.base10_parse::<usize>()?;
126                                // Remember that args.len() is one more
127                                // than the number of expressions,
128                                // because the first arg is the
129                                // threshold
130                                if thresh < 1 || thresh >= args.len() {
131                                    return Err(syn::Error::new(
132                                        litint.span(),
133                                        "threshold out of range",
134                                    ));
135                                }
136                                let children: Result<Vec<StatementTree>> =
137                                    args.iter().skip(1).map(Self::parse).collect();
138                                return Ok(Self::Thresh(thresh, children?));
139                            }
140                        }
141                        _ => {}
142                    }
143                }
144            }
145        }
146        Ok(StatementTree::Leaf(expr.clone()))
147    }
148
149    /// A convenience function that takes a list of [`Expr`]s, and
150    /// returns the [`StatementTree`] that implicitly puts `AND` around
151    /// the [`Expr`]s.  This is useful because a common thing to do is
152    /// to just write a list of [`Expr`]s in the top-level macro
153    /// invocation, having the semantics of "all of these must be true".
154    pub fn parse_andlist(exprlist: &[Expr]) -> Result<Self> {
155        let children: Result<Vec<StatementTree>> = exprlist.iter().map(Self::parse).collect();
156        Ok(StatementTree::And(children?))
157    }
158
159    /// Return a vector of references to all of the leaf expressions in
160    /// the [`StatementTree`]
161    pub fn leaves(&self) -> Vec<&Expr> {
162        match self {
163            StatementTree::Leaf(ref e) => vec![e],
164            StatementTree::And(v) | StatementTree::Or(v) | StatementTree::Thresh(_, v) => {
165                v.iter().fold(Vec::<&Expr>::new(), |mut b, st| {
166                    b.extend(st.leaves());
167                    b
168                })
169            }
170        }
171    }
172
173    /// Return a vector of mutable references to all of the leaf
174    /// expressions in the [`StatementTree`]
175    pub fn leaves_mut(&mut self) -> Vec<&mut Expr> {
176        match self {
177            StatementTree::Leaf(ref mut e) => vec![e],
178            StatementTree::And(v) | StatementTree::Or(v) | StatementTree::Thresh(_, v) => {
179                v.iter_mut().fold(Vec::<&mut Expr>::new(), |mut b, st| {
180                    b.extend(st.leaves_mut());
181                    b
182                })
183            }
184        }
185    }
186
187    /// Return a vector of mutable references to all of the leaves in
188    /// the [`StatementTree`]
189    pub fn leaves_st_mut(&mut self) -> Vec<&mut StatementTree> {
190        match self {
191            StatementTree::Leaf(_) => vec![self],
192            StatementTree::And(v) | StatementTree::Or(v) | StatementTree::Thresh(_, v) => v
193                .iter_mut()
194                .fold(Vec::<&mut StatementTree>::new(), |mut b, st| {
195                    b.extend(st.leaves_st_mut());
196                    b
197                }),
198        }
199    }
200
201    #[cfg(not(doctest))]
202    /// Verify whether the [`StatementTree`] satisfies the disjunction
203    /// invariant.
204    ///
205    /// A _disjunction node_ is an [`Or`](StatementTree::Or) or
206    /// [`Thresh`](StatementTree::Thresh) node in the [`StatementTree`].
207    ///
208    /// A _disjunction branch_ is a subtree rooted at a non-disjunction
209    /// node that is the child of a disjunction node or at the root of
210    /// the [`StatementTree`].
211    ///
212    /// The _disjunction invariant_ is that a private variable (which is
213    /// necessarily a `Scalar` since there are no private `Point`
214    /// variables) that appears in a disjunction branch cannot also
215    /// appear outside of that disjunction branch.
216    ///
217    /// For example, if all of the lowercase variables are private
218    /// `Scalar`s, the [`StatementTree`] created from:
219    ///
220    /// ```
221    ///    AND (
222    ///        C = c*B + r*A,
223    ///        D = d*B + s*A,
224    ///        OR (
225    ///            AND (
226    ///                C = c0*B + r0*A,
227    ///                D = d0*B + s0*A,
228    ///                c0 = d0,
229    ///            ),
230    ///            AND (
231    ///                C = c1*B + r1*A,
232    ///                D = d1*B + s1*A,
233    ///                c1 = d1 + 1,
234    ///            ),
235    ///        )
236    ///    )
237    /// ```
238    ///
239    /// satisfies the disjunction invariant, but
240    ///
241    /// ```
242    ///    AND (
243    ///        C = c*B + r*A,
244    ///        D = d*B + s*A,
245    ///        OR (
246    ///            AND (
247    ///                D = d0*B + s0*A,
248    ///                c = d0,
249    ///            ),
250    ///            AND (
251    ///                C = c1*B + r1*A,
252    ///                D = d1*B + s1*A,
253    ///                c1 = d1 + 1,
254    ///            ),
255    ///        )
256    ///    )
257    /// ```
258    ///
259    /// does not, because `c` appears in the first child of the `OR` and
260    /// also outside of the `OR` entirely.  Indeed, the reason to write
261    /// the first expression above rather than the more natural
262    ///
263    /// ```
264    ///    AND (
265    ///        C = c*B + r*A,
266    ///        D = d*B + s*A,
267    ///        OR (
268    ///            c = d,
269    ///            c = d + 1,
270    ///        )
271    ///    )
272    /// ```
273    ///
274    /// is exactly that the invariant must be satisfied.
275    ///
276    /// If you don't know that your [`StatementTree`] already satisfies
277    /// the invariant, call
278    /// [`enforce_disjunction_invariant`](super::super::enforce_disjunction_invariant),
279    /// which will transform the [`StatementTree`] so that it does (and
280    /// also call this
281    /// [`check_disjunction_invariant`](StatementTree::check_disjunction_invariant)
282    /// function as a sanity check).
283    pub fn check_disjunction_invariant(&self, vars: &VarDict) -> Result<()> {
284        let mut disjunct_map: HashMap<String, usize> = HashMap::new();
285
286        // If the recursive call returns Err, return that Err.
287        // Otherwise, we don't care about the Ok(usize) returned, so
288        // just return Ok(())
289        self.check_disjunction_invariant_rec(vars, &mut disjunct_map, 0, 0)?;
290        Ok(())
291    }
292
293    /// Internal recursive helper for
294    /// [`check_disjunction_invariant`](StatementTree::check_disjunction_invariant).
295    ///
296    /// The `disjunct_map` is a [`HashMap`] that maps the names of
297    /// variables to an identifier of which child of a disjunction node
298    /// the variable appears in (or the root if none).  In the case of
299    /// nested disjunction node, the closest one to the leaf is what
300    /// matters.  Nodes are numbered in pre-order fashion, starting at 0
301    /// for the root, 1 for the first child of the root, 2 for the first
302    /// child of node 1, etc.  `cur_node` is the node id of `self`, and
303    /// `cur_disjunct_child` is the node id of the closest child of a
304    /// disjunction node (or 0 for the root if none).  Returns the next
305    /// node id to use in the preorder traversal.
306    fn check_disjunction_invariant_rec(
307        &self,
308        vars: &VarDict,
309        disjunct_map: &mut HashMap<String, usize>,
310        cur_node: usize,
311        cur_disjunct_child: usize,
312    ) -> Result<usize> {
313        let mut next_node = cur_node;
314        match self {
315            Self::And(v) => {
316                for st in v {
317                    next_node = st.check_disjunction_invariant_rec(
318                        vars,
319                        disjunct_map,
320                        next_node + 1,
321                        cur_disjunct_child,
322                    )?;
323                }
324            }
325            Self::Or(v) | Self::Thresh(_, v) => {
326                for st in v {
327                    next_node = st.check_disjunction_invariant_rec(
328                        vars,
329                        disjunct_map,
330                        next_node + 1,
331                        next_node + 1,
332                    )?;
333                }
334            }
335            Self::Leaf(e) => {
336                let mut psmap = PrivScalarMap {
337                    vars,
338                    closure: &mut |ident| {
339                        let varname = ident.to_string();
340                        if let Some(dis_id) = disjunct_map.get(&varname) {
341                            if *dis_id != cur_disjunct_child {
342                                return Err(syn::Error::new(
343                                    ident.span(),
344                                    "Disjunction invariant violation: a private variable cannot appear both inside and outside a single term of an OR or THRESH"));
345                            }
346                        } else {
347                            disjunct_map.insert(varname, cur_disjunct_child);
348                        }
349                        Ok(())
350                    },
351                    result: Ok(()),
352                };
353                psmap.visit_expr(e);
354                psmap.result?;
355            }
356        }
357        Ok(next_node)
358    }
359
360    /// Call the supplied closure for each [disjunction branch] of the
361    /// given [`StatementTree`] (including the root, if the root is a
362    /// non-disjunction node).
363    ///
364    /// The calls are in preorder traversal (parents before children).
365    /// The given `closure` will be called with the root of each
366    /// [disjunction branch] as well as a slice of [`usize`] indicating
367    /// the path through the [`StatementTree`] to that disjunction
368    /// branch.  The disjunction branch at the root has path `[]`.
369    /// The disjunction branch rooted at, say, the 2nd child of an `Or`
370    /// node in the root disjunction branch will have path `[2]`.  The
371    /// disjunction branch rooted at the 1st child of an `Or` node in
372    /// that disjunction branch will have path `[2,1]`, and so on.
373    ///
374    /// Abort and return `Err` if any call to the closure returns `Err`.
375    ///
376    /// [disjunction branch]: StatementTree::check_disjunction_invariant
377    pub fn for_each_disjunction_branch(
378        &mut self,
379        closure: &mut dyn FnMut(&mut StatementTree, &[usize]) -> Result<()>,
380    ) -> Result<()> {
381        let mut path: Vec<usize> = Vec::new();
382        self.for_each_disjunction_branch_rec(closure, &mut path, 0, true)?;
383        Ok(())
384    }
385
386    /// Internal recursive helper for
387    /// [`for_each_disjunction_branch`](StatementTree::for_each_disjunction_branch).
388    ///
389    ///   - `path` is the path to this disjunction branch
390    ///   - `last_index` is the last index used for a child of this
391    ///     disjunction branch
392    ///   - `is_new_branch` is `true` if this node is the start of a new
393    ///     disjunction branch
394    ///
395    /// The return value (if `Ok`) is the updated value of `last_index`.
396    fn for_each_disjunction_branch_rec(
397        &mut self,
398        closure: &mut dyn FnMut(&mut StatementTree, &[usize]) -> Result<()>,
399        path: &mut Vec<usize>,
400        mut last_index: usize,
401        is_new_branch: bool,
402    ) -> Result<usize> {
403        // We're starting a new branch (and should call the closure) if
404        // and only if both is_new_branch is true, and also we're at a
405        // non-disjunction node
406        match self {
407            StatementTree::Leaf(_) | StatementTree::And(_) => {
408                if is_new_branch {
409                    (closure)(self, path)?;
410                }
411            }
412            _ => {}
413        }
414        match self {
415            StatementTree::Leaf(_) => {}
416            StatementTree::And(stvec) => {
417                stvec.iter_mut().try_for_each(|st| -> Result<()> {
418                    last_index =
419                        st.for_each_disjunction_branch_rec(closure, path, last_index, false)?;
420                    Ok(())
421                })?;
422            }
423            StatementTree::Or(stvec) | StatementTree::Thresh(_, stvec) => {
424                path.push(last_index);
425                let pathlen = path.len();
426                stvec.iter_mut().try_for_each(|st| -> Result<()> {
427                    last_index += 1;
428                    path[pathlen - 1] = last_index;
429                    st.for_each_disjunction_branch_rec(closure, path, 0, true)?;
430                    Ok(())
431                })?;
432                path.pop();
433            }
434        }
435        Ok(last_index)
436    }
437
438    /// Call the supplied closure for each [`StatementTree::Leaf`] of
439    /// the given [disjunction branch].
440    ///
441    /// Abort and return `Err` if any call to the closure returns `Err`.
442    ///
443    /// [disjunction branch]: StatementTree::check_disjunction_invariant
444    pub fn for_each_disjunction_branch_leaf(
445        &mut self,
446        closure: &mut dyn FnMut(&mut StatementTree) -> Result<()>,
447    ) -> Result<()> {
448        match self {
449            StatementTree::Leaf(_) => {
450                (closure)(self)?;
451            }
452            StatementTree::And(stvec) => {
453                stvec
454                    .iter_mut()
455                    .try_for_each(|st| st.for_each_disjunction_branch_leaf(closure))?;
456            }
457            StatementTree::Or(_) | StatementTree::Thresh(_, _) => {
458                // Don't recurse into Or or Thresh nodes, since the
459                // children of those nodes are in different disjunction
460                // branches.
461            }
462        }
463        Ok(())
464    }
465
466    /// Produce a [`HashSet`] of the private Scalars that appear in any
467    /// leaf of the given [disjunction branch].
468    ///
469    /// [disjunction branch]: StatementTree::check_disjunction_invariant
470    pub fn disjunction_branch_priv_scalars(&mut self, vars: &VarDict) -> HashSet<Ident> {
471        let mut priv_scalars: HashSet<Ident> = HashSet::new();
472        self.for_each_disjunction_branch_leaf(&mut |leaf| {
473            if let StatementTree::Leaf(leafexpr) = leaf {
474                let mut psmap = PrivScalarMap {
475                    vars,
476                    closure: &mut |ident| {
477                        priv_scalars.insert(ident.clone());
478                        Ok(())
479                    },
480                    result: Ok(()),
481                };
482                psmap.visit_expr(leafexpr);
483            }
484            Ok(())
485        })
486        .unwrap();
487        priv_scalars
488    }
489
490    #[cfg(not(doctest))]
491    /// Flatten nested `And` nodes in a [`StatementTree`].
492    ///
493    /// The underlying `sigma-proofs` crate can share `Scalars` across
494    /// statements that are direct children of the same `And` node, but
495    /// not in nested `And` nodes.
496    ///
497    /// So a [`StatementTree`] like this:
498    ///
499    /// ```
500    ///    AND (
501    ///        C = x*B + r*A,
502    ///        AND (
503    ///            D = x*B + s*A,
504    ///            E = x*B + t*A,
505    ///        ),
506    ///    )
507    /// ```
508    ///
509    /// Needs to be flattened to:
510    ///
511    /// ```
512    ///    AND (
513    ///        C = x*B + r*A,
514    ///        D = x*B + s*A,
515    ///        E = x*B + t*A,
516    ///    )
517    /// ```
518    pub fn flatten_ands(&mut self) {
519        match self {
520            StatementTree::Leaf(_) => {}
521            StatementTree::Or(svec) | StatementTree::Thresh(_, svec) => {
522                // Flatten each child
523                svec.iter_mut().for_each(|st| st.flatten_ands());
524            }
525            StatementTree::And(svec) => {
526                // Flatten each child, and if any of the children are
527                // `And`s, replace that child with the list of its
528                // children
529                let old_svec = std::mem::take(svec);
530                let mut new_svec: Vec<StatementTree> = Vec::new();
531                for mut st in old_svec {
532                    st.flatten_ands();
533                    match st {
534                        StatementTree::And(mut child_svec) => {
535                            new_svec.append(&mut child_svec);
536                        }
537                        _ => {
538                            new_svec.push(st);
539                        }
540                    }
541                }
542                *self = StatementTree::And(new_svec);
543            }
544        }
545    }
546
547    /// Produce a [`StatementTree`] that represents the constant `true`
548    pub fn leaf_true() -> StatementTree {
549        StatementTree::Leaf(parse_quote! { true })
550    }
551
552    /// Test if the given [`StatementTree`] represents the constant `true`
553    pub fn is_leaf_true(&self) -> bool {
554        if let StatementTree::Leaf(Expr::Lit(exprlit)) = self {
555            if let syn::Lit::Bool(syn::LitBool { value: true, .. }) = exprlit.lit {
556                return true;
557            }
558        }
559        false
560    }
561
562    fn dump_int(&self, depth: usize) {
563        match self {
564            StatementTree::Leaf(e) => {
565                println!(
566                    "{:1$}{2},",
567                    "",
568                    depth * 2,
569                    quote! { #e }.to_string().replace('\n', " ")
570                )
571            }
572            StatementTree::And(v) => {
573                println!("{:1$}And (", "", depth * 2);
574                v.iter().for_each(|n| n.dump_int(depth + 1));
575                println!("{:1$})", "", depth * 2);
576            }
577            StatementTree::Or(v) => {
578                println!("{:1$}Or (", "", depth * 2);
579                v.iter().for_each(|n| n.dump_int(depth + 1));
580                println!("{:1$})", "", depth * 2);
581            }
582            StatementTree::Thresh(thresh, v) => {
583                println!("{:1$}Thresh ({2}", "", depth * 2, thresh);
584                v.iter().for_each(|n| n.dump_int(depth + 1));
585                println!("{:1$})", "", depth * 2);
586            }
587        }
588    }
589
590    pub fn dump(&self) {
591        self.dump_int(0);
592    }
593}
594
595#[cfg(test)]
596mod test {
597    use super::StatementTree::*;
598    use super::*;
599    use quote::quote;
600
601    #[test]
602    fn leaf_true_test() {
603        assert!(StatementTree::leaf_true().is_leaf_true());
604        assert!(!StatementTree::Leaf(parse_quote! { false }).is_leaf_true());
605        assert!(!StatementTree::Leaf(parse_quote! { 1 }).is_leaf_true());
606        assert!(!StatementTree::parse(&parse_quote! {
607            OR(1=1, a=b)
608        })
609        .unwrap()
610        .is_leaf_true());
611    }
612
613    #[test]
614    fn combiners_simple_test() {
615        let exprlist: Vec<Expr> = vec![
616            parse_quote! { C = c*B + r*A },
617            parse_quote! { D = d*B + s*A },
618            parse_quote! { c = d },
619        ];
620
621        let statementtree = StatementTree::parse_andlist(&exprlist).unwrap();
622        let And(v) = statementtree else {
623            panic!("Incorrect result");
624        };
625        let [Leaf(l0), Leaf(l1), Leaf(l2)] = v.as_slice() else {
626            panic!("Incorrect result");
627        };
628        assert_eq!(quote! {#l0}.to_string(), "C = c * B + r * A");
629        assert_eq!(quote! {#l1}.to_string(), "D = d * B + s * A");
630        assert_eq!(quote! {#l2}.to_string(), "c = d");
631    }
632
633    #[test]
634    fn combiners_nested_test() {
635        let exprlist: Vec<Expr> = vec![
636            parse_quote! { C = c*B + r*A },
637            parse_quote! { D = d*B + s*A },
638            parse_quote! {
639            OR (
640                AND (
641                    C = c0*B + r0*A,
642                    D = d0*B + s0*A,
643                    c0 = d0,
644                ),
645                AND (
646                    C = c1*B + r1*A,
647                    D = d1*B + s1*A,
648                    c1 = d1 + 1,
649                ),
650            ) },
651        ];
652
653        let statementtree = StatementTree::parse_andlist(&exprlist).unwrap();
654        let And(v0) = statementtree else {
655            panic!("Incorrect result");
656        };
657        let [Leaf(l0), Leaf(l1), Or(v1)] = v0.as_slice() else {
658            panic!("Incorrect result");
659        };
660        assert_eq!(quote! {#l0}.to_string(), "C = c * B + r * A");
661        assert_eq!(quote! {#l1}.to_string(), "D = d * B + s * A");
662        let [And(v2), And(v3)] = v1.as_slice() else {
663            panic!("Incorrect result");
664        };
665        let [Leaf(l20), Leaf(l21), Leaf(l22)] = v2.as_slice() else {
666            panic!("Incorrect result");
667        };
668        assert_eq!(quote! {#l20}.to_string(), "C = c0 * B + r0 * A");
669        assert_eq!(quote! {#l21}.to_string(), "D = d0 * B + s0 * A");
670        assert_eq!(quote! {#l22}.to_string(), "c0 = d0");
671        let [Leaf(l30), Leaf(l31), Leaf(l32)] = v3.as_slice() else {
672            panic!("Incorrect result");
673        };
674        assert_eq!(quote! {#l30}.to_string(), "C = c1 * B + r1 * A");
675        assert_eq!(quote! {#l31}.to_string(), "D = d1 * B + s1 * A");
676        assert_eq!(quote! {#l32}.to_string(), "c1 = d1 + 1");
677    }
678
679    #[test]
680    fn combiners_thresh_test() {
681        let exprlist: Vec<Expr> = vec![
682            parse_quote! { C = c*B + r*A },
683            parse_quote! { D = d*B + s*A },
684            parse_quote! {
685            THRESH (1,
686                AND (
687                    C = c0*B + r0*A,
688                    D = d0*B + s0*A,
689                    c0 = d0,
690                ),
691                AND (
692                    C = c1*B + r1*A,
693                    D = d1*B + s1*A,
694                    c1 = d1 + 1,
695                ),
696            ) },
697        ];
698
699        let statementtree = StatementTree::parse_andlist(&exprlist).unwrap();
700        let And(v0) = statementtree else {
701            panic!("Incorrect result");
702        };
703        let [Leaf(l0), Leaf(l1), Thresh(thresh, v1)] = v0.as_slice() else {
704            panic!("Incorrect result");
705        };
706        assert_eq!(*thresh, 1);
707        assert_eq!(quote! {#l0}.to_string(), "C = c * B + r * A");
708        assert_eq!(quote! {#l1}.to_string(), "D = d * B + s * A");
709        let [And(v2), And(v3)] = v1.as_slice() else {
710            panic!("Incorrect result");
711        };
712        let [Leaf(l20), Leaf(l21), Leaf(l22)] = v2.as_slice() else {
713            panic!("Incorrect result");
714        };
715        assert_eq!(quote! {#l20}.to_string(), "C = c0 * B + r0 * A");
716        assert_eq!(quote! {#l21}.to_string(), "D = d0 * B + s0 * A");
717        assert_eq!(quote! {#l22}.to_string(), "c0 = d0");
718        let [Leaf(l30), Leaf(l31), Leaf(l32)] = v3.as_slice() else {
719            panic!("Incorrect result");
720        };
721        assert_eq!(quote! {#l30}.to_string(), "C = c1 * B + r1 * A");
722        assert_eq!(quote! {#l31}.to_string(), "D = d1 * B + s1 * A");
723        assert_eq!(quote! {#l32}.to_string(), "c1 = d1 + 1");
724    }
725
726    #[test]
727    #[should_panic]
728    fn combiners_bad_thresh_test() {
729        // The threshold is out of range
730        let exprlist: Vec<Expr> = vec![
731            parse_quote! { C = c*B + r*A },
732            parse_quote! { D = d*B + s*A },
733            parse_quote! {
734            THRESH (3,
735                AND (
736                    C = c0*B + r0*A,
737                    D = d0*B + s0*A,
738                    c0 = d0,
739                ),
740                AND (
741                    C = c1*B + r1*A,
742                    D = d1*B + s1*A,
743                    c1 = d1 + 1,
744                ),
745            ) },
746        ];
747
748        StatementTree::parse_andlist(&exprlist).unwrap();
749    }
750
751    #[test]
752    // Test the disjunction invariant checker
753    fn disjunction_invariant_test() {
754        let vars: VarDict = vardict_from_strs(&[
755            ("c", "S"),
756            ("d", "S"),
757            ("c0", "S"),
758            ("c1", "S"),
759            ("d0", "S"),
760            ("d1", "S"),
761            ("A", "pP"),
762            ("B", "pP"),
763            ("C", "pP"),
764            ("D", "pP"),
765        ]);
766        // This one is OK
767        let st_ok = StatementTree::parse(&parse_quote! {
768           AND (
769               C = c*B + r*A,
770               D = d*B + s*A,
771               OR (
772                   AND (
773                       C = c0*B + r0*A,
774                       D = d0*B + s0*A,
775                       c0 = d0,
776                   ),
777                   AND (
778                       C = c1*B + r1*A,
779                       D = d1*B + s1*A,
780                       c1 = d1 + 1,
781                   ),
782               )
783           )
784        })
785        .unwrap();
786        // not OK: c0 appears in two branches of the OR
787        let st_nok1 = StatementTree::parse(&parse_quote! {
788           AND (
789               C = c*B + r*A,
790               D = d*B + s*A,
791               OR (
792                   AND (
793                       C = c0*B + r0*A,
794                       D = d0*B + s0*A,
795                       c0 = d0,
796                   ),
797                   AND (
798                       C = c0*B + r0*A,
799                       D = d1*B + s1*A,
800                       c0 = d1 + 1,
801                   ),
802               )
803           )
804        })
805        .unwrap();
806        // not OK: c appears in one branch of the OR and also outside
807        // the OR
808        let st_nok2 = StatementTree::parse(&parse_quote! {
809           AND (
810               C = c*B + r*A,
811               D = d*B + s*A,
812               OR (
813                   AND (
814                       D = d0*B + s0*A,
815                       c = d0,
816                   ),
817                   AND (
818                       C = c1*B + r1*A,
819                       D = d1*B + s1*A,
820                       c1 = d1 + 1,
821                   ),
822               )
823           )
824        })
825        .unwrap();
826        // not OK: c and d appear in both branches of the OR, and also
827        // outside it
828        let st_nok3 = StatementTree::parse(&parse_quote! {
829           AND (
830               C = c*B + r*A,
831               D = d*B + s*A,
832               OR (
833                   c = d,
834                   c = d + 1,
835               )
836           )
837        })
838        .unwrap();
839        st_ok.check_disjunction_invariant(&vars).unwrap();
840        st_nok1.check_disjunction_invariant(&vars).unwrap_err();
841        st_nok2.check_disjunction_invariant(&vars).unwrap_err();
842        st_nok3.check_disjunction_invariant(&vars).unwrap_err();
843    }
844
845    fn disjunction_branch_tester(e: Expr, expected: Vec<(Vec<usize>, Expr)>) {
846        let mut output: Vec<(Vec<usize>, StatementTree)> = Vec::new();
847        let expected_st: Vec<(Vec<usize>, StatementTree)> = expected
848            .iter()
849            .map(|(path, ex)| (path.clone(), StatementTree::parse(ex).unwrap()))
850            .collect();
851        let mut st = StatementTree::parse(&e).unwrap();
852        st.for_each_disjunction_branch(&mut |db, path| {
853            output.push((path.to_vec(), db.clone()));
854            Ok(())
855        })
856        .unwrap();
857        assert_eq!(output, expected_st);
858    }
859
860    fn disjunction_branch_abort_tester(e: Expr, expected: Vec<(Vec<usize>, Expr)>) {
861        let mut output: Vec<(Vec<usize>, StatementTree)> = Vec::new();
862        let expected_st: Vec<(Vec<usize>, StatementTree)> = expected
863            .iter()
864            .map(|(path, ex)| (path.clone(), StatementTree::parse(ex).unwrap()))
865            .collect();
866        let mut st = StatementTree::parse(&e).unwrap();
867        st.for_each_disjunction_branch(&mut |st, path| {
868            if st.is_leaf_true() {
869                return Err(syn::Error::new(proc_macro2::Span::call_site(), "true leaf"));
870            }
871            output.push((path.to_vec(), st.clone()));
872            Ok(())
873        })
874        .unwrap_err();
875        assert_eq!(output, expected_st);
876    }
877
878    #[test]
879    fn disjunction_branch_test() {
880        disjunction_branch_tester(
881            parse_quote! {
882                C = c*B + r*A
883            },
884            vec![(
885                vec![],
886                parse_quote! {
887                    C = c*B + r*A
888                },
889            )],
890        );
891
892        disjunction_branch_tester(
893            parse_quote! {
894               AND (
895                   C = c*B + r*A,
896                   D = d*B + s*A,
897                   OR (
898                       c = d,
899                       c = d + 1,
900                   )
901               )
902            },
903            vec![
904                (
905                    vec![],
906                    parse_quote! {
907                       AND (
908                           C = c*B + r*A,
909                           D = d*B + s*A,
910                           OR (
911                               c = d,
912                               c = d + 1,
913                           )
914                       )
915                    },
916                ),
917                (
918                    vec![1],
919                    parse_quote! {
920                        c = d
921                    },
922                ),
923                (
924                    vec![2],
925                    parse_quote! {
926                        c = d + 1
927                    },
928                ),
929            ],
930        );
931
932        disjunction_branch_tester(
933            parse_quote! {
934                OR (
935                    C = c*B + r*A,
936                    D = c*B + r*A,
937                )
938            },
939            vec![
940                (vec![1], parse_quote! { C = c*B + r*A }),
941                (vec![2], parse_quote! { D = c*B + r*A }),
942            ],
943        );
944
945        disjunction_branch_tester(
946            parse_quote! {
947                AND (
948                    C = c*B + r*A,
949                    D = d*B + s*A,
950                    OR (
951                        AND (
952                            c = d,
953                            D = a*B + b*A,
954                            OR (
955                                d = 5,
956                                d = 6,
957                            )
958                        ),
959                        c = d + 1,
960                    )
961                )
962            },
963            vec![
964                (
965                    vec![],
966                    parse_quote! {
967                        AND (
968                            C = c*B + r*A,
969                            D = d*B + s*A,
970                            OR (
971                                AND (
972                                    c = d,
973                                    D = a*B + b*A,
974                                    OR (
975                                        d = 5,
976                                        d = 6,
977                                    )
978                                ),
979                                c = d + 1,
980                            )
981                        )
982                    },
983                ),
984                (
985                    vec![1],
986                    parse_quote! {
987                        AND (
988                            c = d,
989                            D = a*B + b*A,
990                            OR (
991                                d = 5,
992                                d = 6,
993                            )
994                        )
995                    },
996                ),
997                (
998                    vec![1, 1],
999                    parse_quote! {
1000                        d = 5
1001                    },
1002                ),
1003                (
1004                    vec![1, 2],
1005                    parse_quote! {
1006                        d = 6
1007                    },
1008                ),
1009                (
1010                    vec![2],
1011                    parse_quote! {
1012                        c = d + 1
1013                    },
1014                ),
1015            ],
1016        );
1017
1018        disjunction_branch_tester(
1019            parse_quote! {
1020                AND (
1021                    C = c*B + r*A,
1022                    D = d*B + s*A,
1023                    AND (
1024                        c = d + 1,
1025                        AND (
1026                            s = r,
1027                            OR (
1028                                d = 1,
1029                                AND (
1030                                    d = 2,
1031                                    s = 1,
1032                                )
1033                            )
1034                        )
1035                    ),
1036                    OR (
1037                        AND (
1038                            c = d,
1039                            D = a*B + b*A,
1040                            OR (
1041                                d = 5,
1042                                d = 6,
1043                            )
1044                        ),
1045                        c = d + 1,
1046                    )
1047                )
1048            },
1049            vec![
1050                (
1051                    vec![],
1052                    parse_quote! {
1053                        AND (
1054                            C = c*B + r*A,
1055                            D = d*B + s*A,
1056                            AND (
1057                                c = d + 1,
1058                                AND (
1059                                    s = r,
1060                                    OR (
1061                                        d = 1,
1062                                        AND (
1063                                            d = 2,
1064                                            s = 1,
1065                                        )
1066                                    )
1067                                )
1068                            ),
1069                            OR (
1070                                AND (
1071                                    c = d,
1072                                    D = a*B + b*A,
1073                                    OR (
1074                                        d = 5,
1075                                        d = 6,
1076                                    )
1077                                ),
1078                                c = d + 1,
1079                            )
1080                        )
1081                    },
1082                ),
1083                (vec![1], parse_quote! { d = 1 }),
1084                (
1085                    vec![2],
1086                    parse_quote! {
1087                        AND (
1088                            d = 2,
1089                            s = 1,
1090                        )
1091                    },
1092                ),
1093                (
1094                    vec![3],
1095                    parse_quote! {
1096                        AND (
1097                            c = d,
1098                            D = a*B + b*A,
1099                            OR (
1100                                d = 5,
1101                                d = 6,
1102                            )
1103                        )
1104                    },
1105                ),
1106                (
1107                    vec![3, 1],
1108                    parse_quote! {
1109                        d = 5
1110                    },
1111                ),
1112                (
1113                    vec![3, 2],
1114                    parse_quote! {
1115                        d = 6
1116                    },
1117                ),
1118                (
1119                    vec![4],
1120                    parse_quote! {
1121                        c = d + 1
1122                    },
1123                ),
1124            ],
1125        );
1126
1127        disjunction_branch_abort_tester(
1128            parse_quote! {
1129                AND (
1130                    C = c*B + r*A,
1131                    D = d*B + s*A,
1132                    OR (
1133                        AND (
1134                            c = d,
1135                            D = a*B + b*A,
1136                            OR (
1137                                d = 5,
1138                                true,
1139                                d = 6,
1140                            )
1141                        ),
1142                        c = d + 1,
1143                    )
1144                )
1145            },
1146            vec![
1147                (
1148                    vec![],
1149                    parse_quote! {
1150                        AND (
1151                            C = c*B + r*A,
1152                            D = d*B + s*A,
1153                            OR (
1154                                AND (
1155                                    c = d,
1156                                    D = a*B + b*A,
1157                                    OR (
1158                                        d = 5,
1159                                        true,
1160                                        d = 6,
1161                                    )
1162                                ),
1163                                c = d + 1,
1164                            )
1165                        )
1166                    },
1167                ),
1168                (
1169                    vec![1],
1170                    parse_quote! {
1171                        AND (
1172                            c = d,
1173                            D = a*B + b*A,
1174                            OR (
1175                                d = 5,
1176                                true,
1177                                d = 6,
1178                            )
1179                        )
1180                    },
1181                ),
1182                (
1183                    vec![1, 1],
1184                    parse_quote! {
1185                        d = 5
1186                    },
1187                ),
1188            ],
1189        );
1190    }
1191
1192    fn disjunction_branch_leaf_tester(e: Expr, expected: Vec<(Vec<usize>, Vec<Expr>)>) {
1193        let mut output: Vec<(Vec<usize>, Vec<StatementTree>)> = Vec::new();
1194        let expected_st: Vec<(Vec<usize>, Vec<StatementTree>)> = expected
1195            .iter()
1196            .map(|(path, vex)| {
1197                (
1198                    path.clone(),
1199                    vex.iter()
1200                        .map(|ex| StatementTree::parse(ex).unwrap())
1201                        .collect(),
1202                )
1203            })
1204            .collect();
1205        let mut st = StatementTree::parse(&e).unwrap();
1206        st.for_each_disjunction_branch(&mut |db, path| {
1207            let mut dis_branch_output: Vec<StatementTree> = Vec::new();
1208            db.for_each_disjunction_branch_leaf(&mut |leaf| {
1209                dis_branch_output.push(leaf.clone());
1210                Ok(())
1211            })
1212            .unwrap();
1213            output.push((path.to_vec(), dis_branch_output));
1214            Ok(())
1215        })
1216        .unwrap();
1217        assert_eq!(output, expected_st);
1218    }
1219
1220    fn disjunction_branch_leaf_abort_tester(e: Expr, expected: Vec<(Vec<usize>, Vec<Expr>)>) {
1221        let mut output: Vec<(Vec<usize>, Vec<StatementTree>)> = Vec::new();
1222        let expected_st: Vec<(Vec<usize>, Vec<StatementTree>)> = expected
1223            .iter()
1224            .map(|(path, vex)| {
1225                (
1226                    path.clone(),
1227                    vex.iter()
1228                        .map(|ex| StatementTree::parse(ex).unwrap())
1229                        .collect(),
1230                )
1231            })
1232            .collect();
1233        let mut st = StatementTree::parse(&e).unwrap();
1234        st.for_each_disjunction_branch(&mut |db, path| {
1235            let mut dis_branch_output: Vec<StatementTree> = Vec::new();
1236            db.for_each_disjunction_branch_leaf(&mut |leaf| {
1237                if leaf.is_leaf_true() {
1238                    return Err(syn::Error::new(proc_macro2::Span::call_site(), "true leaf"));
1239                }
1240                dis_branch_output.push(leaf.clone());
1241                Ok(())
1242            })?;
1243            output.push((path.to_vec(), dis_branch_output));
1244            Ok(())
1245        })
1246        .unwrap_err();
1247        assert_eq!(output, expected_st);
1248    }
1249
1250    #[test]
1251    fn disjunction_branch_leaf_test() {
1252        disjunction_branch_leaf_tester(
1253            parse_quote! {
1254                C = c*B + r*A
1255            },
1256            vec![(vec![], vec![parse_quote! { C = c*B + r*A }])],
1257        );
1258
1259        disjunction_branch_leaf_tester(
1260            parse_quote! {
1261               AND (
1262                   C = c*B + r*A,
1263                   D = d*B + s*A,
1264                   OR (
1265                       c = d,
1266                       c = d + 1,
1267                   )
1268               )
1269            },
1270            vec![
1271                (
1272                    vec![],
1273                    vec![
1274                        parse_quote! { C = c*B + r*A },
1275                        parse_quote! { D = d*B + s*A },
1276                    ],
1277                ),
1278                (vec![1], vec![parse_quote! { c = d }]),
1279                (vec![2], vec![parse_quote! { c = d + 1 }]),
1280            ],
1281        );
1282
1283        disjunction_branch_leaf_tester(
1284            parse_quote! {
1285               AND (
1286                   C = c*B + r*A,
1287                   D = d*B + s*A,
1288                   OR (
1289                       c = d,
1290                       OR (
1291                           c = d + 1,
1292                           c = d + 2,
1293                        )
1294                   )
1295               )
1296            },
1297            vec![
1298                (
1299                    vec![],
1300                    vec![
1301                        parse_quote! { C = c*B + r*A },
1302                        parse_quote! { D = d*B + s*A },
1303                    ],
1304                ),
1305                (vec![1], vec![parse_quote! { c = d }]),
1306                (vec![2, 1], vec![parse_quote! { c = d + 1 }]),
1307                (vec![2, 2], vec![parse_quote! { c = d + 2 }]),
1308            ],
1309        );
1310
1311        disjunction_branch_leaf_tester(
1312            parse_quote! {
1313                AND (
1314                    C = c*B + r*A,
1315                    D = d*B + s*A,
1316                    OR (
1317                        AND (
1318                            c = d,
1319                            D = a*B + b*A,
1320                            OR (
1321                                d = 5,
1322                                d = 6,
1323                            )
1324                        ),
1325                        c = d + 1,
1326                    )
1327                )
1328            },
1329            vec![
1330                (
1331                    vec![],
1332                    vec![
1333                        parse_quote! { C = c*B + r*A },
1334                        parse_quote! { D = d*B + s*A },
1335                    ],
1336                ),
1337                (
1338                    vec![1],
1339                    vec![
1340                        parse_quote! { c = d },
1341                        parse_quote! { D
1342                        = a*B + b*A },
1343                    ],
1344                ),
1345                (vec![1, 1], vec![parse_quote! { d = 5 }]),
1346                (vec![1, 2], vec![parse_quote! { d = 6 }]),
1347                (vec![2], vec![parse_quote! { c = d + 1 }]),
1348            ],
1349        );
1350
1351        disjunction_branch_leaf_abort_tester(
1352            parse_quote! {
1353                AND (
1354                    C = c*B + r*A,
1355                    D = d*B + s*A,
1356                    OR (
1357                        AND (
1358                            c = d,
1359                            D = a*B + b*A,
1360                            OR (
1361                                d = 5,
1362                                true,
1363                                d = 6,
1364                            )
1365                        ),
1366                        c = d + 1,
1367                    )
1368                )
1369            },
1370            vec![
1371                (
1372                    vec![],
1373                    vec![
1374                        parse_quote! { C = c*B + r*A },
1375                        parse_quote! { D = d*B + s*A },
1376                    ],
1377                ),
1378                (
1379                    vec![1],
1380                    vec![
1381                        parse_quote! { c = d },
1382                        parse_quote! { D
1383                        = a*B + b*A },
1384                    ],
1385                ),
1386                (vec![1, 1], vec![parse_quote! { d = 5 }]),
1387            ],
1388        );
1389    }
1390
1391    fn flatten_ands_tester(e: Expr, flattened_e: Expr) {
1392        let mut st = StatementTree::parse(&e).unwrap();
1393        st.flatten_ands();
1394        assert_eq!(st, StatementTree::parse(&flattened_e).unwrap());
1395    }
1396
1397    #[test]
1398    // Test flatten_ands
1399    fn flatten_ands_test() {
1400        flatten_ands_tester(
1401            parse_quote! {
1402                C = x*B + r*A
1403            },
1404            parse_quote! {
1405                C = x*B + r*A
1406            },
1407        );
1408
1409        flatten_ands_tester(
1410            parse_quote! {
1411                AND (
1412                    C = x*B + r*A,
1413                    AND (
1414                        D = x*B + s*A,
1415                        E = x*B + t*A,
1416                    ),
1417                )
1418            },
1419            parse_quote! {
1420                AND (
1421                    C = x*B + r*A,
1422                    D = x*B + s*A,
1423                    E = x*B + t*A,
1424                )
1425            },
1426        );
1427
1428        flatten_ands_tester(
1429            parse_quote! {
1430                AND (
1431                    AND (
1432                        OR (
1433                            D = B + s*A,
1434                            D = s*A,
1435                        ),
1436                        D = x*B + t*A,
1437                    ),
1438                    C = x*B + r*A,
1439                )
1440            },
1441            parse_quote! {
1442                AND (
1443                    OR (
1444                        D = B + s*A,
1445                        D = s*A,
1446                    ),
1447                    D = x*B + t*A,
1448                    C = x*B + r*A,
1449                )
1450            },
1451        );
1452
1453        flatten_ands_tester(
1454            parse_quote! {
1455                AND (
1456                    AND (
1457                        OR (
1458                            D = B + s*A,
1459                            AND (
1460                                D = s*A,
1461                                AND (
1462                                    E = s*B,
1463                                    F = s*C,
1464                                ),
1465                            ),
1466                        ),
1467                        D = x*B + t*A,
1468                    ),
1469                    C = x*B + r*A,
1470                )
1471            },
1472            parse_quote! {
1473                AND (
1474                    OR (
1475                        D = B + s*A,
1476                        AND (
1477                            D = s*A,
1478                            E = s*B,
1479                            F = s*C,
1480                        )
1481                    ),
1482                    D = x*B + t*A,
1483                    C = x*B + r*A,
1484                )
1485            },
1486        );
1487    }
1488}