Skip to main content

scan_core/grammar/
integer.rs

1use std::ops::{Add, Div, Mul, Neg};
2
3use rand::{Rng, RngExt};
4
5use crate::{
6    Expression, Type, TypeError, Val,
7    grammar::{BooleanExpr, FloatExpr, NaturalExpr},
8};
9
/// Integer values, represented as signed 64-bit integers.
pub type Integer = i64;
12
/// Integer expressions, generic over the variable type `V`.
///
/// Values are computed by [`IntegerExpr::eval`].
#[derive(Debug, Clone)]
pub enum IntegerExpr<V>
where
    V: Clone,
{
    // -------------------
    // General expressions
    // -------------------
    /// A constant value.
    Const(Integer),
    /// A typed variable.
    ///
    /// At evaluation time its value must be a [`Val::Integer`],
    /// and type-checking requires it to have type [`Type::Integer`].
    Var(V),
    /// Conversion from a natural expression.
    Nat(NaturalExpr<V>),
    // -------------
    // Random values
    // -------------
    /// A random integer between a lower bound (included) and an upper bound (excluded).
    ///
    /// Evaluation panics if the range is empty,
    /// i.e., if the lower bound is not smaller than the upper bound.
    Rand(Box<(IntegerExpr<V>, IntegerExpr<V>)>),
    // --------------------
    // Arithmetic operators
    // --------------------
    /// Opposite (arithmetic negation) of an integer expression.
    Opposite(Box<IntegerExpr<V>>),
    /// Arithmetic n-ary sum; the empty sum evaluates to 0.
    Sum(Vec<IntegerExpr<V>>),
    /// Arithmetic n-ary multiplication; the empty product evaluates to 1.
    Product(Vec<IntegerExpr<V>>),
    /// Integer division of the first operand by the second, truncating toward zero.
    Div(Box<(IntegerExpr<V>, IntegerExpr<V>)>),
    /// Euclidean remainder of the first operand by the second (always non-negative).
    ///
    /// NOTE(review): unlike [`IntegerExpr::Div`], which truncates toward zero,
    /// so for negative operands `Div` and `Rem` are not a matching quotient/remainder
    /// pair — confirm this is intended.
    Rem(Box<(IntegerExpr<V>, IntegerExpr<V>)>),
    /// Floor of a float expression, converted to [`Integer`].
    ///
    /// NOTE(review): the conversion is an `as` cast, which saturates out-of-range
    /// values and maps NaN to 0 instead of panicking.
    Floor(Box<FloatExpr<V>>),
    // -----
    // Flow
    // -----
    /// If-Then-Else construct, where If must be a boolean expression,
    /// Then and Else must have the same type,
    /// and this is also the type of the whole expression.
    Ite(Box<(BooleanExpr<V>, IntegerExpr<V>, IntegerExpr<V>)>),
}
56
57impl<V> IntegerExpr<V>
58where
59    V: Clone,
60{
61    /// Returns `true` if the expression is constant, i.e., it contains no variables, and `false` otherwise.
62    pub fn is_constant(&self) -> bool {
63        match self {
64            IntegerExpr::Const(_) => true,
65            IntegerExpr::Var(_) | IntegerExpr::Rand(_) => false,
66            IntegerExpr::Nat(natural_expr) => natural_expr.is_constant(),
67            IntegerExpr::Opposite(integer_expr) => integer_expr.is_constant(),
68            IntegerExpr::Sum(integer_exprs) | IntegerExpr::Product(integer_exprs) => {
69                integer_exprs.iter().all(IntegerExpr::is_constant)
70            }
71            IntegerExpr::Div(args) | IntegerExpr::Rem(args) => {
72                let (lhs, rhs) = args.as_ref();
73                lhs.is_constant() && rhs.is_constant()
74            }
75            IntegerExpr::Floor(float_expr) => float_expr.is_constant(),
76            IntegerExpr::Ite(args) => {
77                let (ite, lhs, rhs) = args.as_ref();
78                ite.is_constant() && lhs.is_constant() && rhs.is_constant()
79            }
80        }
81    }
82
83    /// Returns the [`Integer`] value computed from the expression,
84    /// given the variable evaluation.
85    /// It panics if the evaluation is not possible, including:
86    ///
87    /// - If a variable is not included in the evaluation;
88    /// - If a variable included in the evaluation is not of [`Integer`] type;
89    /// - Division by 0;
90    /// - Overflow.
91    pub fn eval<R: Rng>(&self, vars: &dyn Fn(&V) -> Val, rng: &mut R) -> Integer {
92        match self {
93            IntegerExpr::Const(int) => *int,
94            IntegerExpr::Var(var) => {
95                if let Val::Integer(int) = vars(var) {
96                    int
97                } else {
98                    panic!("type mismatch: expected natural variable")
99                }
100            }
101            IntegerExpr::Nat(natural_expr) => natural_expr.eval(vars, rng) as Integer,
102            IntegerExpr::Rand(bounds) => {
103                let (lower_bound_expr, upper_bound_expr) = bounds.as_ref();
104                let lower_bound = lower_bound_expr.eval(vars, rng);
105                let upper_bound = upper_bound_expr.eval(vars, rng);
106                rng.random_range(lower_bound..upper_bound)
107            }
108            IntegerExpr::Opposite(integer_expr) => integer_expr.eval(vars, rng).strict_neg(),
109            IntegerExpr::Sum(integer_exprs) => integer_exprs
110                .iter()
111                .fold(0, |acc, expr| acc.strict_add(expr.eval(vars, rng))),
112            IntegerExpr::Product(integer_exprs) => integer_exprs
113                .iter()
114                .fold(1, |acc, expr| acc.strict_mul(expr.eval(vars, rng))),
115            IntegerExpr::Div(args) => {
116                let (lhs_expr, rhs_expr) = args.as_ref();
117                let lhs = lhs_expr.eval(vars, rng);
118                let rhs = rhs_expr.eval(vars, rng);
119                lhs.strict_div(rhs)
120            }
121            IntegerExpr::Rem(args) => {
122                let (lhs_expr, rhs_expr) = args.as_ref();
123                let lhs = lhs_expr.eval(vars, rng);
124                let rhs = rhs_expr.eval(vars, rng);
125                lhs.strict_rem_euclid(rhs)
126            }
127            // NOTE WARN: is float-to-int floor operation sound?
128            IntegerExpr::Floor(float_expr) => float_expr.eval(vars, rng).floor() as Integer,
129            IntegerExpr::Ite(args) => {
130                let (ite, lhs, rhs) = args.as_ref();
131                if ite.eval(vars, rng) {
132                    lhs.eval(vars, rng)
133                } else {
134                    rhs.eval(vars, rng)
135                }
136            }
137        }
138    }
139
140    pub(crate) fn map<W: Clone>(self, map: &dyn Fn(V) -> W) -> IntegerExpr<W> {
141        match self {
142            IntegerExpr::Const(i) => IntegerExpr::Const(i),
143            IntegerExpr::Var(var) => IntegerExpr::Var(map(var)),
144            IntegerExpr::Nat(nat_expr) => IntegerExpr::Nat(nat_expr.map(map)),
145            IntegerExpr::Rand(bounds) => {
146                let (lower_bound, upper_bound) = *bounds;
147                IntegerExpr::Rand(Box::new((lower_bound.map(map), upper_bound.map(map))))
148            }
149            IntegerExpr::Opposite(integer_expr) => {
150                IntegerExpr::Opposite(Box::new(integer_expr.map(map)))
151            }
152            IntegerExpr::Sum(integer_exprs) => IntegerExpr::Sum(
153                integer_exprs
154                    .into_iter()
155                    .map(|expr| expr.map(map))
156                    .collect(),
157            ),
158            IntegerExpr::Product(integer_exprs) => IntegerExpr::Product(
159                integer_exprs
160                    .into_iter()
161                    .map(|expr| expr.map(map))
162                    .collect(),
163            ),
164            IntegerExpr::Div(args) => {
165                let (lhs, rhs) = *args;
166                IntegerExpr::Div(Box::new((lhs.map(map), rhs.map(map))))
167            }
168            IntegerExpr::Rem(args) => {
169                let (lhs, rhs) = *args;
170                IntegerExpr::Rem(Box::new((lhs.map(map), rhs.map(map))))
171            }
172            IntegerExpr::Floor(float_expr) => IntegerExpr::Floor(Box::new(float_expr.map(map))),
173            IntegerExpr::Ite(args) => {
174                let (r#if, then, r#else) = *args;
175                IntegerExpr::Ite(Box::new((r#if.map(map), then.map(map), r#else.map(map))))
176            }
177        }
178    }
179
180    pub(crate) fn context(&self, vars: &dyn Fn(V) -> Option<Type>) -> Result<(), TypeError> {
181        match self {
182            IntegerExpr::Const(_) => Ok(()),
183            IntegerExpr::Var(v) => matches!(vars(v.clone()), Some(Type::Integer))
184                .then_some(())
185                .ok_or(TypeError::TypeMismatch),
186            IntegerExpr::Nat(natural_expr) => natural_expr.context(vars),
187            IntegerExpr::Rand(exprs) | IntegerExpr::Div(exprs) | IntegerExpr::Rem(exprs) => {
188                exprs.0.context(vars).and_then(|()| exprs.1.context(vars))
189            }
190            IntegerExpr::Opposite(integer_expr) => integer_expr.context(vars),
191            IntegerExpr::Sum(integer_exprs) | IntegerExpr::Product(integer_exprs) => {
192                integer_exprs.iter().try_for_each(|expr| expr.context(vars))
193            }
194            IntegerExpr::Floor(float_expr) => float_expr.context(vars),
195            IntegerExpr::Ite(exprs) => exprs
196                .0
197                .context(vars)
198                .and_then(|()| exprs.1.context(vars))
199                .and_then(|()| exprs.2.context(vars)),
200        }
201    }
202}
203
204impl<V> From<Integer> for IntegerExpr<V>
205where
206    V: Clone,
207{
208    fn from(value: Integer) -> Self {
209        Self::Const(value)
210    }
211}
212
213impl<V> TryFrom<Expression<V>> for IntegerExpr<V>
214where
215    V: Clone,
216{
217    type Error = TypeError;
218
219    fn try_from(value: Expression<V>) -> Result<Self, Self::Error> {
220        match value {
221            Expression::Boolean(_) | Expression::Float(_) => Err(TypeError::TypeMismatch),
222            Expression::Natural(natural_expr) => Ok(IntegerExpr::Nat(natural_expr)),
223            Expression::Integer(integer_expr) => Ok(integer_expr),
224        }
225    }
226}
227
228impl<V> From<NaturalExpr<V>> for IntegerExpr<V>
229where
230    V: Clone,
231{
232    fn from(value: NaturalExpr<V>) -> Self {
233        Self::Nat(value)
234    }
235}
236
237impl<V> Add for IntegerExpr<V>
238where
239    V: Clone,
240{
241    type Output = Self;
242
243    fn add(mut self, mut rhs: Self) -> Self::Output {
244        if let IntegerExpr::Sum(ref mut exprs) = self {
245            if let IntegerExpr::Sum(rhs_exprs) = rhs {
246                exprs.extend(rhs_exprs);
247            } else {
248                exprs.push(rhs);
249            }
250            self
251        } else if let IntegerExpr::Sum(ref mut rhs_exprs) = rhs {
252            rhs_exprs.push(self);
253            rhs
254        } else {
255            IntegerExpr::Sum(vec![self, rhs])
256        }
257    }
258}
259
260impl<V> Mul for IntegerExpr<V>
261where
262    V: Clone,
263{
264    type Output = Self;
265
266    fn mul(mut self, mut rhs: Self) -> Self::Output {
267        if let IntegerExpr::Product(ref mut exprs) = self {
268            if let IntegerExpr::Product(rhs_exprs) = rhs {
269                exprs.extend(rhs_exprs);
270            } else {
271                exprs.push(rhs);
272            }
273            self
274        } else if let IntegerExpr::Product(ref mut rhs_exprs) = rhs {
275            rhs_exprs.push(self);
276            rhs
277        } else {
278            IntegerExpr::Product(vec![self, rhs])
279        }
280    }
281}
282
283impl<V> Div for IntegerExpr<V>
284where
285    V: Clone,
286{
287    type Output = Self;
288
289    fn div(self, rhs: Self) -> Self::Output {
290        IntegerExpr::Div(Box::new((self, rhs)))
291    }
292}
293
294impl<V> Neg for IntegerExpr<V>
295where
296    V: Clone,
297{
298    type Output = Self;
299
300    fn neg(self) -> Self::Output {
301        if let IntegerExpr::Opposite(expr) = self {
302            *expr
303        } else {
304            IntegerExpr::Opposite(Box::new(self))
305        }
306    }
307}