1use std::ops::{Add, Div, Mul, Neg};
2
3use rand::{Rng, RngExt};
4
5use crate::{
6 Expression, Type, TypeError, Val,
7 grammar::{BooleanExpr, FloatExpr, NaturalExpr},
8};
9
10pub type Integer = i64;
12
13#[derive(Debug, Clone)]
15pub enum IntegerExpr<V>
16where
17 V: Clone,
18{
19 Const(Integer),
24 Var(V),
26 Nat(NaturalExpr<V>),
28 Rand(Box<(IntegerExpr<V>, IntegerExpr<V>)>),
33 Opposite(Box<IntegerExpr<V>>),
38 Sum(Vec<IntegerExpr<V>>),
40 Product(Vec<IntegerExpr<V>>),
42 Div(Box<(IntegerExpr<V>, IntegerExpr<V>)>),
44 Rem(Box<(IntegerExpr<V>, IntegerExpr<V>)>),
46 Floor(Box<FloatExpr<V>>),
48 Ite(Box<(BooleanExpr<V>, IntegerExpr<V>, IntegerExpr<V>)>),
55}
56
57impl<V> IntegerExpr<V>
58where
59 V: Clone,
60{
61 pub fn is_constant(&self) -> bool {
63 match self {
64 IntegerExpr::Const(_) => true,
65 IntegerExpr::Var(_) | IntegerExpr::Rand(_) => false,
66 IntegerExpr::Nat(natural_expr) => natural_expr.is_constant(),
67 IntegerExpr::Opposite(integer_expr) => integer_expr.is_constant(),
68 IntegerExpr::Sum(integer_exprs) | IntegerExpr::Product(integer_exprs) => {
69 integer_exprs.iter().all(IntegerExpr::is_constant)
70 }
71 IntegerExpr::Div(args) | IntegerExpr::Rem(args) => {
72 let (lhs, rhs) = args.as_ref();
73 lhs.is_constant() && rhs.is_constant()
74 }
75 IntegerExpr::Floor(float_expr) => float_expr.is_constant(),
76 IntegerExpr::Ite(args) => {
77 let (ite, lhs, rhs) = args.as_ref();
78 ite.is_constant() && lhs.is_constant() && rhs.is_constant()
79 }
80 }
81 }
82
83 pub fn eval<R: Rng>(&self, vars: &dyn Fn(&V) -> Val, rng: &mut R) -> Integer {
92 match self {
93 IntegerExpr::Const(int) => *int,
94 IntegerExpr::Var(var) => {
95 if let Val::Integer(int) = vars(var) {
96 int
97 } else {
98 panic!("type mismatch: expected natural variable")
99 }
100 }
101 IntegerExpr::Nat(natural_expr) => natural_expr.eval(vars, rng) as Integer,
102 IntegerExpr::Rand(bounds) => {
103 let (lower_bound_expr, upper_bound_expr) = bounds.as_ref();
104 let lower_bound = lower_bound_expr.eval(vars, rng);
105 let upper_bound = upper_bound_expr.eval(vars, rng);
106 rng.random_range(lower_bound..upper_bound)
107 }
108 IntegerExpr::Opposite(integer_expr) => integer_expr.eval(vars, rng).strict_neg(),
109 IntegerExpr::Sum(integer_exprs) => integer_exprs
110 .iter()
111 .fold(0, |acc, expr| acc.strict_add(expr.eval(vars, rng))),
112 IntegerExpr::Product(integer_exprs) => integer_exprs
113 .iter()
114 .fold(1, |acc, expr| acc.strict_mul(expr.eval(vars, rng))),
115 IntegerExpr::Div(args) => {
116 let (lhs_expr, rhs_expr) = args.as_ref();
117 let lhs = lhs_expr.eval(vars, rng);
118 let rhs = rhs_expr.eval(vars, rng);
119 lhs.strict_div(rhs)
120 }
121 IntegerExpr::Rem(args) => {
122 let (lhs_expr, rhs_expr) = args.as_ref();
123 let lhs = lhs_expr.eval(vars, rng);
124 let rhs = rhs_expr.eval(vars, rng);
125 lhs.strict_rem_euclid(rhs)
126 }
127 IntegerExpr::Floor(float_expr) => float_expr.eval(vars, rng).floor() as Integer,
129 IntegerExpr::Ite(args) => {
130 let (ite, lhs, rhs) = args.as_ref();
131 if ite.eval(vars, rng) {
132 lhs.eval(vars, rng)
133 } else {
134 rhs.eval(vars, rng)
135 }
136 }
137 }
138 }
139
140 pub(crate) fn map<W: Clone>(self, map: &dyn Fn(V) -> W) -> IntegerExpr<W> {
141 match self {
142 IntegerExpr::Const(i) => IntegerExpr::Const(i),
143 IntegerExpr::Var(var) => IntegerExpr::Var(map(var)),
144 IntegerExpr::Nat(nat_expr) => IntegerExpr::Nat(nat_expr.map(map)),
145 IntegerExpr::Rand(bounds) => {
146 let (lower_bound, upper_bound) = *bounds;
147 IntegerExpr::Rand(Box::new((lower_bound.map(map), upper_bound.map(map))))
148 }
149 IntegerExpr::Opposite(integer_expr) => {
150 IntegerExpr::Opposite(Box::new(integer_expr.map(map)))
151 }
152 IntegerExpr::Sum(integer_exprs) => IntegerExpr::Sum(
153 integer_exprs
154 .into_iter()
155 .map(|expr| expr.map(map))
156 .collect(),
157 ),
158 IntegerExpr::Product(integer_exprs) => IntegerExpr::Product(
159 integer_exprs
160 .into_iter()
161 .map(|expr| expr.map(map))
162 .collect(),
163 ),
164 IntegerExpr::Div(args) => {
165 let (lhs, rhs) = *args;
166 IntegerExpr::Div(Box::new((lhs.map(map), rhs.map(map))))
167 }
168 IntegerExpr::Rem(args) => {
169 let (lhs, rhs) = *args;
170 IntegerExpr::Rem(Box::new((lhs.map(map), rhs.map(map))))
171 }
172 IntegerExpr::Floor(float_expr) => IntegerExpr::Floor(Box::new(float_expr.map(map))),
173 IntegerExpr::Ite(args) => {
174 let (r#if, then, r#else) = *args;
175 IntegerExpr::Ite(Box::new((r#if.map(map), then.map(map), r#else.map(map))))
176 }
177 }
178 }
179
180 pub(crate) fn context(&self, vars: &dyn Fn(V) -> Option<Type>) -> Result<(), TypeError> {
181 match self {
182 IntegerExpr::Const(_) => Ok(()),
183 IntegerExpr::Var(v) => matches!(vars(v.clone()), Some(Type::Integer))
184 .then_some(())
185 .ok_or(TypeError::TypeMismatch),
186 IntegerExpr::Nat(natural_expr) => natural_expr.context(vars),
187 IntegerExpr::Rand(exprs) | IntegerExpr::Div(exprs) | IntegerExpr::Rem(exprs) => {
188 exprs.0.context(vars).and_then(|()| exprs.1.context(vars))
189 }
190 IntegerExpr::Opposite(integer_expr) => integer_expr.context(vars),
191 IntegerExpr::Sum(integer_exprs) | IntegerExpr::Product(integer_exprs) => {
192 integer_exprs.iter().try_for_each(|expr| expr.context(vars))
193 }
194 IntegerExpr::Floor(float_expr) => float_expr.context(vars),
195 IntegerExpr::Ite(exprs) => exprs
196 .0
197 .context(vars)
198 .and_then(|()| exprs.1.context(vars))
199 .and_then(|()| exprs.2.context(vars)),
200 }
201 }
202}
203
204impl<V> From<Integer> for IntegerExpr<V>
205where
206 V: Clone,
207{
208 fn from(value: Integer) -> Self {
209 Self::Const(value)
210 }
211}
212
213impl<V> TryFrom<Expression<V>> for IntegerExpr<V>
214where
215 V: Clone,
216{
217 type Error = TypeError;
218
219 fn try_from(value: Expression<V>) -> Result<Self, Self::Error> {
220 match value {
221 Expression::Boolean(_) | Expression::Float(_) => Err(TypeError::TypeMismatch),
222 Expression::Natural(natural_expr) => Ok(IntegerExpr::Nat(natural_expr)),
223 Expression::Integer(integer_expr) => Ok(integer_expr),
224 }
225 }
226}
227
228impl<V> From<NaturalExpr<V>> for IntegerExpr<V>
229where
230 V: Clone,
231{
232 fn from(value: NaturalExpr<V>) -> Self {
233 Self::Nat(value)
234 }
235}
236
237impl<V> Add for IntegerExpr<V>
238where
239 V: Clone,
240{
241 type Output = Self;
242
243 fn add(mut self, mut rhs: Self) -> Self::Output {
244 if let IntegerExpr::Sum(ref mut exprs) = self {
245 if let IntegerExpr::Sum(rhs_exprs) = rhs {
246 exprs.extend(rhs_exprs);
247 } else {
248 exprs.push(rhs);
249 }
250 self
251 } else if let IntegerExpr::Sum(ref mut rhs_exprs) = rhs {
252 rhs_exprs.push(self);
253 rhs
254 } else {
255 IntegerExpr::Sum(vec![self, rhs])
256 }
257 }
258}
259
260impl<V> Mul for IntegerExpr<V>
261where
262 V: Clone,
263{
264 type Output = Self;
265
266 fn mul(mut self, mut rhs: Self) -> Self::Output {
267 if let IntegerExpr::Product(ref mut exprs) = self {
268 if let IntegerExpr::Product(rhs_exprs) = rhs {
269 exprs.extend(rhs_exprs);
270 } else {
271 exprs.push(rhs);
272 }
273 self
274 } else if let IntegerExpr::Product(ref mut rhs_exprs) = rhs {
275 rhs_exprs.push(self);
276 rhs
277 } else {
278 IntegerExpr::Product(vec![self, rhs])
279 }
280 }
281}
282
283impl<V> Div for IntegerExpr<V>
284where
285 V: Clone,
286{
287 type Output = Self;
288
289 fn div(self, rhs: Self) -> Self::Output {
290 IntegerExpr::Div(Box::new((self, rhs)))
291 }
292}
293
294impl<V> Neg for IntegerExpr<V>
295where
296 V: Clone,
297{
298 type Output = Self;
299
300 fn neg(self) -> Self::Output {
301 if let IntegerExpr::Opposite(expr) = self {
302 *expr
303 } else {
304 IntegerExpr::Opposite(Box::new(self))
305 }
306 }
307}