1use std::ops::{BitAnd, BitOr, Not};
2
3use rand::{Rng, RngExt};
4
5use crate::{
6 Expression, Type, TypeError, Val,
7 grammar::{FloatExpr, IntegerExpr, NaturalExpr},
8};
9
/// A boolean-valued expression over variables of type `V`.
///
/// Expressions are evaluated with [`BooleanExpr::eval`], type-checked with
/// `BooleanExpr::context`, and can be combined with the `!`, `&` and `|` operators.
#[derive(Debug, Clone)]
pub enum BooleanExpr<V>
where
    V: Clone,
{
    /// A constant truth value.
    Const(bool),
    /// A variable; at evaluation time its value must be a `Val::Boolean`.
    Var(V),
    /// A random truth value that is `true` with the probability given by the float
    /// expression (which must evaluate to a value in `[0, 1]`).
    Rand(FloatExpr<V>),
    /// Conjunction of all operands, evaluated left to right with short-circuiting
    /// (an empty conjunction is `true`).
    And(Vec<BooleanExpr<V>>),
    /// Disjunction of all operands, evaluated left to right with short-circuiting
    /// (an empty disjunction is `false`).
    Or(Vec<BooleanExpr<V>>),
    /// Implication `lhs => rhs`, stored as `(lhs, rhs)`.
    Implies(Box<(BooleanExpr<V>, BooleanExpr<V>)>),
    /// Logical negation.
    Not(Box<BooleanExpr<V>>),
    /// `lhs == rhs` on natural-number expressions.
    NatEqual(NaturalExpr<V>, NaturalExpr<V>),
    /// `lhs == rhs` on integer expressions.
    IntEqual(IntegerExpr<V>, IntegerExpr<V>),
    /// `lhs == rhs` on float expressions (exact `==` comparison).
    FloatEqual(FloatExpr<V>, FloatExpr<V>),
    /// `lhs > rhs` on natural-number expressions.
    NatGreater(NaturalExpr<V>, NaturalExpr<V>),
    /// `lhs > rhs` on integer expressions.
    IntGreater(IntegerExpr<V>, IntegerExpr<V>),
    /// `lhs > rhs` on float expressions.
    FloatGreater(FloatExpr<V>, FloatExpr<V>),
    /// `lhs >= rhs` on natural-number expressions.
    NatGreaterEq(NaturalExpr<V>, NaturalExpr<V>),
    /// `lhs >= rhs` on integer expressions.
    IntGreaterEq(IntegerExpr<V>, IntegerExpr<V>),
    /// `lhs >= rhs` on float expressions.
    FloatGreaterEq(FloatExpr<V>, FloatExpr<V>),
    /// `lhs < rhs` on natural-number expressions.
    NatLess(NaturalExpr<V>, NaturalExpr<V>),
    /// `lhs < rhs` on integer expressions.
    IntLess(IntegerExpr<V>, IntegerExpr<V>),
    /// `lhs < rhs` on float expressions.
    FloatLess(FloatExpr<V>, FloatExpr<V>),
    /// `lhs <= rhs` on natural-number expressions.
    NatLessEq(NaturalExpr<V>, NaturalExpr<V>),
    /// `lhs <= rhs` on integer expressions.
    IntLessEq(IntegerExpr<V>, IntegerExpr<V>),
    /// `lhs <= rhs` on float expressions.
    FloatLessEq(FloatExpr<V>, FloatExpr<V>),
    /// If-then-else, stored as `(condition, then, else)`: evaluates to `then` when
    /// `condition` holds and to `else` otherwise (only the taken branch is evaluated).
    Ite(Box<(BooleanExpr<V>, BooleanExpr<V>, BooleanExpr<V>)>),
}
74
75impl<V> BooleanExpr<V>
76where
77 V: Clone,
78{
79 pub fn is_constant(&self) -> bool {
81 match self {
82 BooleanExpr::Const(_) => true,
83 BooleanExpr::Var(_) => false,
84 BooleanExpr::Rand(_float_expr) => false,
85 BooleanExpr::And(boolean_exprs) | BooleanExpr::Or(boolean_exprs) => {
86 boolean_exprs.iter().all(Self::is_constant)
87 }
88 BooleanExpr::Implies(args) => {
89 let (lhs, rhs) = args.as_ref();
90 lhs.is_constant() && rhs.is_constant()
91 }
92 BooleanExpr::Not(boolean_expr) => boolean_expr.is_constant(),
93 BooleanExpr::NatEqual(natural_expr_lhs, natural_expr_rhs)
94 | BooleanExpr::NatGreater(natural_expr_lhs, natural_expr_rhs)
95 | BooleanExpr::NatGreaterEq(natural_expr_lhs, natural_expr_rhs)
96 | BooleanExpr::NatLess(natural_expr_lhs, natural_expr_rhs)
97 | BooleanExpr::NatLessEq(natural_expr_lhs, natural_expr_rhs) => {
98 natural_expr_lhs.is_constant() && natural_expr_rhs.is_constant()
99 }
100 BooleanExpr::IntEqual(integer_expr, integer_expr1)
101 | BooleanExpr::IntGreater(integer_expr, integer_expr1)
102 | BooleanExpr::IntGreaterEq(integer_expr, integer_expr1)
103 | BooleanExpr::IntLess(integer_expr, integer_expr1)
104 | BooleanExpr::IntLessEq(integer_expr, integer_expr1) => {
105 integer_expr.is_constant() && integer_expr1.is_constant()
106 }
107 BooleanExpr::FloatEqual(float_expr, float_expr1)
108 | BooleanExpr::FloatLess(float_expr, float_expr1)
109 | BooleanExpr::FloatLessEq(float_expr, float_expr1)
110 | BooleanExpr::FloatGreater(float_expr, float_expr1)
111 | BooleanExpr::FloatGreaterEq(float_expr, float_expr1) => {
112 float_expr.is_constant() && float_expr1.is_constant()
113 }
114 BooleanExpr::Ite(args) => {
115 let (ite, lhs, rhs) = args.as_ref();
116 ite.is_constant() && lhs.is_constant() && rhs.is_constant()
117 }
118 }
119 }
120
121 pub fn eval<R: Rng>(&self, vars: &dyn Fn(&V) -> Val, rng: &mut R) -> bool {
128 match self {
129 BooleanExpr::Const(b) => *b,
130 BooleanExpr::Var(var) => {
131 if let Val::Boolean(b) = vars(var) {
132 b
133 } else {
134 panic!("type mismatch: expected boolean variable")
135 }
136 }
137 BooleanExpr::Rand(float_expr) => {
138 let bernoulli = float_expr.eval(vars, rng);
139 rng.random_bool(bernoulli)
140 }
141 BooleanExpr::And(boolean_exprs) => boolean_exprs
142 .iter()
143 .all(|boolean_expr| boolean_expr.eval(vars, rng)),
144 BooleanExpr::Or(boolean_exprs) => boolean_exprs
145 .iter()
146 .any(|boolean_expr| boolean_expr.eval(vars, rng)),
147 BooleanExpr::Implies(boolean_exprs) => {
148 let (lhs, rhs) = boolean_exprs.as_ref();
149 rhs.eval(vars, rng) || !lhs.eval(vars, rng)
150 }
151 BooleanExpr::Not(boolean_expr) => !&boolean_expr.eval(vars, rng),
152 BooleanExpr::NatEqual(natural_expr_lhs, natural_expr_rhs) => {
153 natural_expr_lhs.eval(vars, rng) == natural_expr_rhs.eval(vars, rng)
154 }
155 BooleanExpr::IntEqual(integer_expr_lhs, integer_expr_rhs) => {
156 integer_expr_lhs.eval(vars, rng) == integer_expr_rhs.eval(vars, rng)
157 }
158 BooleanExpr::FloatEqual(float_expr_lhs, float_expr_rhs) => {
159 float_expr_lhs.eval(vars, rng) == float_expr_rhs.eval(vars, rng)
160 }
161 BooleanExpr::NatGreater(natural_expr_lhs, natural_expr_rhs) => {
162 natural_expr_lhs.eval(vars, rng) > natural_expr_rhs.eval(vars, rng)
163 }
164 BooleanExpr::IntGreater(integer_expr_lhs, integer_expr_rhs) => {
165 integer_expr_lhs.eval(vars, rng) > integer_expr_rhs.eval(vars, rng)
166 }
167 BooleanExpr::FloatGreater(float_expr_lhs, float_expr_rhs) => {
168 float_expr_lhs.eval(vars, rng) > float_expr_rhs.eval(vars, rng)
169 }
170 BooleanExpr::NatGreaterEq(natural_expr_lhs, natural_expr_rhs) => {
171 natural_expr_lhs.eval(vars, rng) >= natural_expr_rhs.eval(vars, rng)
172 }
173 BooleanExpr::IntGreaterEq(integer_expr_lhs, integer_expr_rhs) => {
174 integer_expr_lhs.eval(vars, rng) >= integer_expr_rhs.eval(vars, rng)
175 }
176 BooleanExpr::FloatGreaterEq(float_expr_lhs, float_expr_rhs) => {
177 float_expr_lhs.eval(vars, rng) >= float_expr_rhs.eval(vars, rng)
178 }
179 BooleanExpr::NatLess(natural_expr_lhs, natural_expr_rhs) => {
180 natural_expr_lhs.eval(vars, rng) < natural_expr_rhs.eval(vars, rng)
181 }
182 BooleanExpr::IntLess(integer_expr_lhs, integer_expr_rhs) => {
183 integer_expr_lhs.eval(vars, rng) < integer_expr_rhs.eval(vars, rng)
184 }
185 BooleanExpr::FloatLess(float_expr_lhs, float_expr_rhs) => {
186 float_expr_lhs.eval(vars, rng) < float_expr_rhs.eval(vars, rng)
187 }
188 BooleanExpr::NatLessEq(natural_expr_lhs, natural_expr_rhs) => {
189 natural_expr_lhs.eval(vars, rng) <= natural_expr_rhs.eval(vars, rng)
190 }
191 BooleanExpr::IntLessEq(integer_expr_lhs, integer_expr_rhs) => {
192 integer_expr_lhs.eval(vars, rng) <= integer_expr_rhs.eval(vars, rng)
193 }
194 BooleanExpr::FloatLessEq(float_expr_lhs, float_expr_rhs) => {
195 float_expr_lhs.eval(vars, rng) <= float_expr_rhs.eval(vars, rng)
196 }
197 BooleanExpr::Ite(args) => {
198 let (ite, lhs, rhs) = args.as_ref();
199 if ite.eval(vars, rng) {
200 lhs.eval(vars, rng)
201 } else {
202 rhs.eval(vars, rng)
203 }
204 }
205 }
206 }
207
208 pub(crate) fn map<W: Clone>(self, map: &dyn Fn(V) -> W) -> BooleanExpr<W> {
209 match self {
210 BooleanExpr::Const(b) => BooleanExpr::Const(b),
211 BooleanExpr::Var(var) => BooleanExpr::Var(map(var)),
212 BooleanExpr::Rand(float_expr) => BooleanExpr::Rand(float_expr.map(map)),
213 BooleanExpr::And(boolean_exprs) => BooleanExpr::And(
214 boolean_exprs
215 .into_iter()
216 .map(|expr| expr.map(map))
217 .collect(),
218 ),
219 BooleanExpr::Or(boolean_exprs) => BooleanExpr::Or(
220 boolean_exprs
221 .into_iter()
222 .map(|expr| expr.map(map))
223 .collect(),
224 ),
225 BooleanExpr::Implies(args) => {
226 let (lhs, rhs) = *args;
227 BooleanExpr::Implies(Box::new((lhs.map(map), rhs.map(map))))
228 }
229 BooleanExpr::Not(boolean_expr) => BooleanExpr::Not(Box::new(boolean_expr.map(map))),
230 BooleanExpr::NatEqual(natural_expr_lhs, natural_expr_rhs) => {
231 BooleanExpr::NatEqual(natural_expr_lhs.map(map), natural_expr_rhs.map(map))
232 }
233 BooleanExpr::IntEqual(integer_expr_lhs, integer_expr_rhs) => {
234 BooleanExpr::IntEqual(integer_expr_lhs.map(map), integer_expr_rhs.map(map))
235 }
236 BooleanExpr::FloatEqual(float_expr_lhs, float_expr_rhs) => {
237 BooleanExpr::FloatEqual(float_expr_lhs.map(map), float_expr_rhs.map(map))
238 }
239 BooleanExpr::NatGreater(natural_expr_lhs, natural_expr_rhs) => {
240 BooleanExpr::NatGreater(natural_expr_lhs.map(map), natural_expr_rhs.map(map))
241 }
242 BooleanExpr::IntGreater(integer_expr_lhs, integer_expr_rhs) => {
243 BooleanExpr::IntGreater(integer_expr_lhs.map(map), integer_expr_rhs.map(map))
244 }
245 BooleanExpr::FloatGreater(float_expr_lhs, float_expr_rhs) => {
246 BooleanExpr::FloatGreater(float_expr_lhs.map(map), float_expr_rhs.map(map))
247 }
248 BooleanExpr::NatGreaterEq(natural_expr_lhs, natural_expr_rhs) => {
249 BooleanExpr::NatGreaterEq(natural_expr_lhs.map(map), natural_expr_rhs.map(map))
250 }
251 BooleanExpr::IntGreaterEq(integer_expr_lhs, integer_expr_rhs) => {
252 BooleanExpr::IntGreaterEq(integer_expr_lhs.map(map), integer_expr_rhs.map(map))
253 }
254 BooleanExpr::FloatGreaterEq(float_expr_lhs, float_expr_rhs) => {
255 BooleanExpr::FloatGreaterEq(float_expr_lhs.map(map), float_expr_rhs.map(map))
256 }
257 BooleanExpr::NatLess(natural_expr_lhs, natural_expr_rhs) => {
258 BooleanExpr::NatLess(natural_expr_lhs.map(map), natural_expr_rhs.map(map))
259 }
260 BooleanExpr::IntLess(integer_expr_lhs, integer_expr_rhs) => {
261 BooleanExpr::IntLess(integer_expr_lhs.map(map), integer_expr_rhs.map(map))
262 }
263 BooleanExpr::FloatLess(float_expr_lhs, float_expr_rhs) => {
264 BooleanExpr::FloatLess(float_expr_lhs.map(map), float_expr_rhs.map(map))
265 }
266 BooleanExpr::NatLessEq(natural_expr_lhs, natural_expr_rhs) => {
267 BooleanExpr::NatLessEq(natural_expr_lhs.map(map), natural_expr_rhs.map(map))
268 }
269 BooleanExpr::IntLessEq(integer_expr_lhs, integer_expr_rhs) => {
270 BooleanExpr::IntLessEq(integer_expr_lhs.map(map), integer_expr_rhs.map(map))
271 }
272 BooleanExpr::FloatLessEq(float_expr_lhs, float_expr_rhs) => {
273 BooleanExpr::FloatLessEq(float_expr_lhs.map(map), float_expr_rhs.map(map))
274 }
275 BooleanExpr::Ite(args) => {
276 let (r#if, then, r#else) = *args;
277 BooleanExpr::Ite(Box::new((r#if.map(map), then.map(map), r#else.map(map))))
278 }
279 }
280 }
281
282 pub(crate) fn context(&self, vars: &dyn Fn(V) -> Option<Type>) -> Result<(), TypeError> {
283 match self {
284 BooleanExpr::Const(_) => Ok(()),
285 BooleanExpr::Var(v) => matches!(vars(v.clone()), Some(Type::Boolean))
286 .then_some(())
287 .ok_or(TypeError::TypeMismatch),
288 BooleanExpr::Rand(float_expr) => float_expr.context(vars),
289 BooleanExpr::And(boolean_exprs) | BooleanExpr::Or(boolean_exprs) => {
290 boolean_exprs.iter().try_for_each(|expr| expr.context(vars))
291 }
292 BooleanExpr::Implies(exprs) => {
293 exprs.0.context(vars).and_then(|()| exprs.1.context(vars))
294 }
295 BooleanExpr::Not(boolean_expr) => boolean_expr.context(vars),
296 BooleanExpr::NatEqual(natural_expr_lhs, natural_expr_rhs)
297 | BooleanExpr::NatGreater(natural_expr_lhs, natural_expr_rhs)
298 | BooleanExpr::NatGreaterEq(natural_expr_lhs, natural_expr_rhs)
299 | BooleanExpr::NatLess(natural_expr_lhs, natural_expr_rhs)
300 | BooleanExpr::NatLessEq(natural_expr_lhs, natural_expr_rhs) => natural_expr_lhs
301 .context(vars)
302 .and_then(|()| natural_expr_rhs.context(vars)),
303 BooleanExpr::IntEqual(integer_expr_lhs, integer_expr_rhs)
304 | BooleanExpr::IntGreater(integer_expr_lhs, integer_expr_rhs)
305 | BooleanExpr::IntGreaterEq(integer_expr_lhs, integer_expr_rhs)
306 | BooleanExpr::IntLess(integer_expr_lhs, integer_expr_rhs)
307 | BooleanExpr::IntLessEq(integer_expr_lhs, integer_expr_rhs) => integer_expr_lhs
308 .context(vars)
309 .and_then(|()| integer_expr_rhs.context(vars)),
310 BooleanExpr::FloatGreater(float_expr_lhs, float_expr_rhs)
311 | BooleanExpr::FloatLess(float_expr_lhs, float_expr_rhs)
312 | BooleanExpr::FloatEqual(float_expr_lhs, float_expr_rhs)
313 | BooleanExpr::FloatGreaterEq(float_expr_lhs, float_expr_rhs)
314 | BooleanExpr::FloatLessEq(float_expr_lhs, float_expr_rhs) => float_expr_lhs
315 .context(vars)
316 .and_then(|()| float_expr_rhs.context(vars)),
317 BooleanExpr::Ite(exprs) => exprs
318 .0
319 .context(vars)
320 .and_then(|()| exprs.1.context(vars))
321 .and_then(|()| exprs.2.context(vars)),
322 }
323 }
324}
325
326impl<V> From<bool> for BooleanExpr<V>
327where
328 V: Clone,
329{
330 fn from(value: bool) -> Self {
331 Self::Const(value)
332 }
333}
334
335impl<V> TryFrom<Expression<V>> for BooleanExpr<V>
336where
337 V: Clone,
338{
339 type Error = TypeError;
340
341 fn try_from(value: Expression<V>) -> Result<Self, Self::Error> {
342 if let Expression::Boolean(bool_expr) = value {
343 Ok(bool_expr)
344 } else {
345 Err(TypeError::TypeMismatch)
346 }
347 }
348}
349
350impl<V> Not for BooleanExpr<V>
351where
352 V: Clone,
353{
354 type Output = Self;
355
356 fn not(self) -> Self::Output {
357 if let Self::Not(expr) = self {
358 *expr
359 } else {
360 Self::Not(Box::new(self))
361 }
362 }
363}
364
365impl<V> BitAnd for BooleanExpr<V>
366where
367 V: Clone,
368{
369 type Output = Self;
370
371 fn bitand(mut self, mut rhs: Self) -> Self::Output {
372 if let BooleanExpr::And(ref mut exprs) = self {
373 if let BooleanExpr::And(rhs_exprs) = rhs {
374 exprs.extend(rhs_exprs);
375 } else {
376 exprs.push(rhs);
377 }
378 self
379 } else if let BooleanExpr::And(ref mut rhs_exprs) = rhs {
380 rhs_exprs.push(self);
381 rhs
382 } else {
383 BooleanExpr::And(vec![self, rhs])
384 }
385 }
386}
387
388impl<V> BitOr for BooleanExpr<V>
389where
390 V: Clone,
391{
392 type Output = Self;
393
394 fn bitor(mut self, mut rhs: Self) -> Self::Output {
395 if let BooleanExpr::And(ref mut exprs) = self {
396 if let BooleanExpr::Or(rhs_exprs) = rhs {
397 exprs.extend(rhs_exprs);
398 } else {
399 exprs.push(rhs);
400 }
401 self
402 } else if let BooleanExpr::Or(ref mut rhs_exprs) = rhs {
403 rhs_exprs.push(self);
404 rhs
405 } else {
406 BooleanExpr::Or(vec![self, rhs])
407 }
408 }
409}