Skip to main content

tract_data/dim/
tree.rs

1use crate::dim::Assertion;
2use crate::internal::*;
3
4use super::{DimLike, sym::*};
5use itertools::Itertools;
6use num_integer::Integer;
7use num_traits::{AsPrimitive, PrimInt, Zero};
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10use std::fmt::Debug;
11use std::ops::Neg;
12use std::{fmt, ops};
13
/// Error returned when a [`TDim`] cannot be turned into a concrete integer yet.
#[derive(Debug)]
pub enum TooEarly {
    /// The expression (rendered with `Display`) still contains a symbol whose value
    /// is not known.
    UndeterminedSymbol(String),
    /// Any other condition, described by a free-form message that is displayed as-is.
    Other(String),
}
19
20impl std::fmt::Display for TooEarly {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        match self {
23            TooEarly::UndeterminedSymbol(s) => write!(f, "Undetermined symbol in expression: {s}"),
24            TooEarly::Other(s) => write!(f, "{s}"),
25        }
26    }
27}
28
// Makes `TooEarly` a standard error so `Err(TooEarly::..)?` can be propagated out of
// `TractResult`-returning functions such as `TDim::to_i64`.
impl std::error::Error for TooEarly {}
30
// Shorthand for `Box::new(..)`, used when building boxed `TDim` children.
macro_rules! b( ($e:expr) => { Box::new($e) } );
32
/// A symbolic integer expression tree.
///
/// Concrete values are obtained with `eval_to_i64`; the comments on each variant
/// describe how that evaluation interprets the node.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum TDim {
    /// An integer constant.
    Val(i64),
    /// A symbol whose value is looked up in `SymbolValues`.
    Sym(Symbol),
    /// The sum of all terms (0 when empty).
    Add(Vec<TDim>),
    /// The product of all terms (1 when empty).
    Mul(Vec<TDim>),
    /// An integer coefficient times an expression.
    MulInt(i64, Box<TDim>),
    /// An expression divided by a positive integer using `i64` division (`/`).
    Div(Box<TDim>, u64),
    /// The terms combined pairwise with dimension broadcasting.
    Broadcast(Vec<TDim>),
    /// The smallest term (`i64::MAX` when empty).
    Min(Vec<TDim>),
    /// The largest term (`i64::MIN` when empty).
    Max(Vec<TDim>),
}
45
46use TDim::*;
47
48fn tdim_lexi_order(a: &TDim, b: &TDim) -> Ordering {
49    match (a, b) {
50        (Sym(a), Sym(b)) => a.cmp(b),
51        (Val(a), Val(b)) => a.cmp(b),
52        (Add(a), Add(b))
53        | (Mul(a), Mul(b))
54        | (Broadcast(a), Broadcast(b))
55        | (Min(a), Min(b))
56        | (Max(a), Max(b)) => a.len().cmp(&b.len()).then(
57            a.iter()
58                .zip(b.iter())
59                .fold(Ordering::Equal, |acc, (a, b)| acc.then_with(|| tdim_lexi_order(a, b))),
60        ),
61        (MulInt(p, d), MulInt(q, e)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
62        (Div(d, p), Div(e, q)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
63        (Sym(_), _) => Ordering::Less,
64        (_, Sym(_)) => Ordering::Greater,
65        (Val(_), _) => Ordering::Less,
66        (_, Val(_)) => Ordering::Greater,
67        (Add(_), _) => Ordering::Less,
68        (_, Add(_)) => Ordering::Greater,
69        (Mul(_), _) => Ordering::Less,
70        (_, Mul(_)) => Ordering::Greater,
71        (MulInt(_, _), _) => Ordering::Less,
72        (_, MulInt(_, _)) => Ordering::Greater,
73        (Broadcast(_), _) => Ordering::Less,
74        (_, Broadcast(_)) => Ordering::Greater,
75        (Min(_), _) => Ordering::Less,
76        (_, Min(_)) => Ordering::Greater,
77        (Max(_), _) => Ordering::Less,
78        (_, Max(_)) => Ordering::Greater,
79    }
80}
81
82impl fmt::Display for TDim {
83    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
84        match &self {
85            Sym(sym) => write!(fmt, "{sym}"),
86            Val(it) => write!(fmt, "{it}"),
87            Add(it) => write!(fmt, "{}", it.iter().map(|x| format!("{x}")).join("+")),
88            Mul(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("*")),
89            Broadcast(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("#")),
90            Min(it) => write!(fmt, "min({})", it.iter().map(|x| format!("{x}")).join(",")),
91            Max(it) => write!(fmt, "max({})", it.iter().map(|x| format!("{x}")).join(",")),
92            MulInt(a, b) => write!(fmt, "{a}*{b}"),
93            Div(a, b) => write!(fmt, "({a})/{b}"),
94        }
95    }
96}
97
98impl TDim {
99    #[inline]
100    pub fn is_one(&self) -> bool {
101        matches!(self, Val(1))
102    }
103
104    #[inline]
105    pub fn to_i64(&self) -> TractResult<i64> {
106        if let Val(v) = self {
107            Ok(*v)
108        } else {
109            Err(TooEarly::UndeterminedSymbol(self.to_string()))?
110        }
111    }
112
113    #[inline]
114    pub fn as_i64(&self) -> Option<i64> {
115        if let Val(v) = self { Some(*v) } else { None }
116    }
117
118    pub fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64> {
119        match self {
120            Sym(sym) => {
121                let Some(v) = values.get(sym) else {
122                    Err(TooEarly::UndeterminedSymbol(self.to_string()))?
123                };
124                Ok(v)
125            }
126            Val(v) => Ok(*v),
127            Add(terms) => {
128                terms.iter().try_fold(0, |acc, it| it.eval_to_i64(values).map(|x| acc + x))
129            }
130            Mul(terms) => {
131                terms.iter().try_fold(1, |acc, it| it.eval_to_i64(values).map(|x| acc * x))
132            }
133            Min(terms) => terms
134                .iter()
135                .try_fold(i64::MAX, |acc, it| it.eval_to_i64(values).map(|x| acc.min(x))),
136            Max(terms) => terms
137                .iter()
138                .try_fold(i64::MIN, |acc, it| it.eval_to_i64(values).map(|x| acc.max(x))),
139            Broadcast(terms) => terms.iter().try_fold(1i64, |acc, it| {
140                it.eval_to_i64(values)
141                    .and_then(|x| ((acc as usize).broadcast(x as usize)).map(|x| x as i64))
142            }),
143            Div(a, q) => Ok(a.eval_to_i64(values)? / *q as i64),
144            MulInt(p, a) => Ok(a.eval_to_i64(values)? * *p),
145        }
146    }
147
148    pub fn eval(&self, values: &SymbolValues) -> TDim {
149        match self {
150            Sym(sym) => values.get(sym).map(Val).unwrap_or_else(|| Sym(sym.clone())),
151            Val(v) => Val(*v),
152            Add(terms) => terms.iter().fold(Val(0), |acc, it| -> TDim { acc + it.eval(values) }),
153            Mul(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim { acc * it.eval(values) }),
154            Min(terms) => {
155                terms.iter().fold(Val(i64::MAX), |acc, it| -> TDim { acc.mini(it.eval(values)) })
156            }
157            Max(terms) => {
158                terms.iter().fold(Val(i64::MIN), |acc, it| -> TDim { acc.maxi(it.eval(values)) })
159            }
160            Broadcast(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim {
161                acc.broadcast(it.eval(values)).unwrap_or_else(|_| self.clone())
162            }),
163            Div(a, q) => a.eval(values) / *q as i64,
164            MulInt(p, a) => a.eval(values) * *p,
165        }
166    }
167
168    pub fn eval_with_scenario(&self, scenario: &str) -> TDim {
169        if let Val(v) = self {
170            return Val(*v);
171        }
172        let scope = self.find_scope().unwrap();
173        let scope = scope.0;
174        let locked = scope.lock();
175        let scope = locked.borrow();
176        self.clone().simplify_rec(&scope, Some(scenario))
177    }
178
179    pub fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self> {
180        match self {
181            Sym(sym) => Ok(if sym == from { to.clone() } else { self.clone() }),
182            Val(v) => Ok(Val(*v)),
183            Add(terms) => terms.iter().try_fold(Val(0), |acc, it| -> TractResult<TDim> {
184                Ok(acc + it.substitute(from, to)?)
185            }),
186            Mul(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
187                Ok(acc * it.substitute(from, to)?)
188            }),
189            Broadcast(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
190                acc.broadcast(it.substitute(from, to)?)
191            }),
192            Min(terms) => terms.iter().try_fold(Val(i64::MAX), |acc, it| -> TractResult<TDim> {
193                Ok(acc.mini(it.substitute(from, to)?))
194            }),
195            Max(terms) => terms.iter().try_fold(Val(i64::MIN), |acc, it| -> TractResult<TDim> {
196                Ok(acc.maxi(it.substitute(from, to)?))
197            }),
198            Div(a, q) => Ok(a.substitute(from, to)? / *q as i64),
199            MulInt(p, a) => Ok(a.substitute(from, to)? * *p),
200        }
201    }
202
203    pub fn reduce(self) -> TDim {
204        self.simplify()
205            .wiggle()
206            .into_iter()
207            .sorted_by(tdim_lexi_order)
208            .unique()
209            .map(|e| e.simplify())
210            .min_by_key(|e| e.cost())
211            .unwrap()
212    }
213
214    fn cost(&self) -> usize {
215        use self::TDim::*;
216        match self {
217            Sym(_) | Val(_) => 1,
218            Add(terms) => 2 * terms.iter().map(TDim::cost).sum::<usize>(),
219            Mul(terms) => 3 * terms.iter().map(TDim::cost).sum::<usize>(),
220            Broadcast(terms) => 4 * terms.iter().map(TDim::cost).sum::<usize>(),
221            Min(terms) | Max(terms) => 5 * terms.iter().map(TDim::cost).sum::<usize>(),
222            Div(a, _) => 3 * a.cost(),
223            MulInt(_, a) => 2 * a.cost(),
224        }
225    }
226
    /// Generates alternative forms of the expression for `reduce` to choose from.
    ///
    /// The returned list always contains at least one candidate (the original form, or a
    /// rebuilt copy of it). Only `Add`, `MulInt` and `Div` produce extra candidates; the
    /// other variants are returned unchanged.
    fn wiggle(&self) -> Vec<TDim> {
        use self::TDim::*;
        match self {
            Sym(_) | Val(_) | Mul(_) | Broadcast(_) | Min(_) | Max(_) => vec![self.clone()],
            Add(terms) => {
                let mut forms = vec![];
                // Every combination of the candidate forms of each term.
                let sub_exprs = terms.iter().map(|e| e.wiggle()).multi_cartesian_product();

                // Locates the first `Div` term of a sum: (index, numerator, divisor).
                fn first_div_term(terms: &[TDim]) -> Option<(usize, &TDim, u64)> {
                    terms.iter().enumerate().find_map(|(index, t)| match t {
                        Div(numerator, quotient) => Some((index, &**numerator, *quotient)),
                        _ => None,
                    })
                }

                // Builds the terms of a common numerator: the Div term contributes its
                // numerator, every other term is multiplied by `quotient`.
                fn generate_new_numerator(
                    div_index: usize,
                    numerator: &TDim,
                    quotient: u64,
                    expr: &[TDim],
                ) -> Vec<TDim> {
                    expr.iter()
                        .enumerate()
                        .map(|(index, term)| {
                            if index == div_index {
                                numerator.clone()
                            } else {
                                MulInt(quotient as i64, Box::new(term.clone()))
                            }
                        })
                        .collect()
                }

                for expr in sub_exprs {
                    // Candidate `a + n/q  ->  (q*a + n)/q` when the sum contains a division.
                    if let Some((div_index, numerator, quotient)) = first_div_term(&expr) {
                        let new_numerator =
                            generate_new_numerator(div_index, numerator, quotient, &expr);
                        forms.push(Div(Box::new(Add(new_numerator)), quotient))
                    }

                    forms.push(Add(expr));
                }
                forms
            }
            MulInt(p, a) => a.wiggle().into_iter().map(|a| MulInt(*p, b!(a))).collect(),
            Div(a, q) => {
                let mut forms = vec![];
                for num in a.wiggle() {
                    if let Add(terms) = &num {
                        // Terms whose gcd is a multiple of `q` are divided exactly via `div`;
                        // the remaining ones stay together under a single `Div`.
                        let (integer, non_integer): (Vec<_>, Vec<_>) =
                            terms.iter().cloned().partition(|a| a.gcd() % q == 0);
                        let mut new_terms = integer.iter().map(|i| i.div(*q)).collect::<Vec<_>>();
                        if non_integer.len() > 0 {
                            new_terms.push(Div(b!(Add(non_integer)), *q));
                        }
                        forms.push(Add(new_terms))
                    }
                    forms.push(Div(b!(num), *q))
                }
                forms
            }
        }
    }
290
291    fn find_any_sym(tdim: &TDim) -> Option<&Symbol> {
292        match tdim {
293            Val(_) => None,
294            Sym(s) => Some(s),
295            Add(terms) | Mul(terms) | Min(terms) | Max(terms) | Broadcast(terms) => {
296                terms.iter().find_map(Self::find_any_sym)
297            }
298            MulInt(_, t) | Div(t, _) => Self::find_any_sym(t),
299        }
300    }
301
302    pub fn find_scope(&self) -> Option<SymbolScope> {
303        Self::find_any_sym(self).and_then(|s| s.scope().clone())
304    }
305
306    pub fn simplify(self) -> TDim {
307        use self::TDim::*;
308        if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) {
309            return Val(v);
310        }
311        let Some(scope) = self.find_scope() else {
312            return self;
313        };
314        let scope = scope.0;
315        let locked = scope.lock();
316        let scope = locked.borrow();
317        let it = self.simplify_rec(&scope, None);
318        let mut current: Option<TDim> = None;
319        for scenario in scope.scenarios() {
320            let v = it.clone().simplify_rec(&scope, Some(scenario));
321            if current.is_some_and(|c| c != v) {
322                return it;
323            } else {
324                current = Some(v);
325            }
326        }
327        current.unwrap_or(it)
328    }
329
    /// Recursive worker behind `simplify` and `eval_with_scenario`.
    ///
    /// Rewrites `self` into a simpler form using the symbol scope data `scope`:
    /// - flattens nested sums, products, broadcasts, mins and maxes;
    /// - collects like terms and folds constants and integer coefficients;
    /// - uses `scope`'s proof helpers to drop redundant `Min`/`Max`/`Broadcast` operands.
    ///
    /// Symbols equated to a value by an `Assertion::Eq` among `scope.assertions(scenario)`
    /// are replaced by that value.
    fn simplify_rec(self, scope: &SymbolScopeData, scenario: Option<&str>) -> TDim {
        match self {
            Add(mut terms) => {
                // Maps each distinct (simplified) term to its accumulated integer
                // coefficient; constants are accumulated under the key `Val(1)`.
                #[allow(clippy::mutable_key_type)]
                let mut simplified_terms: HashMap<TDim, i64> = HashMap::new();
                // factorize common sub-expr
                while let Some(term) = terms.pop() {
                    let simplified = term.simplify_rec(scope, scenario);
                    match simplified {
                        Val(0) => {} // ignore
                        // Nested sum: push its members back to be processed individually.
                        Add(members) => {
                            terms.extend(members);
                            continue;
                        }
                        Val(value) => *simplified_terms.entry(Val(1)).or_insert(0) += value,
                        MulInt(value, factor) => {
                            *simplified_terms.entry((*factor).clone()).or_insert(0) += value;
                        }
                        n => *simplified_terms.entry(n).or_insert(0) += 1,
                    };
                }

                // Turns a (term, coefficient) pair back into an expression; terms whose
                // coefficients cancelled out to 0 disappear.
                pub fn evaluate_count(term: TDim, count: i64) -> Option<TDim> {
                    match count {
                        0 => None,
                        _ if term == TDim::Val(1) => Some(TDim::Val(count)),
                        1 => Some(term),
                        _ => Some(TDim::MulInt(count, Box::new(term))),
                    }
                }

                let mut members: Vec<TDim> = simplified_terms
                    .into_iter()
                    .filter_map(|(term, count)| evaluate_count(term, count))
                    .collect();
                // HashMap iteration order is arbitrary: sort to get a canonical form.
                members.sort_by(tdim_lexi_order);

                match members.len() {
                    0 => TDim::Val(0),
                    1 => members.into_iter().next().unwrap(),
                    _ => TDim::Add(members),
                }
            }
            Mul(terms) => {
                // in case a term is a multiplication itself, flatten it
                // e.g., (a*b)*c => a*b*c
                let mut flattened_terms = vec![];
                for t in terms {
                    if let Mul(inner_terms) = t.clone().reduce() {
                        flattened_terms.extend(inner_terms);
                    } else {
                        flattened_terms.push(t);
                    }
                }
                let mut terms = flattened_terms;

                // Pull the integer part of every factor into a single coefficient.
                // NOTE(review): when that coefficient is 1, the factors are kept as produced
                // by the flattening step above and are not re-simplified here.
                let mut gcd = Mul(terms.clone()).gcd() as i64;
                if gcd == 0 {
                    return Val(0);
                }
                terms = if gcd != 1 {
                    terms
                        .into_iter()
                        .map(|t| {
                            let gcd = t.gcd();
                            (t / gcd).simplify_rec(scope, scenario)
                        })
                        .collect()
                } else {
                    terms
                };
                // `gcd` is unsigned: an odd number of leftover -1 factors flips its sign.
                if terms.iter().filter(|t| t == &&Val(-1)).count() % 2 == 1 {
                    gcd = -gcd;
                }
                terms.retain(|t| !t.is_one() && t != &Val(-1));
                terms.sort_by(tdim_lexi_order);

                match (gcd, terms.len()) {
                    (_, 0) => Val(gcd), // Case #1: If 0 variables, return product
                    (0, _) => Val(0),   // Case #2: Result is 0 if coef is 0 (actually
                    // unreachable as we check at the beginning)
                    (1, 1) => terms.remove(0), // Case #3: Product is 1, so return the only term
                    (1, _) => Mul(terms), // Case #4: Product is 1, so return the non-integer terms
                    (_, 1) => MulInt(gcd, Box::new(terms.remove(0))), // Case #5: Single variable, convert to 1 MulInt
                    _ => MulInt(gcd, Box::new(Mul(terms))), // Case #6: Multiple variables, convert to MulInt
                }
            }
            MulInt(coef, expr) => {
                // Fast paths before recursing: merge nested coefficients, fold constants.
                match *expr {
                    MulInt(c2, inner) => {
                        return MulInt(coef * c2, inner).simplify_rec(scope, scenario);
                    }
                    Val(v) => return Val(coef * v),
                    _ => {}
                }

                let simplified = expr.simplify_rec(scope, scenario);
                match (coef, simplified) {
                    (0, _) => Val(0), // Case #1: If coef is 0, return 0
                    (1, s) => s,      // Case #2: If coef is 1, return the simplified expression
                    (_, Add(terms)) => Add(terms
                        .into_iter()
                        .map(|term| MulInt(coef, Box::new(term)).simplify_rec(scope, scenario))
                        .collect()), // Case #3: If expression is an addition, distribute the coef
                    (c, Val(v)) => Val(c * v), // Case #4: If expression is a value, combine coefs
                    (c, MulInt(v, inner)) => MulInt(c * v, inner), // Case #5: If expression is a MulInt, combine coefs
                    (_, s) => MulInt(coef, Box::new(s)), // Case #6: Otherwise, return the original
                }
            }
            Div(a, q) => {
                // x/1 == x, and (x/q2)/q is merged into x/(q*q2).
                if q == 1 {
                    return a.simplify_rec(scope, scenario);
                } else if let Div(a, q2) = *a {
                    return Div(a, q * q2).simplify_rec(scope, scenario);
                }
                let a = a.simplify_rec(scope, scenario);
                if let Val(a) = a {
                    // Constant numerator: fold with `i64` division.
                    Val(a / q as i64)
                } else if let MulInt(-1, a) = a {
                    // (-a)/q is rewritten as -(a/q).
                    MulInt(-1, b!(Div(a, q)))
                } else if let Add(mut terms) = a {
                    // A sum containing a negated symbol is negated as a whole:
                    // (terms)/q  ->  -((-terms)/q).
                    if terms
                        .iter()
                        .any(|t| if let MulInt(-1, s) = t { matches!(&**s, Sym(_)) } else { false })
                    {
                        MulInt(
                            -1,
                            b!(Div(
                                b!(Add(terms.into_iter().map(|t| MulInt(-1, b!(t))).collect())
                                    .simplify_rec(scope, scenario)),
                                q
                            )),
                        )
                    } else if let Some(v) =
                        terms.iter().find_map(|t| if let Val(v) = t { Some(*v) } else { None })
                    {
                        // Move whole multiples of `q` out of the constant term `v`:
                        // k = v/q when v >= q, k = -ceil(-v/q) when v < 0, and the result is
                        // k + (terms - k*q)/q.
                        let offset = if v >= q as i64 {
                            Some(v / q as i64)
                        } else if v < 0 {
                            Some(-Integer::div_ceil(&-v, &(q as i64)))
                        } else {
                            None
                        };
                        if let Some(val) = offset {
                            terms.push(Val(-val * q as i64));
                            Add(vec![
                                Val(val),
                                Div(b!(Add(terms).simplify_rec(scope, scenario)), q),
                            ])
                        } else {
                            Div(b!(Add(terms)), q)
                        }
                    } else {
                        Div(b!(Add(terms)), q)
                    }
                } else if let MulInt(p, a) = a {
                    // (p*a)/q: cancel the common factor of the coefficient and the divisor.
                    if p == q as i64 {
                        // NOTE(review): this calls the scope-locking `simplify()` rather than
                        // `simplify_rec(scope, scenario)`, so the active scenario is not
                        // forwarded here — confirm this is intended.
                        a.simplify()
                    } else {
                        let gcd = p.abs().gcd(&(q as i64));
                        if gcd == p {
                            Div(a, q / gcd as u64)
                        } else if gcd == q as i64 {
                            MulInt(p / gcd, a)
                        } else if gcd > 1 {
                            Div(b!(MulInt(p / gcd, a)), q / gcd as u64)
                                .simplify_rec(scope, scenario)
                        } else {
                            Div(b!(MulInt(p, a)), q)
                        }
                    }
                } else {
                    Div(b!(a), q)
                }
            }
            Broadcast(terms) => {
                // Simplify, flatten nested broadcasts, drop 1s, then sort and dedup.
                let mut terms: Vec<TDim> = terms
                    .iter()
                    .map(|s| s.clone().simplify_rec(scope, scenario))
                    .flat_map(|t| if let Broadcast(t) = t { t } else { vec![t] })
                    .filter(|t| !t.is_one())
                    .sorted_by(tdim_lexi_order)
                    .dedup()
                    .collect_vec();
                // a#min(a,b) if a>0 && b>0 => a
                match &*terms {
                    [] => Val(1),
                    [_] => terms.remove(0),
                    [a, Min(m)] | [Min(m), a]
                        if m.contains(a) && m.iter().all(|t| scope.prove_strict_positive(t)) =>
                    {
                        a.clone()
                    }
                    _ => Broadcast(terms),
                }
            }

            Min(terms) => {
                // Simplify, flatten nested mins, sort and dedup.
                let mut flatten: Vec<TDim> = terms
                    .into_iter()
                    .map(|t| t.simplify_rec(scope, scenario))
                    .flat_map(|t| if let Min(t) = t { t } else { vec![t] })
                    .sorted_by(tdim_lexi_order)
                    .dedup()
                    .collect();
                // An operand `a` proven >= another operand `b` can never be the minimum.
                #[allow(clippy::mutable_key_type)]
                let mut redundant = HashSet::<TDim>::default();
                for pair in flatten.iter().permutations(2) {
                    let (a, b) = (pair[0], pair[1]);
                    if redundant.contains(a) || redundant.contains(b) {
                        continue;
                    }
                    let diff = a.clone() - b;
                    if diff.as_i64().is_some_and(|i| i >= 0) || scope.prove_positive_or_zero(&diff)
                    {
                        redundant.insert(a.clone());
                    }
                }
                flatten.retain(|t| !redundant.contains(t));
                // The empty minimum is `i64::MAX`, matching `eval_to_i64`.
                if flatten.len() == 0 {
                    i64::MAX.to_dim()
                } else if flatten.len() == 1 {
                    flatten.into_iter().next().unwrap()
                } else {
                    Min(flatten)
                }
            }
            Max(terms) => {
                // Simplify, flatten nested maxes, sort and dedup.
                let mut flatten: Vec<TDim> = terms
                    .into_iter()
                    .map(|t| t.simplify_rec(scope, scenario))
                    .flat_map(|t| if let Max(t) = t { t } else { vec![t] })
                    .sorted_by(tdim_lexi_order)
                    .dedup()
                    .collect();
                // An operand `b` proven <= another operand `a` can never be the maximum.
                #[allow(clippy::mutable_key_type)]
                let mut redundant = HashSet::<TDim>::default();
                for pair in flatten.iter().permutations(2) {
                    let (a, b) = (pair[0], pair[1]);
                    if redundant.contains(a) || redundant.contains(b) {
                        continue;
                    }
                    let diff = a.clone() - b;
                    if diff.as_i64().is_some_and(|i| i >= 0) || scope.prove_positive_or_zero(&diff)
                    {
                        redundant.insert(b.clone());
                    }
                }
                flatten.retain(|t| !redundant.contains(t));
                // The empty maximum is `i64::MIN`, matching `eval_to_i64`.
                if flatten.len() == 0 {
                    i64::MIN.to_dim()
                } else if flatten.len() == 1 {
                    flatten.into_iter().next().unwrap()
                } else {
                    Max(flatten)
                }
            }
            // Replace a symbol by the value an `Eq` assertion gives it, if any.
            Sym(s) => scope
                .assertions(scenario)
                .find_map(|a| match a {
                    Assertion::Eq(Sym(sym), v) if sym == &s => Some(v.clone()),
                    _ => None,
                })
                .unwrap_or(Sym(s)),
            Val(_) => self,
        }
    }
597
598    pub(super) fn inclusive_bound(&self, scope: &SymbolScopeData, upper: bool) -> Option<i64> {
599        use self::TDim::*;
600        match self {
601            Val(n) => Some(*n),
602            Sym(_) => {
603                if upper {
604                    scope
605                        .all_assertions()
606                        .iter()
607                        .filter_map(|assert| match &assert {
608                            Assertion::LT(left, right)
609                                if left == self && right.as_i64().is_some() =>
610                            {
611                                Some(right.as_i64().unwrap() - 1)
612                            }
613                            Assertion::LTE(left, right)
614                                if left == self && right.as_i64().is_some() =>
615                            {
616                                Some(right.as_i64().unwrap())
617                            }
618                            _ => None,
619                        })
620                        .min()
621                } else {
622                    scope
623                        .all_assertions()
624                        .iter()
625                        .filter_map(|assert| match &assert {
626                            Assertion::GT(left, right)
627                                if left == self && right.as_i64().is_some() =>
628                            {
629                                Some(right.as_i64().unwrap() + 1)
630                            }
631                            Assertion::GTE(left, right)
632                                if left == self && right.as_i64().is_some() =>
633                            {
634                                Some(right.as_i64().unwrap())
635                            }
636                            _ => None,
637                        })
638                        .max()
639                }
640            }
641            Add(terms) => {
642                let mut bound = 0;
643                for t in terms {
644                    if let Some(b) = t.inclusive_bound(scope, upper) {
645                        bound += b;
646                    } else {
647                        return None;
648                    }
649                }
650                Some(bound)
651            }
652            MulInt(p, a) => match p.cmp(&0) {
653                Ordering::Equal => Some(0),
654                Ordering::Greater => a.inclusive_bound(scope, upper).map(|x| x * p),
655                Ordering::Less => a.inclusive_bound(scope, !upper).map(|x| x * p),
656            },
657            Mul(_) => None,
658            Min(terms) if !upper => {
659                terms.iter().filter_map(|t| t.inclusive_bound(scope, false)).min()
660            }
661            Max(terms) if upper => {
662                terms.iter().filter_map(|t| t.inclusive_bound(scope, true)).max()
663            }
664            Div(a, q) => a.inclusive_bound(scope, upper).map(|x| x / (*q as i64)),
665            Broadcast(terms) => {
666                if upper {
667                    Max(terms.clone()).inclusive_bound(scope, true)
668                } else {
669                    Min(terms.clone()).inclusive_bound(scope, false)
670                }
671            }
672            _ => None,
673        }
674    }
675
676    pub fn low_inclusive_bound(&self) -> Option<i64> {
677        if let TDim::Val(v) = self {
678            return Some(*v);
679        }
680        let scope = self.find_scope()?;
681        let data = scope.0.lock();
682        let data = data.borrow();
683        self.inclusive_bound(&data, false)
684    }
685
686    pub fn high_inclusive_bound(&self) -> Option<i64> {
687        if let TDim::Val(v) = self {
688            return Some(*v);
689        }
690        let scope = self.find_scope()?;
691        let data = scope.0.lock();
692        let data = data.borrow();
693        self.inclusive_bound(&data, true)
694    }
695
696    pub fn prove_positive_or_zero(&self) -> bool {
697        if let TDim::Val(v) = self {
698            return *v >= 0;
699        }
700        let Some(scope) = self.find_scope() else { return false };
701        let data = scope.0.lock();
702        let data = data.borrow();
703        data.prove_positive_or_zero(self)
704    }
705
706    pub fn prove_strict_positive(&self) -> bool {
707        if let TDim::Val(v) = self {
708            return *v > 0;
709        }
710        (self.clone() - 1).prove_positive_or_zero()
711    }
712
713    pub fn prove_negative_or_zero(&self) -> bool {
714        if let TDim::Val(v) = self {
715            return *v <= 0;
716        }
717        self.clone().neg().prove_positive_or_zero()
718    }
719
720    pub fn prove_strict_negative(&self) -> bool {
721        if let TDim::Val(v) = self {
722            return *v < 0;
723        }
724        self.clone().neg().prove_strict_positive()
725    }
726
727    pub fn gcd(&self) -> u64 {
728        use self::TDim::*;
729        match self {
730            Val(v) => v.unsigned_abs(),
731            Sym(_) => 1,
732            Add(terms) => {
733                let (head, tail) = terms.split_first().unwrap();
734                tail.iter().fold(head.gcd(), |a, b| a.gcd(&b.gcd()))
735            }
736            MulInt(p, a) => a.gcd() * p.unsigned_abs(),
737            Mul(terms) => terms.iter().map(|t| t.gcd()).product(),
738            Min(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
739            Max(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
740            Div(a, q) => {
741                if a.gcd() % *q == 0 {
742                    a.gcd() / *q
743                } else {
744                    1
745                }
746            }
747            Broadcast(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap_or(1),
748        }
749    }
750
751    fn div(&self, d: u64) -> TDim {
752        use self::TDim::*;
753        if d == 1 {
754            return self.clone();
755        }
756        match self {
757            Val(v) => Val(v / d as i64),
758            Sym(_) => panic!(),
759            Add(terms) => Add(terms.iter().map(|t| t.div(d)).collect()),
760            Min(terms) => Min(terms.iter().map(|t| t.div(d)).collect()),
761            Max(terms) => Max(terms.iter().map(|t| t.div(d)).collect()),
762            Broadcast(terms) => Broadcast(terms.iter().map(|t| t.div(d)).collect()),
763            Mul(_) => Div(Box::new(self.clone()), d),
764            MulInt(p, a) => {
765                if *p == d as i64 {
766                    (**a).clone()
767                } else {
768                    let gcd = p.unsigned_abs().gcd(&d);
769                    MulInt(p / gcd as i64, b!(a.div(d / gcd)))
770                }
771            }
772            Div(a, q) => Div(a.clone(), q * d),
773        }
774    }
775
776    pub fn div_ceil(self, rhs: u64) -> TDim {
777        TDim::Div(Box::new(Add(vec![self, Val(rhs as i64 - 1)])), rhs).reduce()
778    }
779
    /// Heuristically estimates how much `self` changes per unit of `sym`, as a reduced
    /// fraction `(numerator, denominator)` with a non-negative denominator.
    ///
    /// This is exact for expressions built from sums, integer multiples and integer
    /// divisions of symbols. For `Mul` the per-factor slopes are simply multiplied, and
    /// `Min`/`Max`/`Broadcast` only look at their first operand (indexing panics if
    /// that operand list is empty), so the result is a guess for those nodes.
    pub(super) fn guess_slope(&self, sym: &Symbol) -> (i64, u64) {
        // Returns an unreduced fraction (numerator, denominator).
        fn slope_rec(d: &TDim, sym: &Symbol) -> (i64, i64) {
            match d {
                // Constants do not depend on `sym`.
                Val(_) => (0, 1),
                // 1 for the symbol of interest, 0 for any other symbol.
                Sym(s) => ((sym == s) as i64, 1),
                // Sum of fractions: a.0/a.1 + b.0/b.1 = (a.0*b.1 + a.1*b.0) / (a.1*b.1).
                Add(terms) => terms
                    .iter()
                    .map(|d| slope_rec(d, sym))
                    .fold((0, 1), |a, b| ((a.0 * b.1 + a.1 * b.0), (b.1 * a.1))),
                // Heuristic: multiplies the factor slopes (not a true derivative).
                Mul(terms) => terms
                    .iter()
                    .map(|d| slope_rec(d, sym))
                    .fold((1, 1), |a, b| ((a.0 * b.0), (b.1 * a.1))),
                // Integer multiple scales the numerator.
                MulInt(p, a) => {
                    let (n, d) = slope_rec(a, sym);
                    (p * n, d)
                }
                // Integer division scales the denominator.
                Div(a, q) => {
                    let (n, d) = slope_rec(a, sym);
                    (n, d * *q as i64)
                }
                // Only the first operand is considered for these nodes.
                Broadcast(terms) => slope_rec(&terms[0], sym),
                Min(terms) => slope_rec(&terms[0], sym),
                Max(terms) => slope_rec(&terms[0], sym),
            }
        }
        let (p, q) = slope_rec(self, sym);
        // Normalise: divide by the gcd and move any sign onto the numerator.
        reduce_ratio(p, q)
    }
809
810    #[allow(clippy::mutable_key_type)]
811    pub fn symbols(&self) -> std::collections::HashSet<Symbol> {
812        match self {
813            Val(_) => maplit::hashset!(),
814            Sym(s) => maplit::hashset!(s.clone()),
815            Add(terms) | Mul(terms) | Broadcast(terms) | Min(terms) | Max(terms) => {
816                terms.iter().fold(maplit::hashset!(), |mut set, v| {
817                    set.extend(v.symbols());
818                    set
819                })
820            }
821            MulInt(_, a) => a.symbols(),
822            Div(a, _) => a.symbols(),
823        }
824    }
825
826    pub fn compatible_with(&self, other: &TDim) -> bool {
827        if let Ok(x) = (self.clone() - other).to_i64() {
828            return x == 0;
829        }
830        true // maybe ? :)
831    }
832}
833
834pub(super) fn reduce_ratio(mut p: i64, mut q: i64) -> (i64, u64) {
835    let gcd = p.abs().gcd(&q.abs());
836    if gcd > 1 {
837        p /= gcd;
838        q /= gcd;
839    }
840    if q < 0 { (-p, (-q) as u64) } else { (p, q as u64) }
841}
842
843impl Zero for TDim {
844    fn zero() -> Self {
845        Val(0)
846    }
847    fn is_zero(&self) -> bool {
848        matches!(self, Val(0))
849    }
850}
851
852impl Default for TDim {
853    fn default() -> TDim {
854        Val(0)
855    }
856}
857
858impl num_traits::Bounded for TDim {
859    fn min_value() -> Self {
860        TDim::Val(i64::MIN)
861    }
862
863    fn max_value() -> Self {
864        TDim::Val(i64::MAX)
865    }
866}
867
868impl num_traits::One for TDim {
869    fn one() -> Self {
870        TDim::Val(1)
871    }
872}
873
874impl ::std::iter::Sum for TDim {
875    fn sum<I: Iterator<Item = TDim>>(iter: I) -> TDim {
876        iter.fold(0.into(), |a, b| a + b)
877    }
878}
879
880impl<'a> ::std::iter::Sum<&'a TDim> for TDim {
881    fn sum<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
882        iter.fold(0.into(), |a, b| a + b)
883    }
884}
885
886impl std::iter::Product for TDim {
887    fn product<I: Iterator<Item = TDim>>(iter: I) -> Self {
888        iter.fold(TDim::Val(1), |a, b| a * b)
889    }
890}
891
892impl<'a> ::std::iter::Product<&'a TDim> for TDim {
893    fn product<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
894        iter.fold(1.into(), |a, b| a * b)
895    }
896}
897
898macro_rules! from_i {
899    ($i: ty) => {
900        impl From<$i> for TDim {
901            fn from(v: $i) -> TDim {
902                TDim::Val(v as _)
903            }
904        }
905        impl<'a> From<&'a $i> for TDim {
906            fn from(v: &'a $i) -> TDim {
907                TDim::Val(*v as _)
908            }
909        }
910    };
911}
912
913from_i!(i32);
914from_i!(i64);
915from_i!(u64);
916from_i!(isize);
917from_i!(usize);
918
919impl From<Symbol> for TDim {
920    fn from(it: Symbol) -> Self {
921        TDim::Sym(it)
922    }
923}
924
925impl<'a> From<&'a Symbol> for TDim {
926    fn from(it: &'a Symbol) -> Self {
927        TDim::Sym(it.clone())
928    }
929}
930
931impl ops::Neg for TDim {
932    type Output = Self;
933    fn neg(self) -> Self {
934        if let Val(v) = self { Val(-v) } else { TDim::MulInt(-1, Box::new(self)).reduce() }
935    }
936}
937
938impl<'a> ops::AddAssign<&'a TDim> for TDim {
939    fn add_assign(&mut self, rhs: &'a TDim) {
940        if rhs.is_zero() {
941        } else if self.is_zero() {
942            *self = rhs.clone();
943        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
944            *s += o;
945        } else {
946            *self = TDim::Add(vec![std::mem::take(self), rhs.clone()]).reduce()
947        }
948    }
949}
950
951impl<I> ops::AddAssign<I> for TDim
952where
953    I: Into<TDim>,
954{
955    fn add_assign(&mut self, rhs: I) {
956        let rhs = rhs.into();
957        if rhs.is_zero() {
958        } else if self.is_zero() {
959            *self = rhs;
960        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
961            *s += o;
962        } else {
963            *self = TDim::Add(vec![std::mem::take(self), rhs]).reduce()
964        }
965    }
966}
967
968impl<I> ops::Add<I> for TDim
969where
970    I: Into<TDim>,
971{
972    type Output = Self;
973    fn add(mut self, rhs: I) -> Self {
974        self += rhs;
975        self
976    }
977}
978
979impl<'a> ops::Add<&'a TDim> for TDim {
980    type Output = Self;
981    fn add(mut self, rhs: &'a TDim) -> Self {
982        self += rhs;
983        self
984    }
985}
986
987#[allow(clippy::suspicious_op_assign_impl)]
988impl<'a> ops::SubAssign<&'a TDim> for TDim {
989    fn sub_assign(&mut self, rhs: &'a TDim) {
990        if rhs.is_zero() {
991        } else if self.is_zero() {
992            *self = rhs.clone().neg();
993        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
994            *s -= o;
995        } else {
996            *self = TDim::Add(vec![std::mem::take(self), rhs.clone().neg()]).reduce()
997        }
998    }
999}
1000
1001impl<I> ops::SubAssign<I> for TDim
1002where
1003    I: Into<TDim>,
1004{
1005    fn sub_assign(&mut self, rhs: I) {
1006        let rhs = rhs.into();
1007        if rhs.is_zero() {
1008        } else if self.is_zero() {
1009            *self = rhs.neg();
1010        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1011            *s -= o;
1012        } else {
1013            *self = TDim::Add(vec![std::mem::take(self), rhs.neg()]).reduce()
1014        }
1015    }
1016}
1017
1018impl<I> ops::Sub<I> for TDim
1019where
1020    I: Into<TDim>,
1021{
1022    type Output = Self;
1023    fn sub(mut self, rhs: I) -> Self {
1024        self -= rhs;
1025        self
1026    }
1027}
1028
1029impl<'a> ops::Sub<&'a TDim> for TDim {
1030    type Output = Self;
1031    fn sub(mut self, rhs: &'a TDim) -> Self {
1032        self -= rhs;
1033        self
1034    }
1035}
1036
1037impl<I: Into<TDim>> ops::MulAssign<I> for TDim {
1038    fn mul_assign(&mut self, rhs: I) {
1039        let rhs = rhs.into();
1040        if self.is_one() {
1041            *self = rhs
1042        } else if rhs.is_one() {
1043        } else {
1044            *self = TDim::Mul(vec![rhs, std::mem::take(self)]).reduce()
1045        }
1046    }
1047}
1048
1049impl<'a> ops::MulAssign<&'a TDim> for TDim {
1050    fn mul_assign(&mut self, rhs: &'a TDim) {
1051        if self.is_one() {
1052            *self = rhs.clone()
1053        } else if rhs.is_one() {
1054        } else {
1055            *self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce()
1056        }
1057    }
1058}
1059
1060impl<I: Into<TDim>> ops::Mul<I> for TDim {
1061    type Output = Self;
1062    fn mul(mut self, rhs: I) -> Self {
1063        self *= rhs.into();
1064        self
1065    }
1066}
1067
1068impl<'a> ops::Mul<&'a TDim> for TDim {
1069    type Output = Self;
1070    fn mul(mut self, rhs: &'a TDim) -> Self {
1071        self *= rhs;
1072        self
1073    }
1074}
1075
1076impl<I: AsPrimitive<u64> + PrimInt> ops::DivAssign<I> for TDim {
1077    fn div_assign(&mut self, rhs: I) {
1078        *self = TDim::Div(Box::new(std::mem::take(self)), rhs.as_()).reduce()
1079    }
1080}
1081
1082impl<I: AsPrimitive<u64> + PrimInt> ops::Div<I> for TDim {
1083    type Output = Self;
1084    fn div(mut self, rhs: I) -> Self {
1085        self /= rhs.as_();
1086        self
1087    }
1088}
1089
1090impl<I: AsPrimitive<u64> + PrimInt> ops::RemAssign<I> for TDim {
1091    fn rem_assign(&mut self, rhs: I) {
1092        *self += -(self.clone() / rhs.as_() * rhs.as_());
1093    }
1094}
1095
1096impl<I: AsPrimitive<u64> + PrimInt> ops::Rem<I> for TDim {
1097    type Output = Self;
1098    fn rem(mut self, rhs: I) -> Self {
1099        self %= rhs;
1100        self
1101    }
1102}
1103
// Unit tests for TDim construction, reduction, slope guessing and simplification.
// Many assertions compare structurally-built expressions, so the exact order in which
// operands are combined below matters.
#[cfg(test)]
mod tests {
    use super::*;

    // Local boxing shorthand, same shape as the one defined at the top of the file.
    macro_rules! b( ($e:expr) => { Box::new($e) } );

    // Shared symbol scope and a handful of test symbols.
    lazy_static::lazy_static! {
        static ref table: SymbolScope = SymbolScope::default();
        static ref A: Symbol = table.sym("a");
        static ref B: Symbol = table.sym("b");
        static ref C: Symbol = table.sym("c");
        static ref D: Symbol = table.sym("d");
        static ref E: Symbol = table.sym("e");
    }

    // Raw (unreduced) constructors: they build tree nodes directly, bypassing the
    // operator impls, so tests can exercise `reduce()` on arbitrary shapes.
    fn neg(a: &TDim) -> TDim {
        mul(-1, a)
    }

    fn add(a: &TDim, b: &TDim) -> TDim {
        TDim::Add(vec![a.clone(), b.clone()])
    }

    fn mul(a: i64, b: &TDim) -> TDim {
        TDim::MulInt(a, b![b.clone()])
    }

    fn div(a: &TDim, b: u64) -> TDim {
        TDim::Div(b!(a.clone()), b)
    }

    // --- reduce() on hand-built trees ---

    #[test]
    fn reduce_add() {
        assert_eq!(add(&A.to_dim(), &neg(&A.to_dim())).reduce(), Val(0))
    }

    #[test]
    fn reduce_neg_mul() {
        assert_eq!(neg(&mul(2, &A.to_dim())).reduce(), mul(-2, &A.to_dim()))
    }

    #[test]
    fn reduce_cplx_ex_2() {
        assert_eq!(
            add(
                &add(&Val(-4), &mul(-2, &div(&A.to_dim(), 4))),
                &mul(-2, &mul(-1, &div(&A.to_dim(), 4)))
            )
            .reduce(),
            Val(-4)
        )
    }

    #[test]
    fn reduce_cplx_ex_3() {
        assert_eq!(div(&MulInt(1, b!(MulInt(4, b!(A.to_dim())))), 4).reduce(), A.to_dim())
    }

    #[test]
    fn reduce_cplx_ex_4() {
        // (S+1)/2 + (1-S)/2 == 1
        assert_eq!(
            add(&div(&add(&A.to_dim(), &Val(1)), 2), &div(&add(&neg(&A.to_dim()), &Val(1)), 2))
                .reduce(),
            1.into()
        );
    }

    #[test]
    fn reduce_mul_mul_1() {
        assert_eq!(mul(3, &mul(2, &A.to_dim())).reduce(), mul(6, &A.to_dim()))
    }

    #[test]
    fn reduce_mul_mul_2() {
        assert_eq!(mul(-2, &mul(-1, &A.to_dim())).reduce(), mul(2, &A.to_dim()))
    }

    #[test]
    fn reduce_mul_div_1() {
        assert_eq!(mul(2, &div(&mul(-1, &A.to_dim()), 3)).reduce(), mul(-2, &div(&A.to_dim(), 3)))
    }

    // --- evaluation and operator impls ---

    #[test]
    fn const_and_add() {
        let e: TDim = 2i64.into();
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 2);
        let e: TDim = TDim::from(2) + 3;
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 5);
        let e: TDim = TDim::from(2) - 3;
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -1);
        let e: TDim = -TDim::from(2);
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -2);
    }

    #[test]
    fn substitution() {
        let a: TDim = A.to_dim();
        assert_eq!(a.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 2);
        let e = a + 3;
        assert_eq!(e.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 5);
    }

    #[test]
    fn reduce_adds() {
        let e: TDim = TDim::from(2) + 1;
        assert_eq!(e, TDim::from(3));
        let e: TDim = TDim::from(3) + 2;
        assert_eq!(e, TDim::from(5));
        let e: TDim = TDim::from(3) + 0;
        assert_eq!(e, TDim::from(3));
        let e: TDim = TDim::from(3) + 2 + 1;
        assert_eq!(e, TDim::from(6));
    }

    #[test]
    fn reduce_muls() {
        let e: TDim = Val(1) * A.to_dim();
        assert_eq!(e, A.to_dim());
        let e: TDim = A.to_dim() * &B.to_dim() * 1;
        assert_eq!(e, A.to_dim() * &B.to_dim());
    }

    #[test]
    fn reduce_divs() {
        let e: TDim = TDim::from(2) / 1;
        assert_eq!(e, TDim::from(2));
        let e: TDim = TDim::from(3) / 2;
        assert_eq!(e, TDim::from(1));
        let e: TDim = TDim::from(3) % 2;
        assert_eq!(e, TDim::from(1));
        let e: TDim = TDim::from(5) / 2;
        assert_eq!(e, TDim::from(2));
        let e: TDim = TDim::from(5) % 2;
        assert_eq!(e, TDim::from(1));
    }

    // Regression tests for integer-division reduction.
    #[test]
    fn reduce_div_bug_0() {
        let e1: TDim = (A.to_dim() + 23) / 2 - 1;
        let e2: TDim = (A.to_dim() + 21) / 2;
        assert_eq!(e1, e2);
    }

    #[test]
    fn reduce_div_bug_1() {
        let e1: TDim = (A.to_dim() + -1) / 2;
        let e2: TDim = (A.to_dim() + 1) / 2 - 1;
        assert_eq!(e1, e2);
    }

    #[test]
    fn reduce_div_bug_2() {
        let e1: TDim = ((A.to_dim() + 1) / 2 + 1) / 2;
        let e2: TDim = (A.to_dim() + 3) / 4;
        assert_eq!(e1, e2);
    }

    #[test]
    fn reduce_div_bug_3() {
        let e1: TDim = (A.to_dim() / 2) * -4;
        let e2: TDim = (A.to_dim() / 2) * -4 / 1;
        assert_eq!(e1, e2);
    }

    #[test]
    fn reduce_mul_div() {
        let e: TDim = A.to_dim() * 2 / 2;
        assert_eq!(e, A.to_dim());
    }

    // (a / 2) * 2 must NOT collapse to a: integer division loses the low bit.
    #[test]
    fn reduce_div_mul() {
        let e: TDim = A.to_dim() / 2 * 2;
        assert_ne!(e, A.to_dim());
    }

    #[test]
    fn reduce_add_div() {
        let e: TDim = A.to_dim() / 2 + 1;
        assert_eq!(e, ((A.to_dim() + 2) / 2));
    }

    #[test]
    fn reduce_neg_mul_() {
        let e: TDim = TDim::from(1) - A.to_dim() * 2;
        assert_eq!(e, TDim::from(1) + A.to_dim() * -2);
    }

    #[test]
    fn reduce_add_rem_1() {
        assert_eq!(((A.to_dim() + 4) % 2), (A.to_dim() % 2));
    }

    #[test]
    fn reduce_add_rem_2() {
        assert_eq!(((A.to_dim() - 4) % 2), (A.to_dim() % 2));
    }

    #[test]
    fn reduce_rem_div() {
        let e: TDim = A.to_dim() % 2 / 2;
        assert_eq!(e, TDim::from(0));
    }

    // --- div_ceil, as used by convolution output-shape computations ---

    #[test]
    fn conv2d_ex_1() {
        let e = (TDim::from(1) - 1 + 1).div_ceil(1);
        assert_eq!(e, TDim::from(1));
    }

    #[test]
    fn conv2d_ex_2() {
        let e = (A.to_dim() - 3 + 1).div_ceil(1);
        assert_eq!(e, A.to_dim() + -2);
    }

    // --- factorisation and commutativity of products ---

    #[test]
    fn extract_int_gcd_from_muls() {
        let term = (A.to_dim() + 1) / 4;
        let mul = (term.clone() * 24 - 24) * (term.clone() * 2 - 2);
        let target = (term.clone() - 1) * (term.clone() - 1) * 48;
        assert_eq!(mul, target);
    }

    #[test]
    fn equality_of_muls() {
        let term = (A.to_dim() + 1) / 4;
        let mul1 = (term.clone() * 2 - 3) * (term.clone() - 1);
        let mul2 = (term.clone() - 1) * (term.clone() * 2 - 3);
        assert_eq!(mul1, mul2);
    }

    #[test]
    fn factorize_complex_expr_times_int() {
        let term = (A.to_dim() + 1) / 4;
        let e = term.clone() * 2 - &term - 1;
        assert_eq!(e, term - 1);
    }

    // --- broadcast / min ---

    #[test]
    fn broadcast_over_min() {
        // assuming a>0, b>0 then a#min(a,b) can be replaced by a
        // proof:
        //    if b == 1 => min(a,b)=1 => a#1=a => ok
        //    if a <= b => min(a,b)=a => ok
        //    if 1 < B < A => expression was invalid, we're generalizing over the non-domain and ignoring the constraint
        for a in 1..5 {
            for b in 1..5 {
                if b > 1 && a > b {
                    assert!(a.broadcast(a.min(b)).is_err());
                } else {
                    assert_eq!(a.broadcast(a.min(b)).unwrap(), a);
                }
            }
        }
    }

    #[test]
    fn min_ints_1() {
        assert_eq!(2.to_dim().mini(1.to_dim()), 1.to_dim());
    }

    #[test]
    fn min_ints_2() {
        assert_eq!(1.to_dim().mini(2.to_dim()), 1.to_dim());
    }

    #[test]
    fn min_same() {
        assert_eq!(A.to_dim().mini(A.to_dim()), A.to_dim());
    }

    #[test]
    fn min_noop() {
        assert_eq!(A.to_dim().mini(1.to_dim()), A.to_dim().mini(1.to_dim()));
    }

    #[test]
    fn min_diff_1() {
        assert_eq!((A.to_dim() + 1).mini(A.to_dim() + 2), A.to_dim() + 1);
    }

    // --- guess_slope: results are (numerator, denominator) fractions ---

    #[test]
    fn slope_0() {
        assert_eq!(12.to_dim().guess_slope(&A), (0, 1));
    }

    #[test]
    fn slope_1() {
        assert_eq!(A.to_dim().guess_slope(&A), (1, 1));
    }

    #[test]
    fn slope_2() {
        assert_eq!((A.to_dim() * 2).guess_slope(&A), (2, 1));
    }

    #[test]
    fn slope_3() {
        assert_eq!((A.to_dim() * 2 + A.to_dim() / 2).guess_slope(&A), (5, 2));
    }

    #[test]
    fn slope_4() {
        assert_eq!((A.to_dim()).guess_slope(&B), (0, 1));
    }

    #[test]
    fn slope_5() {
        assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
        assert_eq!((A.to_dim() + 1).guess_slope(&B), (0, 1));
    }

    #[test]
    fn slope_6() {
        assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
        assert_eq!((A.to_dim() + B.to_dim()).guess_slope(&B), (1, 1));
    }

    // --- simplify() on parsed expressions ---

    #[test]
    fn min_0() -> TractResult<()> {
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols.parse_tdim("min(S+3, S+2)").unwrap().simplify(),
            symbols.parse_tdim("S+2").unwrap(),
        );
        Ok(())
    }

    #[test]
    fn commutative_mul_parens() -> TractResult<()> {
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols.parse_tdim("A*(B*C)").unwrap().simplify(),
            symbols.parse_tdim("(B*A)*C").unwrap().simplify(),
        );
        Ok(())
    }

    #[test]
    fn commutative_in_nemo_parakeet_model() -> TractResult<()> {
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols
                .parse_tdim("8*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8))*((B)*((S+7)/8))")
                .unwrap()
                .simplify(),
            symbols
                .parse_tdim("8*((B)*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8)))*((S+7)/8)")
                .unwrap()
                .simplify(),
        );
        Ok(())
    }

    #[test]
    fn commutative_mul_parens_deep() -> TractResult<()> {
        let symbols = SymbolScope::default();
        let deep_tdim = Mul(vec![
            Mul(vec![Mul(vec![Mul(vec![A.to_dim(), B.to_dim()]), C.to_dim()]), D.to_dim()]),
            E.to_dim(),
        ])
        .simplify();
        assert_eq!(deep_tdim, symbols.parse_tdim("a*b*c*d*e").unwrap().simplify());
        Ok(())
    }
}