Skip to main content

tract_data/dim/
tree.rs

1use crate::dim::Assertion;
2use crate::internal::*;
3
4use super::{DimLike, sym::*};
5use itertools::Itertools;
6use num_integer::Integer;
7use num_traits::{AsPrimitive, PrimInt, Zero};
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10use std::fmt::Debug;
11use std::ops::Neg;
12use std::{fmt, ops};
13
/// Error signalling that a dimension expression cannot be turned into a
/// concrete integer yet.
#[derive(Debug)]
pub enum TooEarly {
    /// The expression (rendered as text) still contains a symbol with no value.
    UndeterminedSymbol(String),
    /// Free-form message, displayed verbatim.
    Other(String),
}

impl std::fmt::Display for TooEarly {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TooEarly::UndeterminedSymbol(expr) => {
                f.write_str("Undetermined symbol in expression: ")?;
                f.write_str(expr)
            }
            TooEarly::Other(message) => f.write_str(message),
        }
    }
}

impl std::error::Error for TooEarly {}
30
// Shorthand for `Box::new(expr)`, used to build boxed `TDim` children tersely.
macro_rules! b {
    ($e:expr) => {
        Box::new($e)
    };
}
32
// `Hash` stays structural while `PartialEq` accepts an algebraic second chance:
// see the `PartialEq` impl below for the rationale (the simplifier's internal
// `HashMap<TDim, _>` only ever compares within same-canonical-form buckets, so
// the standard `a == b => hash(a) == hash(b)` contract being violated outside
// that path is acceptable here).
// NOTE(review): hashed collections keyed by `TDim` elsewhere (e.g. the
// `unique()` pass in `reduce`) are subject to the same caveat.
/// A symbolic integer dimension, stored as an expression tree.
///
/// Concrete values come from `eval_to_i64`; the per-variant semantics below
/// are the ones that method implements.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Eq, Hash, Debug)]
pub enum TDim {
    /// Integer constant.
    Val(i64),
    /// Symbol, resolved through `SymbolValues` at evaluation time.
    Sym(Symbol),
    /// Sum of the terms (overflow is reported as an error by `eval_to_i64`).
    Add(Vec<TDim>),
    /// Product of the terms (overflow is reported as an error by `eval_to_i64`).
    Mul(Vec<TDim>),
    /// `k * expr` for an integer constant `k`.
    MulInt(i64, Box<TDim>),
    /// `expr / q`, evaluated with Rust's truncating `i64` division (rounds
    /// toward zero).
    Div(Box<TDim>, u64),
    /// Terms combined pairwise with `DimLike::broadcast` at evaluation time.
    Broadcast(Vec<TDim>),
    /// Minimum of the terms (an empty list evaluates to `i64::MAX`).
    Min(Vec<TDim>),
    /// Maximum of the terms (an empty list evaluates to `i64::MIN`).
    Max(Vec<TDim>),
    /// Comparison: evaluates to 1 (true) or 0 (false). lhs >= rhs
    Ge(Box<TDim>, Box<TDim>),
    /// Comparison: evaluates to 1 (true) or 0 (false). lhs == rhs
    Eq(Box<TDim>, Box<TDim>),
}
55
56use TDim::*;
57
58/// Structural equality on the TDim tree — what `#[derive(PartialEq)]` would
59/// produce.  Used as the fast-path inside `PartialEq` (and by the simplifier's
60/// internal `HashMap<TDim, _>`, which compares within same-hash buckets where
61/// structural equality is the only thing that matters).
62fn eq_structural(a: &TDim, b: &TDim) -> bool {
63    match (a, b) {
64        (Val(x), Val(y)) => x == y,
65        (Sym(x), Sym(y)) => x == y,
66        (Add(x), Add(y))
67        | (Mul(x), Mul(y))
68        | (Broadcast(x), Broadcast(y))
69        | (Min(x), Min(y))
70        | (Max(x), Max(y)) => {
71            x.len() == y.len() && x.iter().zip(y).all(|(a, b)| eq_structural(a, b))
72        }
73        (MulInt(p, x), MulInt(q, y)) => p == q && eq_structural(x, y),
74        (Div(x, p), Div(y, q)) => p == q && eq_structural(x, y),
75        (Ge(a, b), Ge(c, d)) | (Eq(a, b), Eq(c, d)) => eq_structural(a, c) && eq_structural(b, d),
76        _ => false,
77    }
78}
79
80// Thread-local guard: while simplifying the difference inside `eq`, fall back
81// to the structural-only path for any nested `==` calls.  Without this guard
82// the simplifier's internal `HashMap<TDim, i64>` would re-enter `eq` from
83// inside `(self - other).simplify()`, recursing without bound on
84// non-structurally-equal inputs.
85std::thread_local! {
86    static EQ_GUARD: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
87}
88
89impl PartialEq for TDim {
90    fn eq(&self, other: &Self) -> bool {
91        // Fast path: structural tree equality.
92        if eq_structural(self, other) {
93            return true;
94        }
95        // Inside an enclosing simplification triggered by a previous
96        // second-chance call, fall back to structural equality only.
97        if EQ_GUARD.with(|g| g.get()) {
98            return false;
99        }
100        // Skip second-chance when either side is a leaf (`Val` or `Sym`).
101        // For `Val(c)` vs anything non-`Val`: if they were semantically
102        // equal, the simplifier should already have folded the other side
103        // to `Val(c)`; running a diff-and-simplify here just risks
104        // arithmetic overflow on extreme constants (e.g. the simplifier
105        // filters against `Val(i64::MAX)`/`Val(i64::MIN)` sentinels).
106        // For `Sym(x)` leaves, assertion-driven equality belongs in
107        // `simplify`, not in `eq`.
108        if matches!(self, Val(_) | Sym(_)) || matches!(other, Val(_) | Sym(_)) {
109            return false;
110        }
111        // Second chance: prove the difference simplifies to zero.  Two
112        // algebraically equal TDims often arrive at different canonical
113        // forms via different construction paths (e.g. `1 + (7S+3)/4` and
114        // `((S+1)*7)/4` after blockify substitutes T → k·S in encoder
115        // shapes).  Subtracting and simplifying lets the existing
116        // simplifier rules cancel them out.
117        EQ_GUARD.with(|g| g.set(true));
118        let diff = (self.clone() - other.clone()).simplify();
119        EQ_GUARD.with(|g| g.set(false));
120        matches!(diff, Val(0))
121    }
122}
123
124fn tdim_lexi_order(a: &TDim, b: &TDim) -> Ordering {
125    match (a, b) {
126        (Sym(a), Sym(b)) => a.cmp(b),
127        (Val(a), Val(b)) => a.cmp(b),
128        (Add(a), Add(b))
129        | (Mul(a), Mul(b))
130        | (Broadcast(a), Broadcast(b))
131        | (Min(a), Min(b))
132        | (Max(a), Max(b)) => a.len().cmp(&b.len()).then(
133            a.iter()
134                .zip(b.iter())
135                .fold(Ordering::Equal, |acc, (a, b)| acc.then_with(|| tdim_lexi_order(a, b))),
136        ),
137        (MulInt(p, d), MulInt(q, e)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
138        (Div(d, p), Div(e, q)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
139        (Sym(_), _) => Ordering::Less,
140        (_, Sym(_)) => Ordering::Greater,
141        (Val(_), _) => Ordering::Less,
142        (_, Val(_)) => Ordering::Greater,
143        (Add(_), _) => Ordering::Less,
144        (_, Add(_)) => Ordering::Greater,
145        (Mul(_), _) => Ordering::Less,
146        (_, Mul(_)) => Ordering::Greater,
147        (MulInt(_, _), _) => Ordering::Less,
148        (_, MulInt(_, _)) => Ordering::Greater,
149        (Broadcast(_), _) => Ordering::Less,
150        (_, Broadcast(_)) => Ordering::Greater,
151        (Min(_), _) => Ordering::Less,
152        (_, Min(_)) => Ordering::Greater,
153        (Max(_), _) => Ordering::Less,
154        (_, Max(_)) => Ordering::Greater,
155        (Ge(a1, b1), Ge(a2, b2)) | (Eq(a1, b1), Eq(a2, b2)) => {
156            tdim_lexi_order(a1, a2).then_with(|| tdim_lexi_order(b1, b2))
157        }
158        (Ge(_, _) | Eq(_, _), _) => Ordering::Less,
159        (_, Ge(_, _) | Eq(_, _)) => Ordering::Greater,
160    }
161}
162
163/// `Div(Add(terms), q)` — try to extract every `MulInt(c, X)` where `c % q == 0`
164/// out of the Div, leaving only a constant remainder in `[0, q)`.
165///
166/// Returns `Some(simplified)` when the residual constant is in `[0, q)` and
167/// every extracted symbolic factor `X` is provably non-negative — both
168/// conditions are required for soundness under tract's truncating
169/// division (`Rust i64 /`):
170///
171/// * the constant being in `[0, q)` makes `c/q_trunc = 0`;
172/// * `X ≥ 0` makes the identity `(k·X + c)/k_trunc = X` hold (it fails
173///   at e.g. `X = -1, k = 2, c = 0` because truncation rounds toward zero).
174///
175/// The `Val` arm above already handles constants outside `[0, q)`, so by
176/// the time we get here `terms` contains at most one `Val` and any number
177/// of `MulInt(c, X)` / other shapes.
178fn try_divide_multiple_plus_remainder(
179    terms: &[TDim],
180    q: u64,
181    scope: &SymbolScopeData,
182    extra: &[Assertion],
183) -> Option<TDim> {
184    let mut quotients: Vec<TDim> = vec![];
185    let mut const_rem: i64 = 0;
186    let mut any_extracted = false;
187    for term in terms {
188        match term {
189            MulInt(c, x) if *c != 0 && c.rem_euclid(q as i64) == 0 => {
190                if !scope.prove_positive_or_zero_with_extra(x, extra) {
191                    return None;
192                }
193                let new_coeff = c / (q as i64);
194                quotients.push(if new_coeff == 1 {
195                    (**x).clone()
196                } else if new_coeff == -1 {
197                    MulInt(-1, x.clone())
198                } else {
199                    MulInt(new_coeff, x.clone())
200                });
201                any_extracted = true;
202            }
203            Val(v) => const_rem += v,
204            _ => return None,
205        }
206    }
207    if !any_extracted {
208        return None;
209    }
210    if !(0..q as i64).contains(&const_rem) {
211        return None;
212    }
213    Some(if quotients.len() == 1 { quotients.remove(0) } else { Add(quotients) })
214}
215
216impl fmt::Display for TDim {
217    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
218        match &self {
219            Sym(sym) => write!(fmt, "{sym}"),
220            Val(it) => write!(fmt, "{it}"),
221            Add(it) => write!(fmt, "{}", it.iter().map(|x| format!("{x}")).join("+")),
222            Mul(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("*")),
223            Broadcast(it) => {
224                write!(fmt, "broadcast({})", it.iter().map(|x| format!("({x})")).join(", "))
225            }
226            Min(it) => write!(fmt, "min({})", it.iter().map(|x| format!("{x}")).join(",")),
227            Max(it) => write!(fmt, "max({})", it.iter().map(|x| format!("{x}")).join(",")),
228            MulInt(a, b) => write!(fmt, "{a}*{b}"),
229            Div(a, b) => write!(fmt, "({a})/{b}"),
230            Ge(a, b) => write!(fmt, "({a}>={b})"),
231            Eq(a, b) => write!(fmt, "({a}=={b})"),
232        }
233    }
234}
235
236impl TDim {
237    #[inline]
238    pub fn is_one(&self) -> bool {
239        matches!(self, Val(1))
240    }
241
242    #[inline]
243    pub fn to_i64(&self) -> TractResult<i64> {
244        if let Val(v) = self {
245            Ok(*v)
246        } else {
247            Err(TooEarly::UndeterminedSymbol(self.to_string()))?
248        }
249    }
250
251    #[inline]
252    pub fn as_i64(&self) -> Option<i64> {
253        if let Val(v) = self { Some(*v) } else { None }
254    }
255
256    pub fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64> {
257        match self {
258            Sym(sym) => {
259                let Some(v) = values.get(sym) else {
260                    Err(TooEarly::UndeterminedSymbol(self.to_string()))?
261                };
262                Ok(v)
263            }
264            Val(v) => Ok(*v),
265            Add(terms) => terms.iter().try_fold(0i64, |acc, it| {
266                let x = it.eval_to_i64(values)?;
267                acc.checked_add(x)
268                    .with_context(|| format!("Overflow in TDim addition ({acc} + {x})"))
269            }),
270            Mul(terms) => terms.iter().try_fold(1i64, |acc, it| {
271                let x = it.eval_to_i64(values)?;
272                acc.checked_mul(x)
273                    .with_context(|| format!("Overflow in TDim multiplication ({acc} * {x})"))
274            }),
275            Min(terms) => terms
276                .iter()
277                .try_fold(i64::MAX, |acc, it| it.eval_to_i64(values).map(|x| acc.min(x))),
278            Max(terms) => terms
279                .iter()
280                .try_fold(i64::MIN, |acc, it| it.eval_to_i64(values).map(|x| acc.max(x))),
281            Broadcast(terms) => terms.iter().try_fold(1i64, |acc, it| {
282                it.eval_to_i64(values)
283                    .and_then(|x| ((acc as usize).broadcast(x as usize)).map(|x| x as i64))
284            }),
285            Div(a, q) => Ok(a.eval_to_i64(values)? / *q as i64),
286            MulInt(p, a) => {
287                let x = a.eval_to_i64(values)?;
288                x.checked_mul(*p)
289                    .with_context(|| format!("Overflow in TDim multiplication ({x} * {p})"))
290            }
291            Ge(a, b) => Ok(if a.eval_to_i64(values)? >= b.eval_to_i64(values)? { 1 } else { 0 }),
292            Eq(a, b) => Ok(if a.eval_to_i64(values)? == b.eval_to_i64(values)? { 1 } else { 0 }),
293        }
294    }
295
296    pub fn eval(&self, values: &SymbolValues) -> TDim {
297        match self {
298            Sym(sym) => values.get(sym).map(Val).unwrap_or_else(|| Sym(sym.clone())),
299            Val(v) => Val(*v),
300            Add(terms) => terms.iter().fold(Val(0), |acc, it| -> TDim { acc + it.eval(values) }),
301            Mul(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim { acc * it.eval(values) }),
302            Min(terms) => {
303                terms.iter().fold(Val(i64::MAX), |acc, it| -> TDim { acc.mini(it.eval(values)) })
304            }
305            Max(terms) => {
306                terms.iter().fold(Val(i64::MIN), |acc, it| -> TDim { acc.maxi(it.eval(values)) })
307            }
308            Broadcast(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim {
309                acc.broadcast(it.eval(values)).unwrap_or_else(|_| self.clone())
310            }),
311            Div(a, q) => a.eval(values) / *q as i64,
312            MulInt(p, a) => a.eval(values) * *p,
313            Ge(a, b) => {
314                let a2 = a.eval(values);
315                let b2 = b.eval(values);
316                if let (Val(av), Val(bv)) = (&a2, &b2) {
317                    Val(if av >= bv { 1 } else { 0 })
318                } else {
319                    Ge(b!(a2), b!(b2))
320                }
321            }
322            Eq(a, b) => {
323                let a2 = a.eval(values);
324                let b2 = b.eval(values);
325                if let (Val(av), Val(bv)) = (&a2, &b2) {
326                    Val(if av == bv { 1 } else { 0 })
327                } else {
328                    Eq(b!(a2), b!(b2))
329                }
330            }
331        }
332    }
333
334    pub fn eval_with_scenario(&self, scenario: &str) -> TDim {
335        if let Val(v) = self {
336            return Val(*v);
337        }
338        let scope = self.find_scope().unwrap();
339        let scope = scope.0;
340        let locked = scope.lock();
341        let scope = locked.borrow();
342        self.clone().simplify_rec(&scope, Some(scenario), &[])
343    }
344
345    pub fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self> {
346        self.substitute_all(&std::collections::HashMap::from([(from.clone(), to.clone())]))
347    }
348
349    pub fn substitute_all(
350        &self,
351        map: &std::collections::HashMap<Symbol, Self>,
352    ) -> TractResult<Self> {
353        match self {
354            Sym(sym) => Ok(map.get(sym).cloned().unwrap_or_else(|| self.clone())),
355            Val(v) => Ok(Val(*v)),
356            Add(terms) => terms.iter().try_fold(Val(0), |acc, it| -> TractResult<TDim> {
357                Ok(acc + it.substitute_all(map)?)
358            }),
359            Mul(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
360                Ok(acc * it.substitute_all(map)?)
361            }),
362            Broadcast(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
363                acc.broadcast(it.substitute_all(map)?)
364            }),
365            Min(terms) => terms.iter().try_fold(Val(i64::MAX), |acc, it| -> TractResult<TDim> {
366                Ok(acc.mini(it.substitute_all(map)?))
367            }),
368            Max(terms) => terms.iter().try_fold(Val(i64::MIN), |acc, it| -> TractResult<TDim> {
369                Ok(acc.maxi(it.substitute_all(map)?))
370            }),
371            Div(a, q) => Ok(a.substitute_all(map)? / *q as i64),
372            MulInt(p, a) => Ok(a.substitute_all(map)? * *p),
373            Ge(a, b) => Ok(Ge(b!(a.substitute_all(map)?), b!(b.substitute_all(map)?))),
374            Eq(a, b) => Ok(Eq(b!(a.substitute_all(map)?), b!(b.substitute_all(map)?))),
375        }
376    }
377
378    pub fn reduce(self) -> TDim {
379        self.simplify()
380            .wiggle()
381            .into_iter()
382            .sorted_by(tdim_lexi_order)
383            .unique()
384            .map(|e| e.simplify())
385            .min_by_key(|e| e.cost())
386            .unwrap()
387    }
388
389    fn cost(&self) -> usize {
390        use self::TDim::*;
391        match self {
392            Sym(_) | Val(_) => 1,
393            Add(terms) => 2 * terms.iter().map(TDim::cost).sum::<usize>(),
394            Mul(terms) => 3 * terms.iter().map(TDim::cost).sum::<usize>(),
395            Broadcast(terms) => 4 * terms.iter().map(TDim::cost).sum::<usize>(),
396            Min(terms) | Max(terms) => 5 * terms.iter().map(TDim::cost).sum::<usize>(),
397            Div(a, _) => 3 * a.cost(),
398            MulInt(_, a) => 2 * a.cost(),
399            Ge(a, b) | Eq(a, b) => 5 * (a.cost() + b.cost()),
400        }
401    }
402
    /// Enumerates alternative (intended-equivalent) forms of `self`, for
    /// `reduce` to choose the cheapest one from.
    ///
    /// * `Add`: every combination of the children's own forms. When a
    ///   combination contains a `Div` term, also emit the form with every
    ///   other term folded into that division's numerator.
    /// * `MulInt`: the child's forms, each re-wrapped.
    /// * `Div`: for each numerator form that is an `Add`, also try splitting
    ///   off the terms whose `gcd` is a multiple of `q`.
    /// * All other variants: only `self`.
    fn wiggle(&self) -> Vec<TDim> {
        use self::TDim::*;
        match self {
            Sym(_) | Val(_) | Mul(_) | Broadcast(_) | Min(_) | Max(_) | Ge(_, _) | Eq(_, _) => {
                vec![self.clone()]
            }
            Add(terms) => {
                let mut forms = vec![];
                // Cartesian product of each term's alternative forms.
                let sub_exprs = terms.iter().map(|e| e.wiggle()).multi_cartesian_product();

                // Index, numerator and divisor of the first `Div` term, if any.
                fn first_div_term(terms: &[TDim]) -> Option<(usize, &TDim, u64)> {
                    terms.iter().enumerate().find_map(|(index, t)| match t {
                        Div(numerator, quotient) => Some((index, &**numerator, *quotient)),
                        _ => None,
                    })
                }

                // Numerator terms for `a + b + n/q` → `(q·a + q·b + n)/q`: the
                // Div term contributes its numerator, every other term is
                // multiplied by `quotient`.
                fn generate_new_numerator(
                    div_index: usize,
                    numerator: &TDim,
                    quotient: u64,
                    expr: &[TDim],
                ) -> Vec<TDim> {
                    expr.iter()
                        .enumerate()
                        .map(|(index, term)| {
                            if index == div_index {
                                numerator.clone()
                            } else {
                                MulInt(quotient as i64, Box::new(term.clone()))
                            }
                        })
                        .collect()
                }

                for expr in sub_exprs {
                    // NOTE(review): folding terms into a truncating division
                    // is only exact when the numerator and the folded terms do
                    // not have mixed signs (e.g. 1 + (-1)/2 = 1 but (2 - 1)/2
                    // = 0). The same caveat is why the Div arm below skips
                    // constant remainders. Confirm that the inputs reaching
                    // `reduce` rule this out.
                    if let Some((div_index, numerator, quotient)) = first_div_term(&expr) {
                        let new_numerator =
                            generate_new_numerator(div_index, numerator, quotient, &expr);
                        forms.push(Div(Box::new(Add(new_numerator)), quotient))
                    }

                    forms.push(Add(expr));
                }
                forms
            }
            MulInt(p, a) => a.wiggle().into_iter().map(|a| MulInt(*p, b!(a))).collect(),
            Div(a, q) => {
                let mut forms = vec![];
                for num in a.wiggle() {
                    if let Add(terms) = &num {
                        // Split the numerator into terms whose gcd is a
                        // multiple of `q` and the rest.
                        let (integer, non_integer): (Vec<_>, Vec<_>) =
                            terms.iter().cloned().partition(|a| a.gcd() % q == 0);
                        // Skip when the non-integer bucket holds a constant:
                        // under tract's truncating `/`, splitting (k·X+c)/k →
                        // X + c/k is unsound for negative X (X=-1, k=2, c=1:
                        // (-1)/2 = 0 ≠ X). The sound version, gated on
                        // prove_positive_or_zero, lives in simplify_rec::Div
                        // via try_divide_multiple_plus_remainder. Cases where
                        // the remainder is purely symbolic (e.g. A%2 → /2
                        // lowers to (A − 2·(A/2))/2, non_integer=[A]) stay
                        // here: the emitted Div(non_integer, q) cancels with
                        // the extracted quotient and reduces to zero.
                        if !non_integer.iter().any(|t| matches!(t, Val(_))) {
                            // `div` here is a `TDim` method defined outside
                            // this chunk (presumably exact division by `q`).
                            let mut new_terms =
                                integer.iter().map(|i| i.div(*q)).collect::<Vec<_>>();
                            if non_integer.len() > 0 {
                                new_terms.push(Div(b!(Add(non_integer)), *q));
                            }
                            forms.push(Add(new_terms))
                        }
                    }
                    // The unsplit form is always a candidate.
                    forms.push(Div(b!(num), *q))
                }
                forms
            }
        }
    }
481
482    fn find_any_sym(tdim: &TDim) -> Option<&Symbol> {
483        match tdim {
484            Val(_) => None,
485            Sym(s) => Some(s),
486            Add(terms) | Mul(terms) | Min(terms) | Max(terms) | Broadcast(terms) => {
487                terms.iter().find_map(Self::find_any_sym)
488            }
489            MulInt(_, t) | Div(t, _) => Self::find_any_sym(t),
490            Ge(a, b) | Eq(a, b) => Self::find_any_sym(a).or_else(|| Self::find_any_sym(b)),
491        }
492    }
493
494    pub fn find_scope(&self) -> Option<SymbolScope> {
495        Self::find_any_sym(self).and_then(|s| s.scope().clone())
496    }
497
498    /// Fully distribute every `Mul` of `Add`s in `self` into a flat sum of
499    /// products, then `simplify`.  Used to compare two algebraically equal
500    /// but differently-factored TDims for equality (e.g. Reshape volume
501    /// checks on graphs where the same dimension is built two ways).
502    ///
503    /// Cost can blow up combinatorially on very deeply factored expressions
504    /// — call this only at boundaries where structural equality is needed,
505    /// not as a general-purpose simplifier.
506    pub fn expand_polynomial(self) -> TDim {
507        use self::TDim::*;
508        match self {
509            Mul(terms) => {
510                let terms: Vec<TDim> = terms.into_iter().map(Self::expand_polynomial).collect();
511                if let Some(add_idx) = terms.iter().position(|t| matches!(t, Add(_))) {
512                    let Add(add_terms) = terms[add_idx].clone() else { unreachable!() };
513                    let others: Vec<TDim> = terms
514                        .iter()
515                        .enumerate()
516                        .filter(|(i, _)| *i != add_idx)
517                        .map(|(_, t)| t.clone())
518                        .collect();
519                    Add(add_terms
520                        .into_iter()
521                        .map(|t| {
522                            let mut product = others.clone();
523                            product.push(t);
524                            Mul(product).expand_polynomial()
525                        })
526                        .collect())
527                    .simplify()
528                } else {
529                    Mul(terms).simplify()
530                }
531            }
532            MulInt(c, inner) => MulInt(c, Box::new(inner.expand_polynomial())).simplify(),
533            Add(terms) => Add(terms.into_iter().map(Self::expand_polynomial).collect()).simplify(),
534            Div(a, q) => Div(Box::new(a.expand_polynomial()), q).simplify(),
535            Min(terms) => Min(terms.into_iter().map(Self::expand_polynomial).collect()).simplify(),
536            Max(terms) => Max(terms.into_iter().map(Self::expand_polynomial).collect()).simplify(),
537            Broadcast(terms) => {
538                Broadcast(terms.into_iter().map(Self::expand_polynomial).collect()).simplify()
539            }
540            Ge(a, b) => {
541                Ge(Box::new(a.expand_polynomial()), Box::new(b.expand_polynomial())).simplify()
542            }
543            Eq(a, b) => {
544                Eq(Box::new(a.expand_polynomial()), Box::new(b.expand_polynomial())).simplify()
545            }
546            it @ (Sym(_) | Val(_)) => it,
547        }
548    }
549
550    pub fn simplify(self) -> TDim {
551        use self::TDim::*;
552        if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) {
553            return Val(v);
554        }
555        let Some(scope) = self.find_scope() else {
556            return self;
557        };
558        let scope = scope.0;
559        let locked = scope.lock();
560        let scope = locked.borrow();
561        let it = self.simplify_rec(&scope, None, &[]);
562        let mut current: Option<TDim> = None;
563        for scenario in scope.scenarios() {
564            let v = it.clone().simplify_rec(&scope, Some(scenario), &[]);
565            if current.is_some_and(|c| c != v) {
566                return it;
567            } else {
568                current = Some(v);
569            }
570        }
571        current.unwrap_or(it)
572    }
573
574    pub fn simplify_with_extra_assertions(self, extra: &[Assertion]) -> TDim {
575        use self::TDim::*;
576        if extra.is_empty() {
577            return self.simplify();
578        }
579        if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) {
580            return Val(v);
581        }
582        let Some(scope) = self.find_scope() else {
583            return self;
584        };
585        let scope = scope.0;
586        let locked = scope.lock();
587        let scope = locked.borrow();
588        let it = self.simplify_rec(&scope, None, extra);
589        let mut current: Option<TDim> = None;
590        for scenario in scope.scenarios() {
591            let v = it.clone().simplify_rec(&scope, Some(scenario), extra);
592            if current.is_some_and(|c| c != v) {
593                return it;
594            } else {
595                current = Some(v);
596            }
597        }
598        current.unwrap_or(it)
599    }
600
601    fn simplify_rec(
602        self,
603        scope: &SymbolScopeData,
604        scenario: Option<&str>,
605        extra: &[Assertion],
606    ) -> TDim {
607        match self {
608            Add(mut terms) => {
609                #[allow(clippy::mutable_key_type)]
610                let mut simplified_terms: HashMap<TDim, i64> = HashMap::new();
611                // factorize common sub-expr
612                while let Some(term) = terms.pop() {
613                    let simplified = term.simplify_rec(scope, scenario, extra);
614                    match simplified {
615                        Val(0) => {} // ignore
616                        Add(members) => {
617                            terms.extend(members);
618                            continue;
619                        }
620                        Val(value) => *simplified_terms.entry(Val(1)).or_insert(0) += value,
621                        MulInt(value, factor) => {
622                            *simplified_terms.entry((*factor).clone()).or_insert(0) += value;
623                        }
624                        n => *simplified_terms.entry(n).or_insert(0) += 1,
625                    };
626                }
627
628                pub fn evaluate_count(term: TDim, count: i64) -> Option<TDim> {
629                    match count {
630                        0 => None,
631                        _ if term == TDim::Val(1) => Some(TDim::Val(count)),
632                        1 => Some(term),
633                        _ => Some(TDim::MulInt(count, Box::new(term))),
634                    }
635                }
636
637                // Pull the integer GCD of all term coefficients out as a
638                // common factor: e.g. Add([Val(6), MulInt(14, S)]) becomes
639                // MulInt(2, Add([Val(3), MulInt(7, S)])).  The downstream
640                // Div(MulInt(p, a), q) arm then cancels (p, q) gcd, so
641                // (6 + 14·S) / 8 reduces to (3 + 7·S) / 4 with no special
642                // Div-over-Add rule needed.
643                //
644                // Only consider entries with non-zero counts — zero-count
645                // entries (canceled-out factors) get filtered later, but
646                // would otherwise drag the gcd to spurious values.  Only
647                // factor when at least one surviving entry has a
648                // non-constant key, otherwise the Add reduces to a single
649                // `Val` and wrapping it in `MulInt(g, Val(c/g))` is a
650                // strict regression in canonical form.
651                let has_non_const =
652                    simplified_terms.iter().any(|(k, &c)| c != 0 && !matches!(k, Val(_)));
653                let coef_gcd = if has_non_const {
654                    simplified_terms
655                        .values()
656                        .filter(|&&c| c != 0)
657                        .map(|c| c.unsigned_abs() as i64)
658                        .reduce(|a, b| a.gcd(&b))
659                        .unwrap_or(0)
660                } else {
661                    0
662                };
663                let outer_factor = if coef_gcd > 1 {
664                    for v in simplified_terms.values_mut() {
665                        *v /= coef_gcd;
666                    }
667                    Some(coef_gcd)
668                } else {
669                    None
670                };
671
672                let mut members: Vec<TDim> = simplified_terms
673                    .into_iter()
674                    .filter_map(|(term, count)| evaluate_count(term, count))
675                    .collect();
676                members.sort_by(tdim_lexi_order);
677
678                let inner = match members.len() {
679                    0 => TDim::Val(0),
680                    1 => members.into_iter().next().unwrap(),
681                    _ => TDim::Add(members),
682                };
683                match outer_factor {
684                    None => inner,
685                    Some(_) if matches!(inner, TDim::Val(0)) => TDim::Val(0),
686                    Some(g) => TDim::MulInt(g, Box::new(inner)),
687                }
688            }
689            Mul(terms) => {
690                // Distribute over Add: if exactly one factor is an Add,
691                // expand Mul([a, Add([b, c])]) => Add([Mul([a, b]), Mul([a, c])]).
692                // This lets (T+1)*P simplify to T*P + P, which is needed for
693                // cancellation in expressions like (T+1)*P - T*P.
694                //
695                // Multi-Add Muls are *not* eagerly expanded here — keeping them
696                // factored matters for `maybe_div`'s bag-of-factors path, which
697                // cancels common Add factors symbolically (e.g. dividing
698                // 16B·(1+Y)² by 8B·(1+Y) needs the (1+Y)s to be visible as
699                // factors, not melted into a polynomial sum).  Use
700                // `expand_polynomial` if you need a fully distributed canonical
701                // form for equality comparisons (see Reshape volume check).
702                {
703                    let add_indices: Vec<usize> = terms
704                        .iter()
705                        .enumerate()
706                        .filter(|(_, t)| matches!(t, Add(_)))
707                        .map(|(i, _)| i)
708                        .collect();
709                    if add_indices.len() == 1 {
710                        let add_idx = add_indices[0];
711                        let Add(add_terms) = &terms[add_idx] else { unreachable!() };
712                        let other_factors: Vec<TDim> = terms
713                            .iter()
714                            .enumerate()
715                            .filter(|(i, _)| *i != add_idx)
716                            .map(|(_, t)| t.clone())
717                            .collect();
718                        let distributed: Vec<TDim> = add_terms
719                            .iter()
720                            .map(|at| {
721                                let mut product = other_factors.clone();
722                                product.push(at.clone());
723                                Mul(product)
724                            })
725                            .collect();
726                        return Add(distributed).simplify_rec(scope, scenario, extra);
727                    }
728                }
729
730                // in case a term is a multiplication itself, flatten it
731                // e.g., (a*b)*c => a*b*c, and MulInt(k, x) => Val(k)*x
732                let mut flattened_terms = vec![];
733                for t in terms {
734                    match t.clone().reduce() {
735                        Mul(inner_terms) => flattened_terms.extend(inner_terms),
736                        MulInt(k, inner) => {
737                            flattened_terms.push(Val(k));
738                            flattened_terms.push(*inner);
739                        }
740                        other => flattened_terms.push(other),
741                    }
742                }
743                let mut terms = flattened_terms;
744
745                let mut gcd = Mul(terms.clone()).gcd() as i64;
746                if gcd == 0 {
747                    return Val(0);
748                }
749                terms = if gcd != 1 {
750                    terms
751                        .into_iter()
752                        .map(|t| {
753                            let gcd = t.gcd();
754                            (t / gcd).simplify_rec(scope, scenario, extra)
755                        })
756                        .collect()
757                } else {
758                    terms
759                };
760                if terms.iter().filter(|t| t == &&Val(-1)).count() % 2 == 1 {
761                    gcd = -gcd;
762                }
763                terms.retain(|t| !t.is_one() && t != &Val(-1));
764                terms.sort_by(tdim_lexi_order);
765
766                match (gcd, terms.len()) {
767                    (_, 0) => Val(gcd), // Case #1: If 0 variables, return product
768                    (0, _) => Val(0),   // Case #2: Result is 0 if coef is 0 (actually
769                    // unreachable as we check at the beginning)
770                    (1, 1) => terms.remove(0), // Case #3: Product is 1, so return the only term
771                    (1, _) => Mul(terms), // Case #4: Product is 1, so return the non-integer terms
772                    (_, 1) => MulInt(gcd, Box::new(terms.remove(0))), // Case #5: Single variable, convert to 1 MulInt
773                    _ => MulInt(gcd, Box::new(Mul(terms))), // Case #6: Multiple variables, convert to MulInt
774                }
775            }
776            MulInt(coef, expr) => {
777                match *expr {
778                    MulInt(c2, inner) => {
779                        if let Some(c) = coef.checked_mul(c2) {
780                            return MulInt(c, inner).simplify_rec(scope, scenario, extra);
781                        } else {
782                            return MulInt(coef, Box::new(MulInt(c2, inner)));
783                        }
784                    }
785                    Val(v) => {
786                        return coef
787                            .checked_mul(v)
788                            .map(Val)
789                            .unwrap_or_else(|| MulInt(coef, Box::new(Val(v))));
790                    }
791                    _ => {}
792                }
793
794                let simplified = expr.simplify_rec(scope, scenario, extra);
795                match (coef, simplified) {
796                    (0, _) => Val(0), // Case #1: If coef is 0, return 0
797                    (1, s) => s,      // Case #2: If coef is 1, return the simplified expression
798                    (_, Add(terms)) => Add(terms
799                        .into_iter()
800                        .map(|term| {
801                            MulInt(coef, Box::new(term)).simplify_rec(scope, scenario, extra)
802                        })
803                        .collect()), // Case #3: If expression is an addition, distribute the coef
804                    (c, Val(v)) => {
805                        c.checked_mul(v).map(Val).unwrap_or_else(|| MulInt(c, Box::new(Val(v))))
806                    } // Case #4: If expression is a value, combine coefs
807                    (c, MulInt(v, inner)) => {
808                        if let Some(cv) = c.checked_mul(v) {
809                            MulInt(cv, inner) // Case #5: If expression is a MulInt, combine coefs
810                        } else {
811                            MulInt(c, Box::new(MulInt(v, inner)))
812                        }
813                    }
814                    (_, s) => MulInt(coef, Box::new(s)), // Case #6: Otherwise, return the original
815                }
816            }
817            Div(a, q) => {
818                if q == 1 {
819                    return a.simplify_rec(scope, scenario, extra);
820                } else if let Div(a, q2) = *a {
821                    return Div(a, q * q2).simplify_rec(scope, scenario, extra);
822                }
823                let a = a.simplify_rec(scope, scenario, extra);
824                if let Val(a) = a {
825                    Val(a / q as i64)
826                } else if let MulInt(-1, a) = a {
827                    MulInt(-1, b!(Div(a, q)))
828                } else if let Add(mut terms) = a {
829                    if terms
830                        .iter()
831                        .any(|t| if let MulInt(-1, s) = t { matches!(&**s, Sym(_)) } else { false })
832                    {
833                        MulInt(
834                            -1,
835                            b!(Div(
836                                b!(Add(terms.into_iter().map(|t| MulInt(-1, b!(t))).collect())
837                                    .simplify_rec(scope, scenario, extra)),
838                                q
839                            )),
840                        )
841                    } else if let Some(val) = terms
842                        .iter()
843                        .find_map(|t| if let Val(v) = t { Some(*v) } else { None })
844                        .and_then(|v| {
845                            if v >= q as i64 {
846                                Some(v / q as i64)
847                            } else if v < 0 {
848                                Some(-Integer::div_ceil(&-v, &(q as i64)))
849                            } else {
850                                None
851                            }
852                        })
853                    {
854                        terms.push(Val(-val * q as i64));
855                        // simplify_rec the inner Div too so that follow-up rules
856                        // (e.g. divide-multiple-plus-remainder below) can collapse
857                        // it once the Val extraction has tidied up the residual.
858                        let inner = Div(b!(Add(terms).simplify_rec(scope, scenario, extra)), q)
859                            .simplify_rec(scope, scenario, extra);
860                        Add(vec![Val(val), inner])
861                    } else if let Some(simplified) =
862                        try_divide_multiple_plus_remainder(&terms, q, scope, extra)
863                    {
864                        // Match `Div(Add([k·X, …, c]), k)` where:
865                        //   - one or more terms have a coefficient that is a multiple of q,
866                        //   - the rest sum to a constant in [0, q),
867                        //   - every extracted X is provably non-negative.
868                        // Then `(k·X + c)/k = X + 0 = X` under tract's truncating
869                        // division.  This is sound for X ≥ 0 only — at X = -1
870                        // the truncation rounds toward zero, not floor, breaking
871                        // the identity.  We use prove_positive_or_zero to gate.
872                        simplified.simplify_rec(scope, scenario, extra)
873                    } else if let Some(found_idx) = terms.iter().position(|term| {
874                        // Rule: (Y − q·(Y/q)) / q = 0  [i.e. (Y mod q) / q = 0]
875                        // Always sound: |Y mod q| < q so (Y mod q)/q = 0 under
876                        // truncating division regardless of sign of Y.
877                        matches!(term, MulInt(p, inner)
878                            if *p == -(q as i64)
879                            && matches!(inner.as_ref(), Div(_, q2) if *q2 == q))
880                    }) {
881                        let MulInt(_, inner) = &terms[found_idx] else { unreachable!() };
882                        let Div(y, _) = inner.as_ref() else { unreachable!() };
883                        let remaining: Vec<TDim> = terms
884                            .iter()
885                            .enumerate()
886                            .filter(|&(i, _)| i != found_idx)
887                            .map(|(_, t)| t.clone())
888                            .collect();
889                        let remaining_sum = match remaining.len() {
890                            0 => Val(0),
891                            1 => remaining.into_iter().next().unwrap(),
892                            _ => Add(remaining),
893                        };
894                        if eq_structural(&remaining_sum, y) {
895                            Val(0)
896                        } else {
897                            Div(b!(Add(terms)), q)
898                        }
899                    } else {
900                        Div(b!(Add(terms)), q)
901                    }
902                } else if let MulInt(p, a) = a {
903                    if p == q as i64 {
904                        a.simplify()
905                    } else {
906                        let gcd = p.abs().gcd(&(q as i64));
907                        if gcd == p {
908                            Div(a, q / gcd as u64)
909                        } else if gcd == q as i64 {
910                            MulInt(p / gcd, a)
911                        } else if gcd > 1 {
912                            Div(b!(MulInt(p / gcd, a)), q / gcd as u64)
913                                .simplify_rec(scope, scenario, extra)
914                        } else {
915                            Div(b!(MulInt(p, a)), q)
916                        }
917                    }
918                } else {
919                    Div(b!(a), q)
920                }
921            }
922            Broadcast(terms) => {
923                let mut terms: Vec<TDim> = terms
924                    .iter()
925                    .map(|s| s.clone().simplify_rec(scope, scenario, extra))
926                    .flat_map(|t| if let Broadcast(t) = t { t } else { vec![t] })
927                    .filter(|t| !t.is_one())
928                    .sorted_by(tdim_lexi_order)
929                    .dedup()
930                    .collect_vec();
931                // a#min(a,b) if a>0 && b>0 => a
932                match &*terms {
933                    [] => Val(1),
934                    [_] => terms.remove(0),
935                    [a, Min(m)] | [Min(m), a]
936                        if m.contains(a)
937                            && m.iter()
938                                .all(|t| scope.prove_strict_positive_with_extra(t, extra)) =>
939                    {
940                        a.clone()
941                    }
942                    _ => Broadcast(terms),
943                }
944            }
945
946            Min(terms) => {
947                let mut flatten: Vec<TDim> = terms
948                    .into_iter()
949                    .map(|t| t.simplify_rec(scope, scenario, extra))
950                    .flat_map(|t| if let Min(t) = t { t } else { vec![t] })
951                    .filter(|t| t != &Val(i64::MAX))
952                    .sorted_by(tdim_lexi_order)
953                    .dedup()
954                    .collect();
955                #[allow(clippy::mutable_key_type)]
956                let mut redundant = HashSet::<TDim>::default();
957                for pair in flatten.iter().permutations(2) {
958                    let (a, b) = (pair[0], pair[1]);
959                    if redundant.contains(a) || redundant.contains(b) {
960                        continue;
961                    }
962                    let diff = a.clone() - b;
963                    if diff.as_i64().is_some_and(|i| i >= 0)
964                        || scope.prove_positive_or_zero_with_extra(&diff, extra)
965                    {
966                        redundant.insert(a.clone());
967                    }
968                }
969                flatten.retain(|t| !redundant.contains(t));
970                if flatten.len() == 0 {
971                    i64::MAX.to_dim()
972                } else if flatten.len() == 1 {
973                    flatten.into_iter().next().unwrap()
974                } else {
975                    Min(flatten)
976                }
977            }
978            Max(terms) => {
979                let mut flatten: Vec<TDim> = terms
980                    .into_iter()
981                    .map(|t| t.simplify_rec(scope, scenario, extra))
982                    .flat_map(|t| if let Max(t) = t { t } else { vec![t] })
983                    .filter(|t| t != &Val(i64::MIN))
984                    .sorted_by(tdim_lexi_order)
985                    .dedup()
986                    .collect();
987                #[allow(clippy::mutable_key_type)]
988                let mut redundant = HashSet::<TDim>::default();
989                for pair in flatten.iter().permutations(2) {
990                    let (a, b) = (pair[0], pair[1]);
991                    if redundant.contains(a) || redundant.contains(b) {
992                        continue;
993                    }
994                    let diff = a.clone() - b;
995                    if diff.as_i64().is_some_and(|i| i >= 0)
996                        || scope.prove_positive_or_zero_with_extra(&diff, extra)
997                    {
998                        redundant.insert(b.clone());
999                    }
1000                }
1001                flatten.retain(|t| !redundant.contains(t));
1002                if flatten.len() == 0 {
1003                    i64::MIN.to_dim()
1004                } else if flatten.len() == 1 {
1005                    flatten.into_iter().next().unwrap()
1006                } else {
1007                    Max(flatten)
1008                }
1009            }
1010            Sym(s) => scope
1011                .assertions(scenario)
1012                .find_map(|a| match a {
1013                    Assertion::Eq(Sym(sym), v) if sym == &s => Some(v.clone()),
1014                    _ => None,
1015                })
1016                .unwrap_or(Sym(s)),
1017            Val(_) => self,
1018            Ge(a, b) => {
1019                let a = a.simplify_rec(scope, scenario, extra);
1020                let b = b.simplify_rec(scope, scenario, extra);
1021                match (&a, &b) {
1022                    (Val(av), Val(bv)) => Val(if av >= bv { 1 } else { 0 }),
1023                    _ => {
1024                        let diff = a.clone() - b.clone();
1025                        if scope.prove_positive_or_zero_with_extra(&diff, extra) {
1026                            Val(1)
1027                        } else if scope
1028                            .prove_strict_positive_with_extra(&(b.clone() - a.clone()), extra)
1029                        {
1030                            Val(0)
1031                        } else {
1032                            Ge(b!(a), b!(b))
1033                        }
1034                    }
1035                }
1036            }
1037            Eq(a, b) => {
1038                let a = a.simplify_rec(scope, scenario, extra);
1039                let b = b.simplify_rec(scope, scenario, extra);
1040                match (&a, &b) {
1041                    (Val(av), Val(bv)) => Val(if av == bv { 1 } else { 0 }),
1042                    _ => {
1043                        let diff = a.clone() - b.clone();
1044                        if scope.prove_strict_positive_with_extra(&diff, extra)
1045                            || scope
1046                                .prove_strict_positive_with_extra(&(b.clone() - a.clone()), extra)
1047                        {
1048                            Val(0)
1049                        } else {
1050                            // When one side is 0 or 1 and the other is
1051                            // provably in [0,1], reduce to boolean algebra:
1052                            //   Eq(expr, 0) → 1 - expr
1053                            //   Eq(expr, 1) → expr
1054                            let boolean_case = match (&a, &b) {
1055                                (Val(0), e) | (e, Val(0)) => Some((e, false)),
1056                                (Val(1), e) | (e, Val(1)) => Some((e, true)),
1057                                _ => None,
1058                            };
1059                            if let Some((expr, equals_one)) = boolean_case
1060                                && scope.prove_positive_or_zero_with_extra(expr, extra)
1061                                && scope.prove_positive_or_zero_with_extra(
1062                                    &(Val(1) - expr.clone()),
1063                                    extra,
1064                                )
1065                            {
1066                                return if equals_one {
1067                                    expr.clone()
1068                                } else {
1069                                    (Val(1) - expr.clone()).simplify_rec(scope, scenario, extra)
1070                                };
1071                            }
1072                            Eq(b!(a), b!(b))
1073                        }
1074                    }
1075                }
1076            }
1077        }
1078    }
1079
1080    pub(super) fn inclusive_bound(&self, scope: &SymbolScopeData, upper: bool) -> Option<i64> {
1081        use self::TDim::*;
1082        match self {
1083            Val(n) => Some(*n),
1084            Sym(_) => {
1085                if upper {
1086                    scope
1087                        .all_assertions()
1088                        .iter()
1089                        .filter_map(|assert| match &assert {
1090                            Assertion::LT(left, right)
1091                                if left == self && right.as_i64().is_some() =>
1092                            {
1093                                Some(right.as_i64().unwrap() - 1)
1094                            }
1095                            Assertion::LTE(left, right)
1096                                if left == self && right.as_i64().is_some() =>
1097                            {
1098                                Some(right.as_i64().unwrap())
1099                            }
1100                            _ => None,
1101                        })
1102                        .min()
1103                } else {
1104                    scope
1105                        .all_assertions()
1106                        .iter()
1107                        .filter_map(|assert| match &assert {
1108                            Assertion::GT(left, right)
1109                                if left == self && right.as_i64().is_some() =>
1110                            {
1111                                Some(right.as_i64().unwrap() + 1)
1112                            }
1113                            Assertion::GTE(left, right)
1114                                if left == self && right.as_i64().is_some() =>
1115                            {
1116                                Some(right.as_i64().unwrap())
1117                            }
1118                            _ => None,
1119                        })
1120                        .max()
1121                }
1122            }
1123            Add(terms) => {
1124                let mut bound: i64 = 0;
1125                for t in terms {
1126                    {
1127                        let b = t.inclusive_bound(scope, upper)?;
1128                        bound = bound.checked_add(b)?;
1129                    }
1130                }
1131                Some(bound)
1132            }
1133            MulInt(p, a) => match p.cmp(&0) {
1134                Ordering::Equal => Some(0),
1135                Ordering::Greater => {
1136                    a.inclusive_bound(scope, upper).and_then(|x| x.checked_mul(*p))
1137                }
1138                Ordering::Less => a.inclusive_bound(scope, !upper).and_then(|x| x.checked_mul(*p)),
1139            },
1140            Mul(terms) => {
1141                // If all factors have known non-negative bounds, we can bound the product.
1142                let mut lo: i64 = 1;
1143                let mut hi: i64 = 1;
1144                for t in terms {
1145                    let t_lo = t.inclusive_bound(scope, false)?;
1146                    let t_hi = t.inclusive_bound(scope, true)?;
1147                    if t_lo < 0 {
1148                        return None;
1149                    }
1150                    lo = lo.checked_mul(t_lo)?;
1151                    hi = hi.checked_mul(t_hi)?;
1152                }
1153                Some(if upper { hi } else { lo })
1154            }
1155            Min(terms) if !upper => {
1156                // All terms must have known lower bounds; if any is unknown,
1157                // the Min lower bound is unknown.
1158                let bounds: Option<Vec<i64>> =
1159                    terms.iter().map(|t| t.inclusive_bound(scope, false)).collect();
1160                bounds.map(|b| b.into_iter().min().unwrap_or(i64::MAX))
1161            }
1162            Max(terms) if upper => {
1163                // All terms must have known upper bounds; if any is unknown,
1164                // the Max upper bound is unknown.
1165                let bounds: Option<Vec<i64>> =
1166                    terms.iter().map(|t| t.inclusive_bound(scope, true)).collect();
1167                bounds.map(|b| b.into_iter().max().unwrap_or(i64::MIN))
1168            }
1169            Div(a, q) => a.inclusive_bound(scope, upper).map(|x| x / (*q as i64)),
1170            Broadcast(terms) => {
1171                if upper {
1172                    Max(terms.clone()).inclusive_bound(scope, true)
1173                } else {
1174                    Min(terms.clone()).inclusive_bound(scope, false)
1175                }
1176            }
1177            Ge(_, _) | Eq(_, _) => {
1178                if upper {
1179                    Some(1)
1180                } else {
1181                    Some(0)
1182                }
1183            }
1184            _ => None,
1185        }
1186    }
1187
1188    pub fn low_inclusive_bound(&self) -> Option<i64> {
1189        if let TDim::Val(v) = self {
1190            return Some(*v);
1191        }
1192        let scope = self.find_scope()?;
1193        let data = scope.0.lock();
1194        let data = data.borrow();
1195        self.inclusive_bound(&data, false)
1196    }
1197
1198    pub fn high_inclusive_bound(&self) -> Option<i64> {
1199        if let TDim::Val(v) = self {
1200            return Some(*v);
1201        }
1202        let scope = self.find_scope()?;
1203        let data = scope.0.lock();
1204        let data = data.borrow();
1205        self.inclusive_bound(&data, true)
1206    }
1207
1208    pub fn prove_positive_or_zero(&self) -> bool {
1209        if let TDim::Val(v) = self {
1210            return *v >= 0;
1211        }
1212        let Some(scope) = self.find_scope() else { return false };
1213        let data = scope.0.lock();
1214        let data = data.borrow();
1215        data.prove_positive_or_zero(self)
1216    }
1217
1218    pub fn prove_strict_positive(&self) -> bool {
1219        if let TDim::Val(v) = self {
1220            return *v > 0;
1221        }
1222        (self.clone() - 1).prove_positive_or_zero()
1223    }
1224
1225    pub fn prove_negative_or_zero(&self) -> bool {
1226        if let TDim::Val(v) = self {
1227            return *v <= 0;
1228        }
1229        self.clone().neg().prove_positive_or_zero()
1230    }
1231
1232    pub fn prove_strict_negative(&self) -> bool {
1233        if let TDim::Val(v) = self {
1234            return *v < 0;
1235        }
1236        self.clone().neg().prove_strict_positive()
1237    }
1238
1239    /// Least common multiple of two `TDim`s when both reduce to positive
1240    /// integers.
1241    ///
1242    /// Returns `Val(0)` if either operand is `0`, and `None` if either is
1243    /// symbolic, negative, or if the LCM would overflow `i64`. Callers
1244    /// that need a safe answer for symbolic operands should fall back at
1245    /// the call site.
1246    pub fn lcm(&self, other: &TDim) -> Option<TDim> {
1247        match (self.as_i64(), other.as_i64()) {
1248            (Some(a), Some(b)) if a > 0 && b > 0 => {
1249                let g = (a as u64).gcd(&(b as u64));
1250                let l = (a as u64 / g).saturating_mul(b as u64);
1251                if l > i64::MAX as u64 { None } else { Some(TDim::Val(l as i64)) }
1252            }
1253            (Some(0), _) | (_, Some(0)) => Some(TDim::Val(0)),
1254            _ => None,
1255        }
1256    }
1257
1258    pub fn gcd(&self) -> u64 {
1259        use self::TDim::*;
1260        match self {
1261            Val(v) => v.unsigned_abs(),
1262            Sym(_) => 1,
1263            Add(terms) => {
1264                let (head, tail) = terms.split_first().unwrap();
1265                tail.iter().fold(head.gcd(), |a, b| a.gcd(&b.gcd()))
1266            }
1267            MulInt(p, a) => a.gcd().saturating_mul(p.unsigned_abs()),
1268            Mul(terms) => terms.iter().map(|t| t.gcd()).fold(1u64, |a, b| a.saturating_mul(b)),
1269            Min(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
1270            Max(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
1271            Div(a, q) => {
1272                if a.gcd() % *q == 0 {
1273                    a.gcd() / *q
1274                } else {
1275                    1
1276                }
1277            }
1278            Broadcast(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap_or(1),
1279            Ge(_, _) | Eq(_, _) => 1,
1280        }
1281    }
1282
1283    fn div(&self, d: u64) -> TDim {
1284        use self::TDim::*;
1285        if d == 1 {
1286            return self.clone();
1287        }
1288        match self {
1289            Val(v) => Val(v / d as i64),
1290            Sym(_) => panic!(),
1291            Add(terms) => Add(terms.iter().map(|t| t.div(d)).collect()),
1292            Min(terms) => Min(terms.iter().map(|t| t.div(d)).collect()),
1293            Max(terms) => Max(terms.iter().map(|t| t.div(d)).collect()),
1294            Broadcast(terms) => Broadcast(terms.iter().map(|t| t.div(d)).collect()),
1295            Mul(_) => Div(Box::new(self.clone()), d),
1296            MulInt(p, a) => {
1297                if *p == d as i64 {
1298                    (**a).clone()
1299                } else {
1300                    let gcd = p.unsigned_abs().gcd(&d);
1301                    MulInt(p / gcd as i64, b!(a.div(d / gcd)))
1302                }
1303            }
1304            Div(a, q) => Div(a.clone(), q * d),
1305            Ge(_, _) | Eq(_, _) => Div(Box::new(self.clone()), d),
1306        }
1307    }
1308
1309    pub fn div_ceil(self, rhs: u64) -> TDim {
1310        TDim::Div(Box::new(Add(vec![self, Val(rhs as i64 - 1)])), rhs).reduce()
1311    }
1312
1313    pub fn guess_slope(&self, sym: &Symbol) -> (i64, u64) {
1314        fn slope_rec(d: &TDim, sym: &Symbol) -> (i64, i64) {
1315            match d {
1316                Val(_) => (0, 1),
1317                Sym(s) => ((sym == s) as i64, 1),
1318                Add(terms) => terms
1319                    .iter()
1320                    .map(|d| slope_rec(d, sym))
1321                    .fold((0, 1), |a, b| ((a.0 * b.1 + a.1 * b.0), (b.1 * a.1))),
1322                Mul(terms) => terms
1323                    .iter()
1324                    .map(|d| slope_rec(d, sym))
1325                    .fold((1, 1), |a, b| ((a.0 * b.0), (b.1 * a.1))),
1326                MulInt(p, a) => {
1327                    let (n, d) = slope_rec(a, sym);
1328                    (p * n, d)
1329                }
1330                Div(a, q) => {
1331                    let (n, d) = slope_rec(a, sym);
1332                    (n, d * *q as i64)
1333                }
1334                Broadcast(terms) => slope_rec(&terms[0], sym),
1335                Min(terms) => slope_rec(&terms[0], sym),
1336                Max(terms) => slope_rec(&terms[0], sym),
1337                Ge(_, _) | Eq(_, _) => (0, 1),
1338            }
1339        }
1340        let (p, q) = slope_rec(self, sym);
1341        reduce_ratio(p, q)
1342    }
1343
1344    #[allow(clippy::mutable_key_type)]
1345    pub fn symbols(&self) -> std::collections::HashSet<Symbol> {
1346        match self {
1347            Val(_) => maplit::hashset!(),
1348            Sym(s) => maplit::hashset!(s.clone()),
1349            Add(terms) | Mul(terms) | Broadcast(terms) | Min(terms) | Max(terms) => {
1350                terms.iter().fold(maplit::hashset!(), |mut set, v| {
1351                    set.extend(v.symbols());
1352                    set
1353                })
1354            }
1355            MulInt(_, a) => a.symbols(),
1356            Div(a, _) => a.symbols(),
1357            Ge(a, b) | Eq(a, b) => {
1358                let mut set = a.symbols();
1359                set.extend(b.symbols());
1360                set
1361            }
1362        }
1363    }
1364
1365    pub fn compatible_with(&self, other: &TDim) -> bool {
1366        if let Ok(x) = (self.clone() - other).to_i64() {
1367            return x == 0;
1368        }
1369        true // maybe ? :)
1370    }
1371}
1372
1373pub(super) fn reduce_ratio(mut p: i64, mut q: i64) -> (i64, u64) {
1374    let gcd = p.abs().gcd(&q.abs());
1375    if gcd > 1 {
1376        p /= gcd;
1377        q /= gcd;
1378    }
1379    if q < 0 { (-p, (-q) as u64) } else { (p, q as u64) }
1380}
1381
1382impl Zero for TDim {
1383    fn zero() -> Self {
1384        Val(0)
1385    }
1386    fn is_zero(&self) -> bool {
1387        matches!(self, Val(0))
1388    }
1389}
1390
1391impl Default for TDim {
1392    fn default() -> TDim {
1393        Val(0)
1394    }
1395}
1396
1397impl num_traits::Bounded for TDim {
1398    fn min_value() -> Self {
1399        TDim::Val(i64::MIN)
1400    }
1401
1402    fn max_value() -> Self {
1403        TDim::Val(i64::MAX)
1404    }
1405}
1406
1407impl num_traits::One for TDim {
1408    fn one() -> Self {
1409        TDim::Val(1)
1410    }
1411}
1412
1413impl ::std::iter::Sum for TDim {
1414    fn sum<I: Iterator<Item = TDim>>(iter: I) -> TDim {
1415        iter.fold(0.into(), |a, b| a + b)
1416    }
1417}
1418
1419impl<'a> ::std::iter::Sum<&'a TDim> for TDim {
1420    fn sum<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
1421        iter.fold(0.into(), |a, b| a + b)
1422    }
1423}
1424
1425impl std::iter::Product for TDim {
1426    fn product<I: Iterator<Item = TDim>>(iter: I) -> Self {
1427        iter.fold(TDim::Val(1), |a, b| a * b)
1428    }
1429}
1430
1431impl<'a> ::std::iter::Product<&'a TDim> for TDim {
1432    fn product<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
1433        iter.fold(1.into(), |a, b| a * b)
1434    }
1435}
1436
1437macro_rules! from_i {
1438    ($i: ty) => {
1439        impl From<$i> for TDim {
1440            fn from(v: $i) -> TDim {
1441                TDim::Val(v as _)
1442            }
1443        }
1444        impl<'a> From<&'a $i> for TDim {
1445            fn from(v: &'a $i) -> TDim {
1446                TDim::Val(*v as _)
1447            }
1448        }
1449    };
1450}
1451
1452from_i!(i32);
1453from_i!(i64);
1454from_i!(u64);
1455from_i!(isize);
1456from_i!(usize);
1457
1458impl From<Symbol> for TDim {
1459    fn from(it: Symbol) -> Self {
1460        TDim::Sym(it)
1461    }
1462}
1463
1464impl<'a> From<&'a Symbol> for TDim {
1465    fn from(it: &'a Symbol) -> Self {
1466        TDim::Sym(it.clone())
1467    }
1468}
1469
1470impl ops::Neg for TDim {
1471    type Output = Self;
1472    fn neg(self) -> Self {
1473        if let Val(v) = self { Val(-v) } else { TDim::MulInt(-1, Box::new(self)).reduce() }
1474    }
1475}
1476
1477impl<'a> ops::AddAssign<&'a TDim> for TDim {
1478    fn add_assign(&mut self, rhs: &'a TDim) {
1479        if rhs.is_zero() {
1480        } else if self.is_zero() {
1481            *self = rhs.clone();
1482        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1483            *s += o;
1484        } else {
1485            *self = TDim::Add(vec![std::mem::take(self), rhs.clone()]).reduce()
1486        }
1487    }
1488}
1489
1490impl<I> ops::AddAssign<I> for TDim
1491where
1492    I: Into<TDim>,
1493{
1494    fn add_assign(&mut self, rhs: I) {
1495        let rhs = rhs.into();
1496        if rhs.is_zero() {
1497        } else if self.is_zero() {
1498            *self = rhs;
1499        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1500            *s += o;
1501        } else {
1502            *self = TDim::Add(vec![std::mem::take(self), rhs]).reduce()
1503        }
1504    }
1505}
1506
1507impl<I> ops::Add<I> for TDim
1508where
1509    I: Into<TDim>,
1510{
1511    type Output = Self;
1512    fn add(mut self, rhs: I) -> Self {
1513        self += rhs;
1514        self
1515    }
1516}
1517
1518impl<'a> ops::Add<&'a TDim> for TDim {
1519    type Output = Self;
1520    fn add(mut self, rhs: &'a TDim) -> Self {
1521        self += rhs;
1522        self
1523    }
1524}
1525
1526#[allow(clippy::suspicious_op_assign_impl)]
1527impl<'a> ops::SubAssign<&'a TDim> for TDim {
1528    fn sub_assign(&mut self, rhs: &'a TDim) {
1529        if rhs.is_zero() {
1530        } else if self.is_zero() {
1531            *self = rhs.clone().neg();
1532        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1533            *s -= o;
1534        } else {
1535            *self = TDim::Add(vec![std::mem::take(self), rhs.clone().neg()]).reduce()
1536        }
1537    }
1538}
1539
1540impl<I> ops::SubAssign<I> for TDim
1541where
1542    I: Into<TDim>,
1543{
1544    fn sub_assign(&mut self, rhs: I) {
1545        let rhs = rhs.into();
1546        if rhs.is_zero() {
1547        } else if self.is_zero() {
1548            *self = rhs.neg();
1549        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1550            *s -= o;
1551        } else {
1552            *self = TDim::Add(vec![std::mem::take(self), rhs.neg()]).reduce()
1553        }
1554    }
1555}
1556
1557impl<I> ops::Sub<I> for TDim
1558where
1559    I: Into<TDim>,
1560{
1561    type Output = Self;
1562    fn sub(mut self, rhs: I) -> Self {
1563        self -= rhs;
1564        self
1565    }
1566}
1567
1568impl<'a> ops::Sub<&'a TDim> for TDim {
1569    type Output = Self;
1570    fn sub(mut self, rhs: &'a TDim) -> Self {
1571        self -= rhs;
1572        self
1573    }
1574}
1575
1576impl<I: Into<TDim>> ops::MulAssign<I> for TDim {
1577    fn mul_assign(&mut self, rhs: I) {
1578        let rhs = rhs.into();
1579        if self.is_one() {
1580            *self = rhs
1581        } else if rhs.is_one() {
1582        } else {
1583            *self = TDim::Mul(vec![rhs, std::mem::take(self)]).reduce()
1584        }
1585    }
1586}
1587
1588impl<'a> ops::MulAssign<&'a TDim> for TDim {
1589    fn mul_assign(&mut self, rhs: &'a TDim) {
1590        if self.is_one() {
1591            *self = rhs.clone()
1592        } else if rhs.is_one() {
1593        } else {
1594            *self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce()
1595        }
1596    }
1597}
1598
1599impl<I: Into<TDim>> ops::Mul<I> for TDim {
1600    type Output = Self;
1601    fn mul(mut self, rhs: I) -> Self {
1602        self *= rhs.into();
1603        self
1604    }
1605}
1606
1607impl<'a> ops::Mul<&'a TDim> for TDim {
1608    type Output = Self;
1609    fn mul(mut self, rhs: &'a TDim) -> Self {
1610        self *= rhs;
1611        self
1612    }
1613}
1614
1615impl<I: AsPrimitive<u64> + PrimInt> ops::DivAssign<I> for TDim {
1616    fn div_assign(&mut self, rhs: I) {
1617        *self = TDim::Div(Box::new(std::mem::take(self)), rhs.as_()).reduce()
1618    }
1619}
1620
1621impl<I: AsPrimitive<u64> + PrimInt> ops::Div<I> for TDim {
1622    type Output = Self;
1623    fn div(mut self, rhs: I) -> Self {
1624        self /= rhs.as_();
1625        self
1626    }
1627}
1628
1629impl<I: AsPrimitive<u64> + PrimInt> ops::RemAssign<I> for TDim {
1630    fn rem_assign(&mut self, rhs: I) {
1631        *self += -(self.clone() / rhs.as_() * rhs.as_());
1632    }
1633}
1634
1635impl<I: AsPrimitive<u64> + PrimInt> ops::Rem<I> for TDim {
1636    type Output = Self;
1637    fn rem(mut self, rhs: I) -> Self {
1638        self %= rhs;
1639        self
1640    }
1641}
1642
// Unit tests for `TDim`: operator-level reduction, `simplify` with and without
// symbol assertions, `guess_slope`, min/max handling and the `Ge`/`Eq`
// comparison variants.
#[cfg(test)]
mod tests {
    use super::*;

    // Boxing shorthand, same shape as the `b!` macro at the top of this file.
    macro_rules! b( ($e:expr) => { Box::new($e) } );

    // One shared symbol scope, plus the symbols `a`..`e` drawn from it.
    lazy_static::lazy_static! {
        static ref table: SymbolScope = SymbolScope::default();
        static ref A: Symbol = table.sym("a");
        static ref B: Symbol = table.sym("b");
        static ref C: Symbol = table.sym("c");
        static ref D: Symbol = table.sym("d");
        static ref E: Symbol = table.sym("e");
    }

    // The four helpers below build raw trees without calling `reduce()`, so a
    // test controls exactly which shape is handed to the simplifier.

    // `-1 * a`
    fn neg(a: &TDim) -> TDim {
        mul(-1, a)
    }

    // `a + b`
    fn add(a: &TDim, b: &TDim) -> TDim {
        TDim::Add(vec![a.clone(), b.clone()])
    }

    // `a * b` with an integer factor `a`
    fn mul(a: i64, b: &TDim) -> TDim {
        TDim::MulInt(a, b![b.clone()])
    }

    // `a / b` with an unsigned divisor `b`
    fn div(a: &TDim, b: u64) -> TDim {
        TDim::Div(b!(a.clone()), b)
    }

    // a + (-1·a) must cancel to the constant 0.
    #[test]
    fn reduce_add() {
        assert_eq!(add(&A.to_dim(), &neg(&A.to_dim())).reduce(), Val(0))
    }

    // `lcm` on constants.
    #[test]
    fn lcm_basic() {
        assert_eq!(Val(16).lcm(&Val(32)), Some(Val(32)));
        assert_eq!(Val(32).lcm(&Val(16)), Some(Val(32)));
        assert_eq!(Val(6).lcm(&Val(8)), Some(Val(24)));
        assert_eq!(Val(7).lcm(&Val(7)), Some(Val(7)));
        // Symbolic: not computable; callers fall back.
        assert_eq!(Val(16).lcm(&A.to_dim()), None);
    }

    // -(2·a) folds the nested integer factors into -2·a.
    #[test]
    fn reduce_neg_mul() {
        assert_eq!(neg(&mul(2, &A.to_dim())).reduce(), mul(-2, &A.to_dim()))
    }

    // (-4 + (-2)·(a/4)) + (-2)·(-1·(a/4)): the two a/4 terms cancel.
    #[test]
    fn reduce_cplx_ex_2() {
        assert_eq!(
            add(
                &add(&Val(-4), &mul(-2, &div(&A.to_dim(), 4))),
                &mul(-2, &mul(-1, &div(&A.to_dim(), 4)))
            )
            .reduce(),
            Val(-4)
        )
    }

    // (1·(4·a)) / 4 reduces back to a.
    #[test]
    fn reduce_cplx_ex_3() {
        assert_eq!(div(&MulInt(1, b!(MulInt(4, b!(A.to_dim())))), 4).reduce(), A.to_dim())
    }

    #[test]
    fn reduce_cplx_ex_4() {
        // (a+1)/2 + (1-a)/2 == 1
        // NOTE(review): with integer division this is not an identity for every
        // a: at a = 0 both halves are 1/2 = 0, so the sum is 0. The test pins the
        // simplifier's current result; confirm that is intended.
        assert_eq!(
            add(&div(&add(&A.to_dim(), &Val(1)), 2), &div(&add(&neg(&A.to_dim()), &Val(1)), 2))
                .reduce(),
            1.into()
        );
    }

    // Nested integer factors multiply: 3·(2·a) → 6·a.
    #[test]
    fn reduce_mul_mul_1() {
        assert_eq!(mul(3, &mul(2, &A.to_dim())).reduce(), mul(6, &A.to_dim()))
    }

    // Signs combine: -2·(-1·a) → 2·a.
    #[test]
    fn reduce_mul_mul_2() {
        assert_eq!(mul(-2, &mul(-1, &A.to_dim())).reduce(), mul(2, &A.to_dim()))
    }

    // 2·((-1·a)/3) → -2·(a/3): the sign is pulled out of the division.
    #[test]
    fn reduce_mul_div_1() {
        assert_eq!(mul(2, &div(&mul(-1, &A.to_dim()), 3)).reduce(), mul(-2, &div(&A.to_dim(), 3)))
    }

    // Constant-only expressions evaluate without any symbol bindings.
    #[test]
    fn const_and_add() {
        let e: TDim = 2i64.into();
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 2);
        let e: TDim = TDim::from(2) + 3;
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 5);
        let e: TDim = TDim::from(2) - 3;
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -1);
        let e: TDim = -TDim::from(2);
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -2);
    }

    // Evaluation with a bound value for `a`.
    #[test]
    fn substitution() {
        let a: TDim = A.to_dim();
        assert_eq!(a.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 2);
        let e = a + 3;
        assert_eq!(e.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 5);
    }

    // Constant additions fold in the operator impls themselves.
    #[test]
    fn reduce_adds() {
        let e: TDim = TDim::from(2) + 1;
        assert_eq!(e, TDim::from(3));
        let e: TDim = TDim::from(3) + 2;
        assert_eq!(e, TDim::from(5));
        let e: TDim = TDim::from(3) + 0;
        assert_eq!(e, TDim::from(3));
        let e: TDim = TDim::from(3) + 2 + 1;
        assert_eq!(e, TDim::from(6));
    }

    // Multiplying by 1 (on either side) is the identity.
    #[test]
    fn reduce_muls() {
        let e: TDim = Val(1) * A.to_dim();
        assert_eq!(e, A.to_dim());
        let e: TDim = A.to_dim() * &B.to_dim() * 1;
        assert_eq!(e, A.to_dim() * &B.to_dim());
    }

    // Constant division and remainder (positive operands).
    #[test]
    fn reduce_divs() {
        let e: TDim = TDim::from(2) / 1;
        assert_eq!(e, TDim::from(2));
        let e: TDim = TDim::from(3) / 2;
        assert_eq!(e, TDim::from(1));
        let e: TDim = TDim::from(3) % 2;
        assert_eq!(e, TDim::from(1));
        let e: TDim = TDim::from(5) / 2;
        assert_eq!(e, TDim::from(2));
        let e: TDim = TDim::from(5) % 2;
        assert_eq!(e, TDim::from(1));
    }

    // The two constructions below must compare equal.
    #[test]
    fn reduce_div_bug_0() {
        let e1: TDim = (A.to_dim() + 23) / 2 - 1;
        let e2: TDim = (A.to_dim() + 21) / 2;
        assert_eq!(e1, e2);
    }

    // The two constructions below must compare equal.
    #[test]
    fn reduce_div_bug_1() {
        let e1: TDim = (A.to_dim() + -1) / 2;
        let e2: TDim = (A.to_dim() + 1) / 2 - 1;
        assert_eq!(e1, e2);
    }

    // Nested divisions combine: ((a+1)/2 + 1)/2 == (a+3)/4.
    #[test]
    fn reduce_div_bug_2() {
        let e1: TDim = ((A.to_dim() + 1) / 2 + 1) / 2;
        let e2: TDim = (A.to_dim() + 3) / 4;
        assert_eq!(e1, e2);
    }

    #[test]
    fn divide_multiple_plus_remainder() {
        // (k·X + r)/k → X under truncating division when 0 ≤ r < k AND X ≥ 0.
        let scope = SymbolScope::default().with_assertion("S>=0").unwrap();
        let s = scope.sym("S");

        // (2S+1)/2 → S
        let e: TDim = (s.to_dim() * 2 + 1) / 2;
        assert_eq!(e.simplify(), s.to_dim());

        // -1 + (2S+1)/2 → S - 1
        let e: TDim = (s.to_dim() * 2 + 1) / 2 - 1;
        assert_eq!(e.simplify(), s.to_dim() - 1);

        // (2S-1)/2 → S - 1   (Val rule extracts -1 first, then our rule)
        let e: TDim = (s.to_dim() * 2 - 1) / 2;
        assert_eq!(e.simplify(), s.to_dim() - 1);

        // (4S+3)/2 → 2S + 1   (Val rule extracts 1 = 3/2, then our rule on (4S+1)/2 → 2S)
        let e: TDim = (s.to_dim() * 4 + 3) / 2;
        assert_eq!(e.simplify(), s.to_dim() * 2 + 1);
    }

    #[test]
    fn divide_multiple_plus_remainder_no_assertion() {
        // Without an X≥0 assertion the (k·X+c)/k → X identity does NOT hold
        // (X=-1, k=2, c=1 gives -1/2=0 ≠ X under truncating division). The
        // wiggle Div arm used to emit that variant unconditionally; reduce()
        // would then pick it on cost. Now wiggle skips the variant when the
        // remainder bucket contains a Val, leaving only the sound rule
        // gated on prove_positive_or_zero in simplify_rec.
        let scope = SymbolScope::default();
        let s = scope.sym("S");
        let e: TDim = (s.to_dim() * 2 + 1) / 2;
        assert_ne!(e.simplify(), s.to_dim());
    }

    #[test]
    fn modulo_div_is_zero() {
        // (Y − q·(Y/q)) / q = 0 for any Y and any q — the modulo remainder
        // divided by the modulus is always zero under truncating division.
        let scope = SymbolScope::default();
        let s = scope.sym("S");
        // Simple case: (S - 2*(S/2)) / 2 = (S mod 2) / 2 = 0
        let e: TDim = (s.to_dim() - s.to_dim() / 2 * 2) / 2;
        assert_eq!(e.simplify(), TDim::Val(0));
        // Composite case: ((S+1) - 2*((S+1)/2)) / 2 = 0
        // This is the exact pattern from SameUpper conv padding.
        let a = s.to_dim() + 1;
        let e2: TDim = (a.clone() - a.clone() / 2 * 2) / 2;
        assert_eq!(e2.simplify(), TDim::Val(0));
    }

    // Dividing by 1 must leave the expression unchanged.
    #[test]
    fn reduce_div_bug_3() {
        let e1: TDim = (A.to_dim() / 2) * -4;
        let e2: TDim = (A.to_dim() / 2) * -4 / 1;
        assert_eq!(e1, e2);
    }

    // (a·2)/2 → a: an exact multiple divides out.
    #[test]
    fn reduce_mul_div() {
        let e: TDim = A.to_dim() * 2 / 2;
        assert_eq!(e, A.to_dim());
    }

    #[test]
    fn expand_polynomial_two_add_factors() {
        // (a + 2*a*b) * (1 + b)  ==poly==  a * (1 + b) * (1 + 2*b)
        // Both fully expand to a + 3*a*b + 2*a*b*b.  We don't auto-expand in
        // simplify (it would block maybe_div on factored-form denominators),
        // but expand_polynomial does, and Reshape uses it for volume checks.
        let a = A.to_dim();
        let b = B.to_dim();
        let lhs = (a.clone() + a.clone() * &b * 2) * (TDim::from(1) + &b);
        let rhs = a.clone() * (TDim::from(1) + &b) * (TDim::from(1) + b.clone() * 2);
        assert_eq!(lhs.expand_polynomial(), rhs.expand_polynomial());
    }

    // (a/2)·2 must NOT collapse to a: the division may have dropped a remainder.
    #[test]
    fn reduce_div_mul() {
        let e: TDim = A.to_dim() / 2 * 2;
        assert_ne!(e, A.to_dim());
    }

    // a/2 + 1 is expected to compare equal to (a+2)/2.
    #[test]
    fn reduce_add_div() {
        let e: TDim = A.to_dim() / 2 + 1;
        assert_eq!(e, ((A.to_dim() + 2) / 2));
    }

    // Subtracting 2·a compares equal to adding (-2)·a.
    #[test]
    fn reduce_neg_mul_() {
        let e: TDim = TDim::from(1) - A.to_dim() * 2;
        assert_eq!(e, TDim::from(1) + A.to_dim() * -2);
    }

    // A multiple of the modulus added to the operand is expected to drop out of `%`.
    #[test]
    fn reduce_add_rem_1() {
        assert_eq!(((A.to_dim() + 4) % 2), (A.to_dim() % 2));
    }

    // Same as above with a subtracted multiple of the modulus.
    #[test]
    fn reduce_add_rem_2() {
        assert_eq!(((A.to_dim() - 4) % 2), (A.to_dim() % 2));
    }

    // (a % 2) / 2 is always 0: the remainder is smaller than the modulus.
    #[test]
    fn reduce_rem_div() {
        let e: TDim = A.to_dim() % 2 / 2;
        assert_eq!(e, TDim::from(0));
    }

    // div_ceil(1) on a constant expression.
    #[test]
    fn conv2d_ex_1() {
        let e = (TDim::from(1) - 1 + 1).div_ceil(1);
        assert_eq!(e, TDim::from(1));
    }

    // div_ceil(1) on a symbolic expression leaves the folded sum a - 2.
    #[test]
    fn conv2d_ex_2() {
        let e = (A.to_dim() - 3 + 1).div_ceil(1);
        assert_eq!(e, A.to_dim() + -2);
    }

    // (24t - 24)(2t - 2) must compare equal to 48·(t-1)·(t-1).
    #[test]
    fn extract_int_gcd_from_muls() {
        let term = (A.to_dim() + 1) / 4;
        let mul = (term.clone() * 24 - 24) * (term.clone() * 2 - 2);
        let target = (term.clone() - 1) * (term.clone() - 1) * 48;
        assert_eq!(mul, target);
    }

    // Factor order in a product must not affect equality.
    #[test]
    fn equality_of_muls() {
        let term = (A.to_dim() + 1) / 4;
        let mul1 = (term.clone() * 2 - 3) * (term.clone() - 1);
        let mul2 = (term.clone() - 1) * (term.clone() * 2 - 3);
        assert_eq!(mul1, mul2);
    }

    // 2t - t - 1 collapses to t - 1 for a non-trivial term t.
    #[test]
    fn factorize_complex_expr_times_int() {
        let term = (A.to_dim() + 1) / 4;
        let e = term.clone() * 2 - &term - 1;
        assert_eq!(e, term - 1);
    }

    #[test]
    fn broadcast_over_min() {
        // assuming a>0, b>0 then a#min(a,b) can be replaced by a
        // proof:
        //    if b == 1 => min(a,b)=1 => a#1=a => ok
        //    if a <= b => min(a,b)=a => ok
        //    if 1 < B < A => expression was invalid, we're generalizing over the non-domain and ignoring the constraint
        for a in 1..5 {
            for b in 1..5 {
                if b > 1 && a > b {
                    assert!(a.broadcast(a.min(b)).is_err());
                } else {
                    assert_eq!(a.broadcast(a.min(b)).unwrap(), a);
                }
            }
        }
    }

    // `mini` on two constants, in both argument orders.
    #[test]
    fn min_ints_1() {
        assert_eq!(2.to_dim().mini(1.to_dim()), 1.to_dim());
    }

    #[test]
    fn min_ints_2() {
        assert_eq!(1.to_dim().mini(2.to_dim()), 1.to_dim());
    }

    // min(a, a) → a.
    #[test]
    fn min_same() {
        assert_eq!(A.to_dim().mini(A.to_dim()), A.to_dim());
    }

    // min(a, 1) cannot be decided; building it twice must give equal results.
    #[test]
    fn min_noop() {
        assert_eq!(A.to_dim().mini(1.to_dim()), A.to_dim().mini(1.to_dim()));
    }

    // Same symbolic part, different constant offsets: the smaller offset wins.
    #[test]
    fn min_diff_1() {
        assert_eq!((A.to_dim() + 1).mini(A.to_dim() + 2), A.to_dim() + 1);
    }

    // `guess_slope(sym)` returns the slope with respect to `sym` as a
    // (numerator, denominator) pair, e.g. 2a + a/2 → (5, 2).
    #[test]
    fn slope_0() {
        assert_eq!(12.to_dim().guess_slope(&A), (0, 1));
    }

    #[test]
    fn slope_1() {
        assert_eq!(A.to_dim().guess_slope(&A), (1, 1));
    }

    #[test]
    fn slope_2() {
        assert_eq!((A.to_dim() * 2).guess_slope(&A), (2, 1));
    }

    #[test]
    fn slope_3() {
        assert_eq!((A.to_dim() * 2 + A.to_dim() / 2).guess_slope(&A), (5, 2));
    }

    #[test]
    fn slope_4() {
        assert_eq!((A.to_dim()).guess_slope(&B), (0, 1));
    }

    #[test]
    fn slope_5() {
        assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
        assert_eq!((A.to_dim() + 1).guess_slope(&B), (0, 1));
    }

    #[test]
    fn slope_6() {
        assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
        assert_eq!((A.to_dim() + B.to_dim()).guess_slope(&B), (1, 1));
    }

    // min(S+3, S+2) simplifies to the smaller offset.
    #[test]
    fn min_0() -> TractResult<()> {
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols.parse_tdim("min(S+3, S+2)").unwrap().simplify(),
            symbols.parse_tdim("S+2").unwrap(),
        );
        Ok(())
    }

    // Parenthesization and factor order must not affect the simplified product.
    #[test]
    fn commutative_mul_parens() -> TractResult<()> {
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols.parse_tdim("A*(B*C)").unwrap().simplify(),
            symbols.parse_tdim("(B*A)*C").unwrap().simplify(),
        );
        Ok(())
    }

    // Same property on a larger expression (name refers to the model it came from).
    #[test]
    fn commutative_in_nemo_parakeet_model() -> TractResult<()> {
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols
                .parse_tdim("8*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8))*((B)*((S+7)/8))")
                .unwrap()
                .simplify(),
            symbols
                .parse_tdim("8*((B)*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8)))*((S+7)/8)")
                .unwrap()
                .simplify(),
        );
        Ok(())
    }

    // Deeply nested `Mul` nodes flatten to the same form as a parsed flat product.
    #[test]
    fn commutative_mul_parens_deep() -> TractResult<()> {
        let symbols = SymbolScope::default();
        let deep_tdim = Mul(vec![
            Mul(vec![Mul(vec![Mul(vec![A.to_dim(), B.to_dim()]), C.to_dim()]), D.to_dim()]),
            E.to_dim(),
        ])
        .simplify();
        assert_eq!(deep_tdim, symbols.parse_tdim("a*b*c*d*e").unwrap().simplify());
        Ok(())
    }

    // ---- Tests for the comparison variants (Ge, Eq) and `1 - x` negation ----

    #[test]
    fn ge_concrete_true() {
        assert_eq!(Ge(b!(Val(5)), b!(Val(3))).reduce(), Val(1));
    }

    #[test]
    fn ge_concrete_false() {
        assert_eq!(Ge(b!(Val(2)), b!(Val(3))).reduce(), Val(0));
    }

    #[test]
    fn lt_concrete_true() {
        // Lt(2,3) normalizes to Ge(3, 2+1) = Ge(3, 3)
        assert_eq!(Ge(b!(Val(3)), b!(Val(3))).reduce(), Val(1));
    }

    #[test]
    fn lt_concrete_false() {
        // Lt(5,3) normalizes to Ge(3, 5+1) = Ge(3, 6)
        assert_eq!(Ge(b!(Val(3)), b!(Val(6))).reduce(), Val(0));
    }

    #[test]
    fn eq_concrete_true() {
        assert_eq!(Eq(b!(Val(3)), b!(Val(3))).reduce(), Val(1));
    }

    #[test]
    fn eq_concrete_false() {
        assert_eq!(Eq(b!(Val(3)), b!(Val(4))).reduce(), Val(0));
    }

    #[test]
    fn not_val_0() {
        // not(0) = 1 - 0 = 1
        assert_eq!((Val(1) - Val(0)).reduce(), Val(1));
    }

    #[test]
    fn not_val_1() {
        // not(1) = 1 - 1 = 0
        assert_eq!((Val(1) - Val(1)).reduce(), Val(0));
    }

    #[test]
    fn not_lt_becomes_ge() {
        // not(Lt(x1, T)) = 1 - Ge(T, x1+1); check it evaluates correctly at boundary
        let s = SymbolScope::default();
        let t = s.sym("T");
        let x1 = s.sym("x1");
        // at x1 = T (boundary), Ge(T, T+1) = 0, so 1 - 0 = 1 (not-lt is true when x1 >= T)
        let expr = Val(1) - Ge(b!(Sym(t.clone())), b!(Sym(x1.clone()) + Val(1)));
        let at_boundary = expr.substitute(&x1, &Sym(t.clone())).unwrap().simplify();
        assert_eq!(at_boundary, Val(1));
    }

    #[test]
    fn eq_with_assertion_proves_false() {
        // Eq(T, 0) should reduce to Val(0) when T >= 1
        let s = SymbolScope::default();
        s.add_assertion("T >= 1").unwrap();
        let t = s.sym("T");
        let expr = Eq(b!(Sym(t)), b!(Val(0)));
        assert_eq!(expr.simplify(), Val(0));
    }

    #[test]
    fn ge_coord_at_extremes() {
        // Ge(x1, T) should not simplify without coordinate substitution
        // NOTE(review): only the substituted form is asserted below; the
        // unsimplified case described above is not checked.
        let s = SymbolScope::default();
        s.add_assertion("T >= 1").unwrap();
        let t = s.sym("T");
        let x1 = s.sym("x1");
        let expr = Ge(b!(Sym(x1.clone())), b!(Sym(t.clone())));
        // simplify() alone can't prove this false (x1 could be > T)
        // but with coordinate substitution (x1 = T-1), Ge(T-1, T) = 0
        let at_max = expr.substitute(&x1, &(Sym(t.clone()) - Val(1))).unwrap().simplify();
        assert_eq!(at_max, Val(0));
    }

    // Ge/Eq evaluate to 1 or 0 through eval_to_i64.
    #[test]
    fn eval_to_i64_new_variants() {
        use super::super::sym::SymbolValues;
        let sv = SymbolValues::default();
        assert_eq!(Ge(b!(Val(5)), b!(Val(3))).eval_to_i64(&sv).unwrap(), 1);
        assert_eq!(Ge(b!(Val(3)), b!(Val(5))).eval_to_i64(&sv).unwrap(), 0);
        assert_eq!(Eq(b!(Val(3)), b!(Val(3))).eval_to_i64(&sv).unwrap(), 1);
        assert_eq!(Eq(b!(Val(3)), b!(Val(4))).eval_to_i64(&sv).unwrap(), 0);
    }

    // With cw asserted to lie in [0, 1], Eq against 0 or 1 becomes arithmetic.
    #[test]
    fn eq_boolean_simplifies() {
        let s = SymbolScope::default();
        s.add_assertion("cw >= 0").unwrap();
        s.add_assertion("cw <= 1").unwrap();
        let cw = s.sym("cw");
        // Eq(1 - cw, 0) → cw
        assert_eq!(Eq(b!(Val(1) - Sym(cw.clone())), b!(Val(0))).simplify(), Sym(cw.clone()));
        // Eq(cw, 0) → 1 - cw
        assert_eq!(Eq(b!(Sym(cw.clone())), b!(Val(0))).simplify(), Val(1) - Sym(cw.clone()));
        // Eq(cw, 1) → cw
        assert_eq!(Eq(b!(Sym(cw.clone())), b!(Val(1))).simplify(), Sym(cw.clone()));
        // Eq(1 - cw, 1) → 1 - cw
        assert_eq!(Eq(b!(Val(1) - Sym(cw.clone())), b!(Val(1))).simplify(), Val(1) - Sym(cw));
    }

    #[test]
    fn eq_boolean_mul_of_ge() {
        // Product of Ge terms: Ge(a,b) * Ge(c,d) is in [0,1]
        // so Eq(product, 0) should simplify to 1 - product
        let s = SymbolScope::default();
        let x = s.sym("x");
        let product =
            Mul(vec![Ge(b!(Val(2)), b!(Sym(x.clone()))), Ge(b!(Sym(x.clone())), b!(Val(0)))]);
        let eq = Eq(b!(product.clone()), b!(Val(0)));
        assert_eq!(eq.simplify(), Val(1) - product);
    }

    #[test]
    fn min_1_max_0_sym() {
        // Min(1, Max(0, X)) must not simplify away the Min when X is unconstrained.
        let s = SymbolScope::default();
        let x = s.sym("X");
        let expr = Min(vec![Val(1), Max(vec![Val(0), Sym(x)])]);
        let simplified = expr.simplify();
        eprintln!("simplified: {simplified}");
        assert!(format!("{simplified}").contains("min"), "Min dropped: {simplified}");
    }

    #[test]
    fn min_preserved_in_subtraction_parts() {
        // Test that Min([1, X]) simplifies correctly in isolation
        let s = SymbolScope::default();
        let t = s.sym("T");
        let p = s.sym("P");
        let ss = s.sym("S");

        let cum_after =
            Max(vec![Val(0), (Sym(t.clone()) + Val(1)) * Sym(p.clone()) - Sym(ss.clone())]);
        let min_after = Min(vec![Val(1), cum_after.clone()]);
        let simplified = min_after.simplify();
        eprintln!("min_after simplified: {simplified}");
        // Must contain "min" — the Min must not be dropped
        assert!(format!("{simplified}").contains("min"), "Min wrapper was dropped: {simplified}");
    }

    #[test]
    fn min_preserved_in_subtraction() {
        // min(1, X) - min(1, Y) must preserve the min() wrappers.
        // This is the pattern used by PulseV2Pad's output_facts for after-padding.
        let s = SymbolScope::default();
        let t = s.sym("T");
        let p = s.sym("P");
        let ss = s.sym("S");

        let cum_after =
            Max(vec![Val(0), (Sym(t.clone()) + Val(1)) * Sym(p.clone()) - Sym(ss.clone())]);
        let cum_before = Max(vec![Val(0), Sym(t.clone()) * Sym(p.clone()) - Sym(ss.clone())]);

        let ap = Min(vec![Val(1), cum_after.clone()]) - Min(vec![Val(1), cum_before.clone()]);
        let simplified = ap.simplify();

        // At T=1, P=4, S=3: min(1, max(0, 8-3)) - min(1, max(0, 4-3)) = 1 - 1 = 0
        use super::super::sym::SymbolValues;
        let sv = SymbolValues::default().with(&t, 1).with(&p, 4).with(&ss, 3);
        assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 0, "simplified: {simplified}");

        // At T=0, P=4, S=3: min(1, max(0, 4-3)) - min(1, max(0, 0-3)) = 1 - 0 = 1
        let sv = SymbolValues::default().with(&t, 0).with(&p, 4).with(&ss, 3);
        assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 1, "simplified: {simplified}");

        // At T=0, P=1, S=1: min(1, max(0, 1-1)) - min(1, max(0, 0-1)) = 0 - 0 = 0
        let sv = SymbolValues::default().with(&t, 0).with(&p, 1).with(&ss, 1);
        assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 0, "simplified: {simplified}");
    }

    #[test]
    fn mul_neg_b_by_8() {
        let s = SymbolScope::default();
        let b = Sym(s.sym("B"));
        // 8*(-1*B) should equal -8*B
        let a = Mul(vec![Val(8), MulInt(-1, Box::new(b.clone()))]);
        let c = MulInt(-8, Box::new(b.clone()));
        let a_s = a.simplify();
        let c_s = c.simplify();
        assert_eq!(a_s, c_s, "8*(-1*B) should simplify the same as -8*B");
    }

    /// Encoder-pulse case: (6 + 14·S) / 8 == (3 + 7·S) / 4.
    /// Both Add terms (Val(6) and MulInt(14, S)) share factor 2 with
    /// divisor 8, so the simplifier should reduce both sides by 2.
    #[test]
    fn reduce_div_by_common_factor_with_divisor() {
        let lhs = (A.to_dim() * 14 + 6) / 8;
        let rhs = (A.to_dim() * 7 + 3) / 4;
        assert_eq!(lhs, rhs);
    }

    /// Common factor that fully divides the divisor → drop the divisor.
    /// (4·a + 8) / 4  ==  a + 2.
    #[test]
    fn reduce_div_when_factor_equals_divisor() {
        let lhs = (A.to_dim() * 4 + 8) / 4;
        let rhs = A.to_dim() + 2;
        assert_eq!(lhs, rhs);
    }

    /// No common factor → no reduction (identity check).
    /// (3 + 7·a) / 4 stays as-is (gcd(3, 7, 4) = 1).
    #[test]
    fn no_reduce_when_terms_coprime_with_divisor() {
        let e = (A.to_dim() * 7 + 3) / 4;
        // We just check it didn't reduce to something weird; the
        // canonical form is `Div(Add(...), 4)`.
        match &e {
            Div(_, q) => assert_eq!(*q, 4),
            other => panic!("expected Div(_, 4), got {other:?}"),
        }
    }

    /// Sym without an explicit `MulInt` wrapper has implicit coefficient
    /// 1.  Any common factor gcd including 1 collapses to 1, so the
    /// reduction does nothing — the rule must not silently drop the Sym.
    #[test]
    fn no_reduce_when_sym_has_implicit_unit_coefficient() {
        // (a + 4) / 2 must stay non-trivial — gcd(1, 4, 2) = 1.
        let e = (A.to_dim() + 4) / 2;
        // It can simplify to other forms but it should still depend on `a`.
        // Eval at a=2 → (2+4)/2 = 3.  Eval at a=4 → (4+4)/2 = 4.
        let sv2 = SymbolValues::default().with(&A, 2);
        let sv4 = SymbolValues::default().with(&A, 4);
        assert_eq!(e.eval_to_i64(&sv2).unwrap(), 3);
        assert_eq!(e.eval_to_i64(&sv4).unwrap(), 4);
    }
}