Skip to main content

scan_core/grammar/
float.rs

1use std::ops::{Add, Div, Mul, Neg};
2
3use rand::{Rng, RngExt};
4
5use crate::{
6    Expression, Type, TypeError, Val,
7    grammar::{BooleanExpr, IntegerExpr, NaturalExpr},
8};
9
10/// Floating-point values.
11pub type Float = f64;
12
13/// Floating-point numerical expression.
14#[derive(Debug, Clone)]
15pub enum FloatExpr<V>
16where
17    V: Clone,
18{
19    // -------------------
20    // General expressions
21    // -------------------
22    /// A constant value.
23    Const(Float),
24    /// A typed variable.
25    Var(V),
26    /// Conversion from Natural
27    Nat(NaturalExpr<V>),
28    /// Conversion from Integer
29    Int(IntegerExpr<V>),
30    // -------------
31    // Random values
32    // -------------
33    /// A random integer between a lower bound (included) and an upper bound (excluded).
34    Rand(Box<(FloatExpr<V>, FloatExpr<V>)>),
35    // --------------------
36    // Arithmetic operators
37    // --------------------
38    /// Opposite of a numerical expression.
39    Opposite(Box<FloatExpr<V>>),
40    /// Arithmetic n-ary sum.
41    Sum(Vec<FloatExpr<V>>),
42    /// Arithmetic n-ary multiplication.
43    Product(Vec<FloatExpr<V>>),
44    /// Div operation
45    Div(Box<(FloatExpr<V>, FloatExpr<V>)>),
46    // -----
47    // Flow
48    // -----
49    /// If-Then-Else construct, where If must be a boolean expression,
50    /// Then and Else must have the same type,
51    /// and this is also the type of the whole expression.
52    Ite(Box<(BooleanExpr<V>, FloatExpr<V>, FloatExpr<V>)>),
53}
54
55impl<V> FloatExpr<V>
56where
57    V: Clone,
58{
59    /// Returns `true` if the expression is constant, i.e., it contains no variables, and `false` otherwise.
60    pub fn is_constant(&self) -> bool {
61        match self {
62            FloatExpr::Const(_) => true,
63            FloatExpr::Var(_) | FloatExpr::Rand(_) => false,
64            FloatExpr::Nat(natural_expr) => natural_expr.is_constant(),
65            FloatExpr::Int(integer_expr) => integer_expr.is_constant(),
66
67            FloatExpr::Opposite(float_expr) => float_expr.is_constant(),
68            FloatExpr::Sum(float_exprs) | FloatExpr::Product(float_exprs) => {
69                float_exprs.iter().all(FloatExpr::is_constant)
70            }
71            FloatExpr::Div(args) => {
72                let (lhs, rhs) = args.as_ref();
73                lhs.is_constant() && rhs.is_constant()
74            }
75            FloatExpr::Ite(args) => {
76                let (ite, lhs, rhs) = args.as_ref();
77                ite.is_constant() && lhs.is_constant() && rhs.is_constant()
78            }
79        }
80    }
81
82    /// Returns the [`Float`] value computed from the expression,
83    /// given the variable evaluation.
84    /// It panics if the evaluation is not possible, including:
85    ///
86    /// - If a variable is not included in the evaluation;
87    /// - If a variable included in the evaluation is not of [`Float`] type;
88    /// - Division by 0;
89    /// - Overflow.
90    pub fn eval<R: Rng>(&self, vars: &dyn Fn(&V) -> Val, rng: &mut R) -> Float {
91        match self {
92            FloatExpr::Const(float) => *float,
93            FloatExpr::Var(var) => {
94                if let Val::Float(float) = vars(var) {
95                    float
96                } else {
97                    panic!("type mismatch: expected float variable")
98                }
99            }
100            // NOTE WARN: the u64 as f64 is lossy!
101            FloatExpr::Nat(natural_expr) => natural_expr.eval(vars, rng) as f64,
102            // NOTE WARN: the i64 as f64 is lossy!
103            FloatExpr::Int(integer_expr) => integer_expr.eval(vars, rng) as f64,
104            FloatExpr::Rand(bounds) => {
105                let (lower_bound_expr, upper_bound_expr) = bounds.as_ref();
106                let lower_bound = lower_bound_expr.eval(vars, rng);
107                let upper_bound = upper_bound_expr.eval(vars, rng);
108                rng.random_range(lower_bound..upper_bound)
109            }
110            FloatExpr::Opposite(float_expr) => -float_expr.eval(vars, rng),
111            FloatExpr::Sum(float_exprs) => {
112                float_exprs.iter().map(|expr| expr.eval(vars, rng)).sum()
113            }
114            FloatExpr::Product(float_exprs) => float_exprs
115                .iter()
116                .map(|expr| expr.eval(vars, rng))
117                .product(),
118            FloatExpr::Div(args) => {
119                let (lhs_expr, rhs_expr) = args.as_ref();
120                let lhs = lhs_expr.eval(vars, rng);
121                let rhs = rhs_expr.eval(vars, rng);
122                lhs / rhs
123            }
124            FloatExpr::Ite(args) => {
125                let (ite, lhs, rhs) = args.as_ref();
126                if ite.eval(vars, rng) {
127                    lhs.eval(vars, rng)
128                } else {
129                    rhs.eval(vars, rng)
130                }
131            }
132        }
133    }
134
135    pub(crate) fn map<W: Clone>(self, map: &dyn Fn(V) -> W) -> FloatExpr<W> {
136        match self {
137            FloatExpr::Const(float) => FloatExpr::Const(float),
138            FloatExpr::Var(var) => FloatExpr::Var(map(var)),
139            FloatExpr::Nat(natural_expr) => FloatExpr::Nat(natural_expr.map(map)),
140            FloatExpr::Int(integer_expr) => FloatExpr::Int(integer_expr.map(map)),
141            FloatExpr::Rand(bounds) => {
142                let (lower_bound, upper_bound) = *bounds;
143                FloatExpr::Rand(Box::new((lower_bound.map(map), upper_bound.map(map))))
144            }
145            FloatExpr::Opposite(float_expr) => FloatExpr::Opposite(Box::new(float_expr.map(map))),
146            FloatExpr::Sum(float_exprs) => {
147                FloatExpr::Sum(float_exprs.into_iter().map(|expr| expr.map(map)).collect())
148            }
149            FloatExpr::Product(float_exprs) => {
150                FloatExpr::Product(float_exprs.into_iter().map(|expr| expr.map(map)).collect())
151            }
152            FloatExpr::Div(args) => {
153                let (lhs, rhs) = *args;
154                FloatExpr::Div(Box::new((lhs.map(map), rhs.map(map))))
155            }
156            FloatExpr::Ite(args) => {
157                let (r#if, then, r#else) = *args;
158                FloatExpr::Ite(Box::new((r#if.map(map), then.map(map), r#else.map(map))))
159            }
160        }
161    }
162
163    pub(crate) fn context(&self, vars: &dyn Fn(V) -> Option<Type>) -> Result<(), TypeError> {
164        match self {
165            FloatExpr::Const(_) => Ok(()),
166            FloatExpr::Var(v) => matches!(vars(v.clone()), Some(Type::Float))
167                .then_some(())
168                .ok_or(TypeError::TypeMismatch),
169            FloatExpr::Nat(natural_expr) => natural_expr.context(vars),
170            FloatExpr::Int(integer_expr) => integer_expr.context(vars),
171            FloatExpr::Rand(exprs) | FloatExpr::Div(exprs) => {
172                exprs.0.context(vars).and_then(|()| exprs.1.context(vars))
173            }
174            FloatExpr::Opposite(integer_expr) => integer_expr.context(vars),
175            FloatExpr::Sum(integer_exprs) | FloatExpr::Product(integer_exprs) => {
176                integer_exprs.iter().try_for_each(|expr| expr.context(vars))
177            }
178            FloatExpr::Ite(exprs) => exprs
179                .0
180                .context(vars)
181                .and_then(|()| exprs.1.context(vars))
182                .and_then(|()| exprs.2.context(vars)),
183        }
184    }
185}
186
187impl<V> From<Float> for FloatExpr<V>
188where
189    V: Clone,
190{
191    fn from(value: Float) -> Self {
192        Self::Const(value)
193    }
194}
195
196impl<V> TryFrom<Expression<V>> for FloatExpr<V>
197where
198    V: Clone,
199{
200    type Error = TypeError;
201
202    fn try_from(value: Expression<V>) -> Result<Self, Self::Error> {
203        match value {
204            Expression::Boolean(_) => Err(TypeError::TypeMismatch),
205            Expression::Natural(natural_expr) => Ok(FloatExpr::Nat(natural_expr)),
206            Expression::Integer(integer_expr) => Ok(FloatExpr::Int(integer_expr)),
207            Expression::Float(float_expr) => Ok(float_expr),
208        }
209    }
210}
211
212impl<V> From<NaturalExpr<V>> for FloatExpr<V>
213where
214    V: Clone,
215{
216    fn from(value: NaturalExpr<V>) -> Self {
217        Self::Nat(value)
218    }
219}
220
221impl<V> From<IntegerExpr<V>> for FloatExpr<V>
222where
223    V: Clone,
224{
225    fn from(value: IntegerExpr<V>) -> Self {
226        if let IntegerExpr::Nat(nat_expr) = value {
227            Self::Nat(nat_expr)
228        } else {
229            Self::Int(value)
230        }
231    }
232}
233
234impl<V> Add for FloatExpr<V>
235where
236    V: Clone,
237{
238    type Output = Self;
239
240    fn add(mut self, mut rhs: Self) -> Self::Output {
241        if let FloatExpr::Sum(ref mut exprs) = self {
242            if let FloatExpr::Sum(rhs_exprs) = rhs {
243                exprs.extend(rhs_exprs);
244            } else {
245                exprs.push(rhs);
246            }
247            self
248        } else if let FloatExpr::Sum(ref mut rhs_exprs) = rhs {
249            rhs_exprs.push(self);
250            rhs
251        } else {
252            FloatExpr::Sum(vec![self, rhs])
253        }
254    }
255}
256
257impl<V> Mul for FloatExpr<V>
258where
259    V: Clone,
260{
261    type Output = Self;
262
263    fn mul(mut self, mut rhs: Self) -> Self::Output {
264        if let FloatExpr::Product(ref mut exprs) = self {
265            if let FloatExpr::Product(rhs_exprs) = rhs {
266                exprs.extend(rhs_exprs);
267            } else {
268                exprs.push(rhs);
269            }
270            self
271        } else if let FloatExpr::Product(ref mut rhs_exprs) = rhs {
272            rhs_exprs.push(self);
273            rhs
274        } else {
275            FloatExpr::Product(vec![self, rhs])
276        }
277    }
278}
279
280impl<V> Neg for FloatExpr<V>
281where
282    V: Clone,
283{
284    type Output = Self;
285
286    fn neg(self) -> Self::Output {
287        if let FloatExpr::Opposite(expr) = self {
288            *expr
289        } else {
290            FloatExpr::Opposite(Box::new(self))
291        }
292    }
293}
294
295impl<V> Div for FloatExpr<V>
296where
297    V: Clone,
298{
299    type Output = Self;
300
301    fn div(self, rhs: Self) -> Self::Output {
302        FloatExpr::Div(Box::new((self, rhs)))
303    }
304}