1use std::ops::{Add, Div, Mul, Neg};
2
3use rand::{Rng, RngExt};
4
5use crate::{
6 Expression, Type, TypeError, Val,
7 grammar::{BooleanExpr, IntegerExpr, NaturalExpr},
8};
9
/// Floating-point type stored in [`FloatExpr::Const`] and returned by [`FloatExpr::eval`].
pub type Float = f64;
12
/// Expression tree that evaluates to a [`Float`], generic over the variable type `V`.
#[derive(Debug, Clone)]
pub enum FloatExpr<V>
where
    V: Clone,
{
    /// A literal floating-point value.
    Const(Float),
    /// A variable; evaluation resolves it through the caller-supplied lookup function.
    Var(V),
    /// A natural-number expression whose evaluated result is cast to a float.
    Nat(NaturalExpr<V>),
    /// An integer expression whose evaluated result is cast to a float.
    Int(IntegerExpr<V>),
    /// A random value drawn from the half-open range given by the
    /// `(lower bound, upper bound)` pair of sub-expressions.
    Rand(Box<(FloatExpr<V>, FloatExpr<V>)>),
    /// Arithmetic negation of the wrapped expression.
    Opposite(Box<FloatExpr<V>>),
    /// Sum of all the contained expressions.
    Sum(Vec<FloatExpr<V>>),
    /// Product of all the contained expressions.
    Product(Vec<FloatExpr<V>>),
    /// Division, stored as a `(dividend, divisor)` pair.
    Div(Box<(FloatExpr<V>, FloatExpr<V>)>),
    /// Conditional, stored as `(condition, value if true, value if false)`.
    Ite(Box<(BooleanExpr<V>, FloatExpr<V>, FloatExpr<V>)>),
}
54
55impl<V> FloatExpr<V>
56where
57 V: Clone,
58{
59 pub fn is_constant(&self) -> bool {
61 match self {
62 FloatExpr::Const(_) => true,
63 FloatExpr::Var(_) | FloatExpr::Rand(_) => false,
64 FloatExpr::Nat(natural_expr) => natural_expr.is_constant(),
65 FloatExpr::Int(integer_expr) => integer_expr.is_constant(),
66
67 FloatExpr::Opposite(float_expr) => float_expr.is_constant(),
68 FloatExpr::Sum(float_exprs) | FloatExpr::Product(float_exprs) => {
69 float_exprs.iter().all(FloatExpr::is_constant)
70 }
71 FloatExpr::Div(args) => {
72 let (lhs, rhs) = args.as_ref();
73 lhs.is_constant() && rhs.is_constant()
74 }
75 FloatExpr::Ite(args) => {
76 let (ite, lhs, rhs) = args.as_ref();
77 ite.is_constant() && lhs.is_constant() && rhs.is_constant()
78 }
79 }
80 }
81
82 pub fn eval<R: Rng>(&self, vars: &dyn Fn(&V) -> Val, rng: &mut R) -> Float {
91 match self {
92 FloatExpr::Const(float) => *float,
93 FloatExpr::Var(var) => {
94 if let Val::Float(float) = vars(var) {
95 float
96 } else {
97 panic!("type mismatch: expected float variable")
98 }
99 }
100 FloatExpr::Nat(natural_expr) => natural_expr.eval(vars, rng) as f64,
102 FloatExpr::Int(integer_expr) => integer_expr.eval(vars, rng) as f64,
104 FloatExpr::Rand(bounds) => {
105 let (lower_bound_expr, upper_bound_expr) = bounds.as_ref();
106 let lower_bound = lower_bound_expr.eval(vars, rng);
107 let upper_bound = upper_bound_expr.eval(vars, rng);
108 rng.random_range(lower_bound..upper_bound)
109 }
110 FloatExpr::Opposite(float_expr) => -float_expr.eval(vars, rng),
111 FloatExpr::Sum(float_exprs) => {
112 float_exprs.iter().map(|expr| expr.eval(vars, rng)).sum()
113 }
114 FloatExpr::Product(float_exprs) => float_exprs
115 .iter()
116 .map(|expr| expr.eval(vars, rng))
117 .product(),
118 FloatExpr::Div(args) => {
119 let (lhs_expr, rhs_expr) = args.as_ref();
120 let lhs = lhs_expr.eval(vars, rng);
121 let rhs = rhs_expr.eval(vars, rng);
122 lhs / rhs
123 }
124 FloatExpr::Ite(args) => {
125 let (ite, lhs, rhs) = args.as_ref();
126 if ite.eval(vars, rng) {
127 lhs.eval(vars, rng)
128 } else {
129 rhs.eval(vars, rng)
130 }
131 }
132 }
133 }
134
135 pub(crate) fn map<W: Clone>(self, map: &dyn Fn(V) -> W) -> FloatExpr<W> {
136 match self {
137 FloatExpr::Const(float) => FloatExpr::Const(float),
138 FloatExpr::Var(var) => FloatExpr::Var(map(var)),
139 FloatExpr::Nat(natural_expr) => FloatExpr::Nat(natural_expr.map(map)),
140 FloatExpr::Int(integer_expr) => FloatExpr::Int(integer_expr.map(map)),
141 FloatExpr::Rand(bounds) => {
142 let (lower_bound, upper_bound) = *bounds;
143 FloatExpr::Rand(Box::new((lower_bound.map(map), upper_bound.map(map))))
144 }
145 FloatExpr::Opposite(float_expr) => FloatExpr::Opposite(Box::new(float_expr.map(map))),
146 FloatExpr::Sum(float_exprs) => {
147 FloatExpr::Sum(float_exprs.into_iter().map(|expr| expr.map(map)).collect())
148 }
149 FloatExpr::Product(float_exprs) => {
150 FloatExpr::Product(float_exprs.into_iter().map(|expr| expr.map(map)).collect())
151 }
152 FloatExpr::Div(args) => {
153 let (lhs, rhs) = *args;
154 FloatExpr::Div(Box::new((lhs.map(map), rhs.map(map))))
155 }
156 FloatExpr::Ite(args) => {
157 let (r#if, then, r#else) = *args;
158 FloatExpr::Ite(Box::new((r#if.map(map), then.map(map), r#else.map(map))))
159 }
160 }
161 }
162
163 pub(crate) fn context(&self, vars: &dyn Fn(V) -> Option<Type>) -> Result<(), TypeError> {
164 match self {
165 FloatExpr::Const(_) => Ok(()),
166 FloatExpr::Var(v) => matches!(vars(v.clone()), Some(Type::Float))
167 .then_some(())
168 .ok_or(TypeError::TypeMismatch),
169 FloatExpr::Nat(natural_expr) => natural_expr.context(vars),
170 FloatExpr::Int(integer_expr) => integer_expr.context(vars),
171 FloatExpr::Rand(exprs) | FloatExpr::Div(exprs) => {
172 exprs.0.context(vars).and_then(|()| exprs.1.context(vars))
173 }
174 FloatExpr::Opposite(integer_expr) => integer_expr.context(vars),
175 FloatExpr::Sum(integer_exprs) | FloatExpr::Product(integer_exprs) => {
176 integer_exprs.iter().try_for_each(|expr| expr.context(vars))
177 }
178 FloatExpr::Ite(exprs) => exprs
179 .0
180 .context(vars)
181 .and_then(|()| exprs.1.context(vars))
182 .and_then(|()| exprs.2.context(vars)),
183 }
184 }
185}
186
187impl<V> From<Float> for FloatExpr<V>
188where
189 V: Clone,
190{
191 fn from(value: Float) -> Self {
192 Self::Const(value)
193 }
194}
195
196impl<V> TryFrom<Expression<V>> for FloatExpr<V>
197where
198 V: Clone,
199{
200 type Error = TypeError;
201
202 fn try_from(value: Expression<V>) -> Result<Self, Self::Error> {
203 match value {
204 Expression::Boolean(_) => Err(TypeError::TypeMismatch),
205 Expression::Natural(natural_expr) => Ok(FloatExpr::Nat(natural_expr)),
206 Expression::Integer(integer_expr) => Ok(FloatExpr::Int(integer_expr)),
207 Expression::Float(float_expr) => Ok(float_expr),
208 }
209 }
210}
211
212impl<V> From<NaturalExpr<V>> for FloatExpr<V>
213where
214 V: Clone,
215{
216 fn from(value: NaturalExpr<V>) -> Self {
217 Self::Nat(value)
218 }
219}
220
221impl<V> From<IntegerExpr<V>> for FloatExpr<V>
222where
223 V: Clone,
224{
225 fn from(value: IntegerExpr<V>) -> Self {
226 if let IntegerExpr::Nat(nat_expr) = value {
227 Self::Nat(nat_expr)
228 } else {
229 Self::Int(value)
230 }
231 }
232}
233
234impl<V> Add for FloatExpr<V>
235where
236 V: Clone,
237{
238 type Output = Self;
239
240 fn add(mut self, mut rhs: Self) -> Self::Output {
241 if let FloatExpr::Sum(ref mut exprs) = self {
242 if let FloatExpr::Sum(rhs_exprs) = rhs {
243 exprs.extend(rhs_exprs);
244 } else {
245 exprs.push(rhs);
246 }
247 self
248 } else if let FloatExpr::Sum(ref mut rhs_exprs) = rhs {
249 rhs_exprs.push(self);
250 rhs
251 } else {
252 FloatExpr::Sum(vec![self, rhs])
253 }
254 }
255}
256
257impl<V> Mul for FloatExpr<V>
258where
259 V: Clone,
260{
261 type Output = Self;
262
263 fn mul(mut self, mut rhs: Self) -> Self::Output {
264 if let FloatExpr::Product(ref mut exprs) = self {
265 if let FloatExpr::Product(rhs_exprs) = rhs {
266 exprs.extend(rhs_exprs);
267 } else {
268 exprs.push(rhs);
269 }
270 self
271 } else if let FloatExpr::Product(ref mut rhs_exprs) = rhs {
272 rhs_exprs.push(self);
273 rhs
274 } else {
275 FloatExpr::Product(vec![self, rhs])
276 }
277 }
278}
279
280impl<V> Neg for FloatExpr<V>
281where
282 V: Clone,
283{
284 type Output = Self;
285
286 fn neg(self) -> Self::Output {
287 if let FloatExpr::Opposite(expr) = self {
288 *expr
289 } else {
290 FloatExpr::Opposite(Box::new(self))
291 }
292 }
293}
294
295impl<V> Div for FloatExpr<V>
296where
297 V: Clone,
298{
299 type Output = Self;
300
301 fn div(self, rhs: Self) -> Self::Output {
302 FloatExpr::Div(Box::new((self, rhs)))
303 }
304}