1use crate::dim::Assertion;
2use crate::internal::*;
3
4use super::{DimLike, sym::*};
5use itertools::Itertools;
6use num_integer::Integer;
7use num_traits::{AsPrimitive, PrimInt, Zero};
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10use std::fmt::Debug;
11use std::ops::Neg;
12use std::{fmt, ops};
13
/// Error returned when an expression cannot be evaluated to a concrete value yet.
#[derive(Debug)]
pub enum TooEarly {
    /// The expression (rendered as text) still contains a symbol with no known value.
    UndeterminedSymbol(String),
    /// Any other "too early to compute" condition; the message is displayed verbatim.
    Other(String),
}

impl std::fmt::Display for TooEarly {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UndeterminedSymbol(s) => write!(f, "Undetermined symbol in expression: {s}"),
            Self::Other(s) => write!(f, "{s}"),
        }
    }
}

impl std::error::Error for TooEarly {}
30
// Shorthand for `Box::new($e)`, used throughout this file to build the boxed operands of
// `TDim` variants such as `MulInt`, `Div`, `Ge` and `Eq`.
macro_rules! b( ($e:expr) => { Box::new($e) } );
32
/// Symbolic integer expression (typically a tensor dimension).
///
/// Evaluation rules below are those implemented by `TDim::eval_to_i64`.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum TDim {
    /// Integer constant.
    Val(i64),
    /// Named symbol, resolved through `SymbolValues` or the symbol scope's assertions.
    Sym(Symbol),
    /// Sum of all terms (an empty sum evaluates to 0).
    Add(Vec<TDim>),
    /// Product of all terms (an empty product evaluates to 1).
    Mul(Vec<TDim>),
    /// Integer coefficient times an expression.
    MulInt(i64, Box<TDim>),
    /// Expression divided by a positive integer, evaluated with Rust's `/` on `i64`
    /// (truncates toward zero).
    Div(Box<TDim>, u64),
    /// Broadcast of the terms' values, as performed by `broadcast` on `usize`.
    Broadcast(Vec<TDim>),
    /// Minimum of the terms (an empty `Min` evaluates to `i64::MAX`).
    Min(Vec<TDim>),
    /// Maximum of the terms (an empty `Max` evaluates to `i64::MIN`).
    Max(Vec<TDim>),
    /// Comparison: evaluates to 1 when the first operand is >= the second, otherwise 0.
    Ge(Box<TDim>, Box<TDim>),
    /// Equality test: evaluates to 1 when both operands are equal, otherwise 0.
    Eq(Box<TDim>, Box<TDim>),
}
49
50use TDim::*;
51
52fn tdim_lexi_order(a: &TDim, b: &TDim) -> Ordering {
53 match (a, b) {
54 (Sym(a), Sym(b)) => a.cmp(b),
55 (Val(a), Val(b)) => a.cmp(b),
56 (Add(a), Add(b))
57 | (Mul(a), Mul(b))
58 | (Broadcast(a), Broadcast(b))
59 | (Min(a), Min(b))
60 | (Max(a), Max(b)) => a.len().cmp(&b.len()).then(
61 a.iter()
62 .zip(b.iter())
63 .fold(Ordering::Equal, |acc, (a, b)| acc.then_with(|| tdim_lexi_order(a, b))),
64 ),
65 (MulInt(p, d), MulInt(q, e)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
66 (Div(d, p), Div(e, q)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
67 (Sym(_), _) => Ordering::Less,
68 (_, Sym(_)) => Ordering::Greater,
69 (Val(_), _) => Ordering::Less,
70 (_, Val(_)) => Ordering::Greater,
71 (Add(_), _) => Ordering::Less,
72 (_, Add(_)) => Ordering::Greater,
73 (Mul(_), _) => Ordering::Less,
74 (_, Mul(_)) => Ordering::Greater,
75 (MulInt(_, _), _) => Ordering::Less,
76 (_, MulInt(_, _)) => Ordering::Greater,
77 (Broadcast(_), _) => Ordering::Less,
78 (_, Broadcast(_)) => Ordering::Greater,
79 (Min(_), _) => Ordering::Less,
80 (_, Min(_)) => Ordering::Greater,
81 (Max(_), _) => Ordering::Less,
82 (_, Max(_)) => Ordering::Greater,
83 (Ge(a1, b1), Ge(a2, b2)) | (Eq(a1, b1), Eq(a2, b2)) => {
84 tdim_lexi_order(a1, a2).then_with(|| tdim_lexi_order(b1, b2))
85 }
86 (Ge(_, _) | Eq(_, _), _) => Ordering::Less,
87 (_, Ge(_, _) | Eq(_, _)) => Ordering::Greater,
88 }
89}
90
91impl fmt::Display for TDim {
92 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
93 match &self {
94 Sym(sym) => write!(fmt, "{sym}"),
95 Val(it) => write!(fmt, "{it}"),
96 Add(it) => write!(fmt, "{}", it.iter().map(|x| format!("{x}")).join("+")),
97 Mul(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("*")),
98 Broadcast(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("#")),
99 Min(it) => write!(fmt, "min({})", it.iter().map(|x| format!("{x}")).join(",")),
100 Max(it) => write!(fmt, "max({})", it.iter().map(|x| format!("{x}")).join(",")),
101 MulInt(a, b) => write!(fmt, "{a}*{b}"),
102 Div(a, b) => write!(fmt, "({a})/{b}"),
103 Ge(a, b) => write!(fmt, "({a}>={b})"),
104 Eq(a, b) => write!(fmt, "({a}=={b})"),
105 }
106 }
107}
108
109impl TDim {
110 #[inline]
111 pub fn is_one(&self) -> bool {
112 matches!(self, Val(1))
113 }
114
115 #[inline]
116 pub fn to_i64(&self) -> TractResult<i64> {
117 if let Val(v) = self {
118 Ok(*v)
119 } else {
120 Err(TooEarly::UndeterminedSymbol(self.to_string()))?
121 }
122 }
123
124 #[inline]
125 pub fn as_i64(&self) -> Option<i64> {
126 if let Val(v) = self { Some(*v) } else { None }
127 }
128
129 pub fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64> {
130 match self {
131 Sym(sym) => {
132 let Some(v) = values.get(sym) else {
133 Err(TooEarly::UndeterminedSymbol(self.to_string()))?
134 };
135 Ok(v)
136 }
137 Val(v) => Ok(*v),
138 Add(terms) => terms.iter().try_fold(0i64, |acc, it| {
139 let x = it.eval_to_i64(values)?;
140 acc.checked_add(x)
141 .with_context(|| format!("Overflow in TDim addition ({acc} + {x})"))
142 }),
143 Mul(terms) => terms.iter().try_fold(1i64, |acc, it| {
144 let x = it.eval_to_i64(values)?;
145 acc.checked_mul(x)
146 .with_context(|| format!("Overflow in TDim multiplication ({acc} * {x})"))
147 }),
148 Min(terms) => terms
149 .iter()
150 .try_fold(i64::MAX, |acc, it| it.eval_to_i64(values).map(|x| acc.min(x))),
151 Max(terms) => terms
152 .iter()
153 .try_fold(i64::MIN, |acc, it| it.eval_to_i64(values).map(|x| acc.max(x))),
154 Broadcast(terms) => terms.iter().try_fold(1i64, |acc, it| {
155 it.eval_to_i64(values)
156 .and_then(|x| ((acc as usize).broadcast(x as usize)).map(|x| x as i64))
157 }),
158 Div(a, q) => Ok(a.eval_to_i64(values)? / *q as i64),
159 MulInt(p, a) => {
160 let x = a.eval_to_i64(values)?;
161 x.checked_mul(*p)
162 .with_context(|| format!("Overflow in TDim multiplication ({x} * {p})"))
163 }
164 Ge(a, b) => Ok(if a.eval_to_i64(values)? >= b.eval_to_i64(values)? { 1 } else { 0 }),
165 Eq(a, b) => Ok(if a.eval_to_i64(values)? == b.eval_to_i64(values)? { 1 } else { 0 }),
166 }
167 }
168
169 pub fn eval(&self, values: &SymbolValues) -> TDim {
170 match self {
171 Sym(sym) => values.get(sym).map(Val).unwrap_or_else(|| Sym(sym.clone())),
172 Val(v) => Val(*v),
173 Add(terms) => terms.iter().fold(Val(0), |acc, it| -> TDim { acc + it.eval(values) }),
174 Mul(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim { acc * it.eval(values) }),
175 Min(terms) => {
176 terms.iter().fold(Val(i64::MAX), |acc, it| -> TDim { acc.mini(it.eval(values)) })
177 }
178 Max(terms) => {
179 terms.iter().fold(Val(i64::MIN), |acc, it| -> TDim { acc.maxi(it.eval(values)) })
180 }
181 Broadcast(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim {
182 acc.broadcast(it.eval(values)).unwrap_or_else(|_| self.clone())
183 }),
184 Div(a, q) => a.eval(values) / *q as i64,
185 MulInt(p, a) => a.eval(values) * *p,
186 Ge(a, b) => {
187 let a2 = a.eval(values);
188 let b2 = b.eval(values);
189 if let (Val(av), Val(bv)) = (&a2, &b2) {
190 Val(if av >= bv { 1 } else { 0 })
191 } else {
192 Ge(b!(a2), b!(b2))
193 }
194 }
195 Eq(a, b) => {
196 let a2 = a.eval(values);
197 let b2 = b.eval(values);
198 if let (Val(av), Val(bv)) = (&a2, &b2) {
199 Val(if av == bv { 1 } else { 0 })
200 } else {
201 Eq(b!(a2), b!(b2))
202 }
203 }
204 }
205 }
206
207 pub fn eval_with_scenario(&self, scenario: &str) -> TDim {
208 if let Val(v) = self {
209 return Val(*v);
210 }
211 let scope = self.find_scope().unwrap();
212 let scope = scope.0;
213 let locked = scope.lock();
214 let scope = locked.borrow();
215 self.clone().simplify_rec(&scope, Some(scenario), &[])
216 }
217
218 pub fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self> {
219 self.substitute_all(&std::collections::HashMap::from([(from.clone(), to.clone())]))
220 }
221
222 pub fn substitute_all(
223 &self,
224 map: &std::collections::HashMap<Symbol, Self>,
225 ) -> TractResult<Self> {
226 match self {
227 Sym(sym) => Ok(map.get(sym).cloned().unwrap_or_else(|| self.clone())),
228 Val(v) => Ok(Val(*v)),
229 Add(terms) => terms.iter().try_fold(Val(0), |acc, it| -> TractResult<TDim> {
230 Ok(acc + it.substitute_all(map)?)
231 }),
232 Mul(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
233 Ok(acc * it.substitute_all(map)?)
234 }),
235 Broadcast(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
236 acc.broadcast(it.substitute_all(map)?)
237 }),
238 Min(terms) => terms.iter().try_fold(Val(i64::MAX), |acc, it| -> TractResult<TDim> {
239 Ok(acc.mini(it.substitute_all(map)?))
240 }),
241 Max(terms) => terms.iter().try_fold(Val(i64::MIN), |acc, it| -> TractResult<TDim> {
242 Ok(acc.maxi(it.substitute_all(map)?))
243 }),
244 Div(a, q) => Ok(a.substitute_all(map)? / *q as i64),
245 MulInt(p, a) => Ok(a.substitute_all(map)? * *p),
246 Ge(a, b) => Ok(Ge(b!(a.substitute_all(map)?), b!(b.substitute_all(map)?))),
247 Eq(a, b) => Ok(Eq(b!(a.substitute_all(map)?), b!(b.substitute_all(map)?))),
248 }
249 }
250
251 pub fn reduce(self) -> TDim {
252 self.simplify()
253 .wiggle()
254 .into_iter()
255 .sorted_by(tdim_lexi_order)
256 .unique()
257 .map(|e| e.simplify())
258 .min_by_key(|e| e.cost())
259 .unwrap()
260 }
261
262 fn cost(&self) -> usize {
263 use self::TDim::*;
264 match self {
265 Sym(_) | Val(_) => 1,
266 Add(terms) => 2 * terms.iter().map(TDim::cost).sum::<usize>(),
267 Mul(terms) => 3 * terms.iter().map(TDim::cost).sum::<usize>(),
268 Broadcast(terms) => 4 * terms.iter().map(TDim::cost).sum::<usize>(),
269 Min(terms) | Max(terms) => 5 * terms.iter().map(TDim::cost).sum::<usize>(),
270 Div(a, _) => 3 * a.cost(),
271 MulInt(_, a) => 2 * a.cost(),
272 Ge(a, b) | Eq(a, b) => 5 * (a.cost() + b.cost()),
273 }
274 }
275
276 fn wiggle(&self) -> Vec<TDim> {
277 use self::TDim::*;
278 match self {
279 Sym(_) | Val(_) | Mul(_) | Broadcast(_) | Min(_) | Max(_) | Ge(_, _) | Eq(_, _) => {
280 vec![self.clone()]
281 }
282 Add(terms) => {
283 let mut forms = vec![];
284 let sub_exprs = terms.iter().map(|e| e.wiggle()).multi_cartesian_product();
285
286 fn first_div_term(terms: &[TDim]) -> Option<(usize, &TDim, u64)> {
287 terms.iter().enumerate().find_map(|(index, t)| match t {
288 Div(numerator, quotient) => Some((index, &**numerator, *quotient)),
289 _ => None,
290 })
291 }
292
293 fn generate_new_numerator(
294 div_index: usize,
295 numerator: &TDim,
296 quotient: u64,
297 expr: &[TDim],
298 ) -> Vec<TDim> {
299 expr.iter()
300 .enumerate()
301 .map(|(index, term)| {
302 if index == div_index {
303 numerator.clone()
304 } else {
305 MulInt(quotient as i64, Box::new(term.clone()))
306 }
307 })
308 .collect()
309 }
310
311 for expr in sub_exprs {
312 if let Some((div_index, numerator, quotient)) = first_div_term(&expr) {
313 let new_numerator =
314 generate_new_numerator(div_index, numerator, quotient, &expr);
315 forms.push(Div(Box::new(Add(new_numerator)), quotient))
316 }
317
318 forms.push(Add(expr));
319 }
320 forms
321 }
322 MulInt(p, a) => a.wiggle().into_iter().map(|a| MulInt(*p, b!(a))).collect(),
323 Div(a, q) => {
324 let mut forms = vec![];
325 for num in a.wiggle() {
326 if let Add(terms) = &num {
327 let (integer, non_integer): (Vec<_>, Vec<_>) =
328 terms.iter().cloned().partition(|a| a.gcd() % q == 0);
329 let mut new_terms = integer.iter().map(|i| i.div(*q)).collect::<Vec<_>>();
330 if non_integer.len() > 0 {
331 new_terms.push(Div(b!(Add(non_integer)), *q));
332 }
333 forms.push(Add(new_terms))
334 }
335 forms.push(Div(b!(num), *q))
336 }
337 forms
338 }
339 }
340 }
341
342 fn find_any_sym(tdim: &TDim) -> Option<&Symbol> {
343 match tdim {
344 Val(_) => None,
345 Sym(s) => Some(s),
346 Add(terms) | Mul(terms) | Min(terms) | Max(terms) | Broadcast(terms) => {
347 terms.iter().find_map(Self::find_any_sym)
348 }
349 MulInt(_, t) | Div(t, _) => Self::find_any_sym(t),
350 Ge(a, b) | Eq(a, b) => Self::find_any_sym(a).or_else(|| Self::find_any_sym(b)),
351 }
352 }
353
354 pub fn find_scope(&self) -> Option<SymbolScope> {
355 Self::find_any_sym(self).and_then(|s| s.scope().clone())
356 }
357
358 pub fn simplify(self) -> TDim {
359 use self::TDim::*;
360 if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) {
361 return Val(v);
362 }
363 let Some(scope) = self.find_scope() else {
364 return self;
365 };
366 let scope = scope.0;
367 let locked = scope.lock();
368 let scope = locked.borrow();
369 let it = self.simplify_rec(&scope, None, &[]);
370 let mut current: Option<TDim> = None;
371 for scenario in scope.scenarios() {
372 let v = it.clone().simplify_rec(&scope, Some(scenario), &[]);
373 if current.is_some_and(|c| c != v) {
374 return it;
375 } else {
376 current = Some(v);
377 }
378 }
379 current.unwrap_or(it)
380 }
381
382 pub fn simplify_with_extra_assertions(self, extra: &[Assertion]) -> TDim {
383 use self::TDim::*;
384 if extra.is_empty() {
385 return self.simplify();
386 }
387 if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) {
388 return Val(v);
389 }
390 let Some(scope) = self.find_scope() else {
391 return self;
392 };
393 let scope = scope.0;
394 let locked = scope.lock();
395 let scope = locked.borrow();
396 let it = self.simplify_rec(&scope, None, extra);
397 let mut current: Option<TDim> = None;
398 for scenario in scope.scenarios() {
399 let v = it.clone().simplify_rec(&scope, Some(scenario), extra);
400 if current.is_some_and(|c| c != v) {
401 return it;
402 } else {
403 current = Some(v);
404 }
405 }
406 current.unwrap_or(it)
407 }
408
/// Recursive worker behind the `simplify*` methods: rewrites `self` into a normalised form.
///
/// It uses the assertions held by `scope` (filtered to `scenario` when it is `Some`, e.g.
/// for `Sym` substitution) plus the caller-supplied `extra` assertions. Callers in this file
/// lock the symbol scope and pass the borrowed scope data.
fn simplify_rec(
    self,
    scope: &SymbolScopeData,
    scenario: Option<&str>,
    extra: &[Assertion],
) -> TDim {
    match self {
        Add(mut terms) => {
            // Work-list over the terms. Each term is simplified, nested sums are spliced back
            // into the list, and like terms are accumulated into a coefficient map.
            #[allow(clippy::mutable_key_type)]
            let mut simplified_terms: HashMap<TDim, i64> = HashMap::new();
            while let Some(term) = terms.pop() {
                let simplified = term.simplify_rec(scope, scenario, extra);
                match simplified {
                    Val(0) => {}
                    Add(members) => {
                        terms.extend(members);
                        continue;
                    }
                    // Constants accumulate under the key `Val(1)`.
                    Val(value) => *simplified_terms.entry(Val(1)).or_insert(0) += value,
                    // `k * x` contributes `k` to the coefficient of `x`.
                    MulInt(value, factor) => {
                        *simplified_terms.entry((*factor).clone()).or_insert(0) += value;
                    }
                    n => *simplified_terms.entry(n).or_insert(0) += 1,
                };
            }

            // Rebuilds a term from its accumulated coefficient. Zero coefficients vanish and
            // the `Val(1)` key turns back into a constant.
            pub fn evaluate_count(term: TDim, count: i64) -> Option<TDim> {
                match count {
                    0 => None,
                    _ if term == TDim::Val(1) => Some(TDim::Val(count)),
                    1 => Some(term),
                    _ => Some(TDim::MulInt(count, Box::new(term))),
                }
            }

            let mut members: Vec<TDim> = simplified_terms
                .into_iter()
                .filter_map(|(term, count)| evaluate_count(term, count))
                .collect();
            // HashMap iteration order is arbitrary: sort to get a canonical result.
            members.sort_by(tdim_lexi_order);

            match members.len() {
                0 => TDim::Val(0),
                1 => members.into_iter().next().unwrap(),
                _ => TDim::Add(members),
            }
        }
        Mul(terms) => {
            {
                // When exactly one factor is a sum, distribute over it:
                // a * (b + c) -> a*b + a*c, then simplify the resulting sum.
                let add_indices: Vec<usize> = terms
                    .iter()
                    .enumerate()
                    .filter(|(_, t)| matches!(t, Add(_)))
                    .map(|(i, _)| i)
                    .collect();
                if add_indices.len() == 1 {
                    let add_idx = add_indices[0];
                    let Add(add_terms) = &terms[add_idx] else { unreachable!() };
                    let other_factors: Vec<TDim> = terms
                        .iter()
                        .enumerate()
                        .filter(|(i, _)| *i != add_idx)
                        .map(|(_, t)| t.clone())
                        .collect();
                    let distributed: Vec<TDim> = add_terms
                        .iter()
                        .map(|at| {
                            let mut product = other_factors.clone();
                            product.push(at.clone());
                            Mul(product)
                        })
                        .collect();
                    return Add(distributed).simplify_rec(scope, scenario, extra);
                }
            }

            // Flatten nested products and split `k * x` into the factors `Val(k)` and `x`.
            // NOTE(review): `reduce()` goes through `simplify`, which locks the symbol scope
            // again while our caller already holds it; presumably that lock is reentrant.
            // Confirm.
            let mut flattened_terms = vec![];
            for t in terms {
                match t.clone().reduce() {
                    Mul(inner_terms) => flattened_terms.extend(inner_terms),
                    MulInt(k, inner) => {
                        flattened_terms.push(Val(k));
                        flattened_terms.push(*inner);
                    }
                    other => flattened_terms.push(other),
                }
            }
            let mut terms = flattened_terms;

            // Pull the product of the factors' known divisors out as an integer coefficient.
            let mut gcd = Mul(terms.clone()).gcd() as i64;
            if gcd == 0 {
                return Val(0);
            }
            terms = if gcd != 1 {
                terms
                    .into_iter()
                    .map(|t| {
                        let gcd = t.gcd();
                        (t / gcd).simplify_rec(scope, scenario, extra)
                    })
                    .collect()
            } else {
                terms
            };
            // Each remaining `Val(-1)` factor flips the sign of the coefficient.
            if terms.iter().filter(|t| t == &&Val(-1)).count() % 2 == 1 {
                gcd = -gcd;
            }
            // Unit factors (1 and -1, the latter already folded into the sign) are dropped.
            terms.retain(|t| !t.is_one() && t != &Val(-1));
            terms.sort_by(tdim_lexi_order);

            match (gcd, terms.len()) {
                (_, 0) => Val(gcd),
                (0, _) => Val(0),
                (1, 1) => terms.remove(0),
                (1, _) => Mul(terms),
                (_, 1) => MulInt(gcd, Box::new(terms.remove(0))),
                _ => MulInt(gcd, Box::new(Mul(terms))),
            }
        }
        MulInt(coef, expr) => {
            // Fold nested coefficients and constant operands before recursing. When the
            // product overflows i64, the nested form is kept as-is.
            match *expr {
                MulInt(c2, inner) => {
                    if let Some(c) = coef.checked_mul(c2) {
                        return MulInt(c, inner).simplify_rec(scope, scenario, extra);
                    } else {
                        return MulInt(coef, Box::new(MulInt(c2, inner)));
                    }
                }
                Val(v) => {
                    return coef
                        .checked_mul(v)
                        .map(Val)
                        .unwrap_or_else(|| MulInt(coef, Box::new(Val(v))));
                }
                _ => {}
            }

            // Then simplify the operand and fold again: 0*x = 0, 1*x = x, distribute over
            // sums, and merge constants/coefficients (again only when no overflow occurs).
            let simplified = expr.simplify_rec(scope, scenario, extra);
            match (coef, simplified) {
                (0, _) => Val(0),
                (1, s) => s,
                (_, Add(terms)) => Add(terms
                    .into_iter()
                    .map(|term| {
                        MulInt(coef, Box::new(term)).simplify_rec(scope, scenario, extra)
                    })
                    .collect()),
                (c, Val(v)) => {
                    c.checked_mul(v).map(Val).unwrap_or_else(|| MulInt(c, Box::new(Val(v))))
                }
                (c, MulInt(v, inner)) => {
                    if let Some(cv) = c.checked_mul(v) {
                        MulInt(cv, inner)
                    } else {
                        MulInt(c, Box::new(MulInt(v, inner)))
                    }
                }
                (_, s) => MulInt(coef, Box::new(s)),
            }
        }
        Div(a, q) => {
            // x / 1 == x, and (x / q1) / q2 == x / (q1 * q2).
            if q == 1 {
                return a.simplify_rec(scope, scenario, extra);
            } else if let Div(a, q2) = *a {
                return Div(a, q * q2).simplify_rec(scope, scenario, extra);
            }
            let a = a.simplify_rec(scope, scenario, extra);
            // NOTE(review): several rewrites below (pulling constants out of the numerator,
            // factoring out -1) match floor-division / non-negative-numerator semantics,
            // while evaluation uses truncating `/`. Presumably numerators are non-negative
            // here. Confirm.
            if let Val(a) = a {
                Val(a / q as i64)
            } else if let MulInt(-1, a) = a {
                // (-x) / q -> -(x / q)
                MulInt(-1, b!(Div(a, q)))
            } else if let Add(mut terms) = a {
                if terms
                    .iter()
                    .any(|t| if let MulInt(-1, s) = t { matches!(&**s, Sym(_)) } else { false })
                {
                    // A sum containing a negated symbol: s / q -> -((-s) / q).
                    MulInt(
                        -1,
                        b!(Div(
                            b!(Add(terms.into_iter().map(|t| MulInt(-1, b!(t))).collect())
                                .simplify_rec(scope, scenario, extra)),
                            q
                        )),
                    )
                } else if let Some(v) =
                    terms.iter().find_map(|t| if let Val(v) = t { Some(*v) } else { None })
                {
                    // Constant term v outside [0, q): move k = v/q (v >= q) or
                    // k = -ceil(-v/q) (v < 0) out of the division:
                    // (r + v) / q -> k + (r + v - k*q) / q
                    let offset = if v >= q as i64 {
                        Some(v / q as i64)
                    } else if v < 0 {
                        Some(-Integer::div_ceil(&-v, &(q as i64)))
                    } else {
                        None
                    };
                    if let Some(val) = offset {
                        terms.push(Val(-val * q as i64));
                        Add(vec![
                            Val(val),
                            Div(b!(Add(terms).simplify_rec(scope, scenario, extra)), q),
                        ])
                    } else {
                        Div(b!(Add(terms)), q)
                    }
                } else {
                    Div(b!(Add(terms)), q)
                }
            } else if let MulInt(p, a) = a {
                // (p * x) / q: cancel the common factor of p and q.
                if p == q as i64 {
                    // NOTE(review): this calls the public `simplify` (no scenario, no `extra`)
                    // rather than `simplify_rec`. Confirm this is intended.
                    a.simplify()
                } else {
                    let gcd = p.abs().gcd(&(q as i64));
                    if gcd == p {
                        // p divides q: (p*x)/q -> x/(q/p)
                        Div(a, q / gcd as u64)
                    } else if gcd == q as i64 {
                        // q divides p: (p*x)/q -> (p/q)*x
                        MulInt(p / gcd, a)
                    } else if gcd > 1 {
                        Div(b!(MulInt(p / gcd, a)), q / gcd as u64)
                            .simplify_rec(scope, scenario, extra)
                    } else {
                        Div(b!(MulInt(p, a)), q)
                    }
                }
            } else {
                Div(b!(a), q)
            }
        }
        Broadcast(terms) => {
            // Simplify, flatten nested broadcasts, drop 1s, then sort and deduplicate.
            let mut terms: Vec<TDim> = terms
                .iter()
                .map(|s| s.clone().simplify_rec(scope, scenario, extra))
                .flat_map(|t| if let Broadcast(t) = t { t } else { vec![t] })
                .filter(|t| !t.is_one())
                .sorted_by(tdim_lexi_order)
                .dedup()
                .collect_vec();
            match &*terms {
                [] => Val(1),
                [_] => terms.remove(0),
                // Broadcasting `a` with a `min(..)` that contains `a` and whose terms are all
                // provably > 0 yields `a`.
                [a, Min(m)] | [Min(m), a]
                    if m.contains(a)
                        && m.iter()
                            .all(|t| scope.prove_strict_positive_with_extra(t, extra)) =>
                {
                    a.clone()
                }
                _ => Broadcast(terms),
            }
        }

        Min(terms) => {
            // Simplify, flatten nested mins, drop the identity i64::MAX, sort and dedup.
            let mut flatten: Vec<TDim> = terms
                .into_iter()
                .map(|t| t.simplify_rec(scope, scenario, extra))
                .flat_map(|t| if let Min(t) = t { t } else { vec![t] })
                .filter(|t| t != &Val(i64::MAX))
                .sorted_by(tdim_lexi_order)
                .dedup()
                .collect();
            // `a` is redundant whenever a - b >= 0 is known: it can never be the minimum.
            #[allow(clippy::mutable_key_type)]
            let mut redundant = HashSet::<TDim>::default();
            for pair in flatten.iter().permutations(2) {
                let (a, b) = (pair[0], pair[1]);
                if redundant.contains(a) || redundant.contains(b) {
                    continue;
                }
                let diff = a.clone() - b;
                if diff.as_i64().is_some_and(|i| i >= 0)
                    || scope.prove_positive_or_zero_with_extra(&diff, extra)
                {
                    redundant.insert(a.clone());
                }
            }
            flatten.retain(|t| !redundant.contains(t));
            if flatten.len() == 0 {
                i64::MAX.to_dim()
            } else if flatten.len() == 1 {
                flatten.into_iter().next().unwrap()
            } else {
                Min(flatten)
            }
        }
        Max(terms) => {
            // Simplify, flatten nested maxes, drop the identity i64::MIN, sort and dedup.
            let mut flatten: Vec<TDim> = terms
                .into_iter()
                .map(|t| t.simplify_rec(scope, scenario, extra))
                .flat_map(|t| if let Max(t) = t { t } else { vec![t] })
                .filter(|t| t != &Val(i64::MIN))
                .sorted_by(tdim_lexi_order)
                .dedup()
                .collect();
            // `b` is redundant whenever a - b >= 0 is known: it can never be the maximum.
            #[allow(clippy::mutable_key_type)]
            let mut redundant = HashSet::<TDim>::default();
            for pair in flatten.iter().permutations(2) {
                let (a, b) = (pair[0], pair[1]);
                if redundant.contains(a) || redundant.contains(b) {
                    continue;
                }
                let diff = a.clone() - b;
                if diff.as_i64().is_some_and(|i| i >= 0)
                    || scope.prove_positive_or_zero_with_extra(&diff, extra)
                {
                    redundant.insert(b.clone());
                }
            }
            flatten.retain(|t| !redundant.contains(t));
            if flatten.len() == 0 {
                i64::MIN.to_dim()
            } else if flatten.len() == 1 {
                flatten.into_iter().next().unwrap()
            } else {
                Max(flatten)
            }
        }
        // A symbol is replaced by the value of the first `Eq(Sym(s), v)` assertion found for
        // this scenario, if any.
        Sym(s) => scope
            .assertions(scenario)
            .find_map(|a| match a {
                Assertion::Eq(Sym(sym), v) if sym == &s => Some(v.clone()),
                _ => None,
            })
            .unwrap_or(Sym(s)),
        Val(_) => self,
        Ge(a, b) => {
            let a = a.simplify_rec(scope, scenario, extra);
            let b = b.simplify_rec(scope, scenario, extra);
            match (&a, &b) {
                (Val(av), Val(bv)) => Val(if av >= bv { 1 } else { 0 }),
                _ => {
                    // Decide by proving a - b >= 0 (true) or b - a > 0 (false).
                    let diff = a.clone() - b.clone();
                    if scope.prove_positive_or_zero_with_extra(&diff, extra) {
                        Val(1)
                    } else if scope
                        .prove_strict_positive_with_extra(&(b.clone() - a.clone()), extra)
                    {
                        Val(0)
                    } else {
                        Ge(b!(a), b!(b))
                    }
                }
            }
        }
        Eq(a, b) => {
            let a = a.simplify_rec(scope, scenario, extra);
            let b = b.simplify_rec(scope, scenario, extra);
            match (&a, &b) {
                (Val(av), Val(bv)) => Val(if av == bv { 1 } else { 0 }),
                _ => {
                    // Provably different in either direction -> false.
                    let diff = a.clone() - b.clone();
                    if scope.prove_strict_positive_with_extra(&diff, extra)
                        || scope
                            .prove_strict_positive_with_extra(&(b.clone() - a.clone()), extra)
                    {
                        Val(0)
                    } else {
                        // Comparing against 0 or 1 an expression provably within [0, 1]:
                        // `e == 1` is `e` itself and `e == 0` is `1 - e`.
                        let boolean_case = match (&a, &b) {
                            (Val(0), e) | (e, Val(0)) => Some((e, false)),
                            (Val(1), e) | (e, Val(1)) => Some((e, true)),
                            _ => None,
                        };
                        if let Some((expr, equals_one)) = boolean_case {
                            if scope.prove_positive_or_zero_with_extra(expr, extra)
                                && scope.prove_positive_or_zero_with_extra(
                                    &(Val(1) - expr.clone()),
                                    extra,
                                )
                            {
                                return if equals_one {
                                    expr.clone()
                                } else {
                                    (Val(1) - expr.clone()).simplify_rec(scope, scenario, extra)
                                };
                            }
                        }
                        Eq(b!(a), b!(b))
                    }
                }
            }
        }
    }
}
801
/// Computes an inclusive bound of the expression from the scope's assertions.
///
/// `upper == true` asks for an upper bound, `false` for a lower bound. Returns `None` when no
/// bound can be derived or when the checked arithmetic overflows.
pub(super) fn inclusive_bound(&self, scope: &SymbolScopeData, upper: bool) -> Option<i64> {
    use self::TDim::*;
    match self {
        Val(n) => Some(*n),
        Sym(_) => {
            // Only assertions comparing this exact expression against a constant are used:
            // `s < c` / `s <= c` for upper bounds, `s > c` / `s >= c` for lower bounds. The
            // tightest one wins.
            if upper {
                scope
                    .all_assertions()
                    .iter()
                    .filter_map(|assert| match &assert {
                        Assertion::LT(left, right)
                            if left == self && right.as_i64().is_some() =>
                        {
                            Some(right.as_i64().unwrap() - 1)
                        }
                        Assertion::LTE(left, right)
                            if left == self && right.as_i64().is_some() =>
                        {
                            Some(right.as_i64().unwrap())
                        }
                        _ => None,
                    })
                    .min()
            } else {
                scope
                    .all_assertions()
                    .iter()
                    .filter_map(|assert| match &assert {
                        Assertion::GT(left, right)
                            if left == self && right.as_i64().is_some() =>
                        {
                            Some(right.as_i64().unwrap() + 1)
                        }
                        Assertion::GTE(left, right)
                            if left == self && right.as_i64().is_some() =>
                        {
                            Some(right.as_i64().unwrap())
                        }
                        _ => None,
                    })
                    .max()
            }
        }
        // Sum of the terms' bounds in the same direction; any unbounded term -> None.
        Add(terms) => {
            let mut bound: i64 = 0;
            for t in terms {
                if let Some(b) = t.inclusive_bound(scope, upper) {
                    bound = bound.checked_add(b)?;
                } else {
                    return None;
                }
            }
            Some(bound)
        }
        // A negative coefficient swaps which bound of the operand is needed.
        MulInt(p, a) => match p.cmp(&0) {
            Ordering::Equal => Some(0),
            Ordering::Greater => {
                a.inclusive_bound(scope, upper).and_then(|x| x.checked_mul(*p))
            }
            Ordering::Less => a.inclusive_bound(scope, !upper).and_then(|x| x.checked_mul(*p)),
        },
        // Only handled when every factor has a known non-negative lower bound (and an upper
        // bound). The bounds are then products of the factors' lower/upper bounds.
        Mul(terms) => {
            let mut lo: i64 = 1;
            let mut hi: i64 = 1;
            for t in terms {
                let t_lo = t.inclusive_bound(scope, false)?;
                let t_hi = t.inclusive_bound(scope, true)?;
                if t_lo < 0 {
                    return None;
                }
                lo = lo.checked_mul(t_lo)?;
                hi = hi.checked_mul(t_hi)?;
            }
            Some(if upper { hi } else { lo })
        }
        // Lower bound of a min: smallest lower bound of its terms (all must be bounded).
        // Upper bounds of `Min` fall through to `_ => None`.
        Min(terms) if !upper => {
            let bounds: Option<Vec<i64>> =
                terms.iter().map(|t| t.inclusive_bound(scope, false)).collect();
            bounds.map(|b| b.into_iter().min().unwrap_or(i64::MAX))
        }
        // Upper bound of a max: largest upper bound of its terms (all must be bounded).
        // Lower bounds of `Max` fall through to `_ => None`.
        Max(terms) if upper => {
            let bounds: Option<Vec<i64>> =
                terms.iter().map(|t| t.inclusive_bound(scope, true)).collect();
            bounds.map(|b| b.into_iter().max().unwrap_or(i64::MIN))
        }
        Div(a, q) => a.inclusive_bound(scope, upper).map(|x| x / (*q as i64)),
        // Bounded like the max (upper) / min (lower) of its terms.
        Broadcast(terms) => {
            if upper {
                Max(terms.clone()).inclusive_bound(scope, true)
            } else {
                Min(terms.clone()).inclusive_bound(scope, false)
            }
        }
        // Comparisons evaluate to 0 or 1.
        Ge(_, _) | Eq(_, _) => {
            if upper {
                Some(1)
            } else {
                Some(0)
            }
        }
        _ => None,
    }
}
910
911 pub fn low_inclusive_bound(&self) -> Option<i64> {
912 if let TDim::Val(v) = self {
913 return Some(*v);
914 }
915 let scope = self.find_scope()?;
916 let data = scope.0.lock();
917 let data = data.borrow();
918 self.inclusive_bound(&data, false)
919 }
920
921 pub fn high_inclusive_bound(&self) -> Option<i64> {
922 if let TDim::Val(v) = self {
923 return Some(*v);
924 }
925 let scope = self.find_scope()?;
926 let data = scope.0.lock();
927 let data = data.borrow();
928 self.inclusive_bound(&data, true)
929 }
930
931 pub fn prove_positive_or_zero(&self) -> bool {
932 if let TDim::Val(v) = self {
933 return *v >= 0;
934 }
935 let Some(scope) = self.find_scope() else { return false };
936 let data = scope.0.lock();
937 let data = data.borrow();
938 data.prove_positive_or_zero(self)
939 }
940
941 pub fn prove_strict_positive(&self) -> bool {
942 if let TDim::Val(v) = self {
943 return *v > 0;
944 }
945 (self.clone() - 1).prove_positive_or_zero()
946 }
947
948 pub fn prove_negative_or_zero(&self) -> bool {
949 if let TDim::Val(v) = self {
950 return *v <= 0;
951 }
952 self.clone().neg().prove_positive_or_zero()
953 }
954
955 pub fn prove_strict_negative(&self) -> bool {
956 if let TDim::Val(v) = self {
957 return *v < 0;
958 }
959 self.clone().neg().prove_strict_positive()
960 }
961
962 pub fn gcd(&self) -> u64 {
963 use self::TDim::*;
964 match self {
965 Val(v) => v.unsigned_abs(),
966 Sym(_) => 1,
967 Add(terms) => {
968 let (head, tail) = terms.split_first().unwrap();
969 tail.iter().fold(head.gcd(), |a, b| a.gcd(&b.gcd()))
970 }
971 MulInt(p, a) => a.gcd().saturating_mul(p.unsigned_abs()),
972 Mul(terms) => terms.iter().map(|t| t.gcd()).fold(1u64, |a, b| a.saturating_mul(b)),
973 Min(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
974 Max(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
975 Div(a, q) => {
976 if a.gcd() % *q == 0 {
977 a.gcd() / *q
978 } else {
979 1
980 }
981 }
982 Broadcast(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap_or(1),
983 Ge(_, _) | Eq(_, _) => 1,
984 }
985 }
986
987 fn div(&self, d: u64) -> TDim {
988 use self::TDim::*;
989 if d == 1 {
990 return self.clone();
991 }
992 match self {
993 Val(v) => Val(v / d as i64),
994 Sym(_) => panic!(),
995 Add(terms) => Add(terms.iter().map(|t| t.div(d)).collect()),
996 Min(terms) => Min(terms.iter().map(|t| t.div(d)).collect()),
997 Max(terms) => Max(terms.iter().map(|t| t.div(d)).collect()),
998 Broadcast(terms) => Broadcast(terms.iter().map(|t| t.div(d)).collect()),
999 Mul(_) => Div(Box::new(self.clone()), d),
1000 MulInt(p, a) => {
1001 if *p == d as i64 {
1002 (**a).clone()
1003 } else {
1004 let gcd = p.unsigned_abs().gcd(&d);
1005 MulInt(p / gcd as i64, b!(a.div(d / gcd)))
1006 }
1007 }
1008 Div(a, q) => Div(a.clone(), q * d),
1009 Ge(_, _) | Eq(_, _) => Div(Box::new(self.clone()), d),
1010 }
1011 }
1012
1013 pub fn div_ceil(self, rhs: u64) -> TDim {
1014 TDim::Div(Box::new(Add(vec![self, Val(rhs as i64 - 1)])), rhs).reduce()
1015 }
1016
1017 pub(super) fn guess_slope(&self, sym: &Symbol) -> (i64, u64) {
1018 fn slope_rec(d: &TDim, sym: &Symbol) -> (i64, i64) {
1019 match d {
1020 Val(_) => (0, 1),
1021 Sym(s) => ((sym == s) as i64, 1),
1022 Add(terms) => terms
1023 .iter()
1024 .map(|d| slope_rec(d, sym))
1025 .fold((0, 1), |a, b| ((a.0 * b.1 + a.1 * b.0), (b.1 * a.1))),
1026 Mul(terms) => terms
1027 .iter()
1028 .map(|d| slope_rec(d, sym))
1029 .fold((1, 1), |a, b| ((a.0 * b.0), (b.1 * a.1))),
1030 MulInt(p, a) => {
1031 let (n, d) = slope_rec(a, sym);
1032 (p * n, d)
1033 }
1034 Div(a, q) => {
1035 let (n, d) = slope_rec(a, sym);
1036 (n, d * *q as i64)
1037 }
1038 Broadcast(terms) => slope_rec(&terms[0], sym),
1039 Min(terms) => slope_rec(&terms[0], sym),
1040 Max(terms) => slope_rec(&terms[0], sym),
1041 Ge(_, _) | Eq(_, _) => (0, 1),
1042 }
1043 }
1044 let (p, q) = slope_rec(self, sym);
1045 reduce_ratio(p, q)
1046 }
1047
1048 #[allow(clippy::mutable_key_type)]
1049 pub fn symbols(&self) -> std::collections::HashSet<Symbol> {
1050 match self {
1051 Val(_) => maplit::hashset!(),
1052 Sym(s) => maplit::hashset!(s.clone()),
1053 Add(terms) | Mul(terms) | Broadcast(terms) | Min(terms) | Max(terms) => {
1054 terms.iter().fold(maplit::hashset!(), |mut set, v| {
1055 set.extend(v.symbols());
1056 set
1057 })
1058 }
1059 MulInt(_, a) => a.symbols(),
1060 Div(a, _) => a.symbols(),
1061 Ge(a, b) | Eq(a, b) => {
1062 let mut set = a.symbols();
1063 set.extend(b.symbols());
1064 set
1065 }
1066 }
1067 }
1068
1069 pub fn compatible_with(&self, other: &TDim) -> bool {
1070 if let Ok(x) = (self.clone() - other).to_i64() {
1071 return x == 0;
1072 }
1073 true }
1075}
1076
1077pub(super) fn reduce_ratio(mut p: i64, mut q: i64) -> (i64, u64) {
1078 let gcd = p.abs().gcd(&q.abs());
1079 if gcd > 1 {
1080 p /= gcd;
1081 q /= gcd;
1082 }
1083 if q < 0 { (-p, (-q) as u64) } else { (p, q as u64) }
1084}
1085
1086impl Zero for TDim {
1087 fn zero() -> Self {
1088 Val(0)
1089 }
1090 fn is_zero(&self) -> bool {
1091 matches!(self, Val(0))
1092 }
1093}
1094
1095impl Default for TDim {
1096 fn default() -> TDim {
1097 Val(0)
1098 }
1099}
1100
1101impl num_traits::Bounded for TDim {
1102 fn min_value() -> Self {
1103 TDim::Val(i64::MIN)
1104 }
1105
1106 fn max_value() -> Self {
1107 TDim::Val(i64::MAX)
1108 }
1109}
1110
1111impl num_traits::One for TDim {
1112 fn one() -> Self {
1113 TDim::Val(1)
1114 }
1115}
1116
1117impl ::std::iter::Sum for TDim {
1118 fn sum<I: Iterator<Item = TDim>>(iter: I) -> TDim {
1119 iter.fold(0.into(), |a, b| a + b)
1120 }
1121}
1122
1123impl<'a> ::std::iter::Sum<&'a TDim> for TDim {
1124 fn sum<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
1125 iter.fold(0.into(), |a, b| a + b)
1126 }
1127}
1128
1129impl std::iter::Product for TDim {
1130 fn product<I: Iterator<Item = TDim>>(iter: I) -> Self {
1131 iter.fold(TDim::Val(1), |a, b| a * b)
1132 }
1133}
1134
1135impl<'a> ::std::iter::Product<&'a TDim> for TDim {
1136 fn product<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
1137 iter.fold(1.into(), |a, b| a * b)
1138 }
1139}
1140
1141macro_rules! from_i {
1142 ($i: ty) => {
1143 impl From<$i> for TDim {
1144 fn from(v: $i) -> TDim {
1145 TDim::Val(v as _)
1146 }
1147 }
1148 impl<'a> From<&'a $i> for TDim {
1149 fn from(v: &'a $i) -> TDim {
1150 TDim::Val(*v as _)
1151 }
1152 }
1153 };
1154}
1155
1156from_i!(i32);
1157from_i!(i64);
1158from_i!(u64);
1159from_i!(isize);
1160from_i!(usize);
1161
1162impl From<Symbol> for TDim {
1163 fn from(it: Symbol) -> Self {
1164 TDim::Sym(it)
1165 }
1166}
1167
1168impl<'a> From<&'a Symbol> for TDim {
1169 fn from(it: &'a Symbol) -> Self {
1170 TDim::Sym(it.clone())
1171 }
1172}
1173
1174impl ops::Neg for TDim {
1175 type Output = Self;
1176 fn neg(self) -> Self {
1177 if let Val(v) = self { Val(-v) } else { TDim::MulInt(-1, Box::new(self)).reduce() }
1178 }
1179}
1180
1181impl<'a> ops::AddAssign<&'a TDim> for TDim {
1182 fn add_assign(&mut self, rhs: &'a TDim) {
1183 if rhs.is_zero() {
1184 } else if self.is_zero() {
1185 *self = rhs.clone();
1186 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1187 *s += o;
1188 } else {
1189 *self = TDim::Add(vec![std::mem::take(self), rhs.clone()]).reduce()
1190 }
1191 }
1192}
1193
1194impl<I> ops::AddAssign<I> for TDim
1195where
1196 I: Into<TDim>,
1197{
1198 fn add_assign(&mut self, rhs: I) {
1199 let rhs = rhs.into();
1200 if rhs.is_zero() {
1201 } else if self.is_zero() {
1202 *self = rhs;
1203 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1204 *s += o;
1205 } else {
1206 *self = TDim::Add(vec![std::mem::take(self), rhs]).reduce()
1207 }
1208 }
1209}
1210
1211impl<I> ops::Add<I> for TDim
1212where
1213 I: Into<TDim>,
1214{
1215 type Output = Self;
1216 fn add(mut self, rhs: I) -> Self {
1217 self += rhs;
1218 self
1219 }
1220}
1221
1222impl<'a> ops::Add<&'a TDim> for TDim {
1223 type Output = Self;
1224 fn add(mut self, rhs: &'a TDim) -> Self {
1225 self += rhs;
1226 self
1227 }
1228}
1229
1230#[allow(clippy::suspicious_op_assign_impl)]
1231impl<'a> ops::SubAssign<&'a TDim> for TDim {
1232 fn sub_assign(&mut self, rhs: &'a TDim) {
1233 if rhs.is_zero() {
1234 } else if self.is_zero() {
1235 *self = rhs.clone().neg();
1236 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1237 *s -= o;
1238 } else {
1239 *self = TDim::Add(vec![std::mem::take(self), rhs.clone().neg()]).reduce()
1240 }
1241 }
1242}
1243
1244impl<I> ops::SubAssign<I> for TDim
1245where
1246 I: Into<TDim>,
1247{
1248 fn sub_assign(&mut self, rhs: I) {
1249 let rhs = rhs.into();
1250 if rhs.is_zero() {
1251 } else if self.is_zero() {
1252 *self = rhs.neg();
1253 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1254 *s -= o;
1255 } else {
1256 *self = TDim::Add(vec![std::mem::take(self), rhs.neg()]).reduce()
1257 }
1258 }
1259}
1260
1261impl<I> ops::Sub<I> for TDim
1262where
1263 I: Into<TDim>,
1264{
1265 type Output = Self;
1266 fn sub(mut self, rhs: I) -> Self {
1267 self -= rhs;
1268 self
1269 }
1270}
1271
1272impl<'a> ops::Sub<&'a TDim> for TDim {
1273 type Output = Self;
1274 fn sub(mut self, rhs: &'a TDim) -> Self {
1275 self -= rhs;
1276 self
1277 }
1278}
1279
1280impl<I: Into<TDim>> ops::MulAssign<I> for TDim {
1281 fn mul_assign(&mut self, rhs: I) {
1282 let rhs = rhs.into();
1283 if self.is_one() {
1284 *self = rhs
1285 } else if rhs.is_one() {
1286 } else {
1287 *self = TDim::Mul(vec![rhs, std::mem::take(self)]).reduce()
1288 }
1289 }
1290}
1291
1292impl<'a> ops::MulAssign<&'a TDim> for TDim {
1293 fn mul_assign(&mut self, rhs: &'a TDim) {
1294 if self.is_one() {
1295 *self = rhs.clone()
1296 } else if rhs.is_one() {
1297 } else {
1298 *self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce()
1299 }
1300 }
1301}
1302
1303impl<I: Into<TDim>> ops::Mul<I> for TDim {
1304 type Output = Self;
1305 fn mul(mut self, rhs: I) -> Self {
1306 self *= rhs.into();
1307 self
1308 }
1309}
1310
1311impl<'a> ops::Mul<&'a TDim> for TDim {
1312 type Output = Self;
1313 fn mul(mut self, rhs: &'a TDim) -> Self {
1314 self *= rhs;
1315 self
1316 }
1317}
1318
1319impl<I: AsPrimitive<u64> + PrimInt> ops::DivAssign<I> for TDim {
1320 fn div_assign(&mut self, rhs: I) {
1321 *self = TDim::Div(Box::new(std::mem::take(self)), rhs.as_()).reduce()
1322 }
1323}
1324
1325impl<I: AsPrimitive<u64> + PrimInt> ops::Div<I> for TDim {
1326 type Output = Self;
1327 fn div(mut self, rhs: I) -> Self {
1328 self /= rhs.as_();
1329 self
1330 }
1331}
1332
1333impl<I: AsPrimitive<u64> + PrimInt> ops::RemAssign<I> for TDim {
1334 fn rem_assign(&mut self, rhs: I) {
1335 *self += -(self.clone() / rhs.as_() * rhs.as_());
1336 }
1337}
1338
1339impl<I: AsPrimitive<u64> + PrimInt> ops::Rem<I> for TDim {
1340 type Output = Self;
1341 fn rem(mut self, rhs: I) -> Self {
1342 self %= rhs;
1343 self
1344 }
1345}
1346
#[cfg(test)]
mod tests {
    use super::*;

    // Local boxing shorthand for building `TDim` trees by hand in the tests below.
    macro_rules! b( ($e:expr) => { Box::new($e) } );

    lazy_static::lazy_static! {
        // All shared test symbols (`a`..`e`) come from this single scope.
        static ref table: SymbolScope = SymbolScope::default();
        static ref A: Symbol = table.sym("a");
        static ref B: Symbol = table.sym("b");
        static ref C: Symbol = table.sym("c");
        static ref D: Symbol = table.sym("d");
        static ref E: Symbol = table.sym("e");
    }

    /// Unreduced negation, written as `MulInt(-1, a)`.
    fn neg(a: &TDim) -> TDim {
        mul(-1, a)
    }

    /// Unreduced two-term sum `a + b`.
    fn add(a: &TDim, b: &TDim) -> TDim {
        TDim::Add(vec![a.clone(), b.clone()])
    }

    /// Unreduced integer multiple `a * b`.
    fn mul(a: i64, b: &TDim) -> TDim {
        TDim::MulInt(a, b![b.clone()])
    }

    /// Unreduced division `a / b`.
    fn div(a: &TDim, b: u64) -> TDim {
        TDim::Div(b!(a.clone()), b)
    }

    // `a + (-a)` cancels to zero.
    #[test]
    fn reduce_add() {
        assert_eq!(add(&A.to_dim(), &neg(&A.to_dim())).reduce(), Val(0))
    }

    // Nested integer factors fold: `-(2a)` is `-2a`.
    #[test]
    fn reduce_neg_mul() {
        assert_eq!(neg(&mul(2, &A.to_dim())).reduce(), mul(-2, &A.to_dim()))
    }

    // `-4 + -2*(a/4) + (-2)*(-1)*(a/4)`: the two `a/4` terms cancel, leaving `-4`.
    #[test]
    fn reduce_cplx_ex_2() {
        assert_eq!(
            add(
                &add(&Val(-4), &mul(-2, &div(&A.to_dim(), 4))),
                &mul(-2, &mul(-1, &div(&A.to_dim(), 4)))
            )
            .reduce(),
            Val(-4)
        )
    }

    // `(1 * (4a)) / 4` reduces back to `a`.
    #[test]
    fn reduce_cplx_ex_3() {
        assert_eq!(div(&MulInt(1, b!(MulInt(4, b!(A.to_dim())))), 4).reduce(), A.to_dim())
    }

    #[test]
    fn reduce_cplx_ex_4() {
        // Pins the reducer's result for `(a + 1)/2 + (-a + 1)/2` to the constant 1.
        assert_eq!(
            add(&div(&add(&A.to_dim(), &Val(1)), 2), &div(&add(&neg(&A.to_dim()), &Val(1)), 2))
                .reduce(),
            1.into()
        );
    }

    // `3 * (2a)` folds to `6a`.
    #[test]
    fn reduce_mul_mul_1() {
        assert_eq!(mul(3, &mul(2, &A.to_dim())).reduce(), mul(6, &A.to_dim()))
    }

    // `-2 * (-a)` folds to `2a`.
    #[test]
    fn reduce_mul_mul_2() {
        assert_eq!(mul(-2, &mul(-1, &A.to_dim())).reduce(), mul(2, &A.to_dim()))
    }

    // `2 * ((-a) / 3)`: the numerator's sign is expected in front of the division.
    #[test]
    fn reduce_mul_div_1() {
        assert_eq!(mul(2, &div(&mul(-1, &A.to_dim()), 3)).reduce(), mul(-2, &div(&A.to_dim(), 3)))
    }

    // Constant expressions evaluate without any symbol bindings.
    #[test]
    fn const_and_add() {
        let e: TDim = 2i64.into();
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 2);
        let e: TDim = TDim::from(2) + 3;
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 5);
        let e: TDim = TDim::from(2) - 3;
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -1);
        let e: TDim = -TDim::from(2);
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -2);
    }

    // Binding `a := 2` resolves both `a` and `a + 3`.
    #[test]
    fn substitution() {
        let a: TDim = A.to_dim();
        assert_eq!(a.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 2);
        let e = a + 3;
        assert_eq!(e.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 5);
    }

    // Constant additions fold eagerly into a single `Val`.
    #[test]
    fn reduce_adds() {
        let e: TDim = TDim::from(2) + 1;
        assert_eq!(e, TDim::from(3));
        let e: TDim = TDim::from(3) + 2;
        assert_eq!(e, TDim::from(5));
        let e: TDim = TDim::from(3) + 0;
        assert_eq!(e, TDim::from(3));
        let e: TDim = TDim::from(3) + 2 + 1;
        assert_eq!(e, TDim::from(6));
    }

    // Multiplying by one, on either side, is a no-op.
    #[test]
    fn reduce_muls() {
        let e: TDim = Val(1) * A.to_dim();
        assert_eq!(e, A.to_dim());
        let e: TDim = A.to_dim() * &B.to_dim() * 1;
        assert_eq!(e, A.to_dim() * &B.to_dim());
    }

    // Constant division and remainder fold as integer operations (3/2 = 1, 5/2 = 2).
    #[test]
    fn reduce_divs() {
        let e: TDim = TDim::from(2) / 1;
        assert_eq!(e, TDim::from(2));
        let e: TDim = TDim::from(3) / 2;
        assert_eq!(e, TDim::from(1));
        let e: TDim = TDim::from(3) % 2;
        assert_eq!(e, TDim::from(1));
        let e: TDim = TDim::from(5) / 2;
        assert_eq!(e, TDim::from(2));
        let e: TDim = TDim::from(5) % 2;
        assert_eq!(e, TDim::from(1));
    }

    // Regression: a constant outside the division folds into its numerator.
    #[test]
    fn reduce_div_bug_0() {
        let e1: TDim = (A.to_dim() + 23) / 2 - 1;
        let e2: TDim = (A.to_dim() + 21) / 2;
        assert_eq!(e1, e2);
    }

    // Regression: `(a - 1)/2` and `(a + 1)/2 - 1` share a canonical form.
    #[test]
    fn reduce_div_bug_1() {
        let e1: TDim = (A.to_dim() + -1) / 2;
        let e2: TDim = (A.to_dim() + 1) / 2 - 1;
        assert_eq!(e1, e2);
    }

    // Regression: nested divisions merge, `((a + 1)/2 + 1)/2 == (a + 3)/4`.
    #[test]
    fn reduce_div_bug_2() {
        let e1: TDim = ((A.to_dim() + 1) / 2 + 1) / 2;
        let e2: TDim = (A.to_dim() + 3) / 4;
        assert_eq!(e1, e2);
    }

    // Regression: a trailing `/ 1` must not change the expression.
    #[test]
    fn reduce_div_bug_3() {
        let e1: TDim = (A.to_dim() / 2) * -4;
        let e2: TDim = (A.to_dim() / 2) * -4 / 1;
        assert_eq!(e1, e2);
    }

    // `2a / 2` cancels back to `a`.
    #[test]
    fn reduce_mul_div() {
        let e: TDim = A.to_dim() * 2 / 2;
        assert_eq!(e, A.to_dim());
    }

    // `a / 2 * 2` must NOT simplify to `a`: the division discards the remainder.
    #[test]
    fn reduce_div_mul() {
        let e: TDim = A.to_dim() / 2 * 2;
        assert_ne!(e, A.to_dim());
    }

    // `a/2 + 1` is folded into the single division `(a + 2)/2`.
    #[test]
    fn reduce_add_div() {
        let e: TDim = A.to_dim() / 2 + 1;
        assert_eq!(e, ((A.to_dim() + 2) / 2));
    }

    // `1 - 2a` and `1 + (-2)a` reduce to the same expression.
    #[test]
    fn reduce_neg_mul_() {
        let e: TDim = TDim::from(1) - A.to_dim() * 2;
        assert_eq!(e, TDim::from(1) + A.to_dim() * -2);
    }

    // Adding a multiple of the modulus leaves the remainder unchanged.
    #[test]
    fn reduce_add_rem_1() {
        assert_eq!(((A.to_dim() + 4) % 2), (A.to_dim() % 2));
    }

    // Same as above with a subtracted multiple of the modulus.
    #[test]
    fn reduce_add_rem_2() {
        assert_eq!(((A.to_dim() - 4) % 2), (A.to_dim() % 2));
    }

    // `(a % 2) / 2` is asserted to reduce to zero.
    #[test]
    fn reduce_rem_div() {
        let e: TDim = A.to_dim() % 2 / 2;
        assert_eq!(e, TDim::from(0));
    }

    // `div_ceil(1)` is the identity (names suggest conv2d output-size arithmetic).
    #[test]
    fn conv2d_ex_1() {
        let e = (TDim::from(1) - 1 + 1).div_ceil(1);
        assert_eq!(e, TDim::from(1));
    }

    #[test]
    fn conv2d_ex_2() {
        let e = (A.to_dim() - 3 + 1).div_ceil(1);
        assert_eq!(e, A.to_dim() + -2);
    }

    // Integer factors are pulled out of each product term:
    // `(24t - 24)(2t - 2) == 48 (t - 1)(t - 1)`.
    #[test]
    fn extract_int_gcd_from_muls() {
        let term = (A.to_dim() + 1) / 4;
        let mul = (term.clone() * 24 - 24) * (term.clone() * 2 - 2);
        let target = (term.clone() - 1) * (term.clone() - 1) * 48;
        assert_eq!(mul, target);
    }

    // Products compare equal regardless of factor order.
    #[test]
    fn equality_of_muls() {
        let term = (A.to_dim() + 1) / 4;
        let mul1 = (term.clone() * 2 - 3) * (term.clone() - 1);
        let mul2 = (term.clone() - 1) * (term.clone() * 2 - 3);
        assert_eq!(mul1, mul2);
    }

    // `2t - t - 1` collapses to `t - 1` for a non-trivial term `t`.
    #[test]
    fn factorize_complex_expr_times_int() {
        let term = (A.to_dim() + 1) / 4;
        let e = term.clone() * 2 - &term - 1;
        assert_eq!(e, term - 1);
    }

    #[test]
    fn broadcast_over_min() {
        for a in 1..5 {
            // Broadcasting `a` against `min(a, b)` fails exactly when the min is a
            // different value that is not 1 (i.e. `b > 1 && a > b`); otherwise it
            // yields `a`.
            for b in 1..5 {
                if b > 1 && a > b {
                    assert!(a.broadcast(a.min(b)).is_err());
                } else {
                    assert_eq!(a.broadcast(a.min(b)).unwrap(), a);
                }
            }
        }
    }

    // `mini` on two constants picks the smaller, in either argument order.
    #[test]
    fn min_ints_1() {
        assert_eq!(2.to_dim().mini(1.to_dim()), 1.to_dim());
    }

    #[test]
    fn min_ints_2() {
        assert_eq!(1.to_dim().mini(2.to_dim()), 1.to_dim());
    }

    // `min(a, a)` is `a`.
    #[test]
    fn min_same() {
        assert_eq!(A.to_dim().mini(A.to_dim()), A.to_dim());
    }

    // `min(a, 1)` cannot be resolved; building it twice must give equal results.
    #[test]
    fn min_noop() {
        assert_eq!(A.to_dim().mini(1.to_dim()), A.to_dim().mini(1.to_dim()));
    }

    // Expressions differing by a constant offset resolve: `min(a+1, a+2) == a+1`.
    #[test]
    fn min_diff_1() {
        assert_eq!((A.to_dim() + 1).mini(A.to_dim() + 2), A.to_dim() + 1);
    }

    // `guess_slope(sym)` returns the coefficient of `sym` as (numerator, denominator).
    #[test]
    fn slope_0() {
        assert_eq!(12.to_dim().guess_slope(&A), (0, 1));
    }

    #[test]
    fn slope_1() {
        assert_eq!(A.to_dim().guess_slope(&A), (1, 1));
    }

    #[test]
    fn slope_2() {
        assert_eq!((A.to_dim() * 2).guess_slope(&A), (2, 1));
    }

    // `2a + a/2` has slope 2 + 1/2 = 5/2.
    #[test]
    fn slope_3() {
        assert_eq!((A.to_dim() * 2 + A.to_dim() / 2).guess_slope(&A), (5, 2));
    }

    // Slope with respect to a symbol that does not appear is 0.
    #[test]
    fn slope_4() {
        assert_eq!((A.to_dim()).guess_slope(&B), (0, 1));
    }

    #[test]
    fn slope_5() {
        assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
        assert_eq!((A.to_dim() + 1).guess_slope(&B), (0, 1));
    }

    #[test]
    fn slope_6() {
        assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
        assert_eq!((A.to_dim() + B.to_dim()).guess_slope(&B), (1, 1));
    }

    // A parsed `min` of offset expressions simplifies to the smaller one.
    #[test]
    fn min_0() -> TractResult<()> {
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols.parse_tdim("min(S+3, S+2)").unwrap().simplify(),
            symbols.parse_tdim("S+2").unwrap(),
        );
        Ok(())
    }

    // Parenthesisation and factor order do not affect the simplified product.
    #[test]
    fn commutative_mul_parens() -> TractResult<()> {
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols.parse_tdim("A*(B*C)").unwrap().simplify(),
            symbols.parse_tdim("(B*A)*C").unwrap().simplify(),
        );
        Ok(())
    }

    // Regression case presumably taken from the NeMo Parakeet model (per the test
    // name): two factor orderings of the same product must simplify identically.
    #[test]
    fn commutative_in_nemo_parakeet_model() -> TractResult<()> {
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols
                .parse_tdim("8*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8))*((B)*((S+7)/8))")
                .unwrap()
                .simplify(),
            symbols
                .parse_tdim("8*((B)*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8)))*((S+7)/8)")
                .unwrap()
                .simplify(),
        );
        Ok(())
    }

    // A deeply left-nested `Mul` tree simplifies to the same form as the flat product.
    #[test]
    fn commutative_mul_parens_deep() -> TractResult<()> {
        let symbols = SymbolScope::default();
        let deep_tdim = Mul(vec![
            Mul(vec![Mul(vec![Mul(vec![A.to_dim(), B.to_dim()]), C.to_dim()]), D.to_dim()]),
            E.to_dim(),
        ])
        .simplify();
        assert_eq!(deep_tdim, symbols.parse_tdim("a*b*c*d*e").unwrap().simplify());
        Ok(())
    }

    // `Ge(a, b)` on constants reduces to 1 when `a >= b`, else 0.
    #[test]
    fn ge_concrete_true() {
        assert_eq!(Ge(b!(Val(5)), b!(Val(3))).reduce(), Val(1));
    }

    #[test]
    fn ge_concrete_false() {
        assert_eq!(Ge(b!(Val(2)), b!(Val(3))).reduce(), Val(0));
    }

    #[test]
    fn lt_concrete_true() {
        assert_eq!(Ge(b!(Val(3)), b!(Val(3))).reduce(), Val(1));
        // Presumably `2 < 3` encoded as `Ge(3, 2 + 1)`, i.e. `a < b` as `Ge(b, a + 1)`.
    }

    #[test]
    fn lt_concrete_false() {
        assert_eq!(Ge(b!(Val(3)), b!(Val(6))).reduce(), Val(0));
        // Presumably `5 < 3` encoded as `Ge(3, 5 + 1)`.
    }

    // `Eq(a, b)` on constants reduces to 1 when equal, else 0.
    #[test]
    fn eq_concrete_true() {
        assert_eq!(Eq(b!(Val(3)), b!(Val(3))).reduce(), Val(1));
    }

    #[test]
    fn eq_concrete_false() {
        assert_eq!(Eq(b!(Val(3)), b!(Val(4))).reduce(), Val(0));
    }

    #[test]
    fn not_val_0() {
        assert_eq!((Val(1) - Val(0)).reduce(), Val(1));
        // Boolean NOT written as `1 - x`: NOT 0 == 1.
    }

    #[test]
    fn not_val_1() {
        assert_eq!((Val(1) - Val(1)).reduce(), Val(0));
        // NOT 1 == 0.
    }

    #[test]
    fn not_lt_becomes_ge() {
        let s = SymbolScope::default();
        // `1 - Ge(T, x1 + 1)` is NOT(x1 < T), i.e. x1 >= T.
        let t = s.sym("T");
        let x1 = s.sym("x1");
        let expr = Val(1) - Ge(b!(Sym(t.clone())), b!(Sym(x1.clone()) + Val(1)));
        // Substituting x1 := T gives `1 - Ge(T, T + 1)` = 1 - 0 = 1.
        let at_boundary = expr.substitute(&x1, &Sym(t.clone())).unwrap().simplify();
        assert_eq!(at_boundary, Val(1));
    }

    #[test]
    fn eq_with_assertion_proves_false() {
        let s = SymbolScope::default();
        // With `T >= 1` asserted, `T == 0` can be proven false.
        s.add_assertion("T >= 1").unwrap();
        let t = s.sym("T");
        let expr = Eq(b!(Sym(t)), b!(Val(0)));
        assert_eq!(expr.simplify(), Val(0));
    }

    #[test]
    fn ge_coord_at_extremes() {
        let s = SymbolScope::default();
        // `x1 >= T` evaluated at the largest coordinate below T.
        s.add_assertion("T >= 1").unwrap();
        let t = s.sym("T");
        let x1 = s.sym("x1");
        let expr = Ge(b!(Sym(x1.clone())), b!(Sym(t.clone())));
        // Substituting x1 := T - 1 gives `Ge(T - 1, T)`, which must simplify to 0.
        let at_max = expr.substitute(&x1, &(Sym(t.clone()) - Val(1))).unwrap().simplify();
        assert_eq!(at_max, Val(0));
    }

    // `Ge` and `Eq` evaluate to 0/1 integers.
    #[test]
    fn eval_to_i64_new_variants() {
        use super::super::sym::SymbolValues;
        let sv = SymbolValues::default();
        assert_eq!(Ge(b!(Val(5)), b!(Val(3))).eval_to_i64(&sv).unwrap(), 1);
        assert_eq!(Ge(b!(Val(3)), b!(Val(5))).eval_to_i64(&sv).unwrap(), 0);
        assert_eq!(Eq(b!(Val(3)), b!(Val(3))).eval_to_i64(&sv).unwrap(), 1);
        assert_eq!(Eq(b!(Val(3)), b!(Val(4))).eval_to_i64(&sv).unwrap(), 0);
    }

    // With `cw` asserted to lie in [0, 1], equality against 0/1 collapses to `cw` or `1 - cw`.
    #[test]
    fn eq_boolean_simplifies() {
        let s = SymbolScope::default();
        s.add_assertion("cw >= 0").unwrap();
        s.add_assertion("cw <= 1").unwrap();
        let cw = s.sym("cw");
        // (1 - cw) == 0  <=>  cw
        assert_eq!(Eq(b!(Val(1) - Sym(cw.clone())), b!(Val(0))).simplify(), Sym(cw.clone()));
        // cw == 0  <=>  1 - cw
        assert_eq!(Eq(b!(Sym(cw.clone())), b!(Val(0))).simplify(), Val(1) - Sym(cw.clone()));
        // cw == 1  <=>  cw
        assert_eq!(Eq(b!(Sym(cw.clone())), b!(Val(1))).simplify(), Sym(cw.clone()));
        // (1 - cw) == 1  <=>  1 - cw
        assert_eq!(Eq(b!(Val(1) - Sym(cw.clone())), b!(Val(1))).simplify(), Val(1) - Sym(cw));
    }

    #[test]
    fn eq_boolean_mul_of_ge() {
        let s = SymbolScope::default();
        // A product of `Ge` terms is a 0/1 value, so `product == 0` becomes `1 - product`.
        let x = s.sym("x");
        let product =
            Mul(vec![Ge(b!(Val(2)), b!(Sym(x.clone()))), Ge(b!(Sym(x.clone())), b!(Val(0)))]);
        let eq = Eq(b!(product.clone()), b!(Val(0)));
        assert_eq!(eq.simplify(), Val(1) - product);
    }

    #[test]
    fn min_1_max_0_sym() {
        let s = SymbolScope::default();
        // `X` is unbounded, so `min(1, max(0, X))` must keep its `min`.
        let x = s.sym("X");
        let expr = Min(vec![Val(1), Max(vec![Val(0), Sym(x)])]);
        let simplified = expr.simplify();
        eprintln!("simplified: {simplified}");
        assert!(format!("{simplified}").contains("min"), "Min dropped: {simplified}");
    }

    #[test]
    fn min_preserved_in_subtraction_parts() {
        let s = SymbolScope::default();
        // Simplifying `min(1, max(0, (T + 1) * P - S))` alone must keep the `min` wrapper.
        let t = s.sym("T");
        let p = s.sym("P");
        let ss = s.sym("S");

        let cum_after =
            Max(vec![Val(0), (Sym(t.clone()) + Val(1)) * Sym(p.clone()) - Sym(ss.clone())]);
        let min_after = Min(vec![Val(1), cum_after.clone()]);
        let simplified = min_after.simplify();
        eprintln!("min_after simplified: {simplified}");
        assert!(format!("{simplified}").contains("min"), "Min wrapper was dropped: {simplified}");
    }

    #[test]
    fn min_preserved_in_subtraction() {
        let s = SymbolScope::default();
        // The difference of two clamped expressions must still evaluate correctly
        // after simplification.
        let t = s.sym("T");
        let p = s.sym("P");
        let ss = s.sym("S");

        let cum_after =
            Max(vec![Val(0), (Sym(t.clone()) + Val(1)) * Sym(p.clone()) - Sym(ss.clone())]);
        let cum_before = Max(vec![Val(0), Sym(t.clone()) * Sym(p.clone()) - Sym(ss.clone())]);

        let ap = Min(vec![Val(1), cum_after.clone()]) - Min(vec![Val(1), cum_before.clone()]);
        let simplified = ap.simplify();

        use super::super::sym::SymbolValues;
        // T=1, P=4, S=3: min(1, max(0, 5)) - min(1, max(0, 1)) = 1 - 1 = 0.
        let sv = SymbolValues::default().with(&t, 1).with(&p, 4).with(&ss, 3);
        assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 0, "simplified: {simplified}");

        // T=0, P=4, S=3: min(1, max(0, 1)) - min(1, max(0, -3)) = 1 - 0 = 1.
        let sv = SymbolValues::default().with(&t, 0).with(&p, 4).with(&ss, 3);
        assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 1, "simplified: {simplified}");

        // T=0, P=1, S=1: min(1, max(0, 0)) - min(1, max(0, -1)) = 0 - 0 = 0.
        let sv = SymbolValues::default().with(&t, 0).with(&p, 1).with(&ss, 1);
        assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 0, "simplified: {simplified}");
    }

    #[test]
    fn mul_neg_b_by_8() {
        let s = SymbolScope::default();
        let b = Sym(s.sym("B"));
        // `8 * (-1 * B)` written as a generic product...
        let a = Mul(vec![Val(8), MulInt(-1, Box::new(b.clone()))]);
        // ...and the equivalent direct integer multiple `-8 * B`.
        let c = MulInt(-8, Box::new(b.clone()));
        let a_s = a.simplify();
        let c_s = c.simplify();
        assert_eq!(a_s, c_s, "8*(-1*B) should simplify the same as -8*B");
    }
}