// tract_data/dim/tree.rs

1use crate::dim::Assertion;
2use crate::internal::*;
3
4use super::{DimLike, sym::*};
5use itertools::Itertools;
6use num_integer::Integer;
7use num_traits::{AsPrimitive, PrimInt, Zero};
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10use std::fmt::Debug;
11use std::ops::Neg;
12use std::{fmt, ops};
13
/// Error reported when a concrete value is requested from an expression that
/// cannot provide one yet.
#[derive(Debug)]
pub enum TooEarly {
    /// The expression still holds an unresolved symbol; the payload is the
    /// rendered expression.
    UndeterminedSymbol(String),
    /// Any other condition; the payload is displayed verbatim.
    Other(String),
}

impl std::fmt::Display for TooEarly {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UndeterminedSymbol(expr) => {
                write!(f, "Undetermined symbol in expression: {expr}")
            }
            // The message is emitted as-is, with no prefix.
            Self::Other(message) => f.write_str(message),
        }
    }
}

impl std::error::Error for TooEarly {}
30
// Shorthand for `Box::new(..)`.
macro_rules! b {
    ($e:expr) => {
        Box::new($e)
    };
}
32
33#[derive(Clone, PartialEq, Eq, Hash, Debug)]
34pub enum TDim {
35    Val(i64),
36    Sym(Symbol),
37    Add(Vec<TDim>),
38    Mul(Vec<TDim>),
39    MulInt(i64, Box<TDim>),
40    Div(Box<TDim>, u64),
41    Broadcast(Vec<TDim>),
42    Min(Vec<TDim>),
43    Max(Vec<TDim>),
44    /// Comparison: evaluates to 1 (true) or 0 (false). lhs >= rhs
45    Ge(Box<TDim>, Box<TDim>),
46    /// Comparison: evaluates to 1 (true) or 0 (false). lhs == rhs
47    Eq(Box<TDim>, Box<TDim>),
48}
49
50use TDim::*;
51
52fn tdim_lexi_order(a: &TDim, b: &TDim) -> Ordering {
53    match (a, b) {
54        (Sym(a), Sym(b)) => a.cmp(b),
55        (Val(a), Val(b)) => a.cmp(b),
56        (Add(a), Add(b))
57        | (Mul(a), Mul(b))
58        | (Broadcast(a), Broadcast(b))
59        | (Min(a), Min(b))
60        | (Max(a), Max(b)) => a.len().cmp(&b.len()).then(
61            a.iter()
62                .zip(b.iter())
63                .fold(Ordering::Equal, |acc, (a, b)| acc.then_with(|| tdim_lexi_order(a, b))),
64        ),
65        (MulInt(p, d), MulInt(q, e)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
66        (Div(d, p), Div(e, q)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
67        (Sym(_), _) => Ordering::Less,
68        (_, Sym(_)) => Ordering::Greater,
69        (Val(_), _) => Ordering::Less,
70        (_, Val(_)) => Ordering::Greater,
71        (Add(_), _) => Ordering::Less,
72        (_, Add(_)) => Ordering::Greater,
73        (Mul(_), _) => Ordering::Less,
74        (_, Mul(_)) => Ordering::Greater,
75        (MulInt(_, _), _) => Ordering::Less,
76        (_, MulInt(_, _)) => Ordering::Greater,
77        (Broadcast(_), _) => Ordering::Less,
78        (_, Broadcast(_)) => Ordering::Greater,
79        (Min(_), _) => Ordering::Less,
80        (_, Min(_)) => Ordering::Greater,
81        (Max(_), _) => Ordering::Less,
82        (_, Max(_)) => Ordering::Greater,
83        (Ge(a1, b1), Ge(a2, b2)) | (Eq(a1, b1), Eq(a2, b2)) => {
84            tdim_lexi_order(a1, a2).then_with(|| tdim_lexi_order(b1, b2))
85        }
86        (Ge(_, _) | Eq(_, _), _) => Ordering::Less,
87        (_, Ge(_, _) | Eq(_, _)) => Ordering::Greater,
88    }
89}
90
91impl fmt::Display for TDim {
92    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
93        match &self {
94            Sym(sym) => write!(fmt, "{sym}"),
95            Val(it) => write!(fmt, "{it}"),
96            Add(it) => write!(fmt, "{}", it.iter().map(|x| format!("{x}")).join("+")),
97            Mul(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("*")),
98            Broadcast(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("#")),
99            Min(it) => write!(fmt, "min({})", it.iter().map(|x| format!("{x}")).join(",")),
100            Max(it) => write!(fmt, "max({})", it.iter().map(|x| format!("{x}")).join(",")),
101            MulInt(a, b) => write!(fmt, "{a}*{b}"),
102            Div(a, b) => write!(fmt, "({a})/{b}"),
103            Ge(a, b) => write!(fmt, "({a}>={b})"),
104            Eq(a, b) => write!(fmt, "({a}=={b})"),
105        }
106    }
107}
108
109impl TDim {
110    #[inline]
111    pub fn is_one(&self) -> bool {
112        matches!(self, Val(1))
113    }
114
115    #[inline]
116    pub fn to_i64(&self) -> TractResult<i64> {
117        if let Val(v) = self {
118            Ok(*v)
119        } else {
120            Err(TooEarly::UndeterminedSymbol(self.to_string()))?
121        }
122    }
123
124    #[inline]
125    pub fn as_i64(&self) -> Option<i64> {
126        if let Val(v) = self { Some(*v) } else { None }
127    }
128
129    pub fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64> {
130        match self {
131            Sym(sym) => {
132                let Some(v) = values.get(sym) else {
133                    Err(TooEarly::UndeterminedSymbol(self.to_string()))?
134                };
135                Ok(v)
136            }
137            Val(v) => Ok(*v),
138            Add(terms) => terms.iter().try_fold(0i64, |acc, it| {
139                let x = it.eval_to_i64(values)?;
140                acc.checked_add(x)
141                    .with_context(|| format!("Overflow in TDim addition ({acc} + {x})"))
142            }),
143            Mul(terms) => terms.iter().try_fold(1i64, |acc, it| {
144                let x = it.eval_to_i64(values)?;
145                acc.checked_mul(x)
146                    .with_context(|| format!("Overflow in TDim multiplication ({acc} * {x})"))
147            }),
148            Min(terms) => terms
149                .iter()
150                .try_fold(i64::MAX, |acc, it| it.eval_to_i64(values).map(|x| acc.min(x))),
151            Max(terms) => terms
152                .iter()
153                .try_fold(i64::MIN, |acc, it| it.eval_to_i64(values).map(|x| acc.max(x))),
154            Broadcast(terms) => terms.iter().try_fold(1i64, |acc, it| {
155                it.eval_to_i64(values)
156                    .and_then(|x| ((acc as usize).broadcast(x as usize)).map(|x| x as i64))
157            }),
158            Div(a, q) => Ok(a.eval_to_i64(values)? / *q as i64),
159            MulInt(p, a) => {
160                let x = a.eval_to_i64(values)?;
161                x.checked_mul(*p)
162                    .with_context(|| format!("Overflow in TDim multiplication ({x} * {p})"))
163            }
164            Ge(a, b) => Ok(if a.eval_to_i64(values)? >= b.eval_to_i64(values)? { 1 } else { 0 }),
165            Eq(a, b) => Ok(if a.eval_to_i64(values)? == b.eval_to_i64(values)? { 1 } else { 0 }),
166        }
167    }
168
169    pub fn eval(&self, values: &SymbolValues) -> TDim {
170        match self {
171            Sym(sym) => values.get(sym).map(Val).unwrap_or_else(|| Sym(sym.clone())),
172            Val(v) => Val(*v),
173            Add(terms) => terms.iter().fold(Val(0), |acc, it| -> TDim { acc + it.eval(values) }),
174            Mul(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim { acc * it.eval(values) }),
175            Min(terms) => {
176                terms.iter().fold(Val(i64::MAX), |acc, it| -> TDim { acc.mini(it.eval(values)) })
177            }
178            Max(terms) => {
179                terms.iter().fold(Val(i64::MIN), |acc, it| -> TDim { acc.maxi(it.eval(values)) })
180            }
181            Broadcast(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim {
182                acc.broadcast(it.eval(values)).unwrap_or_else(|_| self.clone())
183            }),
184            Div(a, q) => a.eval(values) / *q as i64,
185            MulInt(p, a) => a.eval(values) * *p,
186            Ge(a, b) => {
187                let a2 = a.eval(values);
188                let b2 = b.eval(values);
189                if let (Val(av), Val(bv)) = (&a2, &b2) {
190                    Val(if av >= bv { 1 } else { 0 })
191                } else {
192                    Ge(b!(a2), b!(b2))
193                }
194            }
195            Eq(a, b) => {
196                let a2 = a.eval(values);
197                let b2 = b.eval(values);
198                if let (Val(av), Val(bv)) = (&a2, &b2) {
199                    Val(if av == bv { 1 } else { 0 })
200                } else {
201                    Eq(b!(a2), b!(b2))
202                }
203            }
204        }
205    }
206
207    pub fn eval_with_scenario(&self, scenario: &str) -> TDim {
208        if let Val(v) = self {
209            return Val(*v);
210        }
211        let scope = self.find_scope().unwrap();
212        let scope = scope.0;
213        let locked = scope.lock();
214        let scope = locked.borrow();
215        self.clone().simplify_rec(&scope, Some(scenario), &[])
216    }
217
218    pub fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self> {
219        self.substitute_all(&std::collections::HashMap::from([(from.clone(), to.clone())]))
220    }
221
222    pub fn substitute_all(
223        &self,
224        map: &std::collections::HashMap<Symbol, Self>,
225    ) -> TractResult<Self> {
226        match self {
227            Sym(sym) => Ok(map.get(sym).cloned().unwrap_or_else(|| self.clone())),
228            Val(v) => Ok(Val(*v)),
229            Add(terms) => terms.iter().try_fold(Val(0), |acc, it| -> TractResult<TDim> {
230                Ok(acc + it.substitute_all(map)?)
231            }),
232            Mul(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
233                Ok(acc * it.substitute_all(map)?)
234            }),
235            Broadcast(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
236                acc.broadcast(it.substitute_all(map)?)
237            }),
238            Min(terms) => terms.iter().try_fold(Val(i64::MAX), |acc, it| -> TractResult<TDim> {
239                Ok(acc.mini(it.substitute_all(map)?))
240            }),
241            Max(terms) => terms.iter().try_fold(Val(i64::MIN), |acc, it| -> TractResult<TDim> {
242                Ok(acc.maxi(it.substitute_all(map)?))
243            }),
244            Div(a, q) => Ok(a.substitute_all(map)? / *q as i64),
245            MulInt(p, a) => Ok(a.substitute_all(map)? * *p),
246            Ge(a, b) => Ok(Ge(b!(a.substitute_all(map)?), b!(b.substitute_all(map)?))),
247            Eq(a, b) => Ok(Eq(b!(a.substitute_all(map)?), b!(b.substitute_all(map)?))),
248        }
249    }
250
251    pub fn reduce(self) -> TDim {
252        self.simplify()
253            .wiggle()
254            .into_iter()
255            .sorted_by(tdim_lexi_order)
256            .unique()
257            .map(|e| e.simplify())
258            .min_by_key(|e| e.cost())
259            .unwrap()
260    }
261
262    fn cost(&self) -> usize {
263        use self::TDim::*;
264        match self {
265            Sym(_) | Val(_) => 1,
266            Add(terms) => 2 * terms.iter().map(TDim::cost).sum::<usize>(),
267            Mul(terms) => 3 * terms.iter().map(TDim::cost).sum::<usize>(),
268            Broadcast(terms) => 4 * terms.iter().map(TDim::cost).sum::<usize>(),
269            Min(terms) | Max(terms) => 5 * terms.iter().map(TDim::cost).sum::<usize>(),
270            Div(a, _) => 3 * a.cost(),
271            MulInt(_, a) => 2 * a.cost(),
272            Ge(a, b) | Eq(a, b) => 5 * (a.cost() + b.cost()),
273        }
274    }
275
276    fn wiggle(&self) -> Vec<TDim> {
277        use self::TDim::*;
278        match self {
279            Sym(_) | Val(_) | Mul(_) | Broadcast(_) | Min(_) | Max(_) | Ge(_, _) | Eq(_, _) => {
280                vec![self.clone()]
281            }
282            Add(terms) => {
283                let mut forms = vec![];
284                let sub_exprs = terms.iter().map(|e| e.wiggle()).multi_cartesian_product();
285
286                fn first_div_term(terms: &[TDim]) -> Option<(usize, &TDim, u64)> {
287                    terms.iter().enumerate().find_map(|(index, t)| match t {
288                        Div(numerator, quotient) => Some((index, &**numerator, *quotient)),
289                        _ => None,
290                    })
291                }
292
293                fn generate_new_numerator(
294                    div_index: usize,
295                    numerator: &TDim,
296                    quotient: u64,
297                    expr: &[TDim],
298                ) -> Vec<TDim> {
299                    expr.iter()
300                        .enumerate()
301                        .map(|(index, term)| {
302                            if index == div_index {
303                                numerator.clone()
304                            } else {
305                                MulInt(quotient as i64, Box::new(term.clone()))
306                            }
307                        })
308                        .collect()
309                }
310
311                for expr in sub_exprs {
312                    if let Some((div_index, numerator, quotient)) = first_div_term(&expr) {
313                        let new_numerator =
314                            generate_new_numerator(div_index, numerator, quotient, &expr);
315                        forms.push(Div(Box::new(Add(new_numerator)), quotient))
316                    }
317
318                    forms.push(Add(expr));
319                }
320                forms
321            }
322            MulInt(p, a) => a.wiggle().into_iter().map(|a| MulInt(*p, b!(a))).collect(),
323            Div(a, q) => {
324                let mut forms = vec![];
325                for num in a.wiggle() {
326                    if let Add(terms) = &num {
327                        let (integer, non_integer): (Vec<_>, Vec<_>) =
328                            terms.iter().cloned().partition(|a| a.gcd() % q == 0);
329                        let mut new_terms = integer.iter().map(|i| i.div(*q)).collect::<Vec<_>>();
330                        if non_integer.len() > 0 {
331                            new_terms.push(Div(b!(Add(non_integer)), *q));
332                        }
333                        forms.push(Add(new_terms))
334                    }
335                    forms.push(Div(b!(num), *q))
336                }
337                forms
338            }
339        }
340    }
341
342    fn find_any_sym(tdim: &TDim) -> Option<&Symbol> {
343        match tdim {
344            Val(_) => None,
345            Sym(s) => Some(s),
346            Add(terms) | Mul(terms) | Min(terms) | Max(terms) | Broadcast(terms) => {
347                terms.iter().find_map(Self::find_any_sym)
348            }
349            MulInt(_, t) | Div(t, _) => Self::find_any_sym(t),
350            Ge(a, b) | Eq(a, b) => Self::find_any_sym(a).or_else(|| Self::find_any_sym(b)),
351        }
352    }
353
354    pub fn find_scope(&self) -> Option<SymbolScope> {
355        Self::find_any_sym(self).and_then(|s| s.scope().clone())
356    }
357
358    pub fn simplify(self) -> TDim {
359        use self::TDim::*;
360        if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) {
361            return Val(v);
362        }
363        let Some(scope) = self.find_scope() else {
364            return self;
365        };
366        let scope = scope.0;
367        let locked = scope.lock();
368        let scope = locked.borrow();
369        let it = self.simplify_rec(&scope, None, &[]);
370        let mut current: Option<TDim> = None;
371        for scenario in scope.scenarios() {
372            let v = it.clone().simplify_rec(&scope, Some(scenario), &[]);
373            if current.is_some_and(|c| c != v) {
374                return it;
375            } else {
376                current = Some(v);
377            }
378        }
379        current.unwrap_or(it)
380    }
381
382    pub fn simplify_with_extra_assertions(self, extra: &[Assertion]) -> TDim {
383        use self::TDim::*;
384        if extra.is_empty() {
385            return self.simplify();
386        }
387        if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) {
388            return Val(v);
389        }
390        let Some(scope) = self.find_scope() else {
391            return self;
392        };
393        let scope = scope.0;
394        let locked = scope.lock();
395        let scope = locked.borrow();
396        let it = self.simplify_rec(&scope, None, extra);
397        let mut current: Option<TDim> = None;
398        for scenario in scope.scenarios() {
399            let v = it.clone().simplify_rec(&scope, Some(scenario), extra);
400            if current.is_some_and(|c| c != v) {
401                return it;
402            } else {
403                current = Some(v);
404            }
405        }
406        current.unwrap_or(it)
407    }
408
409    fn simplify_rec(
410        self,
411        scope: &SymbolScopeData,
412        scenario: Option<&str>,
413        extra: &[Assertion],
414    ) -> TDim {
415        match self {
416            Add(mut terms) => {
417                #[allow(clippy::mutable_key_type)]
418                let mut simplified_terms: HashMap<TDim, i64> = HashMap::new();
419                // factorize common sub-expr
420                while let Some(term) = terms.pop() {
421                    let simplified = term.simplify_rec(scope, scenario, extra);
422                    match simplified {
423                        Val(0) => {} // ignore
424                        Add(members) => {
425                            terms.extend(members);
426                            continue;
427                        }
428                        Val(value) => *simplified_terms.entry(Val(1)).or_insert(0) += value,
429                        MulInt(value, factor) => {
430                            *simplified_terms.entry((*factor).clone()).or_insert(0) += value;
431                        }
432                        n => *simplified_terms.entry(n).or_insert(0) += 1,
433                    };
434                }
435
436                pub fn evaluate_count(term: TDim, count: i64) -> Option<TDim> {
437                    match count {
438                        0 => None,
439                        _ if term == TDim::Val(1) => Some(TDim::Val(count)),
440                        1 => Some(term),
441                        _ => Some(TDim::MulInt(count, Box::new(term))),
442                    }
443                }
444
445                let mut members: Vec<TDim> = simplified_terms
446                    .into_iter()
447                    .filter_map(|(term, count)| evaluate_count(term, count))
448                    .collect();
449                members.sort_by(tdim_lexi_order);
450
451                match members.len() {
452                    0 => TDim::Val(0),
453                    1 => members.into_iter().next().unwrap(),
454                    _ => TDim::Add(members),
455                }
456            }
457            Mul(terms) => {
458                // Distribute over Add: if exactly one factor is an Add,
459                // expand Mul([a, Add([b, c])]) => Add([Mul([a, b]), Mul([a, c])]).
460                // This lets (T+1)*P simplify to T*P + P, which is needed for
461                // cancellation in expressions like (T+1)*P - T*P.
462                {
463                    let add_indices: Vec<usize> = terms
464                        .iter()
465                        .enumerate()
466                        .filter(|(_, t)| matches!(t, Add(_)))
467                        .map(|(i, _)| i)
468                        .collect();
469                    if add_indices.len() == 1 {
470                        let add_idx = add_indices[0];
471                        let Add(add_terms) = &terms[add_idx] else { unreachable!() };
472                        let other_factors: Vec<TDim> = terms
473                            .iter()
474                            .enumerate()
475                            .filter(|(i, _)| *i != add_idx)
476                            .map(|(_, t)| t.clone())
477                            .collect();
478                        let distributed: Vec<TDim> = add_terms
479                            .iter()
480                            .map(|at| {
481                                let mut product = other_factors.clone();
482                                product.push(at.clone());
483                                Mul(product)
484                            })
485                            .collect();
486                        return Add(distributed).simplify_rec(scope, scenario, extra);
487                    }
488                }
489
490                // in case a term is a multiplication itself, flatten it
491                // e.g., (a*b)*c => a*b*c, and MulInt(k, x) => Val(k)*x
492                let mut flattened_terms = vec![];
493                for t in terms {
494                    match t.clone().reduce() {
495                        Mul(inner_terms) => flattened_terms.extend(inner_terms),
496                        MulInt(k, inner) => {
497                            flattened_terms.push(Val(k));
498                            flattened_terms.push(*inner);
499                        }
500                        other => flattened_terms.push(other),
501                    }
502                }
503                let mut terms = flattened_terms;
504
505                let mut gcd = Mul(terms.clone()).gcd() as i64;
506                if gcd == 0 {
507                    return Val(0);
508                }
509                terms = if gcd != 1 {
510                    terms
511                        .into_iter()
512                        .map(|t| {
513                            let gcd = t.gcd();
514                            (t / gcd).simplify_rec(scope, scenario, extra)
515                        })
516                        .collect()
517                } else {
518                    terms
519                };
520                if terms.iter().filter(|t| t == &&Val(-1)).count() % 2 == 1 {
521                    gcd = -gcd;
522                }
523                terms.retain(|t| !t.is_one() && t != &Val(-1));
524                terms.sort_by(tdim_lexi_order);
525
526                match (gcd, terms.len()) {
527                    (_, 0) => Val(gcd), // Case #1: If 0 variables, return product
528                    (0, _) => Val(0),   // Case #2: Result is 0 if coef is 0 (actually
529                    // unreachable as we check at the beginning)
530                    (1, 1) => terms.remove(0), // Case #3: Product is 1, so return the only term
531                    (1, _) => Mul(terms), // Case #4: Product is 1, so return the non-integer terms
532                    (_, 1) => MulInt(gcd, Box::new(terms.remove(0))), // Case #5: Single variable, convert to 1 MulInt
533                    _ => MulInt(gcd, Box::new(Mul(terms))), // Case #6: Multiple variables, convert to MulInt
534                }
535            }
536            MulInt(coef, expr) => {
537                match *expr {
538                    MulInt(c2, inner) => {
539                        if let Some(c) = coef.checked_mul(c2) {
540                            return MulInt(c, inner).simplify_rec(scope, scenario, extra);
541                        } else {
542                            return MulInt(coef, Box::new(MulInt(c2, inner)));
543                        }
544                    }
545                    Val(v) => {
546                        return coef
547                            .checked_mul(v)
548                            .map(Val)
549                            .unwrap_or_else(|| MulInt(coef, Box::new(Val(v))));
550                    }
551                    _ => {}
552                }
553
554                let simplified = expr.simplify_rec(scope, scenario, extra);
555                match (coef, simplified) {
556                    (0, _) => Val(0), // Case #1: If coef is 0, return 0
557                    (1, s) => s,      // Case #2: If coef is 1, return the simplified expression
558                    (_, Add(terms)) => Add(terms
559                        .into_iter()
560                        .map(|term| {
561                            MulInt(coef, Box::new(term)).simplify_rec(scope, scenario, extra)
562                        })
563                        .collect()), // Case #3: If expression is an addition, distribute the coef
564                    (c, Val(v)) => {
565                        c.checked_mul(v).map(Val).unwrap_or_else(|| MulInt(c, Box::new(Val(v))))
566                    } // Case #4: If expression is a value, combine coefs
567                    (c, MulInt(v, inner)) => {
568                        if let Some(cv) = c.checked_mul(v) {
569                            MulInt(cv, inner) // Case #5: If expression is a MulInt, combine coefs
570                        } else {
571                            MulInt(c, Box::new(MulInt(v, inner)))
572                        }
573                    }
574                    (_, s) => MulInt(coef, Box::new(s)), // Case #6: Otherwise, return the original
575                }
576            }
577            Div(a, q) => {
578                if q == 1 {
579                    return a.simplify_rec(scope, scenario, extra);
580                } else if let Div(a, q2) = *a {
581                    return Div(a, q * q2).simplify_rec(scope, scenario, extra);
582                }
583                let a = a.simplify_rec(scope, scenario, extra);
584                if let Val(a) = a {
585                    Val(a / q as i64)
586                } else if let MulInt(-1, a) = a {
587                    MulInt(-1, b!(Div(a, q)))
588                } else if let Add(mut terms) = a {
589                    if terms
590                        .iter()
591                        .any(|t| if let MulInt(-1, s) = t { matches!(&**s, Sym(_)) } else { false })
592                    {
593                        MulInt(
594                            -1,
595                            b!(Div(
596                                b!(Add(terms.into_iter().map(|t| MulInt(-1, b!(t))).collect())
597                                    .simplify_rec(scope, scenario, extra)),
598                                q
599                            )),
600                        )
601                    } else if let Some(v) =
602                        terms.iter().find_map(|t| if let Val(v) = t { Some(*v) } else { None })
603                    {
604                        let offset = if v >= q as i64 {
605                            Some(v / q as i64)
606                        } else if v < 0 {
607                            Some(-Integer::div_ceil(&-v, &(q as i64)))
608                        } else {
609                            None
610                        };
611                        if let Some(val) = offset {
612                            terms.push(Val(-val * q as i64));
613                            Add(vec![
614                                Val(val),
615                                Div(b!(Add(terms).simplify_rec(scope, scenario, extra)), q),
616                            ])
617                        } else {
618                            Div(b!(Add(terms)), q)
619                        }
620                    } else {
621                        Div(b!(Add(terms)), q)
622                    }
623                } else if let MulInt(p, a) = a {
624                    if p == q as i64 {
625                        a.simplify()
626                    } else {
627                        let gcd = p.abs().gcd(&(q as i64));
628                        if gcd == p {
629                            Div(a, q / gcd as u64)
630                        } else if gcd == q as i64 {
631                            MulInt(p / gcd, a)
632                        } else if gcd > 1 {
633                            Div(b!(MulInt(p / gcd, a)), q / gcd as u64)
634                                .simplify_rec(scope, scenario, extra)
635                        } else {
636                            Div(b!(MulInt(p, a)), q)
637                        }
638                    }
639                } else {
640                    Div(b!(a), q)
641                }
642            }
643            Broadcast(terms) => {
644                let mut terms: Vec<TDim> = terms
645                    .iter()
646                    .map(|s| s.clone().simplify_rec(scope, scenario, extra))
647                    .flat_map(|t| if let Broadcast(t) = t { t } else { vec![t] })
648                    .filter(|t| !t.is_one())
649                    .sorted_by(tdim_lexi_order)
650                    .dedup()
651                    .collect_vec();
652                // a#min(a,b) if a>0 && b>0 => a
653                match &*terms {
654                    [] => Val(1),
655                    [_] => terms.remove(0),
656                    [a, Min(m)] | [Min(m), a]
657                        if m.contains(a)
658                            && m.iter()
659                                .all(|t| scope.prove_strict_positive_with_extra(t, extra)) =>
660                    {
661                        a.clone()
662                    }
663                    _ => Broadcast(terms),
664                }
665            }
666
667            Min(terms) => {
668                let mut flatten: Vec<TDim> = terms
669                    .into_iter()
670                    .map(|t| t.simplify_rec(scope, scenario, extra))
671                    .flat_map(|t| if let Min(t) = t { t } else { vec![t] })
672                    .filter(|t| t != &Val(i64::MAX))
673                    .sorted_by(tdim_lexi_order)
674                    .dedup()
675                    .collect();
676                #[allow(clippy::mutable_key_type)]
677                let mut redundant = HashSet::<TDim>::default();
678                for pair in flatten.iter().permutations(2) {
679                    let (a, b) = (pair[0], pair[1]);
680                    if redundant.contains(a) || redundant.contains(b) {
681                        continue;
682                    }
683                    let diff = a.clone() - b;
684                    if diff.as_i64().is_some_and(|i| i >= 0)
685                        || scope.prove_positive_or_zero_with_extra(&diff, extra)
686                    {
687                        redundant.insert(a.clone());
688                    }
689                }
690                flatten.retain(|t| !redundant.contains(t));
691                if flatten.len() == 0 {
692                    i64::MAX.to_dim()
693                } else if flatten.len() == 1 {
694                    flatten.into_iter().next().unwrap()
695                } else {
696                    Min(flatten)
697                }
698            }
699            Max(terms) => {
700                let mut flatten: Vec<TDim> = terms
701                    .into_iter()
702                    .map(|t| t.simplify_rec(scope, scenario, extra))
703                    .flat_map(|t| if let Max(t) = t { t } else { vec![t] })
704                    .filter(|t| t != &Val(i64::MIN))
705                    .sorted_by(tdim_lexi_order)
706                    .dedup()
707                    .collect();
708                #[allow(clippy::mutable_key_type)]
709                let mut redundant = HashSet::<TDim>::default();
710                for pair in flatten.iter().permutations(2) {
711                    let (a, b) = (pair[0], pair[1]);
712                    if redundant.contains(a) || redundant.contains(b) {
713                        continue;
714                    }
715                    let diff = a.clone() - b;
716                    if diff.as_i64().is_some_and(|i| i >= 0)
717                        || scope.prove_positive_or_zero_with_extra(&diff, extra)
718                    {
719                        redundant.insert(b.clone());
720                    }
721                }
722                flatten.retain(|t| !redundant.contains(t));
723                if flatten.len() == 0 {
724                    i64::MIN.to_dim()
725                } else if flatten.len() == 1 {
726                    flatten.into_iter().next().unwrap()
727                } else {
728                    Max(flatten)
729                }
730            }
731            Sym(s) => scope
732                .assertions(scenario)
733                .find_map(|a| match a {
734                    Assertion::Eq(Sym(sym), v) if sym == &s => Some(v.clone()),
735                    _ => None,
736                })
737                .unwrap_or(Sym(s)),
738            Val(_) => self,
739            Ge(a, b) => {
740                let a = a.simplify_rec(scope, scenario, extra);
741                let b = b.simplify_rec(scope, scenario, extra);
742                match (&a, &b) {
743                    (Val(av), Val(bv)) => Val(if av >= bv { 1 } else { 0 }),
744                    _ => {
745                        let diff = a.clone() - b.clone();
746                        if scope.prove_positive_or_zero_with_extra(&diff, extra) {
747                            Val(1)
748                        } else if scope
749                            .prove_strict_positive_with_extra(&(b.clone() - a.clone()), extra)
750                        {
751                            Val(0)
752                        } else {
753                            Ge(b!(a), b!(b))
754                        }
755                    }
756                }
757            }
758            Eq(a, b) => {
759                let a = a.simplify_rec(scope, scenario, extra);
760                let b = b.simplify_rec(scope, scenario, extra);
761                match (&a, &b) {
762                    (Val(av), Val(bv)) => Val(if av == bv { 1 } else { 0 }),
763                    _ => {
764                        let diff = a.clone() - b.clone();
765                        if scope.prove_strict_positive_with_extra(&diff, extra)
766                            || scope
767                                .prove_strict_positive_with_extra(&(b.clone() - a.clone()), extra)
768                        {
769                            Val(0)
770                        } else {
771                            // When one side is 0 or 1 and the other is
772                            // provably in [0,1], reduce to boolean algebra:
773                            //   Eq(expr, 0) → 1 - expr
774                            //   Eq(expr, 1) → expr
775                            let boolean_case = match (&a, &b) {
776                                (Val(0), e) | (e, Val(0)) => Some((e, false)),
777                                (Val(1), e) | (e, Val(1)) => Some((e, true)),
778                                _ => None,
779                            };
780                            if let Some((expr, equals_one)) = boolean_case {
781                                if scope.prove_positive_or_zero_with_extra(expr, extra)
782                                    && scope.prove_positive_or_zero_with_extra(
783                                        &(Val(1) - expr.clone()),
784                                        extra,
785                                    )
786                                {
787                                    return if equals_one {
788                                        expr.clone()
789                                    } else {
790                                        (Val(1) - expr.clone()).simplify_rec(scope, scenario, extra)
791                                    };
792                                }
793                            }
794                            Eq(b!(a), b!(b))
795                        }
796                    }
797                }
798            }
799        }
800    }
801
802    pub(super) fn inclusive_bound(&self, scope: &SymbolScopeData, upper: bool) -> Option<i64> {
803        use self::TDim::*;
804        match self {
805            Val(n) => Some(*n),
806            Sym(_) => {
807                if upper {
808                    scope
809                        .all_assertions()
810                        .iter()
811                        .filter_map(|assert| match &assert {
812                            Assertion::LT(left, right)
813                                if left == self && right.as_i64().is_some() =>
814                            {
815                                Some(right.as_i64().unwrap() - 1)
816                            }
817                            Assertion::LTE(left, right)
818                                if left == self && right.as_i64().is_some() =>
819                            {
820                                Some(right.as_i64().unwrap())
821                            }
822                            _ => None,
823                        })
824                        .min()
825                } else {
826                    scope
827                        .all_assertions()
828                        .iter()
829                        .filter_map(|assert| match &assert {
830                            Assertion::GT(left, right)
831                                if left == self && right.as_i64().is_some() =>
832                            {
833                                Some(right.as_i64().unwrap() + 1)
834                            }
835                            Assertion::GTE(left, right)
836                                if left == self && right.as_i64().is_some() =>
837                            {
838                                Some(right.as_i64().unwrap())
839                            }
840                            _ => None,
841                        })
842                        .max()
843                }
844            }
845            Add(terms) => {
846                let mut bound: i64 = 0;
847                for t in terms {
848                    if let Some(b) = t.inclusive_bound(scope, upper) {
849                        bound = bound.checked_add(b)?;
850                    } else {
851                        return None;
852                    }
853                }
854                Some(bound)
855            }
856            MulInt(p, a) => match p.cmp(&0) {
857                Ordering::Equal => Some(0),
858                Ordering::Greater => {
859                    a.inclusive_bound(scope, upper).and_then(|x| x.checked_mul(*p))
860                }
861                Ordering::Less => a.inclusive_bound(scope, !upper).and_then(|x| x.checked_mul(*p)),
862            },
863            Mul(terms) => {
864                // If all factors have known non-negative bounds, we can bound the product.
865                let mut lo: i64 = 1;
866                let mut hi: i64 = 1;
867                for t in terms {
868                    let t_lo = t.inclusive_bound(scope, false)?;
869                    let t_hi = t.inclusive_bound(scope, true)?;
870                    if t_lo < 0 {
871                        return None;
872                    }
873                    lo = lo.checked_mul(t_lo)?;
874                    hi = hi.checked_mul(t_hi)?;
875                }
876                Some(if upper { hi } else { lo })
877            }
878            Min(terms) if !upper => {
879                // All terms must have known lower bounds; if any is unknown,
880                // the Min lower bound is unknown.
881                let bounds: Option<Vec<i64>> =
882                    terms.iter().map(|t| t.inclusive_bound(scope, false)).collect();
883                bounds.map(|b| b.into_iter().min().unwrap_or(i64::MAX))
884            }
885            Max(terms) if upper => {
886                // All terms must have known upper bounds; if any is unknown,
887                // the Max upper bound is unknown.
888                let bounds: Option<Vec<i64>> =
889                    terms.iter().map(|t| t.inclusive_bound(scope, true)).collect();
890                bounds.map(|b| b.into_iter().max().unwrap_or(i64::MIN))
891            }
892            Div(a, q) => a.inclusive_bound(scope, upper).map(|x| x / (*q as i64)),
893            Broadcast(terms) => {
894                if upper {
895                    Max(terms.clone()).inclusive_bound(scope, true)
896                } else {
897                    Min(terms.clone()).inclusive_bound(scope, false)
898                }
899            }
900            Ge(_, _) | Eq(_, _) => {
901                if upper {
902                    Some(1)
903                } else {
904                    Some(0)
905                }
906            }
907            _ => None,
908        }
909    }
910
911    pub fn low_inclusive_bound(&self) -> Option<i64> {
912        if let TDim::Val(v) = self {
913            return Some(*v);
914        }
915        let scope = self.find_scope()?;
916        let data = scope.0.lock();
917        let data = data.borrow();
918        self.inclusive_bound(&data, false)
919    }
920
921    pub fn high_inclusive_bound(&self) -> Option<i64> {
922        if let TDim::Val(v) = self {
923            return Some(*v);
924        }
925        let scope = self.find_scope()?;
926        let data = scope.0.lock();
927        let data = data.borrow();
928        self.inclusive_bound(&data, true)
929    }
930
931    pub fn prove_positive_or_zero(&self) -> bool {
932        if let TDim::Val(v) = self {
933            return *v >= 0;
934        }
935        let Some(scope) = self.find_scope() else { return false };
936        let data = scope.0.lock();
937        let data = data.borrow();
938        data.prove_positive_or_zero(self)
939    }
940
941    pub fn prove_strict_positive(&self) -> bool {
942        if let TDim::Val(v) = self {
943            return *v > 0;
944        }
945        (self.clone() - 1).prove_positive_or_zero()
946    }
947
948    pub fn prove_negative_or_zero(&self) -> bool {
949        if let TDim::Val(v) = self {
950            return *v <= 0;
951        }
952        self.clone().neg().prove_positive_or_zero()
953    }
954
955    pub fn prove_strict_negative(&self) -> bool {
956        if let TDim::Val(v) = self {
957            return *v < 0;
958        }
959        self.clone().neg().prove_strict_positive()
960    }
961
962    pub fn gcd(&self) -> u64 {
963        use self::TDim::*;
964        match self {
965            Val(v) => v.unsigned_abs(),
966            Sym(_) => 1,
967            Add(terms) => {
968                let (head, tail) = terms.split_first().unwrap();
969                tail.iter().fold(head.gcd(), |a, b| a.gcd(&b.gcd()))
970            }
971            MulInt(p, a) => a.gcd().saturating_mul(p.unsigned_abs()),
972            Mul(terms) => terms.iter().map(|t| t.gcd()).fold(1u64, |a, b| a.saturating_mul(b)),
973            Min(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
974            Max(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
975            Div(a, q) => {
976                if a.gcd() % *q == 0 {
977                    a.gcd() / *q
978                } else {
979                    1
980                }
981            }
982            Broadcast(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap_or(1),
983            Ge(_, _) | Eq(_, _) => 1,
984        }
985    }
986
987    fn div(&self, d: u64) -> TDim {
988        use self::TDim::*;
989        if d == 1 {
990            return self.clone();
991        }
992        match self {
993            Val(v) => Val(v / d as i64),
994            Sym(_) => panic!(),
995            Add(terms) => Add(terms.iter().map(|t| t.div(d)).collect()),
996            Min(terms) => Min(terms.iter().map(|t| t.div(d)).collect()),
997            Max(terms) => Max(terms.iter().map(|t| t.div(d)).collect()),
998            Broadcast(terms) => Broadcast(terms.iter().map(|t| t.div(d)).collect()),
999            Mul(_) => Div(Box::new(self.clone()), d),
1000            MulInt(p, a) => {
1001                if *p == d as i64 {
1002                    (**a).clone()
1003                } else {
1004                    let gcd = p.unsigned_abs().gcd(&d);
1005                    MulInt(p / gcd as i64, b!(a.div(d / gcd)))
1006                }
1007            }
1008            Div(a, q) => Div(a.clone(), q * d),
1009            Ge(_, _) | Eq(_, _) => Div(Box::new(self.clone()), d),
1010        }
1011    }
1012
1013    pub fn div_ceil(self, rhs: u64) -> TDim {
1014        TDim::Div(Box::new(Add(vec![self, Val(rhs as i64 - 1)])), rhs).reduce()
1015    }
1016
1017    pub(super) fn guess_slope(&self, sym: &Symbol) -> (i64, u64) {
1018        fn slope_rec(d: &TDim, sym: &Symbol) -> (i64, i64) {
1019            match d {
1020                Val(_) => (0, 1),
1021                Sym(s) => ((sym == s) as i64, 1),
1022                Add(terms) => terms
1023                    .iter()
1024                    .map(|d| slope_rec(d, sym))
1025                    .fold((0, 1), |a, b| ((a.0 * b.1 + a.1 * b.0), (b.1 * a.1))),
1026                Mul(terms) => terms
1027                    .iter()
1028                    .map(|d| slope_rec(d, sym))
1029                    .fold((1, 1), |a, b| ((a.0 * b.0), (b.1 * a.1))),
1030                MulInt(p, a) => {
1031                    let (n, d) = slope_rec(a, sym);
1032                    (p * n, d)
1033                }
1034                Div(a, q) => {
1035                    let (n, d) = slope_rec(a, sym);
1036                    (n, d * *q as i64)
1037                }
1038                Broadcast(terms) => slope_rec(&terms[0], sym),
1039                Min(terms) => slope_rec(&terms[0], sym),
1040                Max(terms) => slope_rec(&terms[0], sym),
1041                Ge(_, _) | Eq(_, _) => (0, 1),
1042            }
1043        }
1044        let (p, q) = slope_rec(self, sym);
1045        reduce_ratio(p, q)
1046    }
1047
1048    #[allow(clippy::mutable_key_type)]
1049    pub fn symbols(&self) -> std::collections::HashSet<Symbol> {
1050        match self {
1051            Val(_) => maplit::hashset!(),
1052            Sym(s) => maplit::hashset!(s.clone()),
1053            Add(terms) | Mul(terms) | Broadcast(terms) | Min(terms) | Max(terms) => {
1054                terms.iter().fold(maplit::hashset!(), |mut set, v| {
1055                    set.extend(v.symbols());
1056                    set
1057                })
1058            }
1059            MulInt(_, a) => a.symbols(),
1060            Div(a, _) => a.symbols(),
1061            Ge(a, b) | Eq(a, b) => {
1062                let mut set = a.symbols();
1063                set.extend(b.symbols());
1064                set
1065            }
1066        }
1067    }
1068
1069    pub fn compatible_with(&self, other: &TDim) -> bool {
1070        if let Ok(x) = (self.clone() - other).to_i64() {
1071            return x == 0;
1072        }
1073        true // maybe ? :)
1074    }
1075}
1076
1077pub(super) fn reduce_ratio(mut p: i64, mut q: i64) -> (i64, u64) {
1078    let gcd = p.abs().gcd(&q.abs());
1079    if gcd > 1 {
1080        p /= gcd;
1081        q /= gcd;
1082    }
1083    if q < 0 { (-p, (-q) as u64) } else { (p, q as u64) }
1084}
1085
1086impl Zero for TDim {
1087    fn zero() -> Self {
1088        Val(0)
1089    }
1090    fn is_zero(&self) -> bool {
1091        matches!(self, Val(0))
1092    }
1093}
1094
1095impl Default for TDim {
1096    fn default() -> TDim {
1097        Val(0)
1098    }
1099}
1100
1101impl num_traits::Bounded for TDim {
1102    fn min_value() -> Self {
1103        TDim::Val(i64::MIN)
1104    }
1105
1106    fn max_value() -> Self {
1107        TDim::Val(i64::MAX)
1108    }
1109}
1110
1111impl num_traits::One for TDim {
1112    fn one() -> Self {
1113        TDim::Val(1)
1114    }
1115}
1116
1117impl ::std::iter::Sum for TDim {
1118    fn sum<I: Iterator<Item = TDim>>(iter: I) -> TDim {
1119        iter.fold(0.into(), |a, b| a + b)
1120    }
1121}
1122
1123impl<'a> ::std::iter::Sum<&'a TDim> for TDim {
1124    fn sum<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
1125        iter.fold(0.into(), |a, b| a + b)
1126    }
1127}
1128
1129impl std::iter::Product for TDim {
1130    fn product<I: Iterator<Item = TDim>>(iter: I) -> Self {
1131        iter.fold(TDim::Val(1), |a, b| a * b)
1132    }
1133}
1134
1135impl<'a> ::std::iter::Product<&'a TDim> for TDim {
1136    fn product<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
1137        iter.fold(1.into(), |a, b| a * b)
1138    }
1139}
1140
1141macro_rules! from_i {
1142    ($i: ty) => {
1143        impl From<$i> for TDim {
1144            fn from(v: $i) -> TDim {
1145                TDim::Val(v as _)
1146            }
1147        }
1148        impl<'a> From<&'a $i> for TDim {
1149            fn from(v: &'a $i) -> TDim {
1150                TDim::Val(*v as _)
1151            }
1152        }
1153    };
1154}
1155
1156from_i!(i32);
1157from_i!(i64);
1158from_i!(u64);
1159from_i!(isize);
1160from_i!(usize);
1161
1162impl From<Symbol> for TDim {
1163    fn from(it: Symbol) -> Self {
1164        TDim::Sym(it)
1165    }
1166}
1167
1168impl<'a> From<&'a Symbol> for TDim {
1169    fn from(it: &'a Symbol) -> Self {
1170        TDim::Sym(it.clone())
1171    }
1172}
1173
1174impl ops::Neg for TDim {
1175    type Output = Self;
1176    fn neg(self) -> Self {
1177        if let Val(v) = self { Val(-v) } else { TDim::MulInt(-1, Box::new(self)).reduce() }
1178    }
1179}
1180
1181impl<'a> ops::AddAssign<&'a TDim> for TDim {
1182    fn add_assign(&mut self, rhs: &'a TDim) {
1183        if rhs.is_zero() {
1184        } else if self.is_zero() {
1185            *self = rhs.clone();
1186        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1187            *s += o;
1188        } else {
1189            *self = TDim::Add(vec![std::mem::take(self), rhs.clone()]).reduce()
1190        }
1191    }
1192}
1193
1194impl<I> ops::AddAssign<I> for TDim
1195where
1196    I: Into<TDim>,
1197{
1198    fn add_assign(&mut self, rhs: I) {
1199        let rhs = rhs.into();
1200        if rhs.is_zero() {
1201        } else if self.is_zero() {
1202            *self = rhs;
1203        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1204            *s += o;
1205        } else {
1206            *self = TDim::Add(vec![std::mem::take(self), rhs]).reduce()
1207        }
1208    }
1209}
1210
1211impl<I> ops::Add<I> for TDim
1212where
1213    I: Into<TDim>,
1214{
1215    type Output = Self;
1216    fn add(mut self, rhs: I) -> Self {
1217        self += rhs;
1218        self
1219    }
1220}
1221
1222impl<'a> ops::Add<&'a TDim> for TDim {
1223    type Output = Self;
1224    fn add(mut self, rhs: &'a TDim) -> Self {
1225        self += rhs;
1226        self
1227    }
1228}
1229
1230#[allow(clippy::suspicious_op_assign_impl)]
1231impl<'a> ops::SubAssign<&'a TDim> for TDim {
1232    fn sub_assign(&mut self, rhs: &'a TDim) {
1233        if rhs.is_zero() {
1234        } else if self.is_zero() {
1235            *self = rhs.clone().neg();
1236        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1237            *s -= o;
1238        } else {
1239            *self = TDim::Add(vec![std::mem::take(self), rhs.clone().neg()]).reduce()
1240        }
1241    }
1242}
1243
1244impl<I> ops::SubAssign<I> for TDim
1245where
1246    I: Into<TDim>,
1247{
1248    fn sub_assign(&mut self, rhs: I) {
1249        let rhs = rhs.into();
1250        if rhs.is_zero() {
1251        } else if self.is_zero() {
1252            *self = rhs.neg();
1253        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1254            *s -= o;
1255        } else {
1256            *self = TDim::Add(vec![std::mem::take(self), rhs.neg()]).reduce()
1257        }
1258    }
1259}
1260
1261impl<I> ops::Sub<I> for TDim
1262where
1263    I: Into<TDim>,
1264{
1265    type Output = Self;
1266    fn sub(mut self, rhs: I) -> Self {
1267        self -= rhs;
1268        self
1269    }
1270}
1271
1272impl<'a> ops::Sub<&'a TDim> for TDim {
1273    type Output = Self;
1274    fn sub(mut self, rhs: &'a TDim) -> Self {
1275        self -= rhs;
1276        self
1277    }
1278}
1279
1280impl<I: Into<TDim>> ops::MulAssign<I> for TDim {
1281    fn mul_assign(&mut self, rhs: I) {
1282        let rhs = rhs.into();
1283        if self.is_one() {
1284            *self = rhs
1285        } else if rhs.is_one() {
1286        } else {
1287            *self = TDim::Mul(vec![rhs, std::mem::take(self)]).reduce()
1288        }
1289    }
1290}
1291
1292impl<'a> ops::MulAssign<&'a TDim> for TDim {
1293    fn mul_assign(&mut self, rhs: &'a TDim) {
1294        if self.is_one() {
1295            *self = rhs.clone()
1296        } else if rhs.is_one() {
1297        } else {
1298            *self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce()
1299        }
1300    }
1301}
1302
1303impl<I: Into<TDim>> ops::Mul<I> for TDim {
1304    type Output = Self;
1305    fn mul(mut self, rhs: I) -> Self {
1306        self *= rhs.into();
1307        self
1308    }
1309}
1310
1311impl<'a> ops::Mul<&'a TDim> for TDim {
1312    type Output = Self;
1313    fn mul(mut self, rhs: &'a TDim) -> Self {
1314        self *= rhs;
1315        self
1316    }
1317}
1318
1319impl<I: AsPrimitive<u64> + PrimInt> ops::DivAssign<I> for TDim {
1320    fn div_assign(&mut self, rhs: I) {
1321        *self = TDim::Div(Box::new(std::mem::take(self)), rhs.as_()).reduce()
1322    }
1323}
1324
1325impl<I: AsPrimitive<u64> + PrimInt> ops::Div<I> for TDim {
1326    type Output = Self;
1327    fn div(mut self, rhs: I) -> Self {
1328        self /= rhs.as_();
1329        self
1330    }
1331}
1332
1333impl<I: AsPrimitive<u64> + PrimInt> ops::RemAssign<I> for TDim {
1334    fn rem_assign(&mut self, rhs: I) {
1335        *self += -(self.clone() / rhs.as_() * rhs.as_());
1336    }
1337}
1338
1339impl<I: AsPrimitive<u64> + PrimInt> ops::Rem<I> for TDim {
1340    type Output = Self;
1341    fn rem(mut self, rhs: I) -> Self {
1342        self %= rhs;
1343        self
1344    }
1345}
1346
1347#[cfg(test)]
1348mod tests {
1349    use super::*;
1350
1351    macro_rules! b( ($e:expr) => { Box::new($e) } );
1352
1353    lazy_static::lazy_static! {
1354        static ref table: SymbolScope = SymbolScope::default();
1355        static ref A: Symbol = table.sym("a");
1356        static ref B: Symbol = table.sym("b");
1357        static ref C: Symbol = table.sym("c");
1358        static ref D: Symbol = table.sym("d");
1359        static ref E: Symbol = table.sym("e");
1360    }
1361
1362    fn neg(a: &TDim) -> TDim {
1363        mul(-1, a)
1364    }
1365
1366    fn add(a: &TDim, b: &TDim) -> TDim {
1367        TDim::Add(vec![a.clone(), b.clone()])
1368    }
1369
1370    fn mul(a: i64, b: &TDim) -> TDim {
1371        TDim::MulInt(a, b![b.clone()])
1372    }
1373
1374    fn div(a: &TDim, b: u64) -> TDim {
1375        TDim::Div(b!(a.clone()), b)
1376    }
1377
1378    #[test]
1379    fn reduce_add() {
1380        assert_eq!(add(&A.to_dim(), &neg(&A.to_dim())).reduce(), Val(0))
1381    }
1382
1383    #[test]
1384    fn reduce_neg_mul() {
1385        assert_eq!(neg(&mul(2, &A.to_dim())).reduce(), mul(-2, &A.to_dim()))
1386    }
1387
1388    #[test]
1389    fn reduce_cplx_ex_2() {
1390        assert_eq!(
1391            add(
1392                &add(&Val(-4), &mul(-2, &div(&A.to_dim(), 4))),
1393                &mul(-2, &mul(-1, &div(&A.to_dim(), 4)))
1394            )
1395            .reduce(),
1396            Val(-4)
1397        )
1398    }
1399
1400    #[test]
1401    fn reduce_cplx_ex_3() {
1402        assert_eq!(div(&MulInt(1, b!(MulInt(4, b!(A.to_dim())))), 4).reduce(), A.to_dim())
1403    }
1404
1405    #[test]
1406    fn reduce_cplx_ex_4() {
1407        // (S+1)/2 + (1-S)/2 == 1
1408        assert_eq!(
1409            add(&div(&add(&A.to_dim(), &Val(1)), 2), &div(&add(&neg(&A.to_dim()), &Val(1)), 2))
1410                .reduce(),
1411            1.into()
1412        );
1413    }
1414
1415    #[test]
1416    fn reduce_mul_mul_1() {
1417        assert_eq!(mul(3, &mul(2, &A.to_dim())).reduce(), mul(6, &A.to_dim()))
1418    }
1419
1420    #[test]
1421    fn reduce_mul_mul_2() {
1422        assert_eq!(mul(-2, &mul(-1, &A.to_dim())).reduce(), mul(2, &A.to_dim()))
1423    }
1424
1425    #[test]
1426    fn reduce_mul_div_1() {
1427        assert_eq!(mul(2, &div(&mul(-1, &A.to_dim()), 3)).reduce(), mul(-2, &div(&A.to_dim(), 3)))
1428    }
1429
1430    #[test]
1431    fn const_and_add() {
1432        let e: TDim = 2i64.into();
1433        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 2);
1434        let e: TDim = TDim::from(2) + 3;
1435        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 5);
1436        let e: TDim = TDim::from(2) - 3;
1437        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -1);
1438        let e: TDim = -TDim::from(2);
1439        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -2);
1440    }
1441
1442    #[test]
1443    fn substitution() {
1444        let a: TDim = A.to_dim();
1445        assert_eq!(a.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 2);
1446        let e = a + 3;
1447        assert_eq!(e.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 5);
1448    }
1449
1450    #[test]
1451    fn reduce_adds() {
1452        let e: TDim = TDim::from(2) + 1;
1453        assert_eq!(e, TDim::from(3));
1454        let e: TDim = TDim::from(3) + 2;
1455        assert_eq!(e, TDim::from(5));
1456        let e: TDim = TDim::from(3) + 0;
1457        assert_eq!(e, TDim::from(3));
1458        let e: TDim = TDim::from(3) + 2 + 1;
1459        assert_eq!(e, TDim::from(6));
1460    }
1461
1462    #[test]
1463    fn reduce_muls() {
1464        let e: TDim = Val(1) * A.to_dim();
1465        assert_eq!(e, A.to_dim());
1466        let e: TDim = A.to_dim() * &B.to_dim() * 1;
1467        assert_eq!(e, A.to_dim() * &B.to_dim());
1468    }
1469
1470    #[test]
1471    fn reduce_divs() {
1472        let e: TDim = TDim::from(2) / 1;
1473        assert_eq!(e, TDim::from(2));
1474        let e: TDim = TDim::from(3) / 2;
1475        assert_eq!(e, TDim::from(1));
1476        let e: TDim = TDim::from(3) % 2;
1477        assert_eq!(e, TDim::from(1));
1478        let e: TDim = TDim::from(5) / 2;
1479        assert_eq!(e, TDim::from(2));
1480        let e: TDim = TDim::from(5) % 2;
1481        assert_eq!(e, TDim::from(1));
1482    }
1483
1484    #[test]
1485    fn reduce_div_bug_0() {
1486        let e1: TDim = (A.to_dim() + 23) / 2 - 1;
1487        let e2: TDim = (A.to_dim() + 21) / 2;
1488        assert_eq!(e1, e2);
1489    }
1490
1491    #[test]
1492    fn reduce_div_bug_1() {
1493        let e1: TDim = (A.to_dim() + -1) / 2;
1494        let e2: TDim = (A.to_dim() + 1) / 2 - 1;
1495        assert_eq!(e1, e2);
1496    }
1497
1498    #[test]
1499    fn reduce_div_bug_2() {
1500        let e1: TDim = ((A.to_dim() + 1) / 2 + 1) / 2;
1501        let e2: TDim = (A.to_dim() + 3) / 4;
1502        assert_eq!(e1, e2);
1503    }
1504
1505    #[test]
1506    fn reduce_div_bug_3() {
1507        let e1: TDim = (A.to_dim() / 2) * -4;
1508        let e2: TDim = (A.to_dim() / 2) * -4 / 1;
1509        assert_eq!(e1, e2);
1510    }
1511
1512    #[test]
1513    fn reduce_mul_div() {
1514        let e: TDim = A.to_dim() * 2 / 2;
1515        assert_eq!(e, A.to_dim());
1516    }
1517
1518    #[test]
1519    fn reduce_div_mul() {
1520        let e: TDim = A.to_dim() / 2 * 2;
1521        assert_ne!(e, A.to_dim());
1522    }
1523
1524    #[test]
1525    fn reduce_add_div() {
1526        let e: TDim = A.to_dim() / 2 + 1;
1527        assert_eq!(e, ((A.to_dim() + 2) / 2));
1528    }
1529
1530    #[test]
1531    fn reduce_neg_mul_() {
1532        let e: TDim = TDim::from(1) - A.to_dim() * 2;
1533        assert_eq!(e, TDim::from(1) + A.to_dim() * -2);
1534    }
1535
1536    #[test]
1537    fn reduce_add_rem_1() {
1538        assert_eq!(((A.to_dim() + 4) % 2), (A.to_dim() % 2));
1539    }
1540
1541    #[test]
1542    fn reduce_add_rem_2() {
1543        assert_eq!(((A.to_dim() - 4) % 2), (A.to_dim() % 2));
1544    }
1545
1546    #[test]
1547    fn reduce_rem_div() {
1548        let e: TDim = A.to_dim() % 2 / 2;
1549        assert_eq!(e, TDim::from(0));
1550    }
1551
1552    #[test]
1553    fn conv2d_ex_1() {
1554        let e = (TDim::from(1) - 1 + 1).div_ceil(1);
1555        assert_eq!(e, TDim::from(1));
1556    }
1557
1558    #[test]
1559    fn conv2d_ex_2() {
1560        let e = (A.to_dim() - 3 + 1).div_ceil(1);
1561        assert_eq!(e, A.to_dim() + -2);
1562    }
1563
1564    #[test]
1565    fn extract_int_gcd_from_muls() {
1566        let term = (A.to_dim() + 1) / 4;
1567        let mul = (term.clone() * 24 - 24) * (term.clone() * 2 - 2);
1568        let target = (term.clone() - 1) * (term.clone() - 1) * 48;
1569        assert_eq!(mul, target);
1570    }
1571
1572    #[test]
1573    fn equality_of_muls() {
1574        let term = (A.to_dim() + 1) / 4;
1575        let mul1 = (term.clone() * 2 - 3) * (term.clone() - 1);
1576        let mul2 = (term.clone() - 1) * (term.clone() * 2 - 3);
1577        assert_eq!(mul1, mul2);
1578    }
1579
1580    #[test]
1581    fn factorize_complex_expr_times_int() {
1582        let term = (A.to_dim() + 1) / 4;
1583        let e = term.clone() * 2 - &term - 1;
1584        assert_eq!(e, term - 1);
1585    }
1586
1587    #[test]
1588    fn broadcast_over_min() {
1589        // assuming a>0, b>0 then a#min(a,b) can be replaced by a
1590        // proof:
1591        //    if b == 1 => min(a,b)=1 => a#1=a => ok
1592        //    if a <= b => min(a,b)=a => ok
1593        //    if 1 < B < A => expression was invalid, we're generalizing over the non-domain and ignoring the constraint
1594        for a in 1..5 {
1595            for b in 1..5 {
1596                if b > 1 && a > b {
1597                    assert!(a.broadcast(a.min(b)).is_err());
1598                } else {
1599                    assert_eq!(a.broadcast(a.min(b)).unwrap(), a);
1600                }
1601            }
1602        }
1603    }
1604
1605    #[test]
1606    fn min_ints_1() {
1607        assert_eq!(2.to_dim().mini(1.to_dim()), 1.to_dim());
1608    }
1609
1610    #[test]
1611    fn min_ints_2() {
1612        assert_eq!(1.to_dim().mini(2.to_dim()), 1.to_dim());
1613    }
1614
1615    #[test]
1616    fn min_same() {
1617        assert_eq!(A.to_dim().mini(A.to_dim()), A.to_dim());
1618    }
1619
1620    #[test]
1621    fn min_noop() {
1622        assert_eq!(A.to_dim().mini(1.to_dim()), A.to_dim().mini(1.to_dim()));
1623    }
1624
1625    #[test]
1626    fn min_diff_1() {
1627        assert_eq!((A.to_dim() + 1).mini(A.to_dim() + 2), A.to_dim() + 1);
1628    }
1629
1630    #[test]
1631    fn slope_0() {
1632        assert_eq!(12.to_dim().guess_slope(&A), (0, 1));
1633    }
1634
1635    #[test]
1636    fn slope_1() {
1637        assert_eq!(A.to_dim().guess_slope(&A), (1, 1));
1638    }
1639
1640    #[test]
1641    fn slope_2() {
1642        assert_eq!((A.to_dim() * 2).guess_slope(&A), (2, 1));
1643    }
1644
1645    #[test]
1646    fn slope_3() {
1647        assert_eq!((A.to_dim() * 2 + A.to_dim() / 2).guess_slope(&A), (5, 2));
1648    }
1649
1650    #[test]
1651    fn slope_4() {
1652        assert_eq!((A.to_dim()).guess_slope(&B), (0, 1));
1653    }
1654
1655    #[test]
1656    fn slope_5() {
1657        assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
1658        assert_eq!((A.to_dim() + 1).guess_slope(&B), (0, 1));
1659    }
1660
1661    #[test]
1662    fn slope_6() {
1663        assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
1664        assert_eq!((A.to_dim() + B.to_dim()).guess_slope(&B), (1, 1));
1665    }
1666
1667    #[test]
1668    fn min_0() -> TractResult<()> {
1669        let symbols = SymbolScope::default();
1670        assert_eq!(
1671            symbols.parse_tdim("min(S+3, S+2)").unwrap().simplify(),
1672            symbols.parse_tdim("S+2").unwrap(),
1673        );
1674        Ok(())
1675    }
1676
1677    #[test]
1678    fn commutative_mul_parens() -> TractResult<()> {
1679        let symbols = SymbolScope::default();
1680        assert_eq!(
1681            symbols.parse_tdim("A*(B*C)").unwrap().simplify(),
1682            symbols.parse_tdim("(B*A)*C").unwrap().simplify(),
1683        );
1684        Ok(())
1685    }
1686
1687    #[test]
1688    fn commutative_in_nemo_parakeet_model() -> TractResult<()> {
1689        let symbols = SymbolScope::default();
1690        assert_eq!(
1691            symbols
1692                .parse_tdim("8*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8))*((B)*((S+7)/8))")
1693                .unwrap()
1694                .simplify(),
1695            symbols
1696                .parse_tdim("8*((B)*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8)))*((S+7)/8)")
1697                .unwrap()
1698                .simplify(),
1699        );
1700        Ok(())
1701    }
1702
1703    #[test]
1704    fn commutative_mul_parens_deep() -> TractResult<()> {
1705        let symbols = SymbolScope::default();
1706        let deep_tdim = Mul(vec![
1707            Mul(vec![Mul(vec![Mul(vec![A.to_dim(), B.to_dim()]), C.to_dim()]), D.to_dim()]),
1708            E.to_dim(),
1709        ])
1710        .simplify();
1711        assert_eq!(deep_tdim, symbols.parse_tdim("a*b*c*d*e").unwrap().simplify());
1712        Ok(())
1713    }
1714
1715    // ---- Tests for new comparison/not TDim variants ----
1716
1717    #[test]
1718    fn ge_concrete_true() {
1719        assert_eq!(Ge(b!(Val(5)), b!(Val(3))).reduce(), Val(1));
1720    }
1721
1722    #[test]
1723    fn ge_concrete_false() {
1724        assert_eq!(Ge(b!(Val(2)), b!(Val(3))).reduce(), Val(0));
1725    }
1726
1727    #[test]
1728    fn lt_concrete_true() {
1729        // Lt(2,3) normalizes to Ge(3, 2+1) = Ge(3, 3)
1730        assert_eq!(Ge(b!(Val(3)), b!(Val(3))).reduce(), Val(1));
1731    }
1732
1733    #[test]
1734    fn lt_concrete_false() {
1735        // Lt(5,3) normalizes to Ge(3, 5+1) = Ge(3, 6)
1736        assert_eq!(Ge(b!(Val(3)), b!(Val(6))).reduce(), Val(0));
1737    }
1738
1739    #[test]
1740    fn eq_concrete_true() {
1741        assert_eq!(Eq(b!(Val(3)), b!(Val(3))).reduce(), Val(1));
1742    }
1743
1744    #[test]
1745    fn eq_concrete_false() {
1746        assert_eq!(Eq(b!(Val(3)), b!(Val(4))).reduce(), Val(0));
1747    }
1748
1749    #[test]
1750    fn not_val_0() {
1751        // not(0) = 1 - 0 = 1
1752        assert_eq!((Val(1) - Val(0)).reduce(), Val(1));
1753    }
1754
1755    #[test]
1756    fn not_val_1() {
1757        // not(1) = 1 - 1 = 0
1758        assert_eq!((Val(1) - Val(1)).reduce(), Val(0));
1759    }
1760
1761    #[test]
1762    fn not_lt_becomes_ge() {
1763        // not(Lt(x1, T)) = 1 - Ge(T, x1+1); check it evaluates correctly at boundary
1764        let s = SymbolScope::default();
1765        let t = s.sym("T");
1766        let x1 = s.sym("x1");
1767        // at x1 = T (boundary), Ge(T, T+1) = 0, so 1 - 0 = 1 (not-lt is true when x1 >= T)
1768        let expr = Val(1) - Ge(b!(Sym(t.clone())), b!(Sym(x1.clone()) + Val(1)));
1769        let at_boundary = expr.substitute(&x1, &Sym(t.clone())).unwrap().simplify();
1770        assert_eq!(at_boundary, Val(1));
1771    }
1772
1773    #[test]
1774    fn eq_with_assertion_proves_false() {
1775        // Eq(T, 0) should reduce to Val(0) when T >= 1
1776        let s = SymbolScope::default();
1777        s.add_assertion("T >= 1").unwrap();
1778        let t = s.sym("T");
1779        let expr = Eq(b!(Sym(t)), b!(Val(0)));
1780        assert_eq!(expr.simplify(), Val(0));
1781    }
1782
1783    #[test]
1784    fn ge_coord_at_extremes() {
1785        // Ge(x1, T) should not simplify without coordinate substitution
1786        let s = SymbolScope::default();
1787        s.add_assertion("T >= 1").unwrap();
1788        let t = s.sym("T");
1789        let x1 = s.sym("x1");
1790        let expr = Ge(b!(Sym(x1.clone())), b!(Sym(t.clone())));
1791        // simplify() alone can't prove this false (x1 could be > T)
1792        // but with coordinate substitution (x1 = T-1), Ge(T-1, T) = 0
1793        let at_max = expr.substitute(&x1, &(Sym(t.clone()) - Val(1))).unwrap().simplify();
1794        assert_eq!(at_max, Val(0));
1795    }
1796
1797    #[test]
1798    fn eval_to_i64_new_variants() {
1799        use super::super::sym::SymbolValues;
1800        let sv = SymbolValues::default();
1801        assert_eq!(Ge(b!(Val(5)), b!(Val(3))).eval_to_i64(&sv).unwrap(), 1);
1802        assert_eq!(Ge(b!(Val(3)), b!(Val(5))).eval_to_i64(&sv).unwrap(), 0);
1803        assert_eq!(Eq(b!(Val(3)), b!(Val(3))).eval_to_i64(&sv).unwrap(), 1);
1804        assert_eq!(Eq(b!(Val(3)), b!(Val(4))).eval_to_i64(&sv).unwrap(), 0);
1805    }
1806
1807    #[test]
1808    fn eq_boolean_simplifies() {
1809        let s = SymbolScope::default();
1810        s.add_assertion("cw >= 0").unwrap();
1811        s.add_assertion("cw <= 1").unwrap();
1812        let cw = s.sym("cw");
1813        // Eq(1 - cw, 0) → cw
1814        assert_eq!(Eq(b!(Val(1) - Sym(cw.clone())), b!(Val(0))).simplify(), Sym(cw.clone()));
1815        // Eq(cw, 0) → 1 - cw
1816        assert_eq!(Eq(b!(Sym(cw.clone())), b!(Val(0))).simplify(), Val(1) - Sym(cw.clone()));
1817        // Eq(cw, 1) → cw
1818        assert_eq!(Eq(b!(Sym(cw.clone())), b!(Val(1))).simplify(), Sym(cw.clone()));
1819        // Eq(1 - cw, 1) → 1 - cw
1820        assert_eq!(Eq(b!(Val(1) - Sym(cw.clone())), b!(Val(1))).simplify(), Val(1) - Sym(cw));
1821    }
1822
1823    #[test]
1824    fn eq_boolean_mul_of_ge() {
1825        // Product of Ge terms: Ge(a,b) * Ge(c,d) is in [0,1]
1826        // so Eq(product, 0) should simplify to 1 - product
1827        let s = SymbolScope::default();
1828        let x = s.sym("x");
1829        let product =
1830            Mul(vec![Ge(b!(Val(2)), b!(Sym(x.clone()))), Ge(b!(Sym(x.clone())), b!(Val(0)))]);
1831        let eq = Eq(b!(product.clone()), b!(Val(0)));
1832        assert_eq!(eq.simplify(), Val(1) - product);
1833    }
1834
1835    #[test]
1836    fn min_1_max_0_sym() {
1837        // Min(1, Max(0, X)) must not simplify away the Min when X is unconstrained.
1838        let s = SymbolScope::default();
1839        let x = s.sym("X");
1840        let expr = Min(vec![Val(1), Max(vec![Val(0), Sym(x)])]);
1841        let simplified = expr.simplify();
1842        eprintln!("simplified: {simplified}");
1843        assert!(format!("{simplified}").contains("min"), "Min dropped: {simplified}");
1844    }
1845
1846    #[test]
1847    fn min_preserved_in_subtraction_parts() {
1848        // Test that Min([1, X]) simplifies correctly in isolation
1849        let s = SymbolScope::default();
1850        let t = s.sym("T");
1851        let p = s.sym("P");
1852        let ss = s.sym("S");
1853
1854        let cum_after =
1855            Max(vec![Val(0), (Sym(t.clone()) + Val(1)) * Sym(p.clone()) - Sym(ss.clone())]);
1856        let min_after = Min(vec![Val(1), cum_after.clone()]);
1857        let simplified = min_after.simplify();
1858        eprintln!("min_after simplified: {simplified}");
1859        // Must contain "min" — the Min must not be dropped
1860        assert!(format!("{simplified}").contains("min"), "Min wrapper was dropped: {simplified}");
1861    }
1862
1863    #[test]
1864    fn min_preserved_in_subtraction() {
1865        // min(1, X) - min(1, Y) must preserve the min() wrappers.
1866        // This is the pattern used by PulseV2Pad's output_facts for after-padding.
1867        let s = SymbolScope::default();
1868        let t = s.sym("T");
1869        let p = s.sym("P");
1870        let ss = s.sym("S");
1871
1872        let cum_after =
1873            Max(vec![Val(0), (Sym(t.clone()) + Val(1)) * Sym(p.clone()) - Sym(ss.clone())]);
1874        let cum_before = Max(vec![Val(0), Sym(t.clone()) * Sym(p.clone()) - Sym(ss.clone())]);
1875
1876        let ap = Min(vec![Val(1), cum_after.clone()]) - Min(vec![Val(1), cum_before.clone()]);
1877        let simplified = ap.simplify();
1878
1879        // At T=1, P=4, S=3: min(1, max(0, 8-3)) - min(1, max(0, 4-3)) = 1 - 1 = 0
1880        use super::super::sym::SymbolValues;
1881        let sv = SymbolValues::default().with(&t, 1).with(&p, 4).with(&ss, 3);
1882        assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 0, "simplified: {simplified}");
1883
1884        // At T=0, P=4, S=3: min(1, max(0, 4-3)) - min(1, max(0, 0-3)) = 1 - 0 = 1
1885        let sv = SymbolValues::default().with(&t, 0).with(&p, 4).with(&ss, 3);
1886        assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 1, "simplified: {simplified}");
1887
1888        // At T=0, P=1, S=1: min(1, max(0, 1-1)) - min(1, max(0, 0-1)) = 0 - 0 = 0
1889        let sv = SymbolValues::default().with(&t, 0).with(&p, 1).with(&ss, 1);
1890        assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 0, "simplified: {simplified}");
1891    }
1892
1893    #[test]
1894    fn mul_neg_b_by_8() {
1895        let s = SymbolScope::default();
1896        let b = Sym(s.sym("B"));
1897        // 8*(-1*B) should equal -8*B
1898        let a = Mul(vec![Val(8), MulInt(-1, Box::new(b.clone()))]);
1899        let c = MulInt(-8, Box::new(b.clone()));
1900        let a_s = a.simplify();
1901        let c_s = c.simplify();
1902        assert_eq!(a_s, c_s, "8*(-1*B) should simplify the same as -8*B");
1903    }
1904}