Skip to main content

scan_core/grammar/
boolean.rs

1use std::ops::{BitAnd, BitOr, Not};
2
3use rand::{Rng, RngExt};
4
5use crate::{
6    Expression, Type, TypeError, Val,
7    grammar::{FloatExpr, IntegerExpr, NaturalExpr},
8};
9
10/// Boolean expressions.
11#[derive(Debug, Clone)]
12pub enum BooleanExpr<V>
13where
14    V: Clone,
15{
16    /// A Boolean constant (`true` or `false`).
17    Const(bool),
18    /// A Boolean variable.
19    Var(V),
20    /// A Bernoulli distribution with the given probability.
21    Rand(FloatExpr<V>),
22    // -----------------
23    // Logical operators
24    // -----------------
25    /// n-ary logical conjunction.
26    And(Vec<BooleanExpr<V>>),
27    /// n-ary logical disjunction.
28    Or(Vec<BooleanExpr<V>>),
29    /// Logical implication.
30    Implies(Box<(BooleanExpr<V>, BooleanExpr<V>)>),
31    /// Logical negation.
32    Not(Box<BooleanExpr<V>>),
33    // ------------
34    // (In)Equality
35    // ------------
36    /// Equality of Natural expressions.
37    NatEqual(NaturalExpr<V>, NaturalExpr<V>),
38    /// Equality of Integer expressions.
39    IntEqual(IntegerExpr<V>, IntegerExpr<V>),
40    /// Equality of Float expressions.
41    FloatEqual(FloatExpr<V>, FloatExpr<V>),
42    /// Inequality of Natural expressions: LHS greater than RHS.
43    NatGreater(NaturalExpr<V>, NaturalExpr<V>),
44    /// Inequality of Integer expressions: LHS greater than RHS.
45    IntGreater(IntegerExpr<V>, IntegerExpr<V>),
46    /// Inequality of Float expressions: LHS greater than RHS.
47    FloatGreater(FloatExpr<V>, FloatExpr<V>),
48    /// Inequality of Natural expressions: LHS greater than, or equal to,  RHS.
49    NatGreaterEq(NaturalExpr<V>, NaturalExpr<V>),
50    /// Inequality of Integer expressions: LHS greater than, or equal to,  RHS.
51    IntGreaterEq(IntegerExpr<V>, IntegerExpr<V>),
52    /// Inequality of Float expressions: LHS greater than, or equal to,  RHS.
53    FloatGreaterEq(FloatExpr<V>, FloatExpr<V>),
54    /// Inequality of Natural expressions: LHS less than RHS.
55    NatLess(NaturalExpr<V>, NaturalExpr<V>),
56    /// Inequality of Integer expressions: LHS less than RHS.
57    IntLess(IntegerExpr<V>, IntegerExpr<V>),
58    /// Inequality of Float expressions: LHS less than RHS.
59    FloatLess(FloatExpr<V>, FloatExpr<V>),
60    /// Inequality of Natural expressions: LHS less than, or equal to, RHS.
61    NatLessEq(NaturalExpr<V>, NaturalExpr<V>),
62    /// Inequality of Integer expressions: LHS less than, or equal to, RHS.
63    IntLessEq(IntegerExpr<V>, IntegerExpr<V>),
64    /// Inequality of Float expressions: LHS less than, or equal to, RHS.
65    FloatLessEq(FloatExpr<V>, FloatExpr<V>),
66    // -----
67    // Flow
68    // -----
69    /// If-Then-Else construct, where If must be a boolean expression,
70    /// Then and Else must have the same type,
71    /// and this is also the type of the whole expression.
72    Ite(Box<(BooleanExpr<V>, BooleanExpr<V>, BooleanExpr<V>)>),
73}
74
75impl<V> BooleanExpr<V>
76where
77    V: Clone,
78{
79    /// Returns `true` if the expression is constant, i.e., it contains no variables, and `false` otherwise.
80    pub fn is_constant(&self) -> bool {
81        match self {
82            BooleanExpr::Const(_) => true,
83            BooleanExpr::Var(_) => false,
84            BooleanExpr::Rand(_float_expr) => false,
85            BooleanExpr::And(boolean_exprs) | BooleanExpr::Or(boolean_exprs) => {
86                boolean_exprs.iter().all(Self::is_constant)
87            }
88            BooleanExpr::Implies(args) => {
89                let (lhs, rhs) = args.as_ref();
90                lhs.is_constant() && rhs.is_constant()
91            }
92            BooleanExpr::Not(boolean_expr) => boolean_expr.is_constant(),
93            BooleanExpr::NatEqual(natural_expr_lhs, natural_expr_rhs)
94            | BooleanExpr::NatGreater(natural_expr_lhs, natural_expr_rhs)
95            | BooleanExpr::NatGreaterEq(natural_expr_lhs, natural_expr_rhs)
96            | BooleanExpr::NatLess(natural_expr_lhs, natural_expr_rhs)
97            | BooleanExpr::NatLessEq(natural_expr_lhs, natural_expr_rhs) => {
98                natural_expr_lhs.is_constant() && natural_expr_rhs.is_constant()
99            }
100            BooleanExpr::IntEqual(integer_expr, integer_expr1)
101            | BooleanExpr::IntGreater(integer_expr, integer_expr1)
102            | BooleanExpr::IntGreaterEq(integer_expr, integer_expr1)
103            | BooleanExpr::IntLess(integer_expr, integer_expr1)
104            | BooleanExpr::IntLessEq(integer_expr, integer_expr1) => {
105                integer_expr.is_constant() && integer_expr1.is_constant()
106            }
107            BooleanExpr::FloatEqual(float_expr, float_expr1)
108            | BooleanExpr::FloatLess(float_expr, float_expr1)
109            | BooleanExpr::FloatLessEq(float_expr, float_expr1)
110            | BooleanExpr::FloatGreater(float_expr, float_expr1)
111            | BooleanExpr::FloatGreaterEq(float_expr, float_expr1) => {
112                float_expr.is_constant() && float_expr1.is_constant()
113            }
114            BooleanExpr::Ite(args) => {
115                let (ite, lhs, rhs) = args.as_ref();
116                ite.is_constant() && lhs.is_constant() && rhs.is_constant()
117            }
118        }
119    }
120
121    /// Returns the Boolean value computed from the expression,
122    /// given the variable evaluation.
123    /// It panics if the evaluation is not possible, including:
124    ///
125    /// - If a variable is not included in the evaluation;
126    /// - If a variable included in the evaluation is not of Boolean type.
127    pub fn eval<R: Rng>(&self, vars: &dyn Fn(&V) -> Val, rng: &mut R) -> bool {
128        match self {
129            BooleanExpr::Const(b) => *b,
130            BooleanExpr::Var(var) => {
131                if let Val::Boolean(b) = vars(var) {
132                    b
133                } else {
134                    panic!("type mismatch: expected boolean variable")
135                }
136            }
137            BooleanExpr::Rand(float_expr) => {
138                let bernoulli = float_expr.eval(vars, rng);
139                rng.random_bool(bernoulli)
140            }
141            BooleanExpr::And(boolean_exprs) => boolean_exprs
142                .iter()
143                .all(|boolean_expr| boolean_expr.eval(vars, rng)),
144            BooleanExpr::Or(boolean_exprs) => boolean_exprs
145                .iter()
146                .any(|boolean_expr| boolean_expr.eval(vars, rng)),
147            BooleanExpr::Implies(boolean_exprs) => {
148                let (lhs, rhs) = boolean_exprs.as_ref();
149                rhs.eval(vars, rng) || !lhs.eval(vars, rng)
150            }
151            BooleanExpr::Not(boolean_expr) => !&boolean_expr.eval(vars, rng),
152            BooleanExpr::NatEqual(natural_expr_lhs, natural_expr_rhs) => {
153                natural_expr_lhs.eval(vars, rng) == natural_expr_rhs.eval(vars, rng)
154            }
155            BooleanExpr::IntEqual(integer_expr_lhs, integer_expr_rhs) => {
156                integer_expr_lhs.eval(vars, rng) == integer_expr_rhs.eval(vars, rng)
157            }
158            BooleanExpr::FloatEqual(float_expr_lhs, float_expr_rhs) => {
159                float_expr_lhs.eval(vars, rng) == float_expr_rhs.eval(vars, rng)
160            }
161            BooleanExpr::NatGreater(natural_expr_lhs, natural_expr_rhs) => {
162                natural_expr_lhs.eval(vars, rng) > natural_expr_rhs.eval(vars, rng)
163            }
164            BooleanExpr::IntGreater(integer_expr_lhs, integer_expr_rhs) => {
165                integer_expr_lhs.eval(vars, rng) > integer_expr_rhs.eval(vars, rng)
166            }
167            BooleanExpr::FloatGreater(float_expr_lhs, float_expr_rhs) => {
168                float_expr_lhs.eval(vars, rng) > float_expr_rhs.eval(vars, rng)
169            }
170            BooleanExpr::NatGreaterEq(natural_expr_lhs, natural_expr_rhs) => {
171                natural_expr_lhs.eval(vars, rng) >= natural_expr_rhs.eval(vars, rng)
172            }
173            BooleanExpr::IntGreaterEq(integer_expr_lhs, integer_expr_rhs) => {
174                integer_expr_lhs.eval(vars, rng) >= integer_expr_rhs.eval(vars, rng)
175            }
176            BooleanExpr::FloatGreaterEq(float_expr_lhs, float_expr_rhs) => {
177                float_expr_lhs.eval(vars, rng) >= float_expr_rhs.eval(vars, rng)
178            }
179            BooleanExpr::NatLess(natural_expr_lhs, natural_expr_rhs) => {
180                natural_expr_lhs.eval(vars, rng) < natural_expr_rhs.eval(vars, rng)
181            }
182            BooleanExpr::IntLess(integer_expr_lhs, integer_expr_rhs) => {
183                integer_expr_lhs.eval(vars, rng) < integer_expr_rhs.eval(vars, rng)
184            }
185            BooleanExpr::FloatLess(float_expr_lhs, float_expr_rhs) => {
186                float_expr_lhs.eval(vars, rng) < float_expr_rhs.eval(vars, rng)
187            }
188            BooleanExpr::NatLessEq(natural_expr_lhs, natural_expr_rhs) => {
189                natural_expr_lhs.eval(vars, rng) <= natural_expr_rhs.eval(vars, rng)
190            }
191            BooleanExpr::IntLessEq(integer_expr_lhs, integer_expr_rhs) => {
192                integer_expr_lhs.eval(vars, rng) <= integer_expr_rhs.eval(vars, rng)
193            }
194            BooleanExpr::FloatLessEq(float_expr_lhs, float_expr_rhs) => {
195                float_expr_lhs.eval(vars, rng) <= float_expr_rhs.eval(vars, rng)
196            }
197            BooleanExpr::Ite(args) => {
198                let (ite, lhs, rhs) = args.as_ref();
199                if ite.eval(vars, rng) {
200                    lhs.eval(vars, rng)
201                } else {
202                    rhs.eval(vars, rng)
203                }
204            }
205        }
206    }
207
208    pub(crate) fn map<W: Clone>(self, map: &dyn Fn(V) -> W) -> BooleanExpr<W> {
209        match self {
210            BooleanExpr::Const(b) => BooleanExpr::Const(b),
211            BooleanExpr::Var(var) => BooleanExpr::Var(map(var)),
212            BooleanExpr::Rand(float_expr) => BooleanExpr::Rand(float_expr.map(map)),
213            BooleanExpr::And(boolean_exprs) => BooleanExpr::And(
214                boolean_exprs
215                    .into_iter()
216                    .map(|expr| expr.map(map))
217                    .collect(),
218            ),
219            BooleanExpr::Or(boolean_exprs) => BooleanExpr::Or(
220                boolean_exprs
221                    .into_iter()
222                    .map(|expr| expr.map(map))
223                    .collect(),
224            ),
225            BooleanExpr::Implies(args) => {
226                let (lhs, rhs) = *args;
227                BooleanExpr::Implies(Box::new((lhs.map(map), rhs.map(map))))
228            }
229            BooleanExpr::Not(boolean_expr) => BooleanExpr::Not(Box::new(boolean_expr.map(map))),
230            BooleanExpr::NatEqual(natural_expr_lhs, natural_expr_rhs) => {
231                BooleanExpr::NatEqual(natural_expr_lhs.map(map), natural_expr_rhs.map(map))
232            }
233            BooleanExpr::IntEqual(integer_expr_lhs, integer_expr_rhs) => {
234                BooleanExpr::IntEqual(integer_expr_lhs.map(map), integer_expr_rhs.map(map))
235            }
236            BooleanExpr::FloatEqual(float_expr_lhs, float_expr_rhs) => {
237                BooleanExpr::FloatEqual(float_expr_lhs.map(map), float_expr_rhs.map(map))
238            }
239            BooleanExpr::NatGreater(natural_expr_lhs, natural_expr_rhs) => {
240                BooleanExpr::NatGreater(natural_expr_lhs.map(map), natural_expr_rhs.map(map))
241            }
242            BooleanExpr::IntGreater(integer_expr_lhs, integer_expr_rhs) => {
243                BooleanExpr::IntGreater(integer_expr_lhs.map(map), integer_expr_rhs.map(map))
244            }
245            BooleanExpr::FloatGreater(float_expr_lhs, float_expr_rhs) => {
246                BooleanExpr::FloatGreater(float_expr_lhs.map(map), float_expr_rhs.map(map))
247            }
248            BooleanExpr::NatGreaterEq(natural_expr_lhs, natural_expr_rhs) => {
249                BooleanExpr::NatGreaterEq(natural_expr_lhs.map(map), natural_expr_rhs.map(map))
250            }
251            BooleanExpr::IntGreaterEq(integer_expr_lhs, integer_expr_rhs) => {
252                BooleanExpr::IntGreaterEq(integer_expr_lhs.map(map), integer_expr_rhs.map(map))
253            }
254            BooleanExpr::FloatGreaterEq(float_expr_lhs, float_expr_rhs) => {
255                BooleanExpr::FloatGreaterEq(float_expr_lhs.map(map), float_expr_rhs.map(map))
256            }
257            BooleanExpr::NatLess(natural_expr_lhs, natural_expr_rhs) => {
258                BooleanExpr::NatLess(natural_expr_lhs.map(map), natural_expr_rhs.map(map))
259            }
260            BooleanExpr::IntLess(integer_expr_lhs, integer_expr_rhs) => {
261                BooleanExpr::IntLess(integer_expr_lhs.map(map), integer_expr_rhs.map(map))
262            }
263            BooleanExpr::FloatLess(float_expr_lhs, float_expr_rhs) => {
264                BooleanExpr::FloatLess(float_expr_lhs.map(map), float_expr_rhs.map(map))
265            }
266            BooleanExpr::NatLessEq(natural_expr_lhs, natural_expr_rhs) => {
267                BooleanExpr::NatLessEq(natural_expr_lhs.map(map), natural_expr_rhs.map(map))
268            }
269            BooleanExpr::IntLessEq(integer_expr_lhs, integer_expr_rhs) => {
270                BooleanExpr::IntLessEq(integer_expr_lhs.map(map), integer_expr_rhs.map(map))
271            }
272            BooleanExpr::FloatLessEq(float_expr_lhs, float_expr_rhs) => {
273                BooleanExpr::FloatLessEq(float_expr_lhs.map(map), float_expr_rhs.map(map))
274            }
275            BooleanExpr::Ite(args) => {
276                let (r#if, then, r#else) = *args;
277                BooleanExpr::Ite(Box::new((r#if.map(map), then.map(map), r#else.map(map))))
278            }
279        }
280    }
281
282    pub(crate) fn context(&self, vars: &dyn Fn(V) -> Option<Type>) -> Result<(), TypeError> {
283        match self {
284            BooleanExpr::Const(_) => Ok(()),
285            BooleanExpr::Var(v) => matches!(vars(v.clone()), Some(Type::Boolean))
286                .then_some(())
287                .ok_or(TypeError::TypeMismatch),
288            BooleanExpr::Rand(float_expr) => float_expr.context(vars),
289            BooleanExpr::And(boolean_exprs) | BooleanExpr::Or(boolean_exprs) => {
290                boolean_exprs.iter().try_for_each(|expr| expr.context(vars))
291            }
292            BooleanExpr::Implies(exprs) => {
293                exprs.0.context(vars).and_then(|()| exprs.1.context(vars))
294            }
295            BooleanExpr::Not(boolean_expr) => boolean_expr.context(vars),
296            BooleanExpr::NatEqual(natural_expr_lhs, natural_expr_rhs)
297            | BooleanExpr::NatGreater(natural_expr_lhs, natural_expr_rhs)
298            | BooleanExpr::NatGreaterEq(natural_expr_lhs, natural_expr_rhs)
299            | BooleanExpr::NatLess(natural_expr_lhs, natural_expr_rhs)
300            | BooleanExpr::NatLessEq(natural_expr_lhs, natural_expr_rhs) => natural_expr_lhs
301                .context(vars)
302                .and_then(|()| natural_expr_rhs.context(vars)),
303            BooleanExpr::IntEqual(integer_expr_lhs, integer_expr_rhs)
304            | BooleanExpr::IntGreater(integer_expr_lhs, integer_expr_rhs)
305            | BooleanExpr::IntGreaterEq(integer_expr_lhs, integer_expr_rhs)
306            | BooleanExpr::IntLess(integer_expr_lhs, integer_expr_rhs)
307            | BooleanExpr::IntLessEq(integer_expr_lhs, integer_expr_rhs) => integer_expr_lhs
308                .context(vars)
309                .and_then(|()| integer_expr_rhs.context(vars)),
310            BooleanExpr::FloatGreater(float_expr_lhs, float_expr_rhs)
311            | BooleanExpr::FloatLess(float_expr_lhs, float_expr_rhs)
312            | BooleanExpr::FloatEqual(float_expr_lhs, float_expr_rhs)
313            | BooleanExpr::FloatGreaterEq(float_expr_lhs, float_expr_rhs)
314            | BooleanExpr::FloatLessEq(float_expr_lhs, float_expr_rhs) => float_expr_lhs
315                .context(vars)
316                .and_then(|()| float_expr_rhs.context(vars)),
317            BooleanExpr::Ite(exprs) => exprs
318                .0
319                .context(vars)
320                .and_then(|()| exprs.1.context(vars))
321                .and_then(|()| exprs.2.context(vars)),
322        }
323    }
324}
325
326impl<V> From<bool> for BooleanExpr<V>
327where
328    V: Clone,
329{
330    fn from(value: bool) -> Self {
331        Self::Const(value)
332    }
333}
334
335impl<V> TryFrom<Expression<V>> for BooleanExpr<V>
336where
337    V: Clone,
338{
339    type Error = TypeError;
340
341    fn try_from(value: Expression<V>) -> Result<Self, Self::Error> {
342        if let Expression::Boolean(bool_expr) = value {
343            Ok(bool_expr)
344        } else {
345            Err(TypeError::TypeMismatch)
346        }
347    }
348}
349
350impl<V> Not for BooleanExpr<V>
351where
352    V: Clone,
353{
354    type Output = Self;
355
356    fn not(self) -> Self::Output {
357        if let Self::Not(expr) = self {
358            *expr
359        } else {
360            Self::Not(Box::new(self))
361        }
362    }
363}
364
365impl<V> BitAnd for BooleanExpr<V>
366where
367    V: Clone,
368{
369    type Output = Self;
370
371    fn bitand(mut self, mut rhs: Self) -> Self::Output {
372        if let BooleanExpr::And(ref mut exprs) = self {
373            if let BooleanExpr::And(rhs_exprs) = rhs {
374                exprs.extend(rhs_exprs);
375            } else {
376                exprs.push(rhs);
377            }
378            self
379        } else if let BooleanExpr::And(ref mut rhs_exprs) = rhs {
380            rhs_exprs.push(self);
381            rhs
382        } else {
383            BooleanExpr::And(vec![self, rhs])
384        }
385    }
386}
387
388impl<V> BitOr for BooleanExpr<V>
389where
390    V: Clone,
391{
392    type Output = Self;
393
394    fn bitor(mut self, mut rhs: Self) -> Self::Output {
395        if let BooleanExpr::And(ref mut exprs) = self {
396            if let BooleanExpr::Or(rhs_exprs) = rhs {
397                exprs.extend(rhs_exprs);
398            } else {
399                exprs.push(rhs);
400            }
401            self
402        } else if let BooleanExpr::Or(ref mut rhs_exprs) = rhs {
403            rhs_exprs.push(self);
404            rhs
405        } else {
406            BooleanExpr::Or(vec![self, rhs])
407        }
408    }
409}