Skip to main content

scan_core/grammar/
natural.rs

1use std::ops::{Add, Div, Mul, Rem};
2
3use rand::{Rng, RngExt};
4
5use crate::{
6    Expression, Type, TypeError, Val,
7    grammar::{BooleanExpr, IntegerExpr},
8};
9
/// Natural numbers, i.e., unsigned integer values, represented as [`u64`].
pub type Natural = u64;
12
13/// A [`Natural`] number expression
14#[derive(Debug, Clone)]
15pub enum NaturalExpr<V>
16where
17    V: Clone,
18{
19    // -------------------
20    // General expressions
21    // -------------------
22    /// A constant value.
23    Const(Natural),
24    /// A typed variable.
25    Var(V),
26    // -------------
27    // Random values
28    // -------------
29    /// A random integer between a lower bound (included) and an upper bound (excluded).
30    Rand(Box<(NaturalExpr<V>, NaturalExpr<V>)>),
31    // --------------------
32    // Arithmetic operators
33    // --------------------
34    /// Arithmetic n-ary sum.
35    Sum(Vec<NaturalExpr<V>>),
36    /// Arithmetic n-ary multiplication.
37    Product(Vec<NaturalExpr<V>>),
38    /// Mod operation
39    Rem(Box<(NaturalExpr<V>, NaturalExpr<V>)>),
40    /// Div operation
41    Div(Box<(NaturalExpr<V>, NaturalExpr<V>)>),
42    /// Abs
43    Abs(Box<IntegerExpr<V>>),
44    // -----
45    // Flow
46    // -----
47    /// If-Then-Else construct, where If must be a boolean expression,
48    /// Then and Else must have the same type,
49    /// and this is also the type of the whole expression.
50    Ite(Box<(BooleanExpr<V>, NaturalExpr<V>, NaturalExpr<V>)>),
51}
52
53impl<V> NaturalExpr<V>
54where
55    V: Clone,
56{
57    /// Returns `true` if the expression is constant, i.e., it contains no variables, and `false` otherwise.
58    pub fn is_constant(&self) -> bool {
59        match self {
60            NaturalExpr::Const(_) => true,
61            NaturalExpr::Var(_) | NaturalExpr::Rand(_) => false,
62            NaturalExpr::Sum(natural_exprs) | NaturalExpr::Product(natural_exprs) => {
63                natural_exprs.iter().all(NaturalExpr::is_constant)
64            }
65            NaturalExpr::Rem(args) | NaturalExpr::Div(args) => {
66                let (lhs, rhs) = args.as_ref();
67                lhs.is_constant() && rhs.is_constant()
68            }
69            NaturalExpr::Abs(integer_expr) => integer_expr.is_constant(),
70            NaturalExpr::Ite(args) => {
71                let (ite, lhs, rhs) = args.as_ref();
72                ite.is_constant() && lhs.is_constant() && rhs.is_constant()
73            }
74        }
75    }
76
77    /// Returns the [`Natural`] value computed from the expression,
78    /// given the variable evaluation.
79    /// It panics if the evaluation is not possible, including:
80    ///
81    /// - If a variable is not included in the evaluation;
82    /// - If a variable included in the evaluation is not of [`Natural`] type;
83    /// - Division by 0;
84    /// - Overflow.
85    pub fn eval<R: Rng>(&self, vars: &dyn Fn(&V) -> Val, rng: &mut R) -> Natural {
86        match self {
87            NaturalExpr::Const(nat) => *nat,
88            NaturalExpr::Var(var) => {
89                if let Val::Natural(nat) = vars(var) {
90                    nat
91                } else {
92                    panic!("type mismatch: expected natural variable")
93                }
94            }
95            NaturalExpr::Rand(bounds) => {
96                let (lower_bound_expr, upper_bound_expr) = bounds.as_ref();
97                let lower_bound = lower_bound_expr.eval(vars, rng);
98                let upper_bound = upper_bound_expr.eval(vars, rng);
99                rng.random_range(lower_bound..upper_bound)
100            }
101            NaturalExpr::Sum(natural_exprs) => natural_exprs
102                .iter()
103                .fold(0, |acc, expr| acc.strict_add(expr.eval(vars, rng))),
104            NaturalExpr::Product(natural_exprs) => natural_exprs
105                .iter()
106                .fold(1, |acc, expr| acc.strict_mul(expr.eval(vars, rng))),
107            NaturalExpr::Rem(args) => {
108                let (lhs_expr, rhs_expr) = args.as_ref();
109                let lhs = lhs_expr.eval(vars, rng);
110                let rhs = rhs_expr.eval(vars, rng);
111                lhs.strict_rem_euclid(rhs)
112            }
113            NaturalExpr::Div(_) => todo!(),
114            NaturalExpr::Abs(integer_expr) => integer_expr.eval(vars, rng).unsigned_abs(),
115            NaturalExpr::Ite(args) => {
116                let (ite, lhs, rhs) = args.as_ref();
117                if ite.eval(vars, rng) {
118                    lhs.eval(vars, rng)
119                } else {
120                    rhs.eval(vars, rng)
121                }
122            }
123        }
124    }
125
126    pub(crate) fn map<W: Clone>(self, map: &dyn Fn(V) -> W) -> NaturalExpr<W> {
127        match self {
128            NaturalExpr::Const(n) => NaturalExpr::Const(n),
129            NaturalExpr::Var(var) => NaturalExpr::Var(map(var)),
130            NaturalExpr::Rand(bounds) => {
131                let (lower_bound, upper_bound) = *bounds;
132                NaturalExpr::Rand(Box::new((lower_bound.map(map), upper_bound.map(map))))
133            }
134            NaturalExpr::Sum(natural_exprs) => NaturalExpr::Sum(
135                natural_exprs
136                    .into_iter()
137                    .map(|expr| expr.map(map))
138                    .collect(),
139            ),
140            NaturalExpr::Product(natural_exprs) => NaturalExpr::Product(
141                natural_exprs
142                    .into_iter()
143                    .map(|expr| expr.map(map))
144                    .collect(),
145            ),
146            NaturalExpr::Rem(args) => {
147                let (lhs, rhs) = *args;
148                NaturalExpr::Rem(Box::new((lhs.map(map), rhs.map(map))))
149            }
150            NaturalExpr::Div(args) => {
151                let (lhs, rhs) = *args;
152                NaturalExpr::Div(Box::new((lhs.map(map), rhs.map(map))))
153            }
154            NaturalExpr::Abs(integer_expr) => NaturalExpr::Abs(Box::new(integer_expr.map(map))),
155            NaturalExpr::Ite(args) => {
156                let (r#if, then, r#else) = *args;
157                NaturalExpr::Ite(Box::new((r#if.map(map), then.map(map), r#else.map(map))))
158            }
159        }
160    }
161
162    pub(crate) fn context(&self, vars: &dyn Fn(V) -> Option<Type>) -> Result<(), TypeError> {
163        match self {
164            NaturalExpr::Const(_) => Ok(()),
165            NaturalExpr::Var(v) => matches!(vars(v.clone()), Some(Type::Natural))
166                .then_some(())
167                .ok_or(TypeError::TypeMismatch),
168            NaturalExpr::Rand(exprs) | NaturalExpr::Div(exprs) | NaturalExpr::Rem(exprs) => {
169                exprs.0.context(vars).and_then(|()| exprs.1.context(vars))
170            }
171            NaturalExpr::Sum(integer_exprs) | NaturalExpr::Product(integer_exprs) => {
172                integer_exprs.iter().try_for_each(|expr| expr.context(vars))
173            }
174            NaturalExpr::Ite(exprs) => exprs
175                .0
176                .context(vars)
177                .and_then(|()| exprs.1.context(vars))
178                .and_then(|()| exprs.2.context(vars)),
179            NaturalExpr::Abs(integer_expr) => integer_expr.context(vars),
180        }
181    }
182}
183
184impl<V: Clone> From<Natural> for NaturalExpr<V> {
185    fn from(value: Natural) -> Self {
186        NaturalExpr::Const(value)
187    }
188}
189
190impl<V> TryFrom<Expression<V>> for NaturalExpr<V>
191where
192    V: Clone,
193{
194    type Error = TypeError;
195
196    fn try_from(value: Expression<V>) -> Result<Self, Self::Error> {
197        if let Expression::Natural(nat_expr) = value {
198            Ok(nat_expr)
199        } else {
200            Err(TypeError::TypeMismatch)
201        }
202    }
203}
204
205impl<V> Add for NaturalExpr<V>
206where
207    V: Clone,
208{
209    type Output = Self;
210
211    fn add(mut self, mut rhs: Self) -> Self::Output {
212        if let NaturalExpr::Sum(ref mut exprs) = self {
213            if let NaturalExpr::Sum(rhs_exprs) = rhs {
214                exprs.extend(rhs_exprs);
215            } else {
216                exprs.push(rhs);
217            }
218            self
219        } else if let NaturalExpr::Sum(ref mut rhs_exprs) = rhs {
220            rhs_exprs.push(self);
221            rhs
222        } else {
223            NaturalExpr::Sum(vec![self, rhs])
224        }
225    }
226}
227
228impl<V> Mul for NaturalExpr<V>
229where
230    V: Clone,
231{
232    type Output = Self;
233
234    fn mul(mut self, mut rhs: Self) -> Self::Output {
235        if let NaturalExpr::Product(ref mut exprs) = self {
236            if let NaturalExpr::Product(rhs_exprs) = rhs {
237                exprs.extend(rhs_exprs);
238            } else {
239                exprs.push(rhs);
240            }
241            self
242        } else if let NaturalExpr::Product(ref mut rhs_exprs) = rhs {
243            rhs_exprs.push(self);
244            rhs
245        } else {
246            NaturalExpr::Product(vec![self, rhs])
247        }
248    }
249}
250
251impl<V> Div for NaturalExpr<V>
252where
253    V: Clone,
254{
255    type Output = Self;
256
257    fn div(self, rhs: Self) -> Self::Output {
258        NaturalExpr::Div(Box::new((self, rhs)))
259    }
260}
261
262impl<V> Rem for NaturalExpr<V>
263where
264    V: Clone,
265{
266    type Output = Self;
267
268    fn rem(self, rhs: Self) -> Self::Output {
269        NaturalExpr::Rem(Box::new((self, rhs)))
270    }
271}