1use crate::traits::{One, Zero};
4use lazy_static::__Deref;
5use serde::{Deserialize, Serialize};
6use std::{collections::BTreeSet as Set, fmt};
7use syn::{BinOp, ExprBinary};
8
9use crate::{scope::Scope, traits::*};
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd, Ord, Eq)]
15pub enum Term<T> {
16 Value(T),
17 Scalar(u128),
18 Var(VarValue),
19
20 Add(Box<Self>, Box<Self>),
21 Mul(Box<Self>, Box<Self>),
22}
23
24pub type SimpleTerm = Term<u128>;
25pub type ChromaticTerm = Term<Weight>;
26
27#[derive(Debug, Clone, Serialize, Deserialize, PartialOrd, Ord, Eq)]
28pub struct VarValue(pub String);
30
31impl From<VarValue> for String {
32 fn from(v: VarValue) -> String {
33 v.0
34 }
35}
36
37impl From<String> for VarValue {
38 fn from(s: String) -> Self {
39 Self(s)
40 }
41}
42
43impl From<&str> for VarValue {
44 fn from(s: &str) -> Self {
45 Self(s.into())
46 }
47}
48
49impl std::ops::Deref for VarValue {
50 type Target = String;
51
52 fn deref(&self) -> &Self::Target {
53 &self.0
54 }
55}
56
57impl PartialEq for VarValue {
58 fn eq(&self, other: &Self) -> bool {
59 self.0.replace('_', "") == other.0.replace('_', "")
60 }
61}
62
/// Builds a `Term::Scalar`, casting the argument to `u128` with `as`.
#[macro_export]
macro_rules! scalar {
	($num:expr) => {
		$crate::term::Term::Scalar($num as u128)
	};
}

/// Builds a `SimpleTerm::Value`, casting the argument to `u128` with `as`.
#[macro_export]
macro_rules! val {
	($num:expr) => {
		$crate::term::SimpleTerm::Value($num as u128)
	};
}

/// Builds a `Term::Value` holding the argument as-is (no cast, any value type).
#[macro_export]
macro_rules! cval {
	($value:expr) => {
		$crate::term::Term::Value($value)
	};
}

/// Builds a `SimpleTerm::Var`; the name is converted with `.into()`
/// (e.g. from a `&str` or `String`).
#[macro_export]
macro_rules! var {
	($name:expr) => {
		$crate::term::SimpleTerm::Var($name.into())
	};
}

/// Builds a `Term::Var` for any value type; the name is converted with `.into()`.
#[macro_export]
macro_rules! cvar {
	($name:expr) => {
		$crate::term::Term::Var($name.into())
	};
}
101
/// Builds a `SimpleTerm::Add`; each operand is converted with `.into()`
/// (typically a `SimpleTerm` into a `Box<SimpleTerm>`).
#[macro_export]
macro_rules! add {
	($lhs:expr, $rhs:expr) => {
		$crate::term::SimpleTerm::Add($lhs.into(), $rhs.into())
	};
}

/// Builds a `ChromaticTerm::Add`; each operand is converted with `.into()`.
#[macro_export]
macro_rules! cadd {
	($lhs:expr, $rhs:expr) => {
		$crate::term::ChromaticTerm::Add($lhs.into(), $rhs.into())
	};
}

/// Builds a `SimpleTerm::Mul`; each operand is converted with `.into()`.
#[macro_export]
macro_rules! mul {
	($lhs:expr, $rhs:expr) => {
		$crate::term::SimpleTerm::Mul($lhs.into(), $rhs.into())
	};
}

/// Builds a `ChromaticTerm::Mul`; each operand is converted with `.into()`.
#[macro_export]
macro_rules! cmul {
	($lhs:expr, $rhs:expr) => {
		$crate::term::ChromaticTerm::Mul($lhs.into(), $rhs.into())
	};
}
133
134impl SimpleTerm {
135 pub fn eval(&self, ctx: &crate::scope::SimpleScope) -> Result<u128, String> {
137 match self {
138 Self::Value(x) => Ok(*x),
139 Self::Scalar(x) => Ok(*x),
140 Self::Add(x, y) => Ok(x.eval(ctx)? + y.eval(ctx)?),
141 Self::Mul(x, y) => Ok(x.eval(ctx)? * y.eval(ctx)?),
142 Self::Var(x) =>
143 if let Some(var) = ctx.get(x) {
144 var.eval(ctx)
145 } else {
146 Err(format!("Variable '{}' not found", x.deref()))
147 },
148 }
149 }
150
151 pub fn into_chromatic(self, unit: crate::Dimension) -> ChromaticTerm {
152 match self {
153 Self::Value(x) | Self::Scalar(x) =>
154 ChromaticTerm::Value(Self::scalar_into_term(x, unit)),
155 Self::Add(x, y) => ChromaticTerm::Add(
156 Box::new(x.into_chromatic(unit)),
157 Box::new(y.into_chromatic(unit)),
158 ),
159 Self::Mul(x, y) => ChromaticTerm::Mul(
160 Box::new(x.into_chromatic(unit)),
161 Box::new(y.into_chromatic(unit)),
162 ),
163 Self::Var(x) => ChromaticTerm::Var(x),
164 }
165 }
166
167 fn scalar_into_term(s: u128, unit: crate::Dimension) -> Weight {
168 match unit {
169 crate::Dimension::Time => Weight { time: s, proof: 0 },
170 crate::Dimension::Proof => Weight { proof: s, time: 0 },
171 }
172 }
173}
174
175impl<T> Term<T>
176where
177 T: Clone + core::fmt::Display + One + Zero + PartialEq + Eq + ValueFormatter,
178{
179 pub fn is_const_zero(&self) -> bool {
180 match self {
181 Self::Value(x) => x == &T::zero(),
182 Self::Scalar(x) => *x == 0,
183 _ => false,
184 }
185 }
186
187 pub fn is_const_one(&self) -> bool {
188 match self {
189 Self::Value(x) => x == &T::one(),
190 Self::Scalar(x) => *x == 1,
191 _ => false,
192 }
193 }
194
195 pub fn free_vars(&self, scope: &Scope<Term<T>>) -> Set<String> {
200 match self {
201 Self::Var(var) if scope.get(var).is_some() => Set::default(),
202 Self::Var(var) => Set::from([var.clone().into()]),
203 Self::Scalar(_) => Set::default(),
204 Self::Value(_) => Set::default(),
205 Self::Mul(l, r) | Self::Add(l, r) =>
206 l.free_vars(scope).union(&r.free_vars(scope)).cloned().collect(),
207 }
208 }
209
210 pub fn bound_vars(&self, scope: &Scope<Term<T>>) -> Set<String> {
215 match self {
216 Self::Var(var) if scope.get(var).is_some() => Set::from([var.clone().into()]),
217 Self::Var(_var) => Set::default(),
218 Self::Scalar(_) => Set::default(),
219 Self::Value(_) => Set::default(),
220 Self::Mul(l, r) | Self::Add(l, r) =>
221 l.bound_vars(scope).union(&r.bound_vars(scope)).cloned().collect(),
222 }
223 }
224
225 pub fn fmt_equation(&self, scope: &Scope<Term<T>>) -> String {
226 let bounds = self.bound_vars(scope);
227 let frees = self.free_vars(scope);
228
229 let mut equation = Vec::<String>::new();
230 for var in bounds.iter() {
231 let v = scope.get(var).unwrap();
232 equation.push(format!("{}={}", var, v));
233 }
234 for var in frees.iter() {
235 equation.push(var.clone());
236 }
237 equation.join(", ")
238 }
239
240 pub fn into_substituted(self, var: &str, term: &Term<T>) -> Self {
241 let mut s = self;
242 s.substitute(var, term);
243 s
244 }
245
246 pub fn substitute(&mut self, var: &str, term: &Term<T>) {
247 match self {
248 Self::Var(v) if v.0 == var => *self = term.clone(),
249 Self::Var(_) => {},
250 Self::Scalar(_) => {},
251 Self::Value(_) => {},
252 Self::Mul(l, r) | Self::Add(l, r) => {
253 l.substitute(var, term);
254 r.substitute(var, term);
255 },
256 }
257 }
258
259 fn fmt_with_bracket(&self, has_bracket: bool) -> String {
260 self.maybe_fmt_with_bracket(has_bracket).unwrap_or("0".to_string())
261 }
262
263 fn maybe_fmt_with_bracket(&self, has_bracket: bool) -> Option<String> {
264 match self {
265 Self::Mul(l, r) => {
266 if l.is_const_one() {
268 r.maybe_fmt_with_bracket(has_bracket)
269 } else if r.is_const_one() {
270 l.maybe_fmt_with_bracket(has_bracket)
271 } else if r.is_const_zero() || l.is_const_zero() {
272 None
273 } else {
274 match (l.maybe_fmt_with_bracket(false), r.maybe_fmt_with_bracket(false)) {
275 (Some(l), Some(r)) => Some(format!("{} * {}", l, r)),
276 (Some(l), None) => Some(l),
277 (None, Some(r)) => Some(r),
278 (None, None) => None,
279 }
280 }
281 },
282 Self::Add(l, r) => {
283 if l.is_const_zero() && r.is_const_zero() {
285 None
286 } else if l.is_const_zero() {
287 r.maybe_fmt_with_bracket(has_bracket)
288 } else if r.is_const_zero() {
289 l.maybe_fmt_with_bracket(has_bracket)
290 } else if has_bracket {
291 match (l.maybe_fmt_with_bracket(true), r.maybe_fmt_with_bracket(true)) {
292 (Some(l), Some(r)) => Some(format!("{} + {}", l, r)),
293 (Some(l), None) => Some(l),
294 (None, Some(r)) => Some(r),
295 (None, None) => None,
296 }
297 } else {
298 match (l.maybe_fmt_with_bracket(true), r.maybe_fmt_with_bracket(true)) {
299 (Some(l), Some(r)) => Some(format!("({} + {})", l, r)),
300 (Some(l), None) => Some(l),
301 (None, Some(r)) => Some(r),
302 (None, None) => None,
303 }
304 }
305 },
306 Self::Value(val) => Some(val.format_scalar()),
307 Self::Scalar(val) => Some(crate::Dimension::fmt_scalar(*val)),
308 Self::Var(var) => Some(var.clone().into()),
309 }
310 }
311
312 pub fn visit<F, R>(&self, f: &mut F) -> Result<Vec<R>, String>
313 where
314 F: FnMut(&Self) -> Result<R, String>,
315 {
316 let mut res = Vec::<R>::new();
317 res.push(f(self)?);
318
319 match self {
320 v @ Self::Value(_) => Ok(vec![f(v)?]),
321 v @ Self::Scalar(_) => Ok(vec![f(v)?]),
322 v @ Self::Var(_) => Ok(vec![f(v)?]),
323 Self::Add(l, r) | Self::Mul(l, r) => {
324 res.append(&mut l.visit(f)?);
325 res.append(&mut r.visit(f)?);
326 Ok(res)
327 },
328 }
329 }
330
331 pub fn find_largest_factor(&self, var: &str) -> Option<u128> {
333 self.visit::<_, Option<u128>>(&mut |t| {
334 if let Term::<T>::Mul(l, r) = t {
335 if r.as_var() == Some(var) && l.as_scalar().is_some() {
336 return Ok(Some(l.as_scalar().unwrap()))
337 }
338 if l.as_var() == Some(var) && r.as_scalar().is_some() {
339 return Ok(Some(r.as_scalar().unwrap()))
340 }
341 }
342 Ok(None)
343 })
344 .unwrap()
345 .into_iter()
346 .flatten()
347 .max()
348 }
349
350 pub fn as_scalar(&self) -> Option<u128> {
351 match self {
352 Self::Scalar(val) => Some(*val),
353 _ => None,
354 }
355 }
356
357 pub fn as_var(&self) -> Option<&str> {
358 match self {
359 Self::Var(var) => Some(var),
360 _ => None,
361 }
362 }
363}
364
365impl ChromaticTerm {
366 pub fn eval(&self, ctx: &crate::scope::ChromaticScope) -> Result<Weight, String> {
368 match self {
369 Self::Value(x) => Ok(x.clone()),
370 Self::Scalar(_) => unreachable!("Scalars cannot be evaluated; qed"),
371 Self::Add(x, y) => {
372 let (a, b) = x.eval(ctx)?.into();
373 let (m, n) = y.eval(ctx)?.into();
374 Ok((a + m, b + n).into())
375 },
376 Self::Mul(x, y) => match (x.as_ref(), y.as_ref()) {
377 (Self::Scalar(x), y) => {
378 let (a, b) = y.eval(ctx)?.into();
379 Ok((*x * a, *x * b).into())
380 },
381 (x, Self::Scalar(y)) => {
382 let (a, b) = x.eval(ctx)?.into();
383 Ok((*y * a, *y * b).into())
384 },
385 (Self::Var(x), y) => match ctx.get(x) {
386 Some(Self::Scalar(x)) => Ok(y.eval(ctx)?.mul_scalar(x)),
387 Some(_) => Err(format!("Variable '{}' is not a scalar", x.deref())),
388 None => Err(format!("Variable '{}' not found", x.deref())),
389 },
390 (x, Self::Var(y)) => match ctx.get(y) {
391 Some(Self::Scalar(y)) => Ok(x.eval(ctx)?.mul_scalar(y)),
392 Some(_) => Err(format!("Variable '{}' is not a scalar", y.deref())),
393 None => Err(format!("Variable '{}' not found", y.deref())),
394 },
395 _ => unreachable!("Cannot multiply two terms; qed"),
396 },
397 Self::Var(x) =>
398 if let Some(var) = ctx.get(x) {
399 var.eval(ctx)
400 } else {
401 Err(format!("Variable '{}' not found", x.deref()))
402 },
403 }
404 }
405
406 pub fn simplify(&self, unit: crate::Dimension) -> Result<SimpleTerm, String> {
407 self.for_values(|t| match t {
408 Self::Value(Weight { time, .. }) if unit == crate::Dimension::Time =>
409 Ok(SimpleTerm::Value(*time)),
410 Self::Value(Weight { proof, .. }) if unit == crate::Dimension::Proof =>
411 Ok(SimpleTerm::Value(*proof)),
412 Self::Scalar(val) => Ok(SimpleTerm::Scalar(*val)),
413 Self::Var(var) => Ok(SimpleTerm::Var(var.clone())),
414 _ => unreachable!(),
415 })
416 }
417
418 pub fn for_values<F>(&self, f: F) -> Result<SimpleTerm, String>
419 where
420 F: Fn(&Self) -> Result<SimpleTerm, String> + Clone,
421 {
422 match self {
423 v @ Self::Value(_) | v @ Self::Scalar(_) | v @ Self::Var(_) => f(v),
424 Self::Mul(l, r) => Ok(SimpleTerm::Mul(
425 l.for_values::<F>(f.clone())?.into(),
426 r.for_values::<F>(f)?.into(),
427 )),
428 Self::Add(l, r) => Ok(SimpleTerm::Add(
429 l.for_values::<F>(f.clone())?.into(),
430 r.for_values::<F>(f)?.into(),
431 )),
432 }
433 }
434
435 pub fn splice_add(self, other: Self) -> Self {
437 match (self, other) {
438 (Self::Add(t1, p1), Self::Add(t2, p2)) =>
439 Self::Add(Box::new(t1.splice_add(*t2)), Box::new(p1.splice_add(*p2))),
440 (Self::Value(x), Self::Value(y)) => {
441 if x.time == 0 && y.proof == 0 {
443 Self::Value(Weight { time: y.time, proof: x.proof })
444 } else if x.proof == 0 && y.time == 0 {
445 Self::Value(Weight { time: x.time, proof: y.proof })
446 } else {
447 Self::Add(Box::new(Self::Value(x)), Box::new(Self::Value(y)))
448 }
449 },
450 (s, o) => Self::Add(Box::new(s), Box::new(o)),
451 }
452 }
453}
454
455impl<T> fmt::Display for Term<T>
456where
457 T: Clone + core::fmt::Display + One + Zero + PartialEq + Eq + ValueFormatter,
458{
459 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
460 write!(f, "{}", self.fmt_with_bracket(true))
461 }
462}
463
464impl TryInto<SimpleTerm> for &ExprBinary {
465 type Error = String;
466
467 fn try_into(self) -> Result<SimpleTerm, Self::Error> {
468 let left = crate::parse::pallet::parse_scalar_expression(&self.left)?.into();
469 let right = crate::parse::pallet::parse_scalar_expression(&self.right)?.into();
470
471 let term = match self.op {
472 BinOp::Mul(_) => SimpleTerm::Mul(left, right),
473 BinOp::Add(_) => SimpleTerm::Add(left, right),
474 _ => return Err("Unexpected operator".into()),
475 };
476 Ok(term)
477 }
478}