1use std::ops::{Add, Div, Mul, Rem};
2
3use rand::{Rng, RngExt};
4
5use crate::{
6 Expression, Type, TypeError, Val,
7 grammar::{BooleanExpr, IntegerExpr},
8};
9
/// Value type produced by evaluating a [`NaturalExpr`]: natural numbers (including zero)
/// are represented as unsigned 64-bit integers.
pub type Natural = u64;
12
13#[derive(Debug, Clone)]
15pub enum NaturalExpr<V>
16where
17 V: Clone,
18{
19 Const(Natural),
24 Var(V),
26 Rand(Box<(NaturalExpr<V>, NaturalExpr<V>)>),
31 Sum(Vec<NaturalExpr<V>>),
36 Product(Vec<NaturalExpr<V>>),
38 Rem(Box<(NaturalExpr<V>, NaturalExpr<V>)>),
40 Div(Box<(NaturalExpr<V>, NaturalExpr<V>)>),
42 Abs(Box<IntegerExpr<V>>),
44 Ite(Box<(BooleanExpr<V>, NaturalExpr<V>, NaturalExpr<V>)>),
51}
52
53impl<V> NaturalExpr<V>
54where
55 V: Clone,
56{
57 pub fn is_constant(&self) -> bool {
59 match self {
60 NaturalExpr::Const(_) => true,
61 NaturalExpr::Var(_) | NaturalExpr::Rand(_) => false,
62 NaturalExpr::Sum(natural_exprs) | NaturalExpr::Product(natural_exprs) => {
63 natural_exprs.iter().all(NaturalExpr::is_constant)
64 }
65 NaturalExpr::Rem(args) | NaturalExpr::Div(args) => {
66 let (lhs, rhs) = args.as_ref();
67 lhs.is_constant() && rhs.is_constant()
68 }
69 NaturalExpr::Abs(integer_expr) => integer_expr.is_constant(),
70 NaturalExpr::Ite(args) => {
71 let (ite, lhs, rhs) = args.as_ref();
72 ite.is_constant() && lhs.is_constant() && rhs.is_constant()
73 }
74 }
75 }
76
77 pub fn eval<R: Rng>(&self, vars: &dyn Fn(&V) -> Val, rng: &mut R) -> Natural {
86 match self {
87 NaturalExpr::Const(nat) => *nat,
88 NaturalExpr::Var(var) => {
89 if let Val::Natural(nat) = vars(var) {
90 nat
91 } else {
92 panic!("type mismatch: expected natural variable")
93 }
94 }
95 NaturalExpr::Rand(bounds) => {
96 let (lower_bound_expr, upper_bound_expr) = bounds.as_ref();
97 let lower_bound = lower_bound_expr.eval(vars, rng);
98 let upper_bound = upper_bound_expr.eval(vars, rng);
99 rng.random_range(lower_bound..upper_bound)
100 }
101 NaturalExpr::Sum(natural_exprs) => natural_exprs
102 .iter()
103 .fold(0, |acc, expr| acc.strict_add(expr.eval(vars, rng))),
104 NaturalExpr::Product(natural_exprs) => natural_exprs
105 .iter()
106 .fold(1, |acc, expr| acc.strict_mul(expr.eval(vars, rng))),
107 NaturalExpr::Rem(args) => {
108 let (lhs_expr, rhs_expr) = args.as_ref();
109 let lhs = lhs_expr.eval(vars, rng);
110 let rhs = rhs_expr.eval(vars, rng);
111 lhs.strict_rem_euclid(rhs)
112 }
113 NaturalExpr::Div(_) => todo!(),
114 NaturalExpr::Abs(integer_expr) => integer_expr.eval(vars, rng).unsigned_abs(),
115 NaturalExpr::Ite(args) => {
116 let (ite, lhs, rhs) = args.as_ref();
117 if ite.eval(vars, rng) {
118 lhs.eval(vars, rng)
119 } else {
120 rhs.eval(vars, rng)
121 }
122 }
123 }
124 }
125
126 pub(crate) fn map<W: Clone>(self, map: &dyn Fn(V) -> W) -> NaturalExpr<W> {
127 match self {
128 NaturalExpr::Const(n) => NaturalExpr::Const(n),
129 NaturalExpr::Var(var) => NaturalExpr::Var(map(var)),
130 NaturalExpr::Rand(bounds) => {
131 let (lower_bound, upper_bound) = *bounds;
132 NaturalExpr::Rand(Box::new((lower_bound.map(map), upper_bound.map(map))))
133 }
134 NaturalExpr::Sum(natural_exprs) => NaturalExpr::Sum(
135 natural_exprs
136 .into_iter()
137 .map(|expr| expr.map(map))
138 .collect(),
139 ),
140 NaturalExpr::Product(natural_exprs) => NaturalExpr::Product(
141 natural_exprs
142 .into_iter()
143 .map(|expr| expr.map(map))
144 .collect(),
145 ),
146 NaturalExpr::Rem(args) => {
147 let (lhs, rhs) = *args;
148 NaturalExpr::Rem(Box::new((lhs.map(map), rhs.map(map))))
149 }
150 NaturalExpr::Div(args) => {
151 let (lhs, rhs) = *args;
152 NaturalExpr::Div(Box::new((lhs.map(map), rhs.map(map))))
153 }
154 NaturalExpr::Abs(integer_expr) => NaturalExpr::Abs(Box::new(integer_expr.map(map))),
155 NaturalExpr::Ite(args) => {
156 let (r#if, then, r#else) = *args;
157 NaturalExpr::Ite(Box::new((r#if.map(map), then.map(map), r#else.map(map))))
158 }
159 }
160 }
161
162 pub(crate) fn context(&self, vars: &dyn Fn(V) -> Option<Type>) -> Result<(), TypeError> {
163 match self {
164 NaturalExpr::Const(_) => Ok(()),
165 NaturalExpr::Var(v) => matches!(vars(v.clone()), Some(Type::Natural))
166 .then_some(())
167 .ok_or(TypeError::TypeMismatch),
168 NaturalExpr::Rand(exprs) | NaturalExpr::Div(exprs) | NaturalExpr::Rem(exprs) => {
169 exprs.0.context(vars).and_then(|()| exprs.1.context(vars))
170 }
171 NaturalExpr::Sum(integer_exprs) | NaturalExpr::Product(integer_exprs) => {
172 integer_exprs.iter().try_for_each(|expr| expr.context(vars))
173 }
174 NaturalExpr::Ite(exprs) => exprs
175 .0
176 .context(vars)
177 .and_then(|()| exprs.1.context(vars))
178 .and_then(|()| exprs.2.context(vars)),
179 NaturalExpr::Abs(integer_expr) => integer_expr.context(vars),
180 }
181 }
182}
183
184impl<V: Clone> From<Natural> for NaturalExpr<V> {
185 fn from(value: Natural) -> Self {
186 NaturalExpr::Const(value)
187 }
188}
189
190impl<V> TryFrom<Expression<V>> for NaturalExpr<V>
191where
192 V: Clone,
193{
194 type Error = TypeError;
195
196 fn try_from(value: Expression<V>) -> Result<Self, Self::Error> {
197 if let Expression::Natural(nat_expr) = value {
198 Ok(nat_expr)
199 } else {
200 Err(TypeError::TypeMismatch)
201 }
202 }
203}
204
205impl<V> Add for NaturalExpr<V>
206where
207 V: Clone,
208{
209 type Output = Self;
210
211 fn add(mut self, mut rhs: Self) -> Self::Output {
212 if let NaturalExpr::Sum(ref mut exprs) = self {
213 if let NaturalExpr::Sum(rhs_exprs) = rhs {
214 exprs.extend(rhs_exprs);
215 } else {
216 exprs.push(rhs);
217 }
218 self
219 } else if let NaturalExpr::Sum(ref mut rhs_exprs) = rhs {
220 rhs_exprs.push(self);
221 rhs
222 } else {
223 NaturalExpr::Sum(vec![self, rhs])
224 }
225 }
226}
227
228impl<V> Mul for NaturalExpr<V>
229where
230 V: Clone,
231{
232 type Output = Self;
233
234 fn mul(mut self, mut rhs: Self) -> Self::Output {
235 if let NaturalExpr::Product(ref mut exprs) = self {
236 if let NaturalExpr::Product(rhs_exprs) = rhs {
237 exprs.extend(rhs_exprs);
238 } else {
239 exprs.push(rhs);
240 }
241 self
242 } else if let NaturalExpr::Product(ref mut rhs_exprs) = rhs {
243 rhs_exprs.push(self);
244 rhs
245 } else {
246 NaturalExpr::Product(vec![self, rhs])
247 }
248 }
249}
250
251impl<V> Div for NaturalExpr<V>
252where
253 V: Clone,
254{
255 type Output = Self;
256
257 fn div(self, rhs: Self) -> Self::Output {
258 NaturalExpr::Div(Box::new((self, rhs)))
259 }
260}
261
262impl<V> Rem for NaturalExpr<V>
263where
264 V: Clone,
265{
266 type Output = Self;
267
268 fn rem(self, rhs: Self) -> Self::Output {
269 NaturalExpr::Rem(Box::new((self, rhs)))
270 }
271}