1use crate::dim::Assertion;
2use crate::internal::*;
3
4use super::{DimLike, sym::*};
5use itertools::Itertools;
6use num_integer::Integer;
7use num_traits::{AsPrimitive, PrimInt, Zero};
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10use std::fmt::Debug;
11use std::ops::Neg;
12use std::{fmt, ops};
13
/// Error reported when an expression cannot be resolved to a concrete value yet.
#[derive(Debug)]
pub enum TooEarly {
    /// Carries the textual form of the expression that still contains a symbol.
    UndeterminedSymbol(String),
    /// Free-form message, displayed verbatim.
    Other(String),
}

impl std::fmt::Display for TooEarly {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both variants print their payload; only the first one adds a fixed prefix.
        let (prefix, detail) = match self {
            TooEarly::UndeterminedSymbol(expr) => ("Undetermined symbol in expression: ", expr),
            TooEarly::Other(message) => ("", message),
        };
        write!(f, "{prefix}{detail}")
    }
}

impl std::error::Error for TooEarly {}
30
/// Shorthand for `Box::new(expr)`.
macro_rules! b {
    ($e:expr) => {
        Box::new($e)
    };
}
32
33#[allow(clippy::derived_hash_with_manual_eq)]
39#[derive(Clone, Eq, Hash, Debug)]
40pub enum TDim {
41 Val(i64),
42 Sym(Symbol),
43 Add(Vec<TDim>),
44 Mul(Vec<TDim>),
45 MulInt(i64, Box<TDim>),
46 Div(Box<TDim>, u64),
47 Broadcast(Vec<TDim>),
48 Min(Vec<TDim>),
49 Max(Vec<TDim>),
50 Ge(Box<TDim>, Box<TDim>),
52 Eq(Box<TDim>, Box<TDim>),
54}
55
56use TDim::*;
57
58fn eq_structural(a: &TDim, b: &TDim) -> bool {
63 match (a, b) {
64 (Val(x), Val(y)) => x == y,
65 (Sym(x), Sym(y)) => x == y,
66 (Add(x), Add(y))
67 | (Mul(x), Mul(y))
68 | (Broadcast(x), Broadcast(y))
69 | (Min(x), Min(y))
70 | (Max(x), Max(y)) => {
71 x.len() == y.len() && x.iter().zip(y).all(|(a, b)| eq_structural(a, b))
72 }
73 (MulInt(p, x), MulInt(q, y)) => p == q && eq_structural(x, y),
74 (Div(x, p), Div(y, q)) => p == q && eq_structural(x, y),
75 (Ge(a, b), Ge(c, d)) | (Eq(a, b), Eq(c, d)) => eq_structural(a, c) && eq_structural(b, d),
76 _ => false,
77 }
78}
79
std::thread_local! {
    /// Per-thread flag that is `true` while `TDim::eq` runs its simplification-based
    /// comparison; comparisons made while it is set stay purely structural.
    static EQ_GUARD: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}
88
89impl PartialEq for TDim {
90 fn eq(&self, other: &Self) -> bool {
91 if eq_structural(self, other) {
93 return true;
94 }
95 if EQ_GUARD.with(|g| g.get()) {
98 return false;
99 }
100 if matches!(self, Val(_) | Sym(_)) || matches!(other, Val(_) | Sym(_)) {
109 return false;
110 }
111 EQ_GUARD.with(|g| g.set(true));
118 let diff = (self.clone() - other.clone()).simplify();
119 EQ_GUARD.with(|g| g.set(false));
120 matches!(diff, Val(0))
121 }
122}
123
124fn tdim_lexi_order(a: &TDim, b: &TDim) -> Ordering {
125 match (a, b) {
126 (Sym(a), Sym(b)) => a.cmp(b),
127 (Val(a), Val(b)) => a.cmp(b),
128 (Add(a), Add(b))
129 | (Mul(a), Mul(b))
130 | (Broadcast(a), Broadcast(b))
131 | (Min(a), Min(b))
132 | (Max(a), Max(b)) => a.len().cmp(&b.len()).then(
133 a.iter()
134 .zip(b.iter())
135 .fold(Ordering::Equal, |acc, (a, b)| acc.then_with(|| tdim_lexi_order(a, b))),
136 ),
137 (MulInt(p, d), MulInt(q, e)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
138 (Div(d, p), Div(e, q)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
139 (Sym(_), _) => Ordering::Less,
140 (_, Sym(_)) => Ordering::Greater,
141 (Val(_), _) => Ordering::Less,
142 (_, Val(_)) => Ordering::Greater,
143 (Add(_), _) => Ordering::Less,
144 (_, Add(_)) => Ordering::Greater,
145 (Mul(_), _) => Ordering::Less,
146 (_, Mul(_)) => Ordering::Greater,
147 (MulInt(_, _), _) => Ordering::Less,
148 (_, MulInt(_, _)) => Ordering::Greater,
149 (Broadcast(_), _) => Ordering::Less,
150 (_, Broadcast(_)) => Ordering::Greater,
151 (Min(_), _) => Ordering::Less,
152 (_, Min(_)) => Ordering::Greater,
153 (Max(_), _) => Ordering::Less,
154 (_, Max(_)) => Ordering::Greater,
155 (Ge(a1, b1), Ge(a2, b2)) | (Eq(a1, b1), Eq(a2, b2)) => {
156 tdim_lexi_order(a1, a2).then_with(|| tdim_lexi_order(b1, b2))
157 }
158 (Ge(_, _) | Eq(_, _), _) => Ordering::Less,
159 (_, Ge(_, _) | Eq(_, _)) => Ordering::Greater,
160 }
161}
162
163fn try_divide_multiple_plus_remainder(
179 terms: &[TDim],
180 q: u64,
181 scope: &SymbolScopeData,
182 extra: &[Assertion],
183) -> Option<TDim> {
184 let mut quotients: Vec<TDim> = vec![];
185 let mut const_rem: i64 = 0;
186 let mut any_extracted = false;
187 for term in terms {
188 match term {
189 MulInt(c, x) if *c != 0 && c.rem_euclid(q as i64) == 0 => {
190 if !scope.prove_positive_or_zero_with_extra(x, extra) {
191 return None;
192 }
193 let new_coeff = c / (q as i64);
194 quotients.push(if new_coeff == 1 {
195 (**x).clone()
196 } else if new_coeff == -1 {
197 MulInt(-1, x.clone())
198 } else {
199 MulInt(new_coeff, x.clone())
200 });
201 any_extracted = true;
202 }
203 Val(v) => const_rem += v,
204 _ => return None,
205 }
206 }
207 if !any_extracted {
208 return None;
209 }
210 if !(0..q as i64).contains(&const_rem) {
211 return None;
212 }
213 Some(if quotients.len() == 1 { quotients.remove(0) } else { Add(quotients) })
214}
215
216impl fmt::Display for TDim {
217 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
218 match &self {
219 Sym(sym) => write!(fmt, "{sym}"),
220 Val(it) => write!(fmt, "{it}"),
221 Add(it) => write!(fmt, "{}", it.iter().map(|x| format!("{x}")).join("+")),
222 Mul(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("*")),
223 Broadcast(it) => {
224 write!(fmt, "broadcast({})", it.iter().map(|x| format!("({x})")).join(", "))
225 }
226 Min(it) => write!(fmt, "min({})", it.iter().map(|x| format!("{x}")).join(",")),
227 Max(it) => write!(fmt, "max({})", it.iter().map(|x| format!("{x}")).join(",")),
228 MulInt(a, b) => write!(fmt, "{a}*{b}"),
229 Div(a, b) => write!(fmt, "({a})/{b}"),
230 Ge(a, b) => write!(fmt, "({a}>={b})"),
231 Eq(a, b) => write!(fmt, "({a}=={b})"),
232 }
233 }
234}
235
236impl TDim {
237 #[inline]
238 pub fn is_one(&self) -> bool {
239 matches!(self, Val(1))
240 }
241
242 #[inline]
243 pub fn to_i64(&self) -> TractResult<i64> {
244 if let Val(v) = self {
245 Ok(*v)
246 } else {
247 Err(TooEarly::UndeterminedSymbol(self.to_string()))?
248 }
249 }
250
251 #[inline]
252 pub fn as_i64(&self) -> Option<i64> {
253 if let Val(v) = self { Some(*v) } else { None }
254 }
255
256 pub fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64> {
257 match self {
258 Sym(sym) => {
259 let Some(v) = values.get(sym) else {
260 Err(TooEarly::UndeterminedSymbol(self.to_string()))?
261 };
262 Ok(v)
263 }
264 Val(v) => Ok(*v),
265 Add(terms) => terms.iter().try_fold(0i64, |acc, it| {
266 let x = it.eval_to_i64(values)?;
267 acc.checked_add(x)
268 .with_context(|| format!("Overflow in TDim addition ({acc} + {x})"))
269 }),
270 Mul(terms) => terms.iter().try_fold(1i64, |acc, it| {
271 let x = it.eval_to_i64(values)?;
272 acc.checked_mul(x)
273 .with_context(|| format!("Overflow in TDim multiplication ({acc} * {x})"))
274 }),
275 Min(terms) => terms
276 .iter()
277 .try_fold(i64::MAX, |acc, it| it.eval_to_i64(values).map(|x| acc.min(x))),
278 Max(terms) => terms
279 .iter()
280 .try_fold(i64::MIN, |acc, it| it.eval_to_i64(values).map(|x| acc.max(x))),
281 Broadcast(terms) => terms.iter().try_fold(1i64, |acc, it| {
282 it.eval_to_i64(values)
283 .and_then(|x| ((acc as usize).broadcast(x as usize)).map(|x| x as i64))
284 }),
285 Div(a, q) => Ok(a.eval_to_i64(values)? / *q as i64),
286 MulInt(p, a) => {
287 let x = a.eval_to_i64(values)?;
288 x.checked_mul(*p)
289 .with_context(|| format!("Overflow in TDim multiplication ({x} * {p})"))
290 }
291 Ge(a, b) => Ok(if a.eval_to_i64(values)? >= b.eval_to_i64(values)? { 1 } else { 0 }),
292 Eq(a, b) => Ok(if a.eval_to_i64(values)? == b.eval_to_i64(values)? { 1 } else { 0 }),
293 }
294 }
295
296 pub fn eval(&self, values: &SymbolValues) -> TDim {
297 match self {
298 Sym(sym) => values.get(sym).map(Val).unwrap_or_else(|| Sym(sym.clone())),
299 Val(v) => Val(*v),
300 Add(terms) => terms.iter().fold(Val(0), |acc, it| -> TDim { acc + it.eval(values) }),
301 Mul(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim { acc * it.eval(values) }),
302 Min(terms) => {
303 terms.iter().fold(Val(i64::MAX), |acc, it| -> TDim { acc.mini(it.eval(values)) })
304 }
305 Max(terms) => {
306 terms.iter().fold(Val(i64::MIN), |acc, it| -> TDim { acc.maxi(it.eval(values)) })
307 }
308 Broadcast(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim {
309 acc.broadcast(it.eval(values)).unwrap_or_else(|_| self.clone())
310 }),
311 Div(a, q) => a.eval(values) / *q as i64,
312 MulInt(p, a) => a.eval(values) * *p,
313 Ge(a, b) => {
314 let a2 = a.eval(values);
315 let b2 = b.eval(values);
316 if let (Val(av), Val(bv)) = (&a2, &b2) {
317 Val(if av >= bv { 1 } else { 0 })
318 } else {
319 Ge(b!(a2), b!(b2))
320 }
321 }
322 Eq(a, b) => {
323 let a2 = a.eval(values);
324 let b2 = b.eval(values);
325 if let (Val(av), Val(bv)) = (&a2, &b2) {
326 Val(if av == bv { 1 } else { 0 })
327 } else {
328 Eq(b!(a2), b!(b2))
329 }
330 }
331 }
332 }
333
334 pub fn eval_with_scenario(&self, scenario: &str) -> TDim {
335 if let Val(v) = self {
336 return Val(*v);
337 }
338 let scope = self.find_scope().unwrap();
339 let scope = scope.0;
340 let locked = scope.lock();
341 let scope = locked.borrow();
342 self.clone().simplify_rec(&scope, Some(scenario), &[])
343 }
344
345 pub fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self> {
346 self.substitute_all(&std::collections::HashMap::from([(from.clone(), to.clone())]))
347 }
348
349 pub fn substitute_all(
350 &self,
351 map: &std::collections::HashMap<Symbol, Self>,
352 ) -> TractResult<Self> {
353 match self {
354 Sym(sym) => Ok(map.get(sym).cloned().unwrap_or_else(|| self.clone())),
355 Val(v) => Ok(Val(*v)),
356 Add(terms) => terms.iter().try_fold(Val(0), |acc, it| -> TractResult<TDim> {
357 Ok(acc + it.substitute_all(map)?)
358 }),
359 Mul(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
360 Ok(acc * it.substitute_all(map)?)
361 }),
362 Broadcast(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
363 acc.broadcast(it.substitute_all(map)?)
364 }),
365 Min(terms) => terms.iter().try_fold(Val(i64::MAX), |acc, it| -> TractResult<TDim> {
366 Ok(acc.mini(it.substitute_all(map)?))
367 }),
368 Max(terms) => terms.iter().try_fold(Val(i64::MIN), |acc, it| -> TractResult<TDim> {
369 Ok(acc.maxi(it.substitute_all(map)?))
370 }),
371 Div(a, q) => Ok(a.substitute_all(map)? / *q as i64),
372 MulInt(p, a) => Ok(a.substitute_all(map)? * *p),
373 Ge(a, b) => Ok(Ge(b!(a.substitute_all(map)?), b!(b.substitute_all(map)?))),
374 Eq(a, b) => Ok(Eq(b!(a.substitute_all(map)?), b!(b.substitute_all(map)?))),
375 }
376 }
377
378 pub fn reduce(self) -> TDim {
379 self.simplify()
380 .wiggle()
381 .into_iter()
382 .sorted_by(tdim_lexi_order)
383 .unique()
384 .map(|e| e.simplify())
385 .min_by_key(|e| e.cost())
386 .unwrap()
387 }
388
389 fn cost(&self) -> usize {
390 use self::TDim::*;
391 match self {
392 Sym(_) | Val(_) => 1,
393 Add(terms) => 2 * terms.iter().map(TDim::cost).sum::<usize>(),
394 Mul(terms) => 3 * terms.iter().map(TDim::cost).sum::<usize>(),
395 Broadcast(terms) => 4 * terms.iter().map(TDim::cost).sum::<usize>(),
396 Min(terms) | Max(terms) => 5 * terms.iter().map(TDim::cost).sum::<usize>(),
397 Div(a, _) => 3 * a.cost(),
398 MulInt(_, a) => 2 * a.cost(),
399 Ge(a, b) | Eq(a, b) => 5 * (a.cost() + b.cost()),
400 }
401 }
402
403 fn wiggle(&self) -> Vec<TDim> {
404 use self::TDim::*;
405 match self {
406 Sym(_) | Val(_) | Mul(_) | Broadcast(_) | Min(_) | Max(_) | Ge(_, _) | Eq(_, _) => {
407 vec![self.clone()]
408 }
409 Add(terms) => {
410 let mut forms = vec![];
411 let sub_exprs = terms.iter().map(|e| e.wiggle()).multi_cartesian_product();
412
413 fn first_div_term(terms: &[TDim]) -> Option<(usize, &TDim, u64)> {
414 terms.iter().enumerate().find_map(|(index, t)| match t {
415 Div(numerator, quotient) => Some((index, &**numerator, *quotient)),
416 _ => None,
417 })
418 }
419
420 fn generate_new_numerator(
421 div_index: usize,
422 numerator: &TDim,
423 quotient: u64,
424 expr: &[TDim],
425 ) -> Vec<TDim> {
426 expr.iter()
427 .enumerate()
428 .map(|(index, term)| {
429 if index == div_index {
430 numerator.clone()
431 } else {
432 MulInt(quotient as i64, Box::new(term.clone()))
433 }
434 })
435 .collect()
436 }
437
438 for expr in sub_exprs {
439 if let Some((div_index, numerator, quotient)) = first_div_term(&expr) {
440 let new_numerator =
441 generate_new_numerator(div_index, numerator, quotient, &expr);
442 forms.push(Div(Box::new(Add(new_numerator)), quotient))
443 }
444
445 forms.push(Add(expr));
446 }
447 forms
448 }
449 MulInt(p, a) => a.wiggle().into_iter().map(|a| MulInt(*p, b!(a))).collect(),
450 Div(a, q) => {
451 let mut forms = vec![];
452 for num in a.wiggle() {
453 if let Add(terms) = &num {
454 let (integer, non_integer): (Vec<_>, Vec<_>) =
455 terms.iter().cloned().partition(|a| a.gcd() % q == 0);
456 if !non_integer.iter().any(|t| matches!(t, Val(_))) {
467 let mut new_terms =
468 integer.iter().map(|i| i.div(*q)).collect::<Vec<_>>();
469 if non_integer.len() > 0 {
470 new_terms.push(Div(b!(Add(non_integer)), *q));
471 }
472 forms.push(Add(new_terms))
473 }
474 }
475 forms.push(Div(b!(num), *q))
476 }
477 forms
478 }
479 }
480 }
481
482 fn find_any_sym(tdim: &TDim) -> Option<&Symbol> {
483 match tdim {
484 Val(_) => None,
485 Sym(s) => Some(s),
486 Add(terms) | Mul(terms) | Min(terms) | Max(terms) | Broadcast(terms) => {
487 terms.iter().find_map(Self::find_any_sym)
488 }
489 MulInt(_, t) | Div(t, _) => Self::find_any_sym(t),
490 Ge(a, b) | Eq(a, b) => Self::find_any_sym(a).or_else(|| Self::find_any_sym(b)),
491 }
492 }
493
494 pub fn find_scope(&self) -> Option<SymbolScope> {
495 Self::find_any_sym(self).and_then(|s| s.scope().clone())
496 }
497
498 pub fn expand_polynomial(self) -> TDim {
507 use self::TDim::*;
508 match self {
509 Mul(terms) => {
510 let terms: Vec<TDim> = terms.into_iter().map(Self::expand_polynomial).collect();
511 if let Some(add_idx) = terms.iter().position(|t| matches!(t, Add(_))) {
512 let Add(add_terms) = terms[add_idx].clone() else { unreachable!() };
513 let others: Vec<TDim> = terms
514 .iter()
515 .enumerate()
516 .filter(|(i, _)| *i != add_idx)
517 .map(|(_, t)| t.clone())
518 .collect();
519 Add(add_terms
520 .into_iter()
521 .map(|t| {
522 let mut product = others.clone();
523 product.push(t);
524 Mul(product).expand_polynomial()
525 })
526 .collect())
527 .simplify()
528 } else {
529 Mul(terms).simplify()
530 }
531 }
532 MulInt(c, inner) => MulInt(c, Box::new(inner.expand_polynomial())).simplify(),
533 Add(terms) => Add(terms.into_iter().map(Self::expand_polynomial).collect()).simplify(),
534 Div(a, q) => Div(Box::new(a.expand_polynomial()), q).simplify(),
535 Min(terms) => Min(terms.into_iter().map(Self::expand_polynomial).collect()).simplify(),
536 Max(terms) => Max(terms.into_iter().map(Self::expand_polynomial).collect()).simplify(),
537 Broadcast(terms) => {
538 Broadcast(terms.into_iter().map(Self::expand_polynomial).collect()).simplify()
539 }
540 Ge(a, b) => {
541 Ge(Box::new(a.expand_polynomial()), Box::new(b.expand_polynomial())).simplify()
542 }
543 Eq(a, b) => {
544 Eq(Box::new(a.expand_polynomial()), Box::new(b.expand_polynomial())).simplify()
545 }
546 it @ (Sym(_) | Val(_)) => it,
547 }
548 }
549
550 pub fn simplify(self) -> TDim {
551 use self::TDim::*;
552 if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) {
553 return Val(v);
554 }
555 let Some(scope) = self.find_scope() else {
556 return self;
557 };
558 let scope = scope.0;
559 let locked = scope.lock();
560 let scope = locked.borrow();
561 let it = self.simplify_rec(&scope, None, &[]);
562 let mut current: Option<TDim> = None;
563 for scenario in scope.scenarios() {
564 let v = it.clone().simplify_rec(&scope, Some(scenario), &[]);
565 if current.is_some_and(|c| c != v) {
566 return it;
567 } else {
568 current = Some(v);
569 }
570 }
571 current.unwrap_or(it)
572 }
573
574 pub fn simplify_with_extra_assertions(self, extra: &[Assertion]) -> TDim {
575 use self::TDim::*;
576 if extra.is_empty() {
577 return self.simplify();
578 }
579 if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) {
580 return Val(v);
581 }
582 let Some(scope) = self.find_scope() else {
583 return self;
584 };
585 let scope = scope.0;
586 let locked = scope.lock();
587 let scope = locked.borrow();
588 let it = self.simplify_rec(&scope, None, extra);
589 let mut current: Option<TDim> = None;
590 for scenario in scope.scenarios() {
591 let v = it.clone().simplify_rec(&scope, Some(scenario), extra);
592 if current.is_some_and(|c| c != v) {
593 return it;
594 } else {
595 current = Some(v);
596 }
597 }
598 current.unwrap_or(it)
599 }
600
601 fn simplify_rec(
602 self,
603 scope: &SymbolScopeData,
604 scenario: Option<&str>,
605 extra: &[Assertion],
606 ) -> TDim {
607 match self {
608 Add(mut terms) => {
609 #[allow(clippy::mutable_key_type)]
610 let mut simplified_terms: HashMap<TDim, i64> = HashMap::new();
611 while let Some(term) = terms.pop() {
613 let simplified = term.simplify_rec(scope, scenario, extra);
614 match simplified {
615 Val(0) => {} Add(members) => {
617 terms.extend(members);
618 continue;
619 }
620 Val(value) => *simplified_terms.entry(Val(1)).or_insert(0) += value,
621 MulInt(value, factor) => {
622 *simplified_terms.entry((*factor).clone()).or_insert(0) += value;
623 }
624 n => *simplified_terms.entry(n).or_insert(0) += 1,
625 };
626 }
627
628 pub fn evaluate_count(term: TDim, count: i64) -> Option<TDim> {
629 match count {
630 0 => None,
631 _ if term == TDim::Val(1) => Some(TDim::Val(count)),
632 1 => Some(term),
633 _ => Some(TDim::MulInt(count, Box::new(term))),
634 }
635 }
636
637 let has_non_const =
652 simplified_terms.iter().any(|(k, &c)| c != 0 && !matches!(k, Val(_)));
653 let coef_gcd = if has_non_const {
654 simplified_terms
655 .values()
656 .filter(|&&c| c != 0)
657 .map(|c| c.unsigned_abs() as i64)
658 .reduce(|a, b| a.gcd(&b))
659 .unwrap_or(0)
660 } else {
661 0
662 };
663 let outer_factor = if coef_gcd > 1 {
664 for v in simplified_terms.values_mut() {
665 *v /= coef_gcd;
666 }
667 Some(coef_gcd)
668 } else {
669 None
670 };
671
672 let mut members: Vec<TDim> = simplified_terms
673 .into_iter()
674 .filter_map(|(term, count)| evaluate_count(term, count))
675 .collect();
676 members.sort_by(tdim_lexi_order);
677
678 let inner = match members.len() {
679 0 => TDim::Val(0),
680 1 => members.into_iter().next().unwrap(),
681 _ => TDim::Add(members),
682 };
683 match outer_factor {
684 None => inner,
685 Some(_) if matches!(inner, TDim::Val(0)) => TDim::Val(0),
686 Some(g) => TDim::MulInt(g, Box::new(inner)),
687 }
688 }
689 Mul(terms) => {
690 {
703 let add_indices: Vec<usize> = terms
704 .iter()
705 .enumerate()
706 .filter(|(_, t)| matches!(t, Add(_)))
707 .map(|(i, _)| i)
708 .collect();
709 if add_indices.len() == 1 {
710 let add_idx = add_indices[0];
711 let Add(add_terms) = &terms[add_idx] else { unreachable!() };
712 let other_factors: Vec<TDim> = terms
713 .iter()
714 .enumerate()
715 .filter(|(i, _)| *i != add_idx)
716 .map(|(_, t)| t.clone())
717 .collect();
718 let distributed: Vec<TDim> = add_terms
719 .iter()
720 .map(|at| {
721 let mut product = other_factors.clone();
722 product.push(at.clone());
723 Mul(product)
724 })
725 .collect();
726 return Add(distributed).simplify_rec(scope, scenario, extra);
727 }
728 }
729
730 let mut flattened_terms = vec![];
733 for t in terms {
734 match t.clone().reduce() {
735 Mul(inner_terms) => flattened_terms.extend(inner_terms),
736 MulInt(k, inner) => {
737 flattened_terms.push(Val(k));
738 flattened_terms.push(*inner);
739 }
740 other => flattened_terms.push(other),
741 }
742 }
743 let mut terms = flattened_terms;
744
745 let mut gcd = Mul(terms.clone()).gcd() as i64;
746 if gcd == 0 {
747 return Val(0);
748 }
749 terms = if gcd != 1 {
750 terms
751 .into_iter()
752 .map(|t| {
753 let gcd = t.gcd();
754 (t / gcd).simplify_rec(scope, scenario, extra)
755 })
756 .collect()
757 } else {
758 terms
759 };
760 if terms.iter().filter(|t| t == &&Val(-1)).count() % 2 == 1 {
761 gcd = -gcd;
762 }
763 terms.retain(|t| !t.is_one() && t != &Val(-1));
764 terms.sort_by(tdim_lexi_order);
765
766 match (gcd, terms.len()) {
767 (_, 0) => Val(gcd), (0, _) => Val(0), (1, 1) => terms.remove(0), (1, _) => Mul(terms), (_, 1) => MulInt(gcd, Box::new(terms.remove(0))), _ => MulInt(gcd, Box::new(Mul(terms))), }
775 }
776 MulInt(coef, expr) => {
777 match *expr {
778 MulInt(c2, inner) => {
779 if let Some(c) = coef.checked_mul(c2) {
780 return MulInt(c, inner).simplify_rec(scope, scenario, extra);
781 } else {
782 return MulInt(coef, Box::new(MulInt(c2, inner)));
783 }
784 }
785 Val(v) => {
786 return coef
787 .checked_mul(v)
788 .map(Val)
789 .unwrap_or_else(|| MulInt(coef, Box::new(Val(v))));
790 }
791 _ => {}
792 }
793
794 let simplified = expr.simplify_rec(scope, scenario, extra);
795 match (coef, simplified) {
796 (0, _) => Val(0), (1, s) => s, (_, Add(terms)) => Add(terms
799 .into_iter()
800 .map(|term| {
801 MulInt(coef, Box::new(term)).simplify_rec(scope, scenario, extra)
802 })
803 .collect()), (c, Val(v)) => {
805 c.checked_mul(v).map(Val).unwrap_or_else(|| MulInt(c, Box::new(Val(v))))
806 } (c, MulInt(v, inner)) => {
808 if let Some(cv) = c.checked_mul(v) {
809 MulInt(cv, inner) } else {
811 MulInt(c, Box::new(MulInt(v, inner)))
812 }
813 }
814 (_, s) => MulInt(coef, Box::new(s)), }
816 }
817 Div(a, q) => {
818 if q == 1 {
819 return a.simplify_rec(scope, scenario, extra);
820 } else if let Div(a, q2) = *a {
821 return Div(a, q * q2).simplify_rec(scope, scenario, extra);
822 }
823 let a = a.simplify_rec(scope, scenario, extra);
824 if let Val(a) = a {
825 Val(a / q as i64)
826 } else if let MulInt(-1, a) = a {
827 MulInt(-1, b!(Div(a, q)))
828 } else if let Add(mut terms) = a {
829 if terms
830 .iter()
831 .any(|t| if let MulInt(-1, s) = t { matches!(&**s, Sym(_)) } else { false })
832 {
833 MulInt(
834 -1,
835 b!(Div(
836 b!(Add(terms.into_iter().map(|t| MulInt(-1, b!(t))).collect())
837 .simplify_rec(scope, scenario, extra)),
838 q
839 )),
840 )
841 } else if let Some(val) = terms
842 .iter()
843 .find_map(|t| if let Val(v) = t { Some(*v) } else { None })
844 .and_then(|v| {
845 if v >= q as i64 {
846 Some(v / q as i64)
847 } else if v < 0 {
848 Some(-Integer::div_ceil(&-v, &(q as i64)))
849 } else {
850 None
851 }
852 })
853 {
854 terms.push(Val(-val * q as i64));
855 let inner = Div(b!(Add(terms).simplify_rec(scope, scenario, extra)), q)
859 .simplify_rec(scope, scenario, extra);
860 Add(vec![Val(val), inner])
861 } else if let Some(simplified) =
862 try_divide_multiple_plus_remainder(&terms, q, scope, extra)
863 {
864 simplified.simplify_rec(scope, scenario, extra)
873 } else if let Some(found_idx) = terms.iter().position(|term| {
874 matches!(term, MulInt(p, inner)
878 if *p == -(q as i64)
879 && matches!(inner.as_ref(), Div(_, q2) if *q2 == q))
880 }) {
881 let MulInt(_, inner) = &terms[found_idx] else { unreachable!() };
882 let Div(y, _) = inner.as_ref() else { unreachable!() };
883 let remaining: Vec<TDim> = terms
884 .iter()
885 .enumerate()
886 .filter(|&(i, _)| i != found_idx)
887 .map(|(_, t)| t.clone())
888 .collect();
889 let remaining_sum = match remaining.len() {
890 0 => Val(0),
891 1 => remaining.into_iter().next().unwrap(),
892 _ => Add(remaining),
893 };
894 if eq_structural(&remaining_sum, y) {
895 Val(0)
896 } else {
897 Div(b!(Add(terms)), q)
898 }
899 } else {
900 Div(b!(Add(terms)), q)
901 }
902 } else if let MulInt(p, a) = a {
903 if p == q as i64 {
904 a.simplify()
905 } else {
906 let gcd = p.abs().gcd(&(q as i64));
907 if gcd == p {
908 Div(a, q / gcd as u64)
909 } else if gcd == q as i64 {
910 MulInt(p / gcd, a)
911 } else if gcd > 1 {
912 Div(b!(MulInt(p / gcd, a)), q / gcd as u64)
913 .simplify_rec(scope, scenario, extra)
914 } else {
915 Div(b!(MulInt(p, a)), q)
916 }
917 }
918 } else {
919 Div(b!(a), q)
920 }
921 }
922 Broadcast(terms) => {
923 let mut terms: Vec<TDim> = terms
924 .iter()
925 .map(|s| s.clone().simplify_rec(scope, scenario, extra))
926 .flat_map(|t| if let Broadcast(t) = t { t } else { vec![t] })
927 .filter(|t| !t.is_one())
928 .sorted_by(tdim_lexi_order)
929 .dedup()
930 .collect_vec();
931 match &*terms {
933 [] => Val(1),
934 [_] => terms.remove(0),
935 [a, Min(m)] | [Min(m), a]
936 if m.contains(a)
937 && m.iter()
938 .all(|t| scope.prove_strict_positive_with_extra(t, extra)) =>
939 {
940 a.clone()
941 }
942 _ => Broadcast(terms),
943 }
944 }
945
946 Min(terms) => {
947 let mut flatten: Vec<TDim> = terms
948 .into_iter()
949 .map(|t| t.simplify_rec(scope, scenario, extra))
950 .flat_map(|t| if let Min(t) = t { t } else { vec![t] })
951 .filter(|t| t != &Val(i64::MAX))
952 .sorted_by(tdim_lexi_order)
953 .dedup()
954 .collect();
955 #[allow(clippy::mutable_key_type)]
956 let mut redundant = HashSet::<TDim>::default();
957 for pair in flatten.iter().permutations(2) {
958 let (a, b) = (pair[0], pair[1]);
959 if redundant.contains(a) || redundant.contains(b) {
960 continue;
961 }
962 let diff = a.clone() - b;
963 if diff.as_i64().is_some_and(|i| i >= 0)
964 || scope.prove_positive_or_zero_with_extra(&diff, extra)
965 {
966 redundant.insert(a.clone());
967 }
968 }
969 flatten.retain(|t| !redundant.contains(t));
970 if flatten.len() == 0 {
971 i64::MAX.to_dim()
972 } else if flatten.len() == 1 {
973 flatten.into_iter().next().unwrap()
974 } else {
975 Min(flatten)
976 }
977 }
978 Max(terms) => {
979 let mut flatten: Vec<TDim> = terms
980 .into_iter()
981 .map(|t| t.simplify_rec(scope, scenario, extra))
982 .flat_map(|t| if let Max(t) = t { t } else { vec![t] })
983 .filter(|t| t != &Val(i64::MIN))
984 .sorted_by(tdim_lexi_order)
985 .dedup()
986 .collect();
987 #[allow(clippy::mutable_key_type)]
988 let mut redundant = HashSet::<TDim>::default();
989 for pair in flatten.iter().permutations(2) {
990 let (a, b) = (pair[0], pair[1]);
991 if redundant.contains(a) || redundant.contains(b) {
992 continue;
993 }
994 let diff = a.clone() - b;
995 if diff.as_i64().is_some_and(|i| i >= 0)
996 || scope.prove_positive_or_zero_with_extra(&diff, extra)
997 {
998 redundant.insert(b.clone());
999 }
1000 }
1001 flatten.retain(|t| !redundant.contains(t));
1002 if flatten.len() == 0 {
1003 i64::MIN.to_dim()
1004 } else if flatten.len() == 1 {
1005 flatten.into_iter().next().unwrap()
1006 } else {
1007 Max(flatten)
1008 }
1009 }
1010 Sym(s) => scope
1011 .assertions(scenario)
1012 .find_map(|a| match a {
1013 Assertion::Eq(Sym(sym), v) if sym == &s => Some(v.clone()),
1014 _ => None,
1015 })
1016 .unwrap_or(Sym(s)),
1017 Val(_) => self,
1018 Ge(a, b) => {
1019 let a = a.simplify_rec(scope, scenario, extra);
1020 let b = b.simplify_rec(scope, scenario, extra);
1021 match (&a, &b) {
1022 (Val(av), Val(bv)) => Val(if av >= bv { 1 } else { 0 }),
1023 _ => {
1024 let diff = a.clone() - b.clone();
1025 if scope.prove_positive_or_zero_with_extra(&diff, extra) {
1026 Val(1)
1027 } else if scope
1028 .prove_strict_positive_with_extra(&(b.clone() - a.clone()), extra)
1029 {
1030 Val(0)
1031 } else {
1032 Ge(b!(a), b!(b))
1033 }
1034 }
1035 }
1036 }
1037 Eq(a, b) => {
1038 let a = a.simplify_rec(scope, scenario, extra);
1039 let b = b.simplify_rec(scope, scenario, extra);
1040 match (&a, &b) {
1041 (Val(av), Val(bv)) => Val(if av == bv { 1 } else { 0 }),
1042 _ => {
1043 let diff = a.clone() - b.clone();
1044 if scope.prove_strict_positive_with_extra(&diff, extra)
1045 || scope
1046 .prove_strict_positive_with_extra(&(b.clone() - a.clone()), extra)
1047 {
1048 Val(0)
1049 } else {
1050 let boolean_case = match (&a, &b) {
1055 (Val(0), e) | (e, Val(0)) => Some((e, false)),
1056 (Val(1), e) | (e, Val(1)) => Some((e, true)),
1057 _ => None,
1058 };
1059 if let Some((expr, equals_one)) = boolean_case {
1060 if scope.prove_positive_or_zero_with_extra(expr, extra)
1061 && scope.prove_positive_or_zero_with_extra(
1062 &(Val(1) - expr.clone()),
1063 extra,
1064 )
1065 {
1066 return if equals_one {
1067 expr.clone()
1068 } else {
1069 (Val(1) - expr.clone()).simplify_rec(scope, scenario, extra)
1070 };
1071 }
1072 }
1073 Eq(b!(a), b!(b))
1074 }
1075 }
1076 }
1077 }
1078 }
1079 }
1080
1081 pub(super) fn inclusive_bound(&self, scope: &SymbolScopeData, upper: bool) -> Option<i64> {
1082 use self::TDim::*;
1083 match self {
1084 Val(n) => Some(*n),
1085 Sym(_) => {
1086 if upper {
1087 scope
1088 .all_assertions()
1089 .iter()
1090 .filter_map(|assert| match &assert {
1091 Assertion::LT(left, right)
1092 if left == self && right.as_i64().is_some() =>
1093 {
1094 Some(right.as_i64().unwrap() - 1)
1095 }
1096 Assertion::LTE(left, right)
1097 if left == self && right.as_i64().is_some() =>
1098 {
1099 Some(right.as_i64().unwrap())
1100 }
1101 _ => None,
1102 })
1103 .min()
1104 } else {
1105 scope
1106 .all_assertions()
1107 .iter()
1108 .filter_map(|assert| match &assert {
1109 Assertion::GT(left, right)
1110 if left == self && right.as_i64().is_some() =>
1111 {
1112 Some(right.as_i64().unwrap() + 1)
1113 }
1114 Assertion::GTE(left, right)
1115 if left == self && right.as_i64().is_some() =>
1116 {
1117 Some(right.as_i64().unwrap())
1118 }
1119 _ => None,
1120 })
1121 .max()
1122 }
1123 }
1124 Add(terms) => {
1125 let mut bound: i64 = 0;
1126 for t in terms {
1127 if let Some(b) = t.inclusive_bound(scope, upper) {
1128 bound = bound.checked_add(b)?;
1129 } else {
1130 return None;
1131 }
1132 }
1133 Some(bound)
1134 }
1135 MulInt(p, a) => match p.cmp(&0) {
1136 Ordering::Equal => Some(0),
1137 Ordering::Greater => {
1138 a.inclusive_bound(scope, upper).and_then(|x| x.checked_mul(*p))
1139 }
1140 Ordering::Less => a.inclusive_bound(scope, !upper).and_then(|x| x.checked_mul(*p)),
1141 },
1142 Mul(terms) => {
1143 let mut lo: i64 = 1;
1145 let mut hi: i64 = 1;
1146 for t in terms {
1147 let t_lo = t.inclusive_bound(scope, false)?;
1148 let t_hi = t.inclusive_bound(scope, true)?;
1149 if t_lo < 0 {
1150 return None;
1151 }
1152 lo = lo.checked_mul(t_lo)?;
1153 hi = hi.checked_mul(t_hi)?;
1154 }
1155 Some(if upper { hi } else { lo })
1156 }
1157 Min(terms) if !upper => {
1158 let bounds: Option<Vec<i64>> =
1161 terms.iter().map(|t| t.inclusive_bound(scope, false)).collect();
1162 bounds.map(|b| b.into_iter().min().unwrap_or(i64::MAX))
1163 }
1164 Max(terms) if upper => {
1165 let bounds: Option<Vec<i64>> =
1168 terms.iter().map(|t| t.inclusive_bound(scope, true)).collect();
1169 bounds.map(|b| b.into_iter().max().unwrap_or(i64::MIN))
1170 }
1171 Div(a, q) => a.inclusive_bound(scope, upper).map(|x| x / (*q as i64)),
1172 Broadcast(terms) => {
1173 if upper {
1174 Max(terms.clone()).inclusive_bound(scope, true)
1175 } else {
1176 Min(terms.clone()).inclusive_bound(scope, false)
1177 }
1178 }
1179 Ge(_, _) | Eq(_, _) => {
1180 if upper {
1181 Some(1)
1182 } else {
1183 Some(0)
1184 }
1185 }
1186 _ => None,
1187 }
1188 }
1189
1190 pub fn low_inclusive_bound(&self) -> Option<i64> {
1191 if let TDim::Val(v) = self {
1192 return Some(*v);
1193 }
1194 let scope = self.find_scope()?;
1195 let data = scope.0.lock();
1196 let data = data.borrow();
1197 self.inclusive_bound(&data, false)
1198 }
1199
1200 pub fn high_inclusive_bound(&self) -> Option<i64> {
1201 if let TDim::Val(v) = self {
1202 return Some(*v);
1203 }
1204 let scope = self.find_scope()?;
1205 let data = scope.0.lock();
1206 let data = data.borrow();
1207 self.inclusive_bound(&data, true)
1208 }
1209
1210 pub fn prove_positive_or_zero(&self) -> bool {
1211 if let TDim::Val(v) = self {
1212 return *v >= 0;
1213 }
1214 let Some(scope) = self.find_scope() else { return false };
1215 let data = scope.0.lock();
1216 let data = data.borrow();
1217 data.prove_positive_or_zero(self)
1218 }
1219
1220 pub fn prove_strict_positive(&self) -> bool {
1221 if let TDim::Val(v) = self {
1222 return *v > 0;
1223 }
1224 (self.clone() - 1).prove_positive_or_zero()
1225 }
1226
1227 pub fn prove_negative_or_zero(&self) -> bool {
1228 if let TDim::Val(v) = self {
1229 return *v <= 0;
1230 }
1231 self.clone().neg().prove_positive_or_zero()
1232 }
1233
1234 pub fn prove_strict_negative(&self) -> bool {
1235 if let TDim::Val(v) = self {
1236 return *v < 0;
1237 }
1238 self.clone().neg().prove_strict_positive()
1239 }
1240
1241 pub fn lcm(&self, other: &TDim) -> Option<TDim> {
1249 match (self.as_i64(), other.as_i64()) {
1250 (Some(a), Some(b)) if a > 0 && b > 0 => {
1251 let g = (a as u64).gcd(&(b as u64));
1252 let l = (a as u64 / g).saturating_mul(b as u64);
1253 if l > i64::MAX as u64 { None } else { Some(TDim::Val(l as i64)) }
1254 }
1255 (Some(0), _) | (_, Some(0)) => Some(TDim::Val(0)),
1256 _ => None,
1257 }
1258 }
1259
1260 pub fn gcd(&self) -> u64 {
1261 use self::TDim::*;
1262 match self {
1263 Val(v) => v.unsigned_abs(),
1264 Sym(_) => 1,
1265 Add(terms) => {
1266 let (head, tail) = terms.split_first().unwrap();
1267 tail.iter().fold(head.gcd(), |a, b| a.gcd(&b.gcd()))
1268 }
1269 MulInt(p, a) => a.gcd().saturating_mul(p.unsigned_abs()),
1270 Mul(terms) => terms.iter().map(|t| t.gcd()).fold(1u64, |a, b| a.saturating_mul(b)),
1271 Min(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
1272 Max(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
1273 Div(a, q) => {
1274 if a.gcd() % *q == 0 {
1275 a.gcd() / *q
1276 } else {
1277 1
1278 }
1279 }
1280 Broadcast(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap_or(1),
1281 Ge(_, _) | Eq(_, _) => 1,
1282 }
1283 }
1284
1285 fn div(&self, d: u64) -> TDim {
1286 use self::TDim::*;
1287 if d == 1 {
1288 return self.clone();
1289 }
1290 match self {
1291 Val(v) => Val(v / d as i64),
1292 Sym(_) => panic!(),
1293 Add(terms) => Add(terms.iter().map(|t| t.div(d)).collect()),
1294 Min(terms) => Min(terms.iter().map(|t| t.div(d)).collect()),
1295 Max(terms) => Max(terms.iter().map(|t| t.div(d)).collect()),
1296 Broadcast(terms) => Broadcast(terms.iter().map(|t| t.div(d)).collect()),
1297 Mul(_) => Div(Box::new(self.clone()), d),
1298 MulInt(p, a) => {
1299 if *p == d as i64 {
1300 (**a).clone()
1301 } else {
1302 let gcd = p.unsigned_abs().gcd(&d);
1303 MulInt(p / gcd as i64, b!(a.div(d / gcd)))
1304 }
1305 }
1306 Div(a, q) => Div(a.clone(), q * d),
1307 Ge(_, _) | Eq(_, _) => Div(Box::new(self.clone()), d),
1308 }
1309 }
1310
1311 pub fn div_ceil(self, rhs: u64) -> TDim {
1312 TDim::Div(Box::new(Add(vec![self, Val(rhs as i64 - 1)])), rhs).reduce()
1313 }
1314
1315 pub fn guess_slope(&self, sym: &Symbol) -> (i64, u64) {
1316 fn slope_rec(d: &TDim, sym: &Symbol) -> (i64, i64) {
1317 match d {
1318 Val(_) => (0, 1),
1319 Sym(s) => ((sym == s) as i64, 1),
1320 Add(terms) => terms
1321 .iter()
1322 .map(|d| slope_rec(d, sym))
1323 .fold((0, 1), |a, b| ((a.0 * b.1 + a.1 * b.0), (b.1 * a.1))),
1324 Mul(terms) => terms
1325 .iter()
1326 .map(|d| slope_rec(d, sym))
1327 .fold((1, 1), |a, b| ((a.0 * b.0), (b.1 * a.1))),
1328 MulInt(p, a) => {
1329 let (n, d) = slope_rec(a, sym);
1330 (p * n, d)
1331 }
1332 Div(a, q) => {
1333 let (n, d) = slope_rec(a, sym);
1334 (n, d * *q as i64)
1335 }
1336 Broadcast(terms) => slope_rec(&terms[0], sym),
1337 Min(terms) => slope_rec(&terms[0], sym),
1338 Max(terms) => slope_rec(&terms[0], sym),
1339 Ge(_, _) | Eq(_, _) => (0, 1),
1340 }
1341 }
1342 let (p, q) = slope_rec(self, sym);
1343 reduce_ratio(p, q)
1344 }
1345
1346 #[allow(clippy::mutable_key_type)]
1347 pub fn symbols(&self) -> std::collections::HashSet<Symbol> {
1348 match self {
1349 Val(_) => maplit::hashset!(),
1350 Sym(s) => maplit::hashset!(s.clone()),
1351 Add(terms) | Mul(terms) | Broadcast(terms) | Min(terms) | Max(terms) => {
1352 terms.iter().fold(maplit::hashset!(), |mut set, v| {
1353 set.extend(v.symbols());
1354 set
1355 })
1356 }
1357 MulInt(_, a) => a.symbols(),
1358 Div(a, _) => a.symbols(),
1359 Ge(a, b) | Eq(a, b) => {
1360 let mut set = a.symbols();
1361 set.extend(b.symbols());
1362 set
1363 }
1364 }
1365 }
1366
1367 pub fn compatible_with(&self, other: &TDim) -> bool {
1368 if let Ok(x) = (self.clone() - other).to_i64() {
1369 return x == 0;
1370 }
1371 true }
1373}
1374
1375pub(super) fn reduce_ratio(mut p: i64, mut q: i64) -> (i64, u64) {
1376 let gcd = p.abs().gcd(&q.abs());
1377 if gcd > 1 {
1378 p /= gcd;
1379 q /= gcd;
1380 }
1381 if q < 0 { (-p, (-q) as u64) } else { (p, q as u64) }
1382}
1383
1384impl Zero for TDim {
1385 fn zero() -> Self {
1386 Val(0)
1387 }
1388 fn is_zero(&self) -> bool {
1389 matches!(self, Val(0))
1390 }
1391}
1392
1393impl Default for TDim {
1394 fn default() -> TDim {
1395 Val(0)
1396 }
1397}
1398
1399impl num_traits::Bounded for TDim {
1400 fn min_value() -> Self {
1401 TDim::Val(i64::MIN)
1402 }
1403
1404 fn max_value() -> Self {
1405 TDim::Val(i64::MAX)
1406 }
1407}
1408
1409impl num_traits::One for TDim {
1410 fn one() -> Self {
1411 TDim::Val(1)
1412 }
1413}
1414
1415impl ::std::iter::Sum for TDim {
1416 fn sum<I: Iterator<Item = TDim>>(iter: I) -> TDim {
1417 iter.fold(0.into(), |a, b| a + b)
1418 }
1419}
1420
1421impl<'a> ::std::iter::Sum<&'a TDim> for TDim {
1422 fn sum<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
1423 iter.fold(0.into(), |a, b| a + b)
1424 }
1425}
1426
1427impl std::iter::Product for TDim {
1428 fn product<I: Iterator<Item = TDim>>(iter: I) -> Self {
1429 iter.fold(TDim::Val(1), |a, b| a * b)
1430 }
1431}
1432
1433impl<'a> ::std::iter::Product<&'a TDim> for TDim {
1434 fn product<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
1435 iter.fold(1.into(), |a, b| a * b)
1436 }
1437}
1438
1439macro_rules! from_i {
1440 ($i: ty) => {
1441 impl From<$i> for TDim {
1442 fn from(v: $i) -> TDim {
1443 TDim::Val(v as _)
1444 }
1445 }
1446 impl<'a> From<&'a $i> for TDim {
1447 fn from(v: &'a $i) -> TDim {
1448 TDim::Val(*v as _)
1449 }
1450 }
1451 };
1452}
1453
1454from_i!(i32);
1455from_i!(i64);
1456from_i!(u64);
1457from_i!(isize);
1458from_i!(usize);
1459
1460impl From<Symbol> for TDim {
1461 fn from(it: Symbol) -> Self {
1462 TDim::Sym(it)
1463 }
1464}
1465
1466impl<'a> From<&'a Symbol> for TDim {
1467 fn from(it: &'a Symbol) -> Self {
1468 TDim::Sym(it.clone())
1469 }
1470}
1471
1472impl ops::Neg for TDim {
1473 type Output = Self;
1474 fn neg(self) -> Self {
1475 if let Val(v) = self { Val(-v) } else { TDim::MulInt(-1, Box::new(self)).reduce() }
1476 }
1477}
1478
1479impl<'a> ops::AddAssign<&'a TDim> for TDim {
1480 fn add_assign(&mut self, rhs: &'a TDim) {
1481 if rhs.is_zero() {
1482 } else if self.is_zero() {
1483 *self = rhs.clone();
1484 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1485 *s += o;
1486 } else {
1487 *self = TDim::Add(vec![std::mem::take(self), rhs.clone()]).reduce()
1488 }
1489 }
1490}
1491
1492impl<I> ops::AddAssign<I> for TDim
1493where
1494 I: Into<TDim>,
1495{
1496 fn add_assign(&mut self, rhs: I) {
1497 let rhs = rhs.into();
1498 if rhs.is_zero() {
1499 } else if self.is_zero() {
1500 *self = rhs;
1501 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1502 *s += o;
1503 } else {
1504 *self = TDim::Add(vec![std::mem::take(self), rhs]).reduce()
1505 }
1506 }
1507}
1508
1509impl<I> ops::Add<I> for TDim
1510where
1511 I: Into<TDim>,
1512{
1513 type Output = Self;
1514 fn add(mut self, rhs: I) -> Self {
1515 self += rhs;
1516 self
1517 }
1518}
1519
1520impl<'a> ops::Add<&'a TDim> for TDim {
1521 type Output = Self;
1522 fn add(mut self, rhs: &'a TDim) -> Self {
1523 self += rhs;
1524 self
1525 }
1526}
1527
1528#[allow(clippy::suspicious_op_assign_impl)]
1529impl<'a> ops::SubAssign<&'a TDim> for TDim {
1530 fn sub_assign(&mut self, rhs: &'a TDim) {
1531 if rhs.is_zero() {
1532 } else if self.is_zero() {
1533 *self = rhs.clone().neg();
1534 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1535 *s -= o;
1536 } else {
1537 *self = TDim::Add(vec![std::mem::take(self), rhs.clone().neg()]).reduce()
1538 }
1539 }
1540}
1541
1542impl<I> ops::SubAssign<I> for TDim
1543where
1544 I: Into<TDim>,
1545{
1546 fn sub_assign(&mut self, rhs: I) {
1547 let rhs = rhs.into();
1548 if rhs.is_zero() {
1549 } else if self.is_zero() {
1550 *self = rhs.neg();
1551 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1552 *s -= o;
1553 } else {
1554 *self = TDim::Add(vec![std::mem::take(self), rhs.neg()]).reduce()
1555 }
1556 }
1557}
1558
1559impl<I> ops::Sub<I> for TDim
1560where
1561 I: Into<TDim>,
1562{
1563 type Output = Self;
1564 fn sub(mut self, rhs: I) -> Self {
1565 self -= rhs;
1566 self
1567 }
1568}
1569
1570impl<'a> ops::Sub<&'a TDim> for TDim {
1571 type Output = Self;
1572 fn sub(mut self, rhs: &'a TDim) -> Self {
1573 self -= rhs;
1574 self
1575 }
1576}
1577
1578impl<I: Into<TDim>> ops::MulAssign<I> for TDim {
1579 fn mul_assign(&mut self, rhs: I) {
1580 let rhs = rhs.into();
1581 if self.is_one() {
1582 *self = rhs
1583 } else if rhs.is_one() {
1584 } else {
1585 *self = TDim::Mul(vec![rhs, std::mem::take(self)]).reduce()
1586 }
1587 }
1588}
1589
1590impl<'a> ops::MulAssign<&'a TDim> for TDim {
1591 fn mul_assign(&mut self, rhs: &'a TDim) {
1592 if self.is_one() {
1593 *self = rhs.clone()
1594 } else if rhs.is_one() {
1595 } else {
1596 *self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce()
1597 }
1598 }
1599}
1600
1601impl<I: Into<TDim>> ops::Mul<I> for TDim {
1602 type Output = Self;
1603 fn mul(mut self, rhs: I) -> Self {
1604 self *= rhs.into();
1605 self
1606 }
1607}
1608
1609impl<'a> ops::Mul<&'a TDim> for TDim {
1610 type Output = Self;
1611 fn mul(mut self, rhs: &'a TDim) -> Self {
1612 self *= rhs;
1613 self
1614 }
1615}
1616
1617impl<I: AsPrimitive<u64> + PrimInt> ops::DivAssign<I> for TDim {
1618 fn div_assign(&mut self, rhs: I) {
1619 *self = TDim::Div(Box::new(std::mem::take(self)), rhs.as_()).reduce()
1620 }
1621}
1622
1623impl<I: AsPrimitive<u64> + PrimInt> ops::Div<I> for TDim {
1624 type Output = Self;
1625 fn div(mut self, rhs: I) -> Self {
1626 self /= rhs.as_();
1627 self
1628 }
1629}
1630
1631impl<I: AsPrimitive<u64> + PrimInt> ops::RemAssign<I> for TDim {
1632 fn rem_assign(&mut self, rhs: I) {
1633 *self += -(self.clone() / rhs.as_() * rhs.as_());
1634 }
1635}
1636
1637impl<I: AsPrimitive<u64> + PrimInt> ops::Rem<I> for TDim {
1638 type Output = Self;
1639 fn rem(mut self, rhs: I) -> Self {
1640 self %= rhs;
1641 self
1642 }
1643}
1644
// Unit tests for `TDim` construction, reduction, simplification and the helper queries
// (`lcm`, `guess_slope`, `eval_to_i64`, ...).
#[cfg(test)]
mod tests {
    use super::*;

    // Boxing shorthand used to build nested `TDim` nodes by hand in the tests below.
    macro_rules! b( ($e:expr) => { Box::new($e) } );

    // Symbols `a`..`e`, all created from one shared, lazily-initialised `SymbolScope`.
    lazy_static::lazy_static! {
        static ref table: SymbolScope = SymbolScope::default();
        static ref A: Symbol = table.sym("a");
        static ref B: Symbol = table.sym("b");
        static ref C: Symbol = table.sym("c");
        static ref D: Symbol = table.sym("d");
        static ref E: Symbol = table.sym("e");
    }

    // Builds the raw (unreduced) node `-1 * a`.
    fn neg(a: &TDim) -> TDim {
        mul(-1, a)
    }

    // Builds the raw (unreduced) node `a + b`.
    fn add(a: &TDim, b: &TDim) -> TDim {
        TDim::Add(vec![a.clone(), b.clone()])
    }

    // Builds the raw (unreduced) node `a * b` with an integer coefficient `a`.
    fn mul(a: i64, b: &TDim) -> TDim {
        TDim::MulInt(a, b![b.clone()])
    }

    // Builds the raw (unreduced) node `a / b`.
    fn div(a: &TDim, b: u64) -> TDim {
        TDim::Div(b!(a.clone()), b)
    }

    // `a + (-1 * a)` reduces to 0.
    #[test]
    fn reduce_add() {
        assert_eq!(add(&A.to_dim(), &neg(&A.to_dim())).reduce(), Val(0))
    }

    // `lcm` on positive literals, and `None` when one side is symbolic.
    #[test]
    fn lcm_basic() {
        assert_eq!(Val(16).lcm(&Val(32)), Some(Val(32)));
        assert_eq!(Val(32).lcm(&Val(16)), Some(Val(32)));
        assert_eq!(Val(6).lcm(&Val(8)), Some(Val(24)));
        assert_eq!(Val(7).lcm(&Val(7)), Some(Val(7)));
        assert_eq!(Val(16).lcm(&A.to_dim()), None);
    }

    // `-1 * (2 * a)` reduces to `-2 * a`.
    #[test]
    fn reduce_neg_mul() {
        assert_eq!(neg(&mul(2, &A.to_dim())).reduce(), mul(-2, &A.to_dim()))
    }

    // `(-4 + -2*(a/4)) + -2*(-1*(a/4))`: the two `a/4` terms cancel, leaving -4.
    #[test]
    fn reduce_cplx_ex_2() {
        assert_eq!(
            add(
                &add(&Val(-4), &mul(-2, &div(&A.to_dim(), 4))),
                &mul(-2, &mul(-1, &div(&A.to_dim(), 4)))
            )
            .reduce(),
            Val(-4)
        )
    }

    // `(1 * (4 * a)) / 4` reduces to `a`.
    #[test]
    fn reduce_cplx_ex_3() {
        assert_eq!(div(&MulInt(1, b!(MulInt(4, b!(A.to_dim())))), 4).reduce(), A.to_dim())
    }

    // `(a + 1)/2 + (-a + 1)/2` is expected to reduce to 1.
    #[test]
    fn reduce_cplx_ex_4() {
        assert_eq!(
            add(&div(&add(&A.to_dim(), &Val(1)), 2), &div(&add(&neg(&A.to_dim()), &Val(1)), 2))
                .reduce(),
            1.into()
        );
    }

    // Nested integer coefficients are multiplied together: `3 * (2 * a)` -> `6 * a`.
    #[test]
    fn reduce_mul_mul_1() {
        assert_eq!(mul(3, &mul(2, &A.to_dim())).reduce(), mul(6, &A.to_dim()))
    }

    // `-2 * (-1 * a)` -> `2 * a`.
    #[test]
    fn reduce_mul_mul_2() {
        assert_eq!(mul(-2, &mul(-1, &A.to_dim())).reduce(), mul(2, &A.to_dim()))
    }

    // `2 * ((-1 * a) / 3)` is expected to reduce to `-2 * (a / 3)`.
    #[test]
    fn reduce_mul_div_1() {
        assert_eq!(mul(2, &div(&mul(-1, &A.to_dim()), 3)).reduce(), mul(-2, &div(&A.to_dim(), 3)))
    }

    // Literal arithmetic through the operators, evaluated with no symbol values.
    #[test]
    fn const_and_add() {
        let e: TDim = 2i64.into();
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 2);
        let e: TDim = TDim::from(2) + 3;
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 5);
        let e: TDim = TDim::from(2) - 3;
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -1);
        let e: TDim = -TDim::from(2);
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -2);
    }

    // Evaluating with `a = 2` substitutes the symbol.
    #[test]
    fn substitution() {
        let a: TDim = A.to_dim();
        assert_eq!(a.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 2);
        let e = a + 3;
        assert_eq!(e.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 5);
    }

    // `+` on literals folds to a single literal.
    #[test]
    fn reduce_adds() {
        let e: TDim = TDim::from(2) + 1;
        assert_eq!(e, TDim::from(3));
        let e: TDim = TDim::from(3) + 2;
        assert_eq!(e, TDim::from(5));
        let e: TDim = TDim::from(3) + 0;
        assert_eq!(e, TDim::from(3));
        let e: TDim = TDim::from(3) + 2 + 1;
        assert_eq!(e, TDim::from(6));
    }

    // Multiplying by one leaves the other operand unchanged.
    #[test]
    fn reduce_muls() {
        let e: TDim = Val(1) * A.to_dim();
        assert_eq!(e, A.to_dim());
        let e: TDim = A.to_dim() * &B.to_dim() * 1;
        assert_eq!(e, A.to_dim() * &B.to_dim());
    }

    // `/` and `%` on literals fold to a single literal.
    #[test]
    fn reduce_divs() {
        let e: TDim = TDim::from(2) / 1;
        assert_eq!(e, TDim::from(2));
        let e: TDim = TDim::from(3) / 2;
        assert_eq!(e, TDim::from(1));
        let e: TDim = TDim::from(3) % 2;
        assert_eq!(e, TDim::from(1));
        let e: TDim = TDim::from(5) / 2;
        assert_eq!(e, TDim::from(2));
        let e: TDim = TDim::from(5) % 2;
        assert_eq!(e, TDim::from(1));
    }

    // `(a + 23)/2 - 1` must equal `(a + 21)/2`.
    #[test]
    fn reduce_div_bug_0() {
        let e1: TDim = (A.to_dim() + 23) / 2 - 1;
        let e2: TDim = (A.to_dim() + 21) / 2;
        assert_eq!(e1, e2);
    }

    // `(a - 1)/2` must equal `(a + 1)/2 - 1`.
    #[test]
    fn reduce_div_bug_1() {
        let e1: TDim = (A.to_dim() + -1) / 2;
        let e2: TDim = (A.to_dim() + 1) / 2 - 1;
        assert_eq!(e1, e2);
    }

    // `((a + 1)/2 + 1)/2` must equal `(a + 3)/4`.
    #[test]
    fn reduce_div_bug_2() {
        let e1: TDim = ((A.to_dim() + 1) / 2 + 1) / 2;
        let e2: TDim = (A.to_dim() + 3) / 4;
        assert_eq!(e1, e2);
    }

    // With `S >= 0` asserted in the scope, `(k*q*S + r) / q` style expressions simplify by
    // splitting off the exact multiple.
    #[test]
    fn divide_multiple_plus_remainder() {
        let scope = SymbolScope::default().with_assertion("S>=0").unwrap();
        let s = scope.sym("S");

        let e: TDim = (s.to_dim() * 2 + 1) / 2;
        assert_eq!(e.simplify(), s.to_dim());

        let e: TDim = (s.to_dim() * 2 + 1) / 2 - 1;
        assert_eq!(e.simplify(), s.to_dim() - 1);

        let e: TDim = (s.to_dim() * 2 - 1) / 2;
        assert_eq!(e.simplify(), s.to_dim() - 1);

        let e: TDim = (s.to_dim() * 4 + 3) / 2;
        assert_eq!(e.simplify(), s.to_dim() * 2 + 1);
    }

    // Without any assertion on `S`, `(2*S + 1)/2` must not simplify to `S`.
    #[test]
    fn divide_multiple_plus_remainder_no_assertion() {
        let scope = SymbolScope::default();
        let s = scope.sym("S");
        let e: TDim = (s.to_dim() * 2 + 1) / 2;
        assert_ne!(e.simplify(), s.to_dim());
    }

    // `(x - x/2*2) / 2` simplifies to 0, both for `x = S` and for `x = S + 1`.
    #[test]
    fn modulo_div_is_zero() {
        let scope = SymbolScope::default();
        let s = scope.sym("S");
        let e: TDim = (s.to_dim() - s.to_dim() / 2 * 2) / 2;
        assert_eq!(e.simplify(), TDim::Val(0));
        let a = s.to_dim() + 1;
        let e2: TDim = (a.clone() - a.clone() / 2 * 2) / 2;
        assert_eq!(e2.simplify(), TDim::Val(0));
    }

    // A trailing `/ 1` does not change `(a/2) * -4`.
    #[test]
    fn reduce_div_bug_3() {
        let e1: TDim = (A.to_dim() / 2) * -4;
        let e2: TDim = (A.to_dim() / 2) * -4 / 1;
        assert_eq!(e1, e2);
    }

    // `a * 2 / 2` gives back `a`.
    #[test]
    fn reduce_mul_div() {
        let e: TDim = A.to_dim() * 2 / 2;
        assert_eq!(e, A.to_dim());
    }

    // Two factorisations of the same polynomial expand to the same expression.
    #[test]
    fn expand_polynomial_two_add_factors() {
        let a = A.to_dim();
        let b = B.to_dim();
        let lhs = (a.clone() + a.clone() * &b * 2) * (TDim::from(1) + &b);
        let rhs = a.clone() * (TDim::from(1) + &b) * (TDim::from(1) + b.clone() * 2);
        assert_eq!(lhs.expand_polynomial(), rhs.expand_polynomial());
    }

    // `a / 2 * 2` must NOT collapse to `a`.
    #[test]
    fn reduce_div_mul() {
        let e: TDim = A.to_dim() / 2 * 2;
        assert_ne!(e, A.to_dim());
    }

    // `a/2 + 1` equals `(a + 2)/2`.
    #[test]
    fn reduce_add_div() {
        let e: TDim = A.to_dim() / 2 + 1;
        assert_eq!(e, ((A.to_dim() + 2) / 2));
    }

    // `1 - a*2` equals `1 + a*(-2)`.
    #[test]
    fn reduce_neg_mul_() {
        let e: TDim = TDim::from(1) - A.to_dim() * 2;
        assert_eq!(e, TDim::from(1) + A.to_dim() * -2);
    }

    // Adding a multiple of the modulus does not change the remainder.
    #[test]
    fn reduce_add_rem_1() {
        assert_eq!(((A.to_dim() + 4) % 2), (A.to_dim() % 2));
    }

    // Subtracting a multiple of the modulus does not change the remainder.
    #[test]
    fn reduce_add_rem_2() {
        assert_eq!(((A.to_dim() - 4) % 2), (A.to_dim() % 2));
    }

    // `a % 2 / 2` is 0.
    #[test]
    fn reduce_rem_div() {
        let e: TDim = A.to_dim() % 2 / 2;
        assert_eq!(e, TDim::from(0));
    }

    // `div_ceil(1)` on the literal `1 - 1 + 1` gives 1.
    #[test]
    fn conv2d_ex_1() {
        let e = (TDim::from(1) - 1 + 1).div_ceil(1);
        assert_eq!(e, TDim::from(1));
    }

    // `div_ceil(1)` on `a - 3 + 1` gives `a - 2`.
    #[test]
    fn conv2d_ex_2() {
        let e = (A.to_dim() - 3 + 1).div_ceil(1);
        assert_eq!(e, A.to_dim() + -2);
    }

    // Integer factors 24 and 2 are pulled out of both product operands and merged into 48.
    #[test]
    fn extract_int_gcd_from_muls() {
        let term = (A.to_dim() + 1) / 4;
        let mul = (term.clone() * 24 - 24) * (term.clone() * 2 - 2);
        let target = (term.clone() - 1) * (term.clone() - 1) * 48;
        assert_eq!(mul, target);
    }

    // Products compare equal regardless of operand order.
    #[test]
    fn equality_of_muls() {
        let term = (A.to_dim() + 1) / 4;
        let mul1 = (term.clone() * 2 - 3) * (term.clone() - 1);
        let mul2 = (term.clone() - 1) * (term.clone() * 2 - 3);
        assert_eq!(mul1, mul2);
    }

    // `2*t - t - 1` collapses to `t - 1` for a compound term `t`.
    #[test]
    fn factorize_complex_expr_times_int() {
        let term = (A.to_dim() + 1) / 4;
        let e = term.clone() * 2 - &term - 1;
        assert_eq!(e, term - 1);
    }

    // For small integers, broadcasting `a` against `min(a, b)` fails when `b > 1 && a > b`
    // and yields `a` otherwise.
    #[test]
    fn broadcast_over_min() {
        for a in 1..5 {
            for b in 1..5 {
                if b > 1 && a > b {
                    assert!(a.broadcast(a.min(b)).is_err());
                } else {
                    assert_eq!(a.broadcast(a.min(b)).unwrap(), a);
                }
            }
        }
    }

    // `mini` on literals picks the smaller one (larger first).
    #[test]
    fn min_ints_1() {
        assert_eq!(2.to_dim().mini(1.to_dim()), 1.to_dim());
    }

    // `mini` on literals picks the smaller one (smaller first).
    #[test]
    fn min_ints_2() {
        assert_eq!(1.to_dim().mini(2.to_dim()), 1.to_dim());
    }

    // `mini(a, a)` is `a`.
    #[test]
    fn min_same() {
        assert_eq!(A.to_dim().mini(A.to_dim()), A.to_dim());
    }

    // Both sides are built by the same `mini` call, so this only checks the call succeeds
    // and produces a value equal to itself.
    #[test]
    fn min_noop() {
        assert_eq!(A.to_dim().mini(1.to_dim()), A.to_dim().mini(1.to_dim()));
    }

    // `mini(a + 1, a + 2)` resolves to `a + 1`.
    #[test]
    fn min_diff_1() {
        assert_eq!((A.to_dim() + 1).mini(A.to_dim() + 2), A.to_dim() + 1);
    }

    // A literal has slope 0.
    #[test]
    fn slope_0() {
        assert_eq!(12.to_dim().guess_slope(&A), (0, 1));
    }

    // A symbol has slope 1 with respect to itself.
    #[test]
    fn slope_1() {
        assert_eq!(A.to_dim().guess_slope(&A), (1, 1));
    }

    // `2 * a` has slope 2.
    #[test]
    fn slope_2() {
        assert_eq!((A.to_dim() * 2).guess_slope(&A), (2, 1));
    }

    // `2*a + a/2` has slope 5/2.
    #[test]
    fn slope_3() {
        assert_eq!((A.to_dim() * 2 + A.to_dim() / 2).guess_slope(&A), (5, 2));
    }

    // `a` has slope 0 with respect to another symbol.
    #[test]
    fn slope_4() {
        assert_eq!((A.to_dim()).guess_slope(&B), (0, 1));
    }

    // A constant offset does not affect the slope.
    #[test]
    fn slope_5() {
        assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
        assert_eq!((A.to_dim() + 1).guess_slope(&B), (0, 1));
    }

    // In `a + b`, the slope with respect to `b` is 1.
    #[test]
    fn slope_6() {
        assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
        assert_eq!((A.to_dim() + B.to_dim()).guess_slope(&B), (1, 1));
    }

    // Parsed `min(S+3, S+2)` simplifies to `S+2`.
    #[test]
    fn min_0() -> TractResult<()> {
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols.parse_tdim("min(S+3, S+2)").unwrap().simplify(),
            symbols.parse_tdim("S+2").unwrap(),
        );
        Ok(())
    }

    // Differently parenthesised and ordered products simplify to the same expression.
    #[test]
    fn commutative_mul_parens() -> TractResult<()> {
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols.parse_tdim("A*(B*C)").unwrap().simplify(),
            symbols.parse_tdim("(B*A)*C").unwrap().simplify(),
        );
        Ok(())
    }

    // Same property as above on a larger expression: the two inputs differ only in how the
    // product's factors are grouped and ordered.
    #[test]
    fn commutative_in_nemo_parakeet_model() -> TractResult<()> {
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols
                .parse_tdim("8*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8))*((B)*((S+7)/8))")
                .unwrap()
                .simplify(),
            symbols
                .parse_tdim("8*((B)*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8)))*((S+7)/8)")
                .unwrap()
                .simplify(),
        );
        Ok(())
    }

    // A deeply left-nested hand-built `Mul` is compared, after simplification, with the
    // simplified parse of `a*b*c*d*e`.
    #[test]
    fn commutative_mul_parens_deep() -> TractResult<()> {
        let symbols = SymbolScope::default();
        let deep_tdim = Mul(vec![
            Mul(vec![Mul(vec![Mul(vec![A.to_dim(), B.to_dim()]), C.to_dim()]), D.to_dim()]),
            E.to_dim(),
        ])
        .simplify();
        assert_eq!(deep_tdim, symbols.parse_tdim("a*b*c*d*e").unwrap().simplify());
        Ok(())
    }

    // `Ge(5, 3)` reduces to 1.
    #[test]
    fn ge_concrete_true() {
        assert_eq!(Ge(b!(Val(5)), b!(Val(3))).reduce(), Val(1));
    }

    // `Ge(2, 3)` reduces to 0.
    #[test]
    fn ge_concrete_false() {
        assert_eq!(Ge(b!(Val(2)), b!(Val(3))).reduce(), Val(0));
    }

    // `Ge(3, 3)` reduces to 1.
    // NOTE(review): the test name says "lt" but the node built here is a `Ge`; presumably a
    // less-than is being expressed through `Ge` with adjusted operands — confirm.
    #[test]
    fn lt_concrete_true() {
        assert_eq!(Ge(b!(Val(3)), b!(Val(3))).reduce(), Val(1));
    }

    // `Ge(3, 6)` reduces to 0 (see the naming note on `lt_concrete_true`).
    #[test]
    fn lt_concrete_false() {
        assert_eq!(Ge(b!(Val(3)), b!(Val(6))).reduce(), Val(0));
    }

    // `Eq(3, 3)` reduces to 1.
    #[test]
    fn eq_concrete_true() {
        assert_eq!(Eq(b!(Val(3)), b!(Val(3))).reduce(), Val(1));
    }

    // `Eq(3, 4)` reduces to 0.
    #[test]
    fn eq_concrete_false() {
        assert_eq!(Eq(b!(Val(3)), b!(Val(4))).reduce(), Val(0));
    }

    // `1 - 0` reduces to 1.
    #[test]
    fn not_val_0() {
        assert_eq!((Val(1) - Val(0)).reduce(), Val(1));
    }

    // `1 - 1` reduces to 0.
    #[test]
    fn not_val_1() {
        assert_eq!((Val(1) - Val(1)).reduce(), Val(0));
    }

    // `1 - Ge(T, x1 + 1)` with `x1` replaced by `T` simplifies to 1.
    #[test]
    fn not_lt_becomes_ge() {
        let s = SymbolScope::default();
        let t = s.sym("T");
        let x1 = s.sym("x1");
        let expr = Val(1) - Ge(b!(Sym(t.clone())), b!(Sym(x1.clone()) + Val(1)));
        let at_boundary = expr.substitute(&x1, &Sym(t.clone())).unwrap().simplify();
        assert_eq!(at_boundary, Val(1));
    }

    // With `T >= 1` asserted, `Eq(T, 0)` simplifies to 0.
    #[test]
    fn eq_with_assertion_proves_false() {
        let s = SymbolScope::default();
        s.add_assertion("T >= 1").unwrap();
        let t = s.sym("T");
        let expr = Eq(b!(Sym(t)), b!(Val(0)));
        assert_eq!(expr.simplify(), Val(0));
    }

    // With `T >= 1` asserted, `Ge(x1, T)` with `x1` replaced by `T - 1` simplifies to 0.
    #[test]
    fn ge_coord_at_extremes() {
        let s = SymbolScope::default();
        s.add_assertion("T >= 1").unwrap();
        let t = s.sym("T");
        let x1 = s.sym("x1");
        let expr = Ge(b!(Sym(x1.clone())), b!(Sym(t.clone())));
        let at_max = expr.substitute(&x1, &(Sym(t.clone()) - Val(1))).unwrap().simplify();
        assert_eq!(at_max, Val(0));
    }

    // `eval_to_i64` evaluates `Ge` and `Eq` on literals to 1 or 0.
    #[test]
    fn eval_to_i64_new_variants() {
        use super::super::sym::SymbolValues;
        let sv = SymbolValues::default();
        assert_eq!(Ge(b!(Val(5)), b!(Val(3))).eval_to_i64(&sv).unwrap(), 1);
        assert_eq!(Ge(b!(Val(3)), b!(Val(5))).eval_to_i64(&sv).unwrap(), 0);
        assert_eq!(Eq(b!(Val(3)), b!(Val(3))).eval_to_i64(&sv).unwrap(), 1);
        assert_eq!(Eq(b!(Val(3)), b!(Val(4))).eval_to_i64(&sv).unwrap(), 0);
    }

    // With `0 <= cw <= 1` asserted, `Eq` of `cw` (or `1 - cw`) against 0 or 1 simplifies to
    // `cw` or `1 - cw`.
    #[test]
    fn eq_boolean_simplifies() {
        let s = SymbolScope::default();
        s.add_assertion("cw >= 0").unwrap();
        s.add_assertion("cw <= 1").unwrap();
        let cw = s.sym("cw");
        assert_eq!(Eq(b!(Val(1) - Sym(cw.clone())), b!(Val(0))).simplify(), Sym(cw.clone()));
        assert_eq!(Eq(b!(Sym(cw.clone())), b!(Val(0))).simplify(), Val(1) - Sym(cw.clone()));
        assert_eq!(Eq(b!(Sym(cw.clone())), b!(Val(1))).simplify(), Sym(cw.clone()));
        assert_eq!(Eq(b!(Val(1) - Sym(cw.clone())), b!(Val(1))).simplify(), Val(1) - Sym(cw));
    }

    // `Eq(p, 0)` where `p` is a product of two `Ge` nodes simplifies to `1 - p`.
    #[test]
    fn eq_boolean_mul_of_ge() {
        let s = SymbolScope::default();
        let x = s.sym("x");
        let product =
            Mul(vec![Ge(b!(Val(2)), b!(Sym(x.clone()))), Ge(b!(Sym(x.clone())), b!(Val(0)))]);
        let eq = Eq(b!(product.clone()), b!(Val(0)));
        assert_eq!(eq.simplify(), Val(1) - product);
    }

    // `min(1, max(0, X))` must still print with a `min` after simplification.
    #[test]
    fn min_1_max_0_sym() {
        let s = SymbolScope::default();
        let x = s.sym("X");
        let expr = Min(vec![Val(1), Max(vec![Val(0), Sym(x)])]);
        let simplified = expr.simplify();
        eprintln!("simplified: {simplified}");
        assert!(format!("{simplified}").contains("min"), "Min dropped: {simplified}");
    }

    // `min(1, max(0, (T+1)*P - S))` must still print with a `min` after simplification.
    #[test]
    fn min_preserved_in_subtraction_parts() {
        let s = SymbolScope::default();
        let t = s.sym("T");
        let p = s.sym("P");
        let ss = s.sym("S");

        let cum_after =
            Max(vec![Val(0), (Sym(t.clone()) + Val(1)) * Sym(p.clone()) - Sym(ss.clone())]);
        let min_after = Min(vec![Val(1), cum_after.clone()]);
        let simplified = min_after.simplify();
        eprintln!("min_after simplified: {simplified}");
        assert!(format!("{simplified}").contains("min"), "Min wrapper was dropped: {simplified}");
    }

    // The simplified difference `min(1, max(0, (T+1)*P - S)) - min(1, max(0, T*P - S))` is
    // evaluated at three concrete points and compared with the hand-computed values.
    #[test]
    fn min_preserved_in_subtraction() {
        let s = SymbolScope::default();
        let t = s.sym("T");
        let p = s.sym("P");
        let ss = s.sym("S");

        let cum_after =
            Max(vec![Val(0), (Sym(t.clone()) + Val(1)) * Sym(p.clone()) - Sym(ss.clone())]);
        let cum_before = Max(vec![Val(0), Sym(t.clone()) * Sym(p.clone()) - Sym(ss.clone())]);

        let ap = Min(vec![Val(1), cum_after.clone()]) - Min(vec![Val(1), cum_before.clone()]);
        let simplified = ap.simplify();

        use super::super::sym::SymbolValues;
        // T=1, P=4, S=3: min(1, 5) - min(1, 1) = 0.
        let sv = SymbolValues::default().with(&t, 1).with(&p, 4).with(&ss, 3);
        assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 0, "simplified: {simplified}");

        // T=0, P=4, S=3: min(1, 1) - min(1, 0) = 1.
        let sv = SymbolValues::default().with(&t, 0).with(&p, 4).with(&ss, 3);
        assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 1, "simplified: {simplified}");

        // T=0, P=1, S=1: min(1, 0) - min(1, 0) = 0.
        let sv = SymbolValues::default().with(&t, 0).with(&p, 1).with(&ss, 1);
        assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 0, "simplified: {simplified}");
    }

    // `Mul[8, -1*B]` and `-8*B` simplify to the same expression.
    #[test]
    fn mul_neg_b_by_8() {
        let s = SymbolScope::default();
        let b = Sym(s.sym("B"));
        let a = Mul(vec![Val(8), MulInt(-1, Box::new(b.clone()))]);
        let c = MulInt(-8, Box::new(b.clone()));
        let a_s = a.simplify();
        let c_s = c.simplify();
        assert_eq!(a_s, c_s, "8*(-1*B) should simplify the same as -8*B");
    }

    // A factor of 2 shared by every numerator term and the divisor is cancelled.
    #[test]
    fn reduce_div_by_common_factor_with_divisor() {
        let lhs = (A.to_dim() * 14 + 6) / 8;
        let rhs = (A.to_dim() * 7 + 3) / 4;
        assert_eq!(lhs, rhs);
    }

    // When every numerator term is a multiple of the divisor, the division disappears.
    #[test]
    fn reduce_div_when_factor_equals_divisor() {
        let lhs = (A.to_dim() * 4 + 8) / 4;
        let rhs = A.to_dim() + 2;
        assert_eq!(lhs, rhs);
    }

    // `(7*a + 3) / 4` must stay a `Div` node with divisor 4.
    #[test]
    fn no_reduce_when_terms_coprime_with_divisor() {
        let e = (A.to_dim() * 7 + 3) / 4;
        match &e {
            Div(_, q) => assert_eq!(*q, 4),
            other => panic!("expected Div(_, 4), got {other:?}"),
        }
    }

    // `(a + 4) / 2` still evaluates correctly for `a = 2` and `a = 4`.
    #[test]
    fn no_reduce_when_sym_has_implicit_unit_coefficient() {
        let e = (A.to_dim() + 4) / 2;
        let sv2 = SymbolValues::default().with(&A, 2);
        let sv4 = SymbolValues::default().with(&A, 4);
        assert_eq!(e.eval_to_i64(&sv2).unwrap(), 3);
        assert_eq!(e.eval_to_i64(&sv4).unwrap(), 4);
    }
}