1use crate::dim::Assertion;
2use crate::internal::*;
3
4use super::{DimLike, sym::*};
5use itertools::Itertools;
6use num_integer::Integer;
7use num_traits::{AsPrimitive, PrimInt, Zero};
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10use std::fmt::Debug;
11use std::ops::Neg;
12use std::{fmt, ops};
13
14#[derive(Debug)]
15pub enum TooEarly {
16 UndeterminedSymbol(String),
17 Other(String),
18}
19
20impl std::fmt::Display for TooEarly {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 match self {
23 TooEarly::UndeterminedSymbol(s) => write!(f, "Undetermined symbol in expression: {s}"),
24 TooEarly::Other(s) => write!(f, "{s}"),
25 }
26 }
27}
28
29impl std::error::Error for TooEarly {}
30
31macro_rules! b( ($e:expr) => { Box::new($e) } );
32
33#[allow(clippy::derived_hash_with_manual_eq)]
39#[derive(Clone, Eq, Hash, Debug)]
40pub enum TDim {
41 Val(i64),
42 Sym(Symbol),
43 Add(Vec<TDim>),
44 Mul(Vec<TDim>),
45 MulInt(i64, Box<TDim>),
46 Div(Box<TDim>, u64),
47 Broadcast(Vec<TDim>),
48 Min(Vec<TDim>),
49 Max(Vec<TDim>),
50 Ge(Box<TDim>, Box<TDim>),
52 Eq(Box<TDim>, Box<TDim>),
54}
55
56use TDim::*;
57
58fn eq_structural(a: &TDim, b: &TDim) -> bool {
63 match (a, b) {
64 (Val(x), Val(y)) => x == y,
65 (Sym(x), Sym(y)) => x == y,
66 (Add(x), Add(y))
67 | (Mul(x), Mul(y))
68 | (Broadcast(x), Broadcast(y))
69 | (Min(x), Min(y))
70 | (Max(x), Max(y)) => {
71 x.len() == y.len() && x.iter().zip(y).all(|(a, b)| eq_structural(a, b))
72 }
73 (MulInt(p, x), MulInt(q, y)) => p == q && eq_structural(x, y),
74 (Div(x, p), Div(y, q)) => p == q && eq_structural(x, y),
75 (Ge(a, b), Ge(c, d)) | (Eq(a, b), Eq(c, d)) => eq_structural(a, c) && eq_structural(b, d),
76 _ => false,
77 }
78}
79
80std::thread_local! {
86 static EQ_GUARD: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
87}
88
89impl PartialEq for TDim {
90 fn eq(&self, other: &Self) -> bool {
91 if eq_structural(self, other) {
93 return true;
94 }
95 if EQ_GUARD.with(|g| g.get()) {
98 return false;
99 }
100 if matches!(self, Val(_) | Sym(_)) || matches!(other, Val(_) | Sym(_)) {
109 return false;
110 }
111 EQ_GUARD.with(|g| g.set(true));
118 let diff = (self.clone() - other.clone()).simplify();
119 EQ_GUARD.with(|g| g.set(false));
120 matches!(diff, Val(0))
121 }
122}
123
124fn tdim_lexi_order(a: &TDim, b: &TDim) -> Ordering {
125 match (a, b) {
126 (Sym(a), Sym(b)) => a.cmp(b),
127 (Val(a), Val(b)) => a.cmp(b),
128 (Add(a), Add(b))
129 | (Mul(a), Mul(b))
130 | (Broadcast(a), Broadcast(b))
131 | (Min(a), Min(b))
132 | (Max(a), Max(b)) => a.len().cmp(&b.len()).then(
133 a.iter()
134 .zip(b.iter())
135 .fold(Ordering::Equal, |acc, (a, b)| acc.then_with(|| tdim_lexi_order(a, b))),
136 ),
137 (MulInt(p, d), MulInt(q, e)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
138 (Div(d, p), Div(e, q)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
139 (Sym(_), _) => Ordering::Less,
140 (_, Sym(_)) => Ordering::Greater,
141 (Val(_), _) => Ordering::Less,
142 (_, Val(_)) => Ordering::Greater,
143 (Add(_), _) => Ordering::Less,
144 (_, Add(_)) => Ordering::Greater,
145 (Mul(_), _) => Ordering::Less,
146 (_, Mul(_)) => Ordering::Greater,
147 (MulInt(_, _), _) => Ordering::Less,
148 (_, MulInt(_, _)) => Ordering::Greater,
149 (Broadcast(_), _) => Ordering::Less,
150 (_, Broadcast(_)) => Ordering::Greater,
151 (Min(_), _) => Ordering::Less,
152 (_, Min(_)) => Ordering::Greater,
153 (Max(_), _) => Ordering::Less,
154 (_, Max(_)) => Ordering::Greater,
155 (Ge(a1, b1), Ge(a2, b2)) | (Eq(a1, b1), Eq(a2, b2)) => {
156 tdim_lexi_order(a1, a2).then_with(|| tdim_lexi_order(b1, b2))
157 }
158 (Ge(_, _) | Eq(_, _), _) => Ordering::Less,
159 (_, Ge(_, _) | Eq(_, _)) => Ordering::Greater,
160 }
161}
162
163fn try_divide_multiple_plus_remainder(
179 terms: &[TDim],
180 q: u64,
181 scope: &SymbolScopeData,
182 extra: &[Assertion],
183) -> Option<TDim> {
184 let mut quotients: Vec<TDim> = vec![];
185 let mut const_rem: i64 = 0;
186 let mut any_extracted = false;
187 for term in terms {
188 match term {
189 MulInt(c, x) if *c != 0 && c.rem_euclid(q as i64) == 0 => {
190 if !scope.prove_positive_or_zero_with_extra(x, extra) {
191 return None;
192 }
193 let new_coeff = c / (q as i64);
194 quotients.push(if new_coeff == 1 {
195 (**x).clone()
196 } else if new_coeff == -1 {
197 MulInt(-1, x.clone())
198 } else {
199 MulInt(new_coeff, x.clone())
200 });
201 any_extracted = true;
202 }
203 Val(v) => const_rem += v,
204 _ => return None,
205 }
206 }
207 if !any_extracted {
208 return None;
209 }
210 if !(0..q as i64).contains(&const_rem) {
211 return None;
212 }
213 Some(if quotients.len() == 1 { quotients.remove(0) } else { Add(quotients) })
214}
215
216impl fmt::Display for TDim {
217 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
218 match &self {
219 Sym(sym) => write!(fmt, "{sym}"),
220 Val(it) => write!(fmt, "{it}"),
221 Add(it) => write!(fmt, "{}", it.iter().map(|x| format!("{x}")).join("+")),
222 Mul(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("*")),
223 Broadcast(it) => {
224 write!(fmt, "broadcast({})", it.iter().map(|x| format!("({x})")).join(", "))
225 }
226 Min(it) => write!(fmt, "min({})", it.iter().map(|x| format!("{x}")).join(",")),
227 Max(it) => write!(fmt, "max({})", it.iter().map(|x| format!("{x}")).join(",")),
228 MulInt(a, b) => write!(fmt, "{a}*{b}"),
229 Div(a, b) => write!(fmt, "({a})/{b}"),
230 Ge(a, b) => write!(fmt, "({a}>={b})"),
231 Eq(a, b) => write!(fmt, "({a}=={b})"),
232 }
233 }
234}
235
236impl TDim {
237 #[inline]
238 pub fn is_one(&self) -> bool {
239 matches!(self, Val(1))
240 }
241
242 #[inline]
243 pub fn to_i64(&self) -> TractResult<i64> {
244 if let Val(v) = self {
245 Ok(*v)
246 } else {
247 Err(TooEarly::UndeterminedSymbol(self.to_string()))?
248 }
249 }
250
251 #[inline]
252 pub fn as_i64(&self) -> Option<i64> {
253 if let Val(v) = self { Some(*v) } else { None }
254 }
255
256 pub fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64> {
257 match self {
258 Sym(sym) => {
259 let Some(v) = values.get(sym) else {
260 Err(TooEarly::UndeterminedSymbol(self.to_string()))?
261 };
262 Ok(v)
263 }
264 Val(v) => Ok(*v),
265 Add(terms) => terms.iter().try_fold(0i64, |acc, it| {
266 let x = it.eval_to_i64(values)?;
267 acc.checked_add(x)
268 .with_context(|| format!("Overflow in TDim addition ({acc} + {x})"))
269 }),
270 Mul(terms) => terms.iter().try_fold(1i64, |acc, it| {
271 let x = it.eval_to_i64(values)?;
272 acc.checked_mul(x)
273 .with_context(|| format!("Overflow in TDim multiplication ({acc} * {x})"))
274 }),
275 Min(terms) => terms
276 .iter()
277 .try_fold(i64::MAX, |acc, it| it.eval_to_i64(values).map(|x| acc.min(x))),
278 Max(terms) => terms
279 .iter()
280 .try_fold(i64::MIN, |acc, it| it.eval_to_i64(values).map(|x| acc.max(x))),
281 Broadcast(terms) => terms.iter().try_fold(1i64, |acc, it| {
282 it.eval_to_i64(values)
283 .and_then(|x| ((acc as usize).broadcast(x as usize)).map(|x| x as i64))
284 }),
285 Div(a, q) => Ok(a.eval_to_i64(values)? / *q as i64),
286 MulInt(p, a) => {
287 let x = a.eval_to_i64(values)?;
288 x.checked_mul(*p)
289 .with_context(|| format!("Overflow in TDim multiplication ({x} * {p})"))
290 }
291 Ge(a, b) => Ok(if a.eval_to_i64(values)? >= b.eval_to_i64(values)? { 1 } else { 0 }),
292 Eq(a, b) => Ok(if a.eval_to_i64(values)? == b.eval_to_i64(values)? { 1 } else { 0 }),
293 }
294 }
295
296 pub fn eval(&self, values: &SymbolValues) -> TDim {
297 match self {
298 Sym(sym) => values.get(sym).map(Val).unwrap_or_else(|| Sym(sym.clone())),
299 Val(v) => Val(*v),
300 Add(terms) => terms.iter().fold(Val(0), |acc, it| -> TDim { acc + it.eval(values) }),
301 Mul(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim { acc * it.eval(values) }),
302 Min(terms) => {
303 terms.iter().fold(Val(i64::MAX), |acc, it| -> TDim { acc.mini(it.eval(values)) })
304 }
305 Max(terms) => {
306 terms.iter().fold(Val(i64::MIN), |acc, it| -> TDim { acc.maxi(it.eval(values)) })
307 }
308 Broadcast(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim {
309 acc.broadcast(it.eval(values)).unwrap_or_else(|_| self.clone())
310 }),
311 Div(a, q) => a.eval(values) / *q as i64,
312 MulInt(p, a) => a.eval(values) * *p,
313 Ge(a, b) => {
314 let a2 = a.eval(values);
315 let b2 = b.eval(values);
316 if let (Val(av), Val(bv)) = (&a2, &b2) {
317 Val(if av >= bv { 1 } else { 0 })
318 } else {
319 Ge(b!(a2), b!(b2))
320 }
321 }
322 Eq(a, b) => {
323 let a2 = a.eval(values);
324 let b2 = b.eval(values);
325 if let (Val(av), Val(bv)) = (&a2, &b2) {
326 Val(if av == bv { 1 } else { 0 })
327 } else {
328 Eq(b!(a2), b!(b2))
329 }
330 }
331 }
332 }
333
334 pub fn eval_with_scenario(&self, scenario: &str) -> TDim {
335 if let Val(v) = self {
336 return Val(*v);
337 }
338 let scope = self.find_scope().unwrap();
339 let scope = scope.0;
340 let locked = scope.lock();
341 let scope = locked.borrow();
342 self.clone().simplify_rec(&scope, Some(scenario), &[])
343 }
344
345 pub fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self> {
346 self.substitute_all(&std::collections::HashMap::from([(from.clone(), to.clone())]))
347 }
348
349 pub fn substitute_all(
350 &self,
351 map: &std::collections::HashMap<Symbol, Self>,
352 ) -> TractResult<Self> {
353 match self {
354 Sym(sym) => Ok(map.get(sym).cloned().unwrap_or_else(|| self.clone())),
355 Val(v) => Ok(Val(*v)),
356 Add(terms) => terms.iter().try_fold(Val(0), |acc, it| -> TractResult<TDim> {
357 Ok(acc + it.substitute_all(map)?)
358 }),
359 Mul(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
360 Ok(acc * it.substitute_all(map)?)
361 }),
362 Broadcast(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
363 acc.broadcast(it.substitute_all(map)?)
364 }),
365 Min(terms) => terms.iter().try_fold(Val(i64::MAX), |acc, it| -> TractResult<TDim> {
366 Ok(acc.mini(it.substitute_all(map)?))
367 }),
368 Max(terms) => terms.iter().try_fold(Val(i64::MIN), |acc, it| -> TractResult<TDim> {
369 Ok(acc.maxi(it.substitute_all(map)?))
370 }),
371 Div(a, q) => Ok(a.substitute_all(map)? / *q as i64),
372 MulInt(p, a) => Ok(a.substitute_all(map)? * *p),
373 Ge(a, b) => Ok(Ge(b!(a.substitute_all(map)?), b!(b.substitute_all(map)?))),
374 Eq(a, b) => Ok(Eq(b!(a.substitute_all(map)?), b!(b.substitute_all(map)?))),
375 }
376 }
377
378 pub fn reduce(self) -> TDim {
379 self.simplify()
380 .wiggle()
381 .into_iter()
382 .sorted_by(tdim_lexi_order)
383 .unique()
384 .map(|e| e.simplify())
385 .min_by_key(|e| e.cost())
386 .unwrap()
387 }
388
389 fn cost(&self) -> usize {
390 use self::TDim::*;
391 match self {
392 Sym(_) | Val(_) => 1,
393 Add(terms) => 2 * terms.iter().map(TDim::cost).sum::<usize>(),
394 Mul(terms) => 3 * terms.iter().map(TDim::cost).sum::<usize>(),
395 Broadcast(terms) => 4 * terms.iter().map(TDim::cost).sum::<usize>(),
396 Min(terms) | Max(terms) => 5 * terms.iter().map(TDim::cost).sum::<usize>(),
397 Div(a, _) => 3 * a.cost(),
398 MulInt(_, a) => 2 * a.cost(),
399 Ge(a, b) | Eq(a, b) => 5 * (a.cost() + b.cost()),
400 }
401 }
402
403 fn wiggle(&self) -> Vec<TDim> {
404 use self::TDim::*;
405 match self {
406 Sym(_) | Val(_) | Mul(_) | Broadcast(_) | Min(_) | Max(_) | Ge(_, _) | Eq(_, _) => {
407 vec![self.clone()]
408 }
409 Add(terms) => {
410 let mut forms = vec![];
411 let sub_exprs = terms.iter().map(|e| e.wiggle()).multi_cartesian_product();
412
413 fn first_div_term(terms: &[TDim]) -> Option<(usize, &TDim, u64)> {
414 terms.iter().enumerate().find_map(|(index, t)| match t {
415 Div(numerator, quotient) => Some((index, &**numerator, *quotient)),
416 _ => None,
417 })
418 }
419
420 fn generate_new_numerator(
421 div_index: usize,
422 numerator: &TDim,
423 quotient: u64,
424 expr: &[TDim],
425 ) -> Vec<TDim> {
426 expr.iter()
427 .enumerate()
428 .map(|(index, term)| {
429 if index == div_index {
430 numerator.clone()
431 } else {
432 MulInt(quotient as i64, Box::new(term.clone()))
433 }
434 })
435 .collect()
436 }
437
438 for expr in sub_exprs {
439 if let Some((div_index, numerator, quotient)) = first_div_term(&expr) {
440 let new_numerator =
441 generate_new_numerator(div_index, numerator, quotient, &expr);
442 forms.push(Div(Box::new(Add(new_numerator)), quotient))
443 }
444
445 forms.push(Add(expr));
446 }
447 forms
448 }
449 MulInt(p, a) => a.wiggle().into_iter().map(|a| MulInt(*p, b!(a))).collect(),
450 Div(a, q) => {
451 let mut forms = vec![];
452 for num in a.wiggle() {
453 if let Add(terms) = &num {
454 let (integer, non_integer): (Vec<_>, Vec<_>) =
455 terms.iter().cloned().partition(|a| a.gcd() % q == 0);
456 if !non_integer.iter().any(|t| matches!(t, Val(_))) {
467 let mut new_terms =
468 integer.iter().map(|i| i.div(*q)).collect::<Vec<_>>();
469 if non_integer.len() > 0 {
470 new_terms.push(Div(b!(Add(non_integer)), *q));
471 }
472 forms.push(Add(new_terms))
473 }
474 }
475 forms.push(Div(b!(num), *q))
476 }
477 forms
478 }
479 }
480 }
481
482 fn find_any_sym(tdim: &TDim) -> Option<&Symbol> {
483 match tdim {
484 Val(_) => None,
485 Sym(s) => Some(s),
486 Add(terms) | Mul(terms) | Min(terms) | Max(terms) | Broadcast(terms) => {
487 terms.iter().find_map(Self::find_any_sym)
488 }
489 MulInt(_, t) | Div(t, _) => Self::find_any_sym(t),
490 Ge(a, b) | Eq(a, b) => Self::find_any_sym(a).or_else(|| Self::find_any_sym(b)),
491 }
492 }
493
494 pub fn find_scope(&self) -> Option<SymbolScope> {
495 Self::find_any_sym(self).and_then(|s| s.scope().clone())
496 }
497
498 pub fn expand_polynomial(self) -> TDim {
507 use self::TDim::*;
508 match self {
509 Mul(terms) => {
510 let terms: Vec<TDim> = terms.into_iter().map(Self::expand_polynomial).collect();
511 if let Some(add_idx) = terms.iter().position(|t| matches!(t, Add(_))) {
512 let Add(add_terms) = terms[add_idx].clone() else { unreachable!() };
513 let others: Vec<TDim> = terms
514 .iter()
515 .enumerate()
516 .filter(|(i, _)| *i != add_idx)
517 .map(|(_, t)| t.clone())
518 .collect();
519 Add(add_terms
520 .into_iter()
521 .map(|t| {
522 let mut product = others.clone();
523 product.push(t);
524 Mul(product).expand_polynomial()
525 })
526 .collect())
527 .simplify()
528 } else {
529 Mul(terms).simplify()
530 }
531 }
532 MulInt(c, inner) => MulInt(c, Box::new(inner.expand_polynomial())).simplify(),
533 Add(terms) => Add(terms.into_iter().map(Self::expand_polynomial).collect()).simplify(),
534 Div(a, q) => Div(Box::new(a.expand_polynomial()), q).simplify(),
535 Min(terms) => Min(terms.into_iter().map(Self::expand_polynomial).collect()).simplify(),
536 Max(terms) => Max(terms.into_iter().map(Self::expand_polynomial).collect()).simplify(),
537 Broadcast(terms) => {
538 Broadcast(terms.into_iter().map(Self::expand_polynomial).collect()).simplify()
539 }
540 Ge(a, b) => {
541 Ge(Box::new(a.expand_polynomial()), Box::new(b.expand_polynomial())).simplify()
542 }
543 Eq(a, b) => {
544 Eq(Box::new(a.expand_polynomial()), Box::new(b.expand_polynomial())).simplify()
545 }
546 it @ (Sym(_) | Val(_)) => it,
547 }
548 }
549
550 pub fn simplify(self) -> TDim {
551 use self::TDim::*;
552 if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) {
553 return Val(v);
554 }
555 let Some(scope) = self.find_scope() else {
556 return self;
557 };
558 let scope = scope.0;
559 let locked = scope.lock();
560 let scope = locked.borrow();
561 let it = self.simplify_rec(&scope, None, &[]);
562 let mut current: Option<TDim> = None;
563 for scenario in scope.scenarios() {
564 let v = it.clone().simplify_rec(&scope, Some(scenario), &[]);
565 if current.is_some_and(|c| c != v) {
566 return it;
567 } else {
568 current = Some(v);
569 }
570 }
571 current.unwrap_or(it)
572 }
573
574 pub fn simplify_with_extra_assertions(self, extra: &[Assertion]) -> TDim {
575 use self::TDim::*;
576 if extra.is_empty() {
577 return self.simplify();
578 }
579 if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) {
580 return Val(v);
581 }
582 let Some(scope) = self.find_scope() else {
583 return self;
584 };
585 let scope = scope.0;
586 let locked = scope.lock();
587 let scope = locked.borrow();
588 let it = self.simplify_rec(&scope, None, extra);
589 let mut current: Option<TDim> = None;
590 for scenario in scope.scenarios() {
591 let v = it.clone().simplify_rec(&scope, Some(scenario), extra);
592 if current.is_some_and(|c| c != v) {
593 return it;
594 } else {
595 current = Some(v);
596 }
597 }
598 current.unwrap_or(it)
599 }
600
601 fn simplify_rec(
602 self,
603 scope: &SymbolScopeData,
604 scenario: Option<&str>,
605 extra: &[Assertion],
606 ) -> TDim {
607 match self {
608 Add(mut terms) => {
609 #[allow(clippy::mutable_key_type)]
610 let mut simplified_terms: HashMap<TDim, i64> = HashMap::new();
611 while let Some(term) = terms.pop() {
613 let simplified = term.simplify_rec(scope, scenario, extra);
614 match simplified {
615 Val(0) => {} Add(members) => {
617 terms.extend(members);
618 continue;
619 }
620 Val(value) => *simplified_terms.entry(Val(1)).or_insert(0) += value,
621 MulInt(value, factor) => {
622 *simplified_terms.entry((*factor).clone()).or_insert(0) += value;
623 }
624 n => *simplified_terms.entry(n).or_insert(0) += 1,
625 };
626 }
627
628 pub fn evaluate_count(term: TDim, count: i64) -> Option<TDim> {
629 match count {
630 0 => None,
631 _ if term == TDim::Val(1) => Some(TDim::Val(count)),
632 1 => Some(term),
633 _ => Some(TDim::MulInt(count, Box::new(term))),
634 }
635 }
636
637 let has_non_const =
652 simplified_terms.iter().any(|(k, &c)| c != 0 && !matches!(k, Val(_)));
653 let coef_gcd = if has_non_const {
654 simplified_terms
655 .values()
656 .filter(|&&c| c != 0)
657 .map(|c| c.unsigned_abs() as i64)
658 .reduce(|a, b| a.gcd(&b))
659 .unwrap_or(0)
660 } else {
661 0
662 };
663 let outer_factor = if coef_gcd > 1 {
664 for v in simplified_terms.values_mut() {
665 *v /= coef_gcd;
666 }
667 Some(coef_gcd)
668 } else {
669 None
670 };
671
672 let mut members: Vec<TDim> = simplified_terms
673 .into_iter()
674 .filter_map(|(term, count)| evaluate_count(term, count))
675 .collect();
676 members.sort_by(tdim_lexi_order);
677
678 let inner = match members.len() {
679 0 => TDim::Val(0),
680 1 => members.into_iter().next().unwrap(),
681 _ => TDim::Add(members),
682 };
683 match outer_factor {
684 None => inner,
685 Some(_) if matches!(inner, TDim::Val(0)) => TDim::Val(0),
686 Some(g) => TDim::MulInt(g, Box::new(inner)),
687 }
688 }
689 Mul(terms) => {
690 {
703 let add_indices: Vec<usize> = terms
704 .iter()
705 .enumerate()
706 .filter(|(_, t)| matches!(t, Add(_)))
707 .map(|(i, _)| i)
708 .collect();
709 if add_indices.len() == 1 {
710 let add_idx = add_indices[0];
711 let Add(add_terms) = &terms[add_idx] else { unreachable!() };
712 let other_factors: Vec<TDim> = terms
713 .iter()
714 .enumerate()
715 .filter(|(i, _)| *i != add_idx)
716 .map(|(_, t)| t.clone())
717 .collect();
718 let distributed: Vec<TDim> = add_terms
719 .iter()
720 .map(|at| {
721 let mut product = other_factors.clone();
722 product.push(at.clone());
723 Mul(product)
724 })
725 .collect();
726 return Add(distributed).simplify_rec(scope, scenario, extra);
727 }
728 }
729
730 let mut flattened_terms = vec![];
733 for t in terms {
734 match t.clone().reduce() {
735 Mul(inner_terms) => flattened_terms.extend(inner_terms),
736 MulInt(k, inner) => {
737 flattened_terms.push(Val(k));
738 flattened_terms.push(*inner);
739 }
740 other => flattened_terms.push(other),
741 }
742 }
743 let mut terms = flattened_terms;
744
745 let mut gcd = Mul(terms.clone()).gcd() as i64;
746 if gcd == 0 {
747 return Val(0);
748 }
749 terms = if gcd != 1 {
750 terms
751 .into_iter()
752 .map(|t| {
753 let gcd = t.gcd();
754 (t / gcd).simplify_rec(scope, scenario, extra)
755 })
756 .collect()
757 } else {
758 terms
759 };
760 if terms.iter().filter(|t| t == &&Val(-1)).count() % 2 == 1 {
761 gcd = -gcd;
762 }
763 terms.retain(|t| !t.is_one() && t != &Val(-1));
764 terms.sort_by(tdim_lexi_order);
765
766 match (gcd, terms.len()) {
767 (_, 0) => Val(gcd), (0, _) => Val(0), (1, 1) => terms.remove(0), (1, _) => Mul(terms), (_, 1) => MulInt(gcd, Box::new(terms.remove(0))), _ => MulInt(gcd, Box::new(Mul(terms))), }
775 }
776 MulInt(coef, expr) => {
777 match *expr {
778 MulInt(c2, inner) => {
779 if let Some(c) = coef.checked_mul(c2) {
780 return MulInt(c, inner).simplify_rec(scope, scenario, extra);
781 } else {
782 return MulInt(coef, Box::new(MulInt(c2, inner)));
783 }
784 }
785 Val(v) => {
786 return coef
787 .checked_mul(v)
788 .map(Val)
789 .unwrap_or_else(|| MulInt(coef, Box::new(Val(v))));
790 }
791 _ => {}
792 }
793
794 let simplified = expr.simplify_rec(scope, scenario, extra);
795 match (coef, simplified) {
796 (0, _) => Val(0), (1, s) => s, (_, Add(terms)) => Add(terms
799 .into_iter()
800 .map(|term| {
801 MulInt(coef, Box::new(term)).simplify_rec(scope, scenario, extra)
802 })
803 .collect()), (c, Val(v)) => {
805 c.checked_mul(v).map(Val).unwrap_or_else(|| MulInt(c, Box::new(Val(v))))
806 } (c, MulInt(v, inner)) => {
808 if let Some(cv) = c.checked_mul(v) {
809 MulInt(cv, inner) } else {
811 MulInt(c, Box::new(MulInt(v, inner)))
812 }
813 }
814 (_, s) => MulInt(coef, Box::new(s)), }
816 }
817 Div(a, q) => {
818 if q == 1 {
819 return a.simplify_rec(scope, scenario, extra);
820 } else if let Div(a, q2) = *a {
821 return Div(a, q * q2).simplify_rec(scope, scenario, extra);
822 }
823 let a = a.simplify_rec(scope, scenario, extra);
824 if let Val(a) = a {
825 Val(a / q as i64)
826 } else if let MulInt(-1, a) = a {
827 MulInt(-1, b!(Div(a, q)))
828 } else if let Add(mut terms) = a {
829 if terms
830 .iter()
831 .any(|t| if let MulInt(-1, s) = t { matches!(&**s, Sym(_)) } else { false })
832 {
833 MulInt(
834 -1,
835 b!(Div(
836 b!(Add(terms.into_iter().map(|t| MulInt(-1, b!(t))).collect())
837 .simplify_rec(scope, scenario, extra)),
838 q
839 )),
840 )
841 } else if let Some(val) = terms
842 .iter()
843 .find_map(|t| if let Val(v) = t { Some(*v) } else { None })
844 .and_then(|v| {
845 if v >= q as i64 {
846 Some(v / q as i64)
847 } else if v < 0 {
848 Some(-Integer::div_ceil(&-v, &(q as i64)))
849 } else {
850 None
851 }
852 })
853 {
854 terms.push(Val(-val * q as i64));
855 let inner = Div(b!(Add(terms).simplify_rec(scope, scenario, extra)), q)
859 .simplify_rec(scope, scenario, extra);
860 Add(vec![Val(val), inner])
861 } else if let Some(simplified) =
862 try_divide_multiple_plus_remainder(&terms, q, scope, extra)
863 {
864 simplified.simplify_rec(scope, scenario, extra)
873 } else if let Some(found_idx) = terms.iter().position(|term| {
874 matches!(term, MulInt(p, inner)
878 if *p == -(q as i64)
879 && matches!(inner.as_ref(), Div(_, q2) if *q2 == q))
880 }) {
881 let MulInt(_, inner) = &terms[found_idx] else { unreachable!() };
882 let Div(y, _) = inner.as_ref() else { unreachable!() };
883 let remaining: Vec<TDim> = terms
884 .iter()
885 .enumerate()
886 .filter(|&(i, _)| i != found_idx)
887 .map(|(_, t)| t.clone())
888 .collect();
889 let remaining_sum = match remaining.len() {
890 0 => Val(0),
891 1 => remaining.into_iter().next().unwrap(),
892 _ => Add(remaining),
893 };
894 if eq_structural(&remaining_sum, y) {
895 Val(0)
896 } else {
897 Div(b!(Add(terms)), q)
898 }
899 } else {
900 Div(b!(Add(terms)), q)
901 }
902 } else if let MulInt(p, a) = a {
903 if p == q as i64 {
904 a.simplify()
905 } else {
906 let gcd = p.abs().gcd(&(q as i64));
907 if gcd == p {
908 Div(a, q / gcd as u64)
909 } else if gcd == q as i64 {
910 MulInt(p / gcd, a)
911 } else if gcd > 1 {
912 Div(b!(MulInt(p / gcd, a)), q / gcd as u64)
913 .simplify_rec(scope, scenario, extra)
914 } else {
915 Div(b!(MulInt(p, a)), q)
916 }
917 }
918 } else {
919 Div(b!(a), q)
920 }
921 }
922 Broadcast(terms) => {
923 let mut terms: Vec<TDim> = terms
924 .iter()
925 .map(|s| s.clone().simplify_rec(scope, scenario, extra))
926 .flat_map(|t| if let Broadcast(t) = t { t } else { vec![t] })
927 .filter(|t| !t.is_one())
928 .sorted_by(tdim_lexi_order)
929 .dedup()
930 .collect_vec();
931 match &*terms {
933 [] => Val(1),
934 [_] => terms.remove(0),
935 [a, Min(m)] | [Min(m), a]
936 if m.contains(a)
937 && m.iter()
938 .all(|t| scope.prove_strict_positive_with_extra(t, extra)) =>
939 {
940 a.clone()
941 }
942 _ => Broadcast(terms),
943 }
944 }
945
946 Min(terms) => {
947 let mut flatten: Vec<TDim> = terms
948 .into_iter()
949 .map(|t| t.simplify_rec(scope, scenario, extra))
950 .flat_map(|t| if let Min(t) = t { t } else { vec![t] })
951 .filter(|t| t != &Val(i64::MAX))
952 .sorted_by(tdim_lexi_order)
953 .dedup()
954 .collect();
955 #[allow(clippy::mutable_key_type)]
956 let mut redundant = HashSet::<TDim>::default();
957 for pair in flatten.iter().permutations(2) {
958 let (a, b) = (pair[0], pair[1]);
959 if redundant.contains(a) || redundant.contains(b) {
960 continue;
961 }
962 let diff = a.clone() - b;
963 if diff.as_i64().is_some_and(|i| i >= 0)
964 || scope.prove_positive_or_zero_with_extra(&diff, extra)
965 {
966 redundant.insert(a.clone());
967 }
968 }
969 flatten.retain(|t| !redundant.contains(t));
970 if flatten.len() == 0 {
971 i64::MAX.to_dim()
972 } else if flatten.len() == 1 {
973 flatten.into_iter().next().unwrap()
974 } else {
975 Min(flatten)
976 }
977 }
978 Max(terms) => {
979 let mut flatten: Vec<TDim> = terms
980 .into_iter()
981 .map(|t| t.simplify_rec(scope, scenario, extra))
982 .flat_map(|t| if let Max(t) = t { t } else { vec![t] })
983 .filter(|t| t != &Val(i64::MIN))
984 .sorted_by(tdim_lexi_order)
985 .dedup()
986 .collect();
987 #[allow(clippy::mutable_key_type)]
988 let mut redundant = HashSet::<TDim>::default();
989 for pair in flatten.iter().permutations(2) {
990 let (a, b) = (pair[0], pair[1]);
991 if redundant.contains(a) || redundant.contains(b) {
992 continue;
993 }
994 let diff = a.clone() - b;
995 if diff.as_i64().is_some_and(|i| i >= 0)
996 || scope.prove_positive_or_zero_with_extra(&diff, extra)
997 {
998 redundant.insert(b.clone());
999 }
1000 }
1001 flatten.retain(|t| !redundant.contains(t));
1002 if flatten.len() == 0 {
1003 i64::MIN.to_dim()
1004 } else if flatten.len() == 1 {
1005 flatten.into_iter().next().unwrap()
1006 } else {
1007 Max(flatten)
1008 }
1009 }
1010 Sym(s) => scope
1011 .assertions(scenario)
1012 .find_map(|a| match a {
1013 Assertion::Eq(Sym(sym), v) if sym == &s => Some(v.clone()),
1014 _ => None,
1015 })
1016 .unwrap_or(Sym(s)),
1017 Val(_) => self,
1018 Ge(a, b) => {
1019 let a = a.simplify_rec(scope, scenario, extra);
1020 let b = b.simplify_rec(scope, scenario, extra);
1021 match (&a, &b) {
1022 (Val(av), Val(bv)) => Val(if av >= bv { 1 } else { 0 }),
1023 _ => {
1024 let diff = a.clone() - b.clone();
1025 if scope.prove_positive_or_zero_with_extra(&diff, extra) {
1026 Val(1)
1027 } else if scope
1028 .prove_strict_positive_with_extra(&(b.clone() - a.clone()), extra)
1029 {
1030 Val(0)
1031 } else {
1032 Ge(b!(a), b!(b))
1033 }
1034 }
1035 }
1036 }
1037 Eq(a, b) => {
1038 let a = a.simplify_rec(scope, scenario, extra);
1039 let b = b.simplify_rec(scope, scenario, extra);
1040 match (&a, &b) {
1041 (Val(av), Val(bv)) => Val(if av == bv { 1 } else { 0 }),
1042 _ => {
1043 let diff = a.clone() - b.clone();
1044 if scope.prove_strict_positive_with_extra(&diff, extra)
1045 || scope
1046 .prove_strict_positive_with_extra(&(b.clone() - a.clone()), extra)
1047 {
1048 Val(0)
1049 } else {
1050 let boolean_case = match (&a, &b) {
1055 (Val(0), e) | (e, Val(0)) => Some((e, false)),
1056 (Val(1), e) | (e, Val(1)) => Some((e, true)),
1057 _ => None,
1058 };
1059 if let Some((expr, equals_one)) = boolean_case
1060 && scope.prove_positive_or_zero_with_extra(expr, extra)
1061 && scope.prove_positive_or_zero_with_extra(
1062 &(Val(1) - expr.clone()),
1063 extra,
1064 )
1065 {
1066 return if equals_one {
1067 expr.clone()
1068 } else {
1069 (Val(1) - expr.clone()).simplify_rec(scope, scenario, extra)
1070 };
1071 }
1072 Eq(b!(a), b!(b))
1073 }
1074 }
1075 }
1076 }
1077 }
1078 }
1079
1080 pub(super) fn inclusive_bound(&self, scope: &SymbolScopeData, upper: bool) -> Option<i64> {
1081 use self::TDim::*;
1082 match self {
1083 Val(n) => Some(*n),
1084 Sym(_) => {
1085 if upper {
1086 scope
1087 .all_assertions()
1088 .iter()
1089 .filter_map(|assert| match &assert {
1090 Assertion::LT(left, right)
1091 if left == self && right.as_i64().is_some() =>
1092 {
1093 Some(right.as_i64().unwrap() - 1)
1094 }
1095 Assertion::LTE(left, right)
1096 if left == self && right.as_i64().is_some() =>
1097 {
1098 Some(right.as_i64().unwrap())
1099 }
1100 _ => None,
1101 })
1102 .min()
1103 } else {
1104 scope
1105 .all_assertions()
1106 .iter()
1107 .filter_map(|assert| match &assert {
1108 Assertion::GT(left, right)
1109 if left == self && right.as_i64().is_some() =>
1110 {
1111 Some(right.as_i64().unwrap() + 1)
1112 }
1113 Assertion::GTE(left, right)
1114 if left == self && right.as_i64().is_some() =>
1115 {
1116 Some(right.as_i64().unwrap())
1117 }
1118 _ => None,
1119 })
1120 .max()
1121 }
1122 }
1123 Add(terms) => {
1124 let mut bound: i64 = 0;
1125 for t in terms {
1126 {
1127 let b = t.inclusive_bound(scope, upper)?;
1128 bound = bound.checked_add(b)?;
1129 }
1130 }
1131 Some(bound)
1132 }
1133 MulInt(p, a) => match p.cmp(&0) {
1134 Ordering::Equal => Some(0),
1135 Ordering::Greater => {
1136 a.inclusive_bound(scope, upper).and_then(|x| x.checked_mul(*p))
1137 }
1138 Ordering::Less => a.inclusive_bound(scope, !upper).and_then(|x| x.checked_mul(*p)),
1139 },
1140 Mul(terms) => {
1141 let mut lo: i64 = 1;
1143 let mut hi: i64 = 1;
1144 for t in terms {
1145 let t_lo = t.inclusive_bound(scope, false)?;
1146 let t_hi = t.inclusive_bound(scope, true)?;
1147 if t_lo < 0 {
1148 return None;
1149 }
1150 lo = lo.checked_mul(t_lo)?;
1151 hi = hi.checked_mul(t_hi)?;
1152 }
1153 Some(if upper { hi } else { lo })
1154 }
1155 Min(terms) if !upper => {
1156 let bounds: Option<Vec<i64>> =
1159 terms.iter().map(|t| t.inclusive_bound(scope, false)).collect();
1160 bounds.map(|b| b.into_iter().min().unwrap_or(i64::MAX))
1161 }
1162 Max(terms) if upper => {
1163 let bounds: Option<Vec<i64>> =
1166 terms.iter().map(|t| t.inclusive_bound(scope, true)).collect();
1167 bounds.map(|b| b.into_iter().max().unwrap_or(i64::MIN))
1168 }
1169 Div(a, q) => a.inclusive_bound(scope, upper).map(|x| x / (*q as i64)),
1170 Broadcast(terms) => {
1171 if upper {
1172 Max(terms.clone()).inclusive_bound(scope, true)
1173 } else {
1174 Min(terms.clone()).inclusive_bound(scope, false)
1175 }
1176 }
1177 Ge(_, _) | Eq(_, _) => {
1178 if upper {
1179 Some(1)
1180 } else {
1181 Some(0)
1182 }
1183 }
1184 _ => None,
1185 }
1186 }
1187
1188 pub fn low_inclusive_bound(&self) -> Option<i64> {
1189 if let TDim::Val(v) = self {
1190 return Some(*v);
1191 }
1192 let scope = self.find_scope()?;
1193 let data = scope.0.lock();
1194 let data = data.borrow();
1195 self.inclusive_bound(&data, false)
1196 }
1197
1198 pub fn high_inclusive_bound(&self) -> Option<i64> {
1199 if let TDim::Val(v) = self {
1200 return Some(*v);
1201 }
1202 let scope = self.find_scope()?;
1203 let data = scope.0.lock();
1204 let data = data.borrow();
1205 self.inclusive_bound(&data, true)
1206 }
1207
1208 pub fn prove_positive_or_zero(&self) -> bool {
1209 if let TDim::Val(v) = self {
1210 return *v >= 0;
1211 }
1212 let Some(scope) = self.find_scope() else { return false };
1213 let data = scope.0.lock();
1214 let data = data.borrow();
1215 data.prove_positive_or_zero(self)
1216 }
1217
1218 pub fn prove_strict_positive(&self) -> bool {
1219 if let TDim::Val(v) = self {
1220 return *v > 0;
1221 }
1222 (self.clone() - 1).prove_positive_or_zero()
1223 }
1224
1225 pub fn prove_negative_or_zero(&self) -> bool {
1226 if let TDim::Val(v) = self {
1227 return *v <= 0;
1228 }
1229 self.clone().neg().prove_positive_or_zero()
1230 }
1231
1232 pub fn prove_strict_negative(&self) -> bool {
1233 if let TDim::Val(v) = self {
1234 return *v < 0;
1235 }
1236 self.clone().neg().prove_strict_positive()
1237 }
1238
1239 pub fn lcm(&self, other: &TDim) -> Option<TDim> {
1247 match (self.as_i64(), other.as_i64()) {
1248 (Some(a), Some(b)) if a > 0 && b > 0 => {
1249 let g = (a as u64).gcd(&(b as u64));
1250 let l = (a as u64 / g).saturating_mul(b as u64);
1251 if l > i64::MAX as u64 { None } else { Some(TDim::Val(l as i64)) }
1252 }
1253 (Some(0), _) | (_, Some(0)) => Some(TDim::Val(0)),
1254 _ => None,
1255 }
1256 }
1257
1258 pub fn gcd(&self) -> u64 {
1259 use self::TDim::*;
1260 match self {
1261 Val(v) => v.unsigned_abs(),
1262 Sym(_) => 1,
1263 Add(terms) => {
1264 let (head, tail) = terms.split_first().unwrap();
1265 tail.iter().fold(head.gcd(), |a, b| a.gcd(&b.gcd()))
1266 }
1267 MulInt(p, a) => a.gcd().saturating_mul(p.unsigned_abs()),
1268 Mul(terms) => terms.iter().map(|t| t.gcd()).fold(1u64, |a, b| a.saturating_mul(b)),
1269 Min(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
1270 Max(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
1271 Div(a, q) => {
1272 if a.gcd() % *q == 0 {
1273 a.gcd() / *q
1274 } else {
1275 1
1276 }
1277 }
1278 Broadcast(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap_or(1),
1279 Ge(_, _) | Eq(_, _) => 1,
1280 }
1281 }
1282
1283 fn div(&self, d: u64) -> TDim {
1284 use self::TDim::*;
1285 if d == 1 {
1286 return self.clone();
1287 }
1288 match self {
1289 Val(v) => Val(v / d as i64),
1290 Sym(_) => panic!(),
1291 Add(terms) => Add(terms.iter().map(|t| t.div(d)).collect()),
1292 Min(terms) => Min(terms.iter().map(|t| t.div(d)).collect()),
1293 Max(terms) => Max(terms.iter().map(|t| t.div(d)).collect()),
1294 Broadcast(terms) => Broadcast(terms.iter().map(|t| t.div(d)).collect()),
1295 Mul(_) => Div(Box::new(self.clone()), d),
1296 MulInt(p, a) => {
1297 if *p == d as i64 {
1298 (**a).clone()
1299 } else {
1300 let gcd = p.unsigned_abs().gcd(&d);
1301 MulInt(p / gcd as i64, b!(a.div(d / gcd)))
1302 }
1303 }
1304 Div(a, q) => Div(a.clone(), q * d),
1305 Ge(_, _) | Eq(_, _) => Div(Box::new(self.clone()), d),
1306 }
1307 }
1308
1309 pub fn div_ceil(self, rhs: u64) -> TDim {
1310 TDim::Div(Box::new(Add(vec![self, Val(rhs as i64 - 1)])), rhs).reduce()
1311 }
1312
1313 pub fn guess_slope(&self, sym: &Symbol) -> (i64, u64) {
1314 fn slope_rec(d: &TDim, sym: &Symbol) -> (i64, i64) {
1315 match d {
1316 Val(_) => (0, 1),
1317 Sym(s) => ((sym == s) as i64, 1),
1318 Add(terms) => terms
1319 .iter()
1320 .map(|d| slope_rec(d, sym))
1321 .fold((0, 1), |a, b| ((a.0 * b.1 + a.1 * b.0), (b.1 * a.1))),
1322 Mul(terms) => terms
1323 .iter()
1324 .map(|d| slope_rec(d, sym))
1325 .fold((1, 1), |a, b| ((a.0 * b.0), (b.1 * a.1))),
1326 MulInt(p, a) => {
1327 let (n, d) = slope_rec(a, sym);
1328 (p * n, d)
1329 }
1330 Div(a, q) => {
1331 let (n, d) = slope_rec(a, sym);
1332 (n, d * *q as i64)
1333 }
1334 Broadcast(terms) => slope_rec(&terms[0], sym),
1335 Min(terms) => slope_rec(&terms[0], sym),
1336 Max(terms) => slope_rec(&terms[0], sym),
1337 Ge(_, _) | Eq(_, _) => (0, 1),
1338 }
1339 }
1340 let (p, q) = slope_rec(self, sym);
1341 reduce_ratio(p, q)
1342 }
1343
1344 #[allow(clippy::mutable_key_type)]
1345 pub fn symbols(&self) -> std::collections::HashSet<Symbol> {
1346 match self {
1347 Val(_) => maplit::hashset!(),
1348 Sym(s) => maplit::hashset!(s.clone()),
1349 Add(terms) | Mul(terms) | Broadcast(terms) | Min(terms) | Max(terms) => {
1350 terms.iter().fold(maplit::hashset!(), |mut set, v| {
1351 set.extend(v.symbols());
1352 set
1353 })
1354 }
1355 MulInt(_, a) => a.symbols(),
1356 Div(a, _) => a.symbols(),
1357 Ge(a, b) | Eq(a, b) => {
1358 let mut set = a.symbols();
1359 set.extend(b.symbols());
1360 set
1361 }
1362 }
1363 }
1364
1365 pub fn compatible_with(&self, other: &TDim) -> bool {
1366 if let Ok(x) = (self.clone() - other).to_i64() {
1367 return x == 0;
1368 }
1369 true }
1371}
1372
1373pub(super) fn reduce_ratio(mut p: i64, mut q: i64) -> (i64, u64) {
1374 let gcd = p.abs().gcd(&q.abs());
1375 if gcd > 1 {
1376 p /= gcd;
1377 q /= gcd;
1378 }
1379 if q < 0 { (-p, (-q) as u64) } else { (p, q as u64) }
1380}
1381
1382impl Zero for TDim {
1383 fn zero() -> Self {
1384 Val(0)
1385 }
1386 fn is_zero(&self) -> bool {
1387 matches!(self, Val(0))
1388 }
1389}
1390
1391impl Default for TDim {
1392 fn default() -> TDim {
1393 Val(0)
1394 }
1395}
1396
1397impl num_traits::Bounded for TDim {
1398 fn min_value() -> Self {
1399 TDim::Val(i64::MIN)
1400 }
1401
1402 fn max_value() -> Self {
1403 TDim::Val(i64::MAX)
1404 }
1405}
1406
1407impl num_traits::One for TDim {
1408 fn one() -> Self {
1409 TDim::Val(1)
1410 }
1411}
1412
1413impl ::std::iter::Sum for TDim {
1414 fn sum<I: Iterator<Item = TDim>>(iter: I) -> TDim {
1415 iter.fold(0.into(), |a, b| a + b)
1416 }
1417}
1418
1419impl<'a> ::std::iter::Sum<&'a TDim> for TDim {
1420 fn sum<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
1421 iter.fold(0.into(), |a, b| a + b)
1422 }
1423}
1424
1425impl std::iter::Product for TDim {
1426 fn product<I: Iterator<Item = TDim>>(iter: I) -> Self {
1427 iter.fold(TDim::Val(1), |a, b| a * b)
1428 }
1429}
1430
1431impl<'a> ::std::iter::Product<&'a TDim> for TDim {
1432 fn product<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
1433 iter.fold(1.into(), |a, b| a * b)
1434 }
1435}
1436
1437macro_rules! from_i {
1438 ($i: ty) => {
1439 impl From<$i> for TDim {
1440 fn from(v: $i) -> TDim {
1441 TDim::Val(v as _)
1442 }
1443 }
1444 impl<'a> From<&'a $i> for TDim {
1445 fn from(v: &'a $i) -> TDim {
1446 TDim::Val(*v as _)
1447 }
1448 }
1449 };
1450}
1451
1452from_i!(i32);
1453from_i!(i64);
1454from_i!(u64);
1455from_i!(isize);
1456from_i!(usize);
1457
1458impl From<Symbol> for TDim {
1459 fn from(it: Symbol) -> Self {
1460 TDim::Sym(it)
1461 }
1462}
1463
1464impl<'a> From<&'a Symbol> for TDim {
1465 fn from(it: &'a Symbol) -> Self {
1466 TDim::Sym(it.clone())
1467 }
1468}
1469
1470impl ops::Neg for TDim {
1471 type Output = Self;
1472 fn neg(self) -> Self {
1473 if let Val(v) = self { Val(-v) } else { TDim::MulInt(-1, Box::new(self)).reduce() }
1474 }
1475}
1476
1477impl<'a> ops::AddAssign<&'a TDim> for TDim {
1478 fn add_assign(&mut self, rhs: &'a TDim) {
1479 if rhs.is_zero() {
1480 } else if self.is_zero() {
1481 *self = rhs.clone();
1482 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1483 *s += o;
1484 } else {
1485 *self = TDim::Add(vec![std::mem::take(self), rhs.clone()]).reduce()
1486 }
1487 }
1488}
1489
1490impl<I> ops::AddAssign<I> for TDim
1491where
1492 I: Into<TDim>,
1493{
1494 fn add_assign(&mut self, rhs: I) {
1495 let rhs = rhs.into();
1496 if rhs.is_zero() {
1497 } else if self.is_zero() {
1498 *self = rhs;
1499 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1500 *s += o;
1501 } else {
1502 *self = TDim::Add(vec![std::mem::take(self), rhs]).reduce()
1503 }
1504 }
1505}
1506
1507impl<I> ops::Add<I> for TDim
1508where
1509 I: Into<TDim>,
1510{
1511 type Output = Self;
1512 fn add(mut self, rhs: I) -> Self {
1513 self += rhs;
1514 self
1515 }
1516}
1517
1518impl<'a> ops::Add<&'a TDim> for TDim {
1519 type Output = Self;
1520 fn add(mut self, rhs: &'a TDim) -> Self {
1521 self += rhs;
1522 self
1523 }
1524}
1525
1526#[allow(clippy::suspicious_op_assign_impl)]
1527impl<'a> ops::SubAssign<&'a TDim> for TDim {
1528 fn sub_assign(&mut self, rhs: &'a TDim) {
1529 if rhs.is_zero() {
1530 } else if self.is_zero() {
1531 *self = rhs.clone().neg();
1532 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1533 *s -= o;
1534 } else {
1535 *self = TDim::Add(vec![std::mem::take(self), rhs.clone().neg()]).reduce()
1536 }
1537 }
1538}
1539
1540impl<I> ops::SubAssign<I> for TDim
1541where
1542 I: Into<TDim>,
1543{
1544 fn sub_assign(&mut self, rhs: I) {
1545 let rhs = rhs.into();
1546 if rhs.is_zero() {
1547 } else if self.is_zero() {
1548 *self = rhs.neg();
1549 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1550 *s -= o;
1551 } else {
1552 *self = TDim::Add(vec![std::mem::take(self), rhs.neg()]).reduce()
1553 }
1554 }
1555}
1556
1557impl<I> ops::Sub<I> for TDim
1558where
1559 I: Into<TDim>,
1560{
1561 type Output = Self;
1562 fn sub(mut self, rhs: I) -> Self {
1563 self -= rhs;
1564 self
1565 }
1566}
1567
1568impl<'a> ops::Sub<&'a TDim> for TDim {
1569 type Output = Self;
1570 fn sub(mut self, rhs: &'a TDim) -> Self {
1571 self -= rhs;
1572 self
1573 }
1574}
1575
1576impl<I: Into<TDim>> ops::MulAssign<I> for TDim {
1577 fn mul_assign(&mut self, rhs: I) {
1578 let rhs = rhs.into();
1579 if self.is_one() {
1580 *self = rhs
1581 } else if rhs.is_one() {
1582 } else {
1583 *self = TDim::Mul(vec![rhs, std::mem::take(self)]).reduce()
1584 }
1585 }
1586}
1587
1588impl<'a> ops::MulAssign<&'a TDim> for TDim {
1589 fn mul_assign(&mut self, rhs: &'a TDim) {
1590 if self.is_one() {
1591 *self = rhs.clone()
1592 } else if rhs.is_one() {
1593 } else {
1594 *self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce()
1595 }
1596 }
1597}
1598
1599impl<I: Into<TDim>> ops::Mul<I> for TDim {
1600 type Output = Self;
1601 fn mul(mut self, rhs: I) -> Self {
1602 self *= rhs.into();
1603 self
1604 }
1605}
1606
1607impl<'a> ops::Mul<&'a TDim> for TDim {
1608 type Output = Self;
1609 fn mul(mut self, rhs: &'a TDim) -> Self {
1610 self *= rhs;
1611 self
1612 }
1613}
1614
1615impl<I: AsPrimitive<u64> + PrimInt> ops::DivAssign<I> for TDim {
1616 fn div_assign(&mut self, rhs: I) {
1617 *self = TDim::Div(Box::new(std::mem::take(self)), rhs.as_()).reduce()
1618 }
1619}
1620
1621impl<I: AsPrimitive<u64> + PrimInt> ops::Div<I> for TDim {
1622 type Output = Self;
1623 fn div(mut self, rhs: I) -> Self {
1624 self /= rhs.as_();
1625 self
1626 }
1627}
1628
1629impl<I: AsPrimitive<u64> + PrimInt> ops::RemAssign<I> for TDim {
1630 fn rem_assign(&mut self, rhs: I) {
1631 *self += -(self.clone() / rhs.as_() * rhs.as_());
1632 }
1633}
1634
1635impl<I: AsPrimitive<u64> + PrimInt> ops::Rem<I> for TDim {
1636 type Output = Self;
1637 fn rem(mut self, rhs: I) -> Self {
1638 self %= rhs;
1639 self
1640 }
1641}
1642
1643#[cfg(test)]
1644mod tests {
1645 use super::*;
1646
1647 macro_rules! b( ($e:expr) => { Box::new($e) } );
1648
1649 lazy_static::lazy_static! {
1650 static ref table: SymbolScope = SymbolScope::default();
1651 static ref A: Symbol = table.sym("a");
1652 static ref B: Symbol = table.sym("b");
1653 static ref C: Symbol = table.sym("c");
1654 static ref D: Symbol = table.sym("d");
1655 static ref E: Symbol = table.sym("e");
1656 }
1657
1658 fn neg(a: &TDim) -> TDim {
1659 mul(-1, a)
1660 }
1661
1662 fn add(a: &TDim, b: &TDim) -> TDim {
1663 TDim::Add(vec![a.clone(), b.clone()])
1664 }
1665
1666 fn mul(a: i64, b: &TDim) -> TDim {
1667 TDim::MulInt(a, b![b.clone()])
1668 }
1669
1670 fn div(a: &TDim, b: u64) -> TDim {
1671 TDim::Div(b!(a.clone()), b)
1672 }
1673
1674 #[test]
1675 fn reduce_add() {
1676 assert_eq!(add(&A.to_dim(), &neg(&A.to_dim())).reduce(), Val(0))
1677 }
1678
1679 #[test]
1680 fn lcm_basic() {
1681 assert_eq!(Val(16).lcm(&Val(32)), Some(Val(32)));
1682 assert_eq!(Val(32).lcm(&Val(16)), Some(Val(32)));
1683 assert_eq!(Val(6).lcm(&Val(8)), Some(Val(24)));
1684 assert_eq!(Val(7).lcm(&Val(7)), Some(Val(7)));
1685 assert_eq!(Val(16).lcm(&A.to_dim()), None);
1687 }
1688
1689 #[test]
1690 fn reduce_neg_mul() {
1691 assert_eq!(neg(&mul(2, &A.to_dim())).reduce(), mul(-2, &A.to_dim()))
1692 }
1693
1694 #[test]
1695 fn reduce_cplx_ex_2() {
1696 assert_eq!(
1697 add(
1698 &add(&Val(-4), &mul(-2, &div(&A.to_dim(), 4))),
1699 &mul(-2, &mul(-1, &div(&A.to_dim(), 4)))
1700 )
1701 .reduce(),
1702 Val(-4)
1703 )
1704 }
1705
1706 #[test]
1707 fn reduce_cplx_ex_3() {
1708 assert_eq!(div(&MulInt(1, b!(MulInt(4, b!(A.to_dim())))), 4).reduce(), A.to_dim())
1709 }
1710
1711 #[test]
1712 fn reduce_cplx_ex_4() {
1713 assert_eq!(
1715 add(&div(&add(&A.to_dim(), &Val(1)), 2), &div(&add(&neg(&A.to_dim()), &Val(1)), 2))
1716 .reduce(),
1717 1.into()
1718 );
1719 }
1720
1721 #[test]
1722 fn reduce_mul_mul_1() {
1723 assert_eq!(mul(3, &mul(2, &A.to_dim())).reduce(), mul(6, &A.to_dim()))
1724 }
1725
1726 #[test]
1727 fn reduce_mul_mul_2() {
1728 assert_eq!(mul(-2, &mul(-1, &A.to_dim())).reduce(), mul(2, &A.to_dim()))
1729 }
1730
1731 #[test]
1732 fn reduce_mul_div_1() {
1733 assert_eq!(mul(2, &div(&mul(-1, &A.to_dim()), 3)).reduce(), mul(-2, &div(&A.to_dim(), 3)))
1734 }
1735
1736 #[test]
1737 fn const_and_add() {
1738 let e: TDim = 2i64.into();
1739 assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 2);
1740 let e: TDim = TDim::from(2) + 3;
1741 assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 5);
1742 let e: TDim = TDim::from(2) - 3;
1743 assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -1);
1744 let e: TDim = -TDim::from(2);
1745 assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -2);
1746 }
1747
1748 #[test]
1749 fn substitution() {
1750 let a: TDim = A.to_dim();
1751 assert_eq!(a.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 2);
1752 let e = a + 3;
1753 assert_eq!(e.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 5);
1754 }
1755
1756 #[test]
1757 fn reduce_adds() {
1758 let e: TDim = TDim::from(2) + 1;
1759 assert_eq!(e, TDim::from(3));
1760 let e: TDim = TDim::from(3) + 2;
1761 assert_eq!(e, TDim::from(5));
1762 let e: TDim = TDim::from(3) + 0;
1763 assert_eq!(e, TDim::from(3));
1764 let e: TDim = TDim::from(3) + 2 + 1;
1765 assert_eq!(e, TDim::from(6));
1766 }
1767
1768 #[test]
1769 fn reduce_muls() {
1770 let e: TDim = Val(1) * A.to_dim();
1771 assert_eq!(e, A.to_dim());
1772 let e: TDim = A.to_dim() * &B.to_dim() * 1;
1773 assert_eq!(e, A.to_dim() * &B.to_dim());
1774 }
1775
1776 #[test]
1777 fn reduce_divs() {
1778 let e: TDim = TDim::from(2) / 1;
1779 assert_eq!(e, TDim::from(2));
1780 let e: TDim = TDim::from(3) / 2;
1781 assert_eq!(e, TDim::from(1));
1782 let e: TDim = TDim::from(3) % 2;
1783 assert_eq!(e, TDim::from(1));
1784 let e: TDim = TDim::from(5) / 2;
1785 assert_eq!(e, TDim::from(2));
1786 let e: TDim = TDim::from(5) % 2;
1787 assert_eq!(e, TDim::from(1));
1788 }
1789
1790 #[test]
1791 fn reduce_div_bug_0() {
1792 let e1: TDim = (A.to_dim() + 23) / 2 - 1;
1793 let e2: TDim = (A.to_dim() + 21) / 2;
1794 assert_eq!(e1, e2);
1795 }
1796
1797 #[test]
1798 fn reduce_div_bug_1() {
1799 let e1: TDim = (A.to_dim() + -1) / 2;
1800 let e2: TDim = (A.to_dim() + 1) / 2 - 1;
1801 assert_eq!(e1, e2);
1802 }
1803
1804 #[test]
1805 fn reduce_div_bug_2() {
1806 let e1: TDim = ((A.to_dim() + 1) / 2 + 1) / 2;
1807 let e2: TDim = (A.to_dim() + 3) / 4;
1808 assert_eq!(e1, e2);
1809 }
1810
1811 #[test]
1812 fn divide_multiple_plus_remainder() {
1813 let scope = SymbolScope::default().with_assertion("S>=0").unwrap();
1815 let s = scope.sym("S");
1816
1817 let e: TDim = (s.to_dim() * 2 + 1) / 2;
1819 assert_eq!(e.simplify(), s.to_dim());
1820
1821 let e: TDim = (s.to_dim() * 2 + 1) / 2 - 1;
1823 assert_eq!(e.simplify(), s.to_dim() - 1);
1824
1825 let e: TDim = (s.to_dim() * 2 - 1) / 2;
1827 assert_eq!(e.simplify(), s.to_dim() - 1);
1828
1829 let e: TDim = (s.to_dim() * 4 + 3) / 2;
1831 assert_eq!(e.simplify(), s.to_dim() * 2 + 1);
1832 }
1833
1834 #[test]
1835 fn divide_multiple_plus_remainder_no_assertion() {
1836 let scope = SymbolScope::default();
1843 let s = scope.sym("S");
1844 let e: TDim = (s.to_dim() * 2 + 1) / 2;
1845 assert_ne!(e.simplify(), s.to_dim());
1846 }
1847
1848 #[test]
1849 fn modulo_div_is_zero() {
1850 let scope = SymbolScope::default();
1853 let s = scope.sym("S");
1854 let e: TDim = (s.to_dim() - s.to_dim() / 2 * 2) / 2;
1856 assert_eq!(e.simplify(), TDim::Val(0));
1857 let a = s.to_dim() + 1;
1860 let e2: TDim = (a.clone() - a.clone() / 2 * 2) / 2;
1861 assert_eq!(e2.simplify(), TDim::Val(0));
1862 }
1863
1864 #[test]
1865 fn reduce_div_bug_3() {
1866 let e1: TDim = (A.to_dim() / 2) * -4;
1867 let e2: TDim = (A.to_dim() / 2) * -4 / 1;
1868 assert_eq!(e1, e2);
1869 }
1870
1871 #[test]
1872 fn reduce_mul_div() {
1873 let e: TDim = A.to_dim() * 2 / 2;
1874 assert_eq!(e, A.to_dim());
1875 }
1876
1877 #[test]
1878 fn expand_polynomial_two_add_factors() {
1879 let a = A.to_dim();
1884 let b = B.to_dim();
1885 let lhs = (a.clone() + a.clone() * &b * 2) * (TDim::from(1) + &b);
1886 let rhs = a.clone() * (TDim::from(1) + &b) * (TDim::from(1) + b.clone() * 2);
1887 assert_eq!(lhs.expand_polynomial(), rhs.expand_polynomial());
1888 }
1889
1890 #[test]
1891 fn reduce_div_mul() {
1892 let e: TDim = A.to_dim() / 2 * 2;
1893 assert_ne!(e, A.to_dim());
1894 }
1895
1896 #[test]
1897 fn reduce_add_div() {
1898 let e: TDim = A.to_dim() / 2 + 1;
1899 assert_eq!(e, ((A.to_dim() + 2) / 2));
1900 }
1901
1902 #[test]
1903 fn reduce_neg_mul_() {
1904 let e: TDim = TDim::from(1) - A.to_dim() * 2;
1905 assert_eq!(e, TDim::from(1) + A.to_dim() * -2);
1906 }
1907
1908 #[test]
1909 fn reduce_add_rem_1() {
1910 assert_eq!(((A.to_dim() + 4) % 2), (A.to_dim() % 2));
1911 }
1912
1913 #[test]
1914 fn reduce_add_rem_2() {
1915 assert_eq!(((A.to_dim() - 4) % 2), (A.to_dim() % 2));
1916 }
1917
1918 #[test]
1919 fn reduce_rem_div() {
1920 let e: TDim = A.to_dim() % 2 / 2;
1921 assert_eq!(e, TDim::from(0));
1922 }
1923
1924 #[test]
1925 fn conv2d_ex_1() {
1926 let e = (TDim::from(1) - 1 + 1).div_ceil(1);
1927 assert_eq!(e, TDim::from(1));
1928 }
1929
1930 #[test]
1931 fn conv2d_ex_2() {
1932 let e = (A.to_dim() - 3 + 1).div_ceil(1);
1933 assert_eq!(e, A.to_dim() + -2);
1934 }
1935
1936 #[test]
1937 fn extract_int_gcd_from_muls() {
1938 let term = (A.to_dim() + 1) / 4;
1939 let mul = (term.clone() * 24 - 24) * (term.clone() * 2 - 2);
1940 let target = (term.clone() - 1) * (term.clone() - 1) * 48;
1941 assert_eq!(mul, target);
1942 }
1943
1944 #[test]
1945 fn equality_of_muls() {
1946 let term = (A.to_dim() + 1) / 4;
1947 let mul1 = (term.clone() * 2 - 3) * (term.clone() - 1);
1948 let mul2 = (term.clone() - 1) * (term.clone() * 2 - 3);
1949 assert_eq!(mul1, mul2);
1950 }
1951
1952 #[test]
1953 fn factorize_complex_expr_times_int() {
1954 let term = (A.to_dim() + 1) / 4;
1955 let e = term.clone() * 2 - &term - 1;
1956 assert_eq!(e, term - 1);
1957 }
1958
1959 #[test]
1960 fn broadcast_over_min() {
1961 for a in 1..5 {
1967 for b in 1..5 {
1968 if b > 1 && a > b {
1969 assert!(a.broadcast(a.min(b)).is_err());
1970 } else {
1971 assert_eq!(a.broadcast(a.min(b)).unwrap(), a);
1972 }
1973 }
1974 }
1975 }
1976
1977 #[test]
1978 fn min_ints_1() {
1979 assert_eq!(2.to_dim().mini(1.to_dim()), 1.to_dim());
1980 }
1981
1982 #[test]
1983 fn min_ints_2() {
1984 assert_eq!(1.to_dim().mini(2.to_dim()), 1.to_dim());
1985 }
1986
1987 #[test]
1988 fn min_same() {
1989 assert_eq!(A.to_dim().mini(A.to_dim()), A.to_dim());
1990 }
1991
1992 #[test]
1993 fn min_noop() {
1994 assert_eq!(A.to_dim().mini(1.to_dim()), A.to_dim().mini(1.to_dim()));
1995 }
1996
1997 #[test]
1998 fn min_diff_1() {
1999 assert_eq!((A.to_dim() + 1).mini(A.to_dim() + 2), A.to_dim() + 1);
2000 }
2001
2002 #[test]
2003 fn slope_0() {
2004 assert_eq!(12.to_dim().guess_slope(&A), (0, 1));
2005 }
2006
2007 #[test]
2008 fn slope_1() {
2009 assert_eq!(A.to_dim().guess_slope(&A), (1, 1));
2010 }
2011
2012 #[test]
2013 fn slope_2() {
2014 assert_eq!((A.to_dim() * 2).guess_slope(&A), (2, 1));
2015 }
2016
2017 #[test]
2018 fn slope_3() {
2019 assert_eq!((A.to_dim() * 2 + A.to_dim() / 2).guess_slope(&A), (5, 2));
2020 }
2021
2022 #[test]
2023 fn slope_4() {
2024 assert_eq!((A.to_dim()).guess_slope(&B), (0, 1));
2025 }
2026
2027 #[test]
2028 fn slope_5() {
2029 assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
2030 assert_eq!((A.to_dim() + 1).guess_slope(&B), (0, 1));
2031 }
2032
2033 #[test]
2034 fn slope_6() {
2035 assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
2036 assert_eq!((A.to_dim() + B.to_dim()).guess_slope(&B), (1, 1));
2037 }
2038
2039 #[test]
2040 fn min_0() -> TractResult<()> {
2041 let symbols = SymbolScope::default();
2042 assert_eq!(
2043 symbols.parse_tdim("min(S+3, S+2)").unwrap().simplify(),
2044 symbols.parse_tdim("S+2").unwrap(),
2045 );
2046 Ok(())
2047 }
2048
2049 #[test]
2050 fn commutative_mul_parens() -> TractResult<()> {
2051 let symbols = SymbolScope::default();
2052 assert_eq!(
2053 symbols.parse_tdim("A*(B*C)").unwrap().simplify(),
2054 symbols.parse_tdim("(B*A)*C").unwrap().simplify(),
2055 );
2056 Ok(())
2057 }
2058
2059 #[test]
2060 fn commutative_in_nemo_parakeet_model() -> TractResult<()> {
2061 let symbols = SymbolScope::default();
2062 assert_eq!(
2063 symbols
2064 .parse_tdim("8*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8))*((B)*((S+7)/8))")
2065 .unwrap()
2066 .simplify(),
2067 symbols
2068 .parse_tdim("8*((B)*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8)))*((S+7)/8)")
2069 .unwrap()
2070 .simplify(),
2071 );
2072 Ok(())
2073 }
2074
2075 #[test]
2076 fn commutative_mul_parens_deep() -> TractResult<()> {
2077 let symbols = SymbolScope::default();
2078 let deep_tdim = Mul(vec![
2079 Mul(vec![Mul(vec![Mul(vec![A.to_dim(), B.to_dim()]), C.to_dim()]), D.to_dim()]),
2080 E.to_dim(),
2081 ])
2082 .simplify();
2083 assert_eq!(deep_tdim, symbols.parse_tdim("a*b*c*d*e").unwrap().simplify());
2084 Ok(())
2085 }
2086
2087 #[test]
2090 fn ge_concrete_true() {
2091 assert_eq!(Ge(b!(Val(5)), b!(Val(3))).reduce(), Val(1));
2092 }
2093
2094 #[test]
2095 fn ge_concrete_false() {
2096 assert_eq!(Ge(b!(Val(2)), b!(Val(3))).reduce(), Val(0));
2097 }
2098
2099 #[test]
2100 fn lt_concrete_true() {
2101 assert_eq!(Ge(b!(Val(3)), b!(Val(3))).reduce(), Val(1));
2103 }
2104
2105 #[test]
2106 fn lt_concrete_false() {
2107 assert_eq!(Ge(b!(Val(3)), b!(Val(6))).reduce(), Val(0));
2109 }
2110
2111 #[test]
2112 fn eq_concrete_true() {
2113 assert_eq!(Eq(b!(Val(3)), b!(Val(3))).reduce(), Val(1));
2114 }
2115
2116 #[test]
2117 fn eq_concrete_false() {
2118 assert_eq!(Eq(b!(Val(3)), b!(Val(4))).reduce(), Val(0));
2119 }
2120
2121 #[test]
2122 fn not_val_0() {
2123 assert_eq!((Val(1) - Val(0)).reduce(), Val(1));
2125 }
2126
2127 #[test]
2128 fn not_val_1() {
2129 assert_eq!((Val(1) - Val(1)).reduce(), Val(0));
2131 }
2132
2133 #[test]
2134 fn not_lt_becomes_ge() {
2135 let s = SymbolScope::default();
2137 let t = s.sym("T");
2138 let x1 = s.sym("x1");
2139 let expr = Val(1) - Ge(b!(Sym(t.clone())), b!(Sym(x1.clone()) + Val(1)));
2141 let at_boundary = expr.substitute(&x1, &Sym(t.clone())).unwrap().simplify();
2142 assert_eq!(at_boundary, Val(1));
2143 }
2144
2145 #[test]
2146 fn eq_with_assertion_proves_false() {
2147 let s = SymbolScope::default();
2149 s.add_assertion("T >= 1").unwrap();
2150 let t = s.sym("T");
2151 let expr = Eq(b!(Sym(t)), b!(Val(0)));
2152 assert_eq!(expr.simplify(), Val(0));
2153 }
2154
2155 #[test]
2156 fn ge_coord_at_extremes() {
2157 let s = SymbolScope::default();
2159 s.add_assertion("T >= 1").unwrap();
2160 let t = s.sym("T");
2161 let x1 = s.sym("x1");
2162 let expr = Ge(b!(Sym(x1.clone())), b!(Sym(t.clone())));
2163 let at_max = expr.substitute(&x1, &(Sym(t.clone()) - Val(1))).unwrap().simplify();
2166 assert_eq!(at_max, Val(0));
2167 }
2168
2169 #[test]
2170 fn eval_to_i64_new_variants() {
2171 use super::super::sym::SymbolValues;
2172 let sv = SymbolValues::default();
2173 assert_eq!(Ge(b!(Val(5)), b!(Val(3))).eval_to_i64(&sv).unwrap(), 1);
2174 assert_eq!(Ge(b!(Val(3)), b!(Val(5))).eval_to_i64(&sv).unwrap(), 0);
2175 assert_eq!(Eq(b!(Val(3)), b!(Val(3))).eval_to_i64(&sv).unwrap(), 1);
2176 assert_eq!(Eq(b!(Val(3)), b!(Val(4))).eval_to_i64(&sv).unwrap(), 0);
2177 }
2178
2179 #[test]
2180 fn eq_boolean_simplifies() {
2181 let s = SymbolScope::default();
2182 s.add_assertion("cw >= 0").unwrap();
2183 s.add_assertion("cw <= 1").unwrap();
2184 let cw = s.sym("cw");
2185 assert_eq!(Eq(b!(Val(1) - Sym(cw.clone())), b!(Val(0))).simplify(), Sym(cw.clone()));
2187 assert_eq!(Eq(b!(Sym(cw.clone())), b!(Val(0))).simplify(), Val(1) - Sym(cw.clone()));
2189 assert_eq!(Eq(b!(Sym(cw.clone())), b!(Val(1))).simplify(), Sym(cw.clone()));
2191 assert_eq!(Eq(b!(Val(1) - Sym(cw.clone())), b!(Val(1))).simplify(), Val(1) - Sym(cw));
2193 }
2194
2195 #[test]
2196 fn eq_boolean_mul_of_ge() {
2197 let s = SymbolScope::default();
2200 let x = s.sym("x");
2201 let product =
2202 Mul(vec![Ge(b!(Val(2)), b!(Sym(x.clone()))), Ge(b!(Sym(x.clone())), b!(Val(0)))]);
2203 let eq = Eq(b!(product.clone()), b!(Val(0)));
2204 assert_eq!(eq.simplify(), Val(1) - product);
2205 }
2206
2207 #[test]
2208 fn min_1_max_0_sym() {
2209 let s = SymbolScope::default();
2211 let x = s.sym("X");
2212 let expr = Min(vec![Val(1), Max(vec![Val(0), Sym(x)])]);
2213 let simplified = expr.simplify();
2214 eprintln!("simplified: {simplified}");
2215 assert!(format!("{simplified}").contains("min"), "Min dropped: {simplified}");
2216 }
2217
2218 #[test]
2219 fn min_preserved_in_subtraction_parts() {
2220 let s = SymbolScope::default();
2222 let t = s.sym("T");
2223 let p = s.sym("P");
2224 let ss = s.sym("S");
2225
2226 let cum_after =
2227 Max(vec![Val(0), (Sym(t.clone()) + Val(1)) * Sym(p.clone()) - Sym(ss.clone())]);
2228 let min_after = Min(vec![Val(1), cum_after.clone()]);
2229 let simplified = min_after.simplify();
2230 eprintln!("min_after simplified: {simplified}");
2231 assert!(format!("{simplified}").contains("min"), "Min wrapper was dropped: {simplified}");
2233 }
2234
2235 #[test]
2236 fn min_preserved_in_subtraction() {
2237 let s = SymbolScope::default();
2240 let t = s.sym("T");
2241 let p = s.sym("P");
2242 let ss = s.sym("S");
2243
2244 let cum_after =
2245 Max(vec![Val(0), (Sym(t.clone()) + Val(1)) * Sym(p.clone()) - Sym(ss.clone())]);
2246 let cum_before = Max(vec![Val(0), Sym(t.clone()) * Sym(p.clone()) - Sym(ss.clone())]);
2247
2248 let ap = Min(vec![Val(1), cum_after.clone()]) - Min(vec![Val(1), cum_before.clone()]);
2249 let simplified = ap.simplify();
2250
2251 use super::super::sym::SymbolValues;
2253 let sv = SymbolValues::default().with(&t, 1).with(&p, 4).with(&ss, 3);
2254 assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 0, "simplified: {simplified}");
2255
2256 let sv = SymbolValues::default().with(&t, 0).with(&p, 4).with(&ss, 3);
2258 assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 1, "simplified: {simplified}");
2259
2260 let sv = SymbolValues::default().with(&t, 0).with(&p, 1).with(&ss, 1);
2262 assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 0, "simplified: {simplified}");
2263 }
2264
2265 #[test]
2266 fn mul_neg_b_by_8() {
2267 let s = SymbolScope::default();
2268 let b = Sym(s.sym("B"));
2269 let a = Mul(vec![Val(8), MulInt(-1, Box::new(b.clone()))]);
2271 let c = MulInt(-8, Box::new(b.clone()));
2272 let a_s = a.simplify();
2273 let c_s = c.simplify();
2274 assert_eq!(a_s, c_s, "8*(-1*B) should simplify the same as -8*B");
2275 }
2276
2277 #[test]
2281 fn reduce_div_by_common_factor_with_divisor() {
2282 let lhs = (A.to_dim() * 14 + 6) / 8;
2283 let rhs = (A.to_dim() * 7 + 3) / 4;
2284 assert_eq!(lhs, rhs);
2285 }
2286
2287 #[test]
2290 fn reduce_div_when_factor_equals_divisor() {
2291 let lhs = (A.to_dim() * 4 + 8) / 4;
2292 let rhs = A.to_dim() + 2;
2293 assert_eq!(lhs, rhs);
2294 }
2295
2296 #[test]
2299 fn no_reduce_when_terms_coprime_with_divisor() {
2300 let e = (A.to_dim() * 7 + 3) / 4;
2301 match &e {
2304 Div(_, q) => assert_eq!(*q, 4),
2305 other => panic!("expected Div(_, 4), got {other:?}"),
2306 }
2307 }
2308
2309 #[test]
2313 fn no_reduce_when_sym_has_implicit_unit_coefficient() {
2314 let e = (A.to_dim() + 4) / 2;
2316 let sv2 = SymbolValues::default().with(&A, 2);
2319 let sv4 = SymbolValues::default().with(&A, 4);
2320 assert_eq!(e.eval_to_i64(&sv2).unwrap(), 3);
2321 assert_eq!(e.eval_to_i64(&sv4).unwrap(), 4);
2322 }
2323}