Skip to main content

tract_data/dim/
tree.rs

1use crate::dim::Assertion;
2use crate::internal::*;
3
4use super::{DimLike, sym::*};
5use itertools::Itertools;
6use num_integer::Integer;
7use num_traits::{AsPrimitive, PrimInt, Zero};
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10use std::fmt::Debug;
11use std::ops::Neg;
12use std::{fmt, ops};
13
/// Error reporting that a concrete value was requested before it could be
/// determined.
#[derive(Debug)]
pub enum TooEarly {
    /// An expression could not be resolved to a value; the payload is the
    /// text displayed after "Undetermined symbol in expression: ".
    UndeterminedSymbol(String),
    /// Any other reason; the payload is displayed verbatim.
    Other(String),
}
19
20impl std::fmt::Display for TooEarly {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        match self {
23            TooEarly::UndeterminedSymbol(s) => write!(f, "Undetermined symbol in expression: {s}"),
24            TooEarly::Other(s) => write!(f, "{s}"),
25        }
26    }
27}
28
// Marker impl with no overrides: `TooEarly` only provides `Debug` and `Display`.
impl std::error::Error for TooEarly {}
30
// Shorthand used throughout this file: `b!(expr)` expands to `Box::new(expr)`.
macro_rules! b( ($e:expr) => { Box::new($e) } );
32
// `Hash` stays structural while `PartialEq` accepts an algebraic second chance:
// see the `PartialEq` impl below for the rationale (the simplifier's internal
// `HashMap<TDim, _>` only ever compares within same-canonical-form buckets, so
// the standard `a == b => hash(a) == hash(b)` contract being violated outside
// that path is acceptable here).
/// Symbolic integer expression tree.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Eq, Hash, Debug)]
pub enum TDim {
    /// Integer constant.
    Val(i64),
    /// Symbol; evaluation looks its value up in a `SymbolValues`.
    Sym(Symbol),
    /// Sum of the operands.
    Add(Vec<TDim>),
    /// Product of the operands.
    Mul(Vec<TDim>),
    /// Operand multiplied by an integer constant.
    MulInt(i64, Box<TDim>),
    /// Operand divided by an unsigned constant (evaluated with Rust's `i64` `/`).
    Div(Box<TDim>, u64),
    /// Broadcast of the operands (evaluation folds them pairwise, starting from 1).
    Broadcast(Vec<TDim>),
    /// Minimum of the operands.
    Min(Vec<TDim>),
    /// Maximum of the operands.
    Max(Vec<TDim>),
    /// Comparison: evaluates to 1 (true) or 0 (false). lhs >= rhs
    Ge(Box<TDim>, Box<TDim>),
    /// Comparison: evaluates to 1 (true) or 0 (false). lhs == rhs
    Eq(Box<TDim>, Box<TDim>),
}
55
56use TDim::*;
57
58/// Structural equality on the TDim tree — what `#[derive(PartialEq)]` would
59/// produce.  Used as the fast-path inside `PartialEq` (and by the simplifier's
60/// internal `HashMap<TDim, _>`, which compares within same-hash buckets where
61/// structural equality is the only thing that matters).
62fn eq_structural(a: &TDim, b: &TDim) -> bool {
63    match (a, b) {
64        (Val(x), Val(y)) => x == y,
65        (Sym(x), Sym(y)) => x == y,
66        (Add(x), Add(y))
67        | (Mul(x), Mul(y))
68        | (Broadcast(x), Broadcast(y))
69        | (Min(x), Min(y))
70        | (Max(x), Max(y)) => {
71            x.len() == y.len() && x.iter().zip(y).all(|(a, b)| eq_structural(a, b))
72        }
73        (MulInt(p, x), MulInt(q, y)) => p == q && eq_structural(x, y),
74        (Div(x, p), Div(y, q)) => p == q && eq_structural(x, y),
75        (Ge(a, b), Ge(c, d)) | (Eq(a, b), Eq(c, d)) => eq_structural(a, c) && eq_structural(b, d),
76        _ => false,
77    }
78}
79
// Thread-local guard: while simplifying the difference inside `eq`, fall back
// to the structural-only path for any nested `==` calls.  Without this guard
// the simplifier's internal `HashMap<TDim, i64>` would re-enter `eq` from
// inside `(self - other).simplify()`, recursing without bound on
// non-structurally-equal inputs.
//
// The flag starts out `false`; `PartialEq::eq` raises it around its
// diff-and-simplify step and lowers it again afterwards.
std::thread_local! {
    static EQ_GUARD: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}
88
89impl PartialEq for TDim {
90    fn eq(&self, other: &Self) -> bool {
91        // Fast path: structural tree equality.
92        if eq_structural(self, other) {
93            return true;
94        }
95        // Inside an enclosing simplification triggered by a previous
96        // second-chance call, fall back to structural equality only.
97        if EQ_GUARD.with(|g| g.get()) {
98            return false;
99        }
100        // Skip second-chance when either side is a leaf (`Val` or `Sym`).
101        // For `Val(c)` vs anything non-`Val`: if they were semantically
102        // equal, the simplifier should already have folded the other side
103        // to `Val(c)`; running a diff-and-simplify here just risks
104        // arithmetic overflow on extreme constants (e.g. the simplifier
105        // filters against `Val(i64::MAX)`/`Val(i64::MIN)` sentinels).
106        // For `Sym(x)` leaves, assertion-driven equality belongs in
107        // `simplify`, not in `eq`.
108        if matches!(self, Val(_) | Sym(_)) || matches!(other, Val(_) | Sym(_)) {
109            return false;
110        }
111        // Second chance: prove the difference simplifies to zero.  Two
112        // algebraically equal TDims often arrive at different canonical
113        // forms via different construction paths (e.g. `1 + (7S+3)/4` and
114        // `((S+1)*7)/4` after blockify substitutes T → k·S in encoder
115        // shapes).  Subtracting and simplifying lets the existing
116        // simplifier rules cancel them out.
117        EQ_GUARD.with(|g| g.set(true));
118        let diff = (self.clone() - other.clone()).simplify();
119        EQ_GUARD.with(|g| g.set(false));
120        matches!(diff, Val(0))
121    }
122}
123
124fn tdim_lexi_order(a: &TDim, b: &TDim) -> Ordering {
125    match (a, b) {
126        (Sym(a), Sym(b)) => a.cmp(b),
127        (Val(a), Val(b)) => a.cmp(b),
128        (Add(a), Add(b))
129        | (Mul(a), Mul(b))
130        | (Broadcast(a), Broadcast(b))
131        | (Min(a), Min(b))
132        | (Max(a), Max(b)) => a.len().cmp(&b.len()).then(
133            a.iter()
134                .zip(b.iter())
135                .fold(Ordering::Equal, |acc, (a, b)| acc.then_with(|| tdim_lexi_order(a, b))),
136        ),
137        (MulInt(p, d), MulInt(q, e)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
138        (Div(d, p), Div(e, q)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
139        (Sym(_), _) => Ordering::Less,
140        (_, Sym(_)) => Ordering::Greater,
141        (Val(_), _) => Ordering::Less,
142        (_, Val(_)) => Ordering::Greater,
143        (Add(_), _) => Ordering::Less,
144        (_, Add(_)) => Ordering::Greater,
145        (Mul(_), _) => Ordering::Less,
146        (_, Mul(_)) => Ordering::Greater,
147        (MulInt(_, _), _) => Ordering::Less,
148        (_, MulInt(_, _)) => Ordering::Greater,
149        (Broadcast(_), _) => Ordering::Less,
150        (_, Broadcast(_)) => Ordering::Greater,
151        (Min(_), _) => Ordering::Less,
152        (_, Min(_)) => Ordering::Greater,
153        (Max(_), _) => Ordering::Less,
154        (_, Max(_)) => Ordering::Greater,
155        (Ge(a1, b1), Ge(a2, b2)) | (Eq(a1, b1), Eq(a2, b2)) => {
156            tdim_lexi_order(a1, a2).then_with(|| tdim_lexi_order(b1, b2))
157        }
158        (Ge(_, _) | Eq(_, _), _) => Ordering::Less,
159        (_, Ge(_, _) | Eq(_, _)) => Ordering::Greater,
160    }
161}
162
163/// `Div(Add(terms), q)` — try to extract every `MulInt(c, X)` where `c % q == 0`
164/// out of the Div, leaving only a constant remainder in `[0, q)`.
165///
166/// Returns `Some(simplified)` when the residual constant is in `[0, q)` and
167/// every extracted symbolic factor `X` is provably non-negative — both
168/// conditions are required for soundness under tract's truncating
169/// division (`Rust i64 /`):
170///
171/// * the constant being in `[0, q)` makes `c/q_trunc = 0`;
172/// * `X ≥ 0` makes the identity `(k·X + c)/k_trunc = X` hold (it fails
173///   at e.g. `X = -1, k = 2, c = 0` because truncation rounds toward zero).
174///
175/// The `Val` arm above already handles constants outside `[0, q)`, so by
176/// the time we get here `terms` contains at most one `Val` and any number
177/// of `MulInt(c, X)` / other shapes.
178fn try_divide_multiple_plus_remainder(
179    terms: &[TDim],
180    q: u64,
181    scope: &SymbolScopeData,
182    extra: &[Assertion],
183) -> Option<TDim> {
184    let mut quotients: Vec<TDim> = vec![];
185    let mut const_rem: i64 = 0;
186    let mut any_extracted = false;
187    for term in terms {
188        match term {
189            MulInt(c, x) if *c != 0 && c.rem_euclid(q as i64) == 0 => {
190                if !scope.prove_positive_or_zero_with_extra(x, extra) {
191                    return None;
192                }
193                let new_coeff = c / (q as i64);
194                quotients.push(if new_coeff == 1 {
195                    (**x).clone()
196                } else if new_coeff == -1 {
197                    MulInt(-1, x.clone())
198                } else {
199                    MulInt(new_coeff, x.clone())
200                });
201                any_extracted = true;
202            }
203            Val(v) => const_rem += v,
204            _ => return None,
205        }
206    }
207    if !any_extracted {
208        return None;
209    }
210    if !(0..q as i64).contains(&const_rem) {
211        return None;
212    }
213    Some(if quotients.len() == 1 { quotients.remove(0) } else { Add(quotients) })
214}
215
216impl fmt::Display for TDim {
217    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
218        match &self {
219            Sym(sym) => write!(fmt, "{sym}"),
220            Val(it) => write!(fmt, "{it}"),
221            Add(it) => write!(fmt, "{}", it.iter().map(|x| format!("{x}")).join("+")),
222            Mul(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("*")),
223            Broadcast(it) => {
224                write!(fmt, "broadcast({})", it.iter().map(|x| format!("({x})")).join(", "))
225            }
226            Min(it) => write!(fmt, "min({})", it.iter().map(|x| format!("{x}")).join(",")),
227            Max(it) => write!(fmt, "max({})", it.iter().map(|x| format!("{x}")).join(",")),
228            MulInt(a, b) => write!(fmt, "{a}*{b}"),
229            Div(a, b) => write!(fmt, "({a})/{b}"),
230            Ge(a, b) => write!(fmt, "({a}>={b})"),
231            Eq(a, b) => write!(fmt, "({a}=={b})"),
232        }
233    }
234}
235
236impl TDim {
237    #[inline]
238    pub fn is_one(&self) -> bool {
239        matches!(self, Val(1))
240    }
241
242    #[inline]
243    pub fn to_i64(&self) -> TractResult<i64> {
244        if let Val(v) = self {
245            Ok(*v)
246        } else {
247            Err(TooEarly::UndeterminedSymbol(self.to_string()))?
248        }
249    }
250
251    #[inline]
252    pub fn as_i64(&self) -> Option<i64> {
253        if let Val(v) = self { Some(*v) } else { None }
254    }
255
256    pub fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64> {
257        match self {
258            Sym(sym) => {
259                let Some(v) = values.get(sym) else {
260                    Err(TooEarly::UndeterminedSymbol(self.to_string()))?
261                };
262                Ok(v)
263            }
264            Val(v) => Ok(*v),
265            Add(terms) => terms.iter().try_fold(0i64, |acc, it| {
266                let x = it.eval_to_i64(values)?;
267                acc.checked_add(x)
268                    .with_context(|| format!("Overflow in TDim addition ({acc} + {x})"))
269            }),
270            Mul(terms) => terms.iter().try_fold(1i64, |acc, it| {
271                let x = it.eval_to_i64(values)?;
272                acc.checked_mul(x)
273                    .with_context(|| format!("Overflow in TDim multiplication ({acc} * {x})"))
274            }),
275            Min(terms) => terms
276                .iter()
277                .try_fold(i64::MAX, |acc, it| it.eval_to_i64(values).map(|x| acc.min(x))),
278            Max(terms) => terms
279                .iter()
280                .try_fold(i64::MIN, |acc, it| it.eval_to_i64(values).map(|x| acc.max(x))),
281            Broadcast(terms) => terms.iter().try_fold(1i64, |acc, it| {
282                it.eval_to_i64(values)
283                    .and_then(|x| ((acc as usize).broadcast(x as usize)).map(|x| x as i64))
284            }),
285            Div(a, q) => Ok(a.eval_to_i64(values)? / *q as i64),
286            MulInt(p, a) => {
287                let x = a.eval_to_i64(values)?;
288                x.checked_mul(*p)
289                    .with_context(|| format!("Overflow in TDim multiplication ({x} * {p})"))
290            }
291            Ge(a, b) => Ok(if a.eval_to_i64(values)? >= b.eval_to_i64(values)? { 1 } else { 0 }),
292            Eq(a, b) => Ok(if a.eval_to_i64(values)? == b.eval_to_i64(values)? { 1 } else { 0 }),
293        }
294    }
295
296    pub fn eval(&self, values: &SymbolValues) -> TDim {
297        match self {
298            Sym(sym) => values.get(sym).map(Val).unwrap_or_else(|| Sym(sym.clone())),
299            Val(v) => Val(*v),
300            Add(terms) => terms.iter().fold(Val(0), |acc, it| -> TDim { acc + it.eval(values) }),
301            Mul(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim { acc * it.eval(values) }),
302            Min(terms) => {
303                terms.iter().fold(Val(i64::MAX), |acc, it| -> TDim { acc.mini(it.eval(values)) })
304            }
305            Max(terms) => {
306                terms.iter().fold(Val(i64::MIN), |acc, it| -> TDim { acc.maxi(it.eval(values)) })
307            }
308            Broadcast(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim {
309                acc.broadcast(it.eval(values)).unwrap_or_else(|_| self.clone())
310            }),
311            Div(a, q) => a.eval(values) / *q as i64,
312            MulInt(p, a) => a.eval(values) * *p,
313            Ge(a, b) => {
314                let a2 = a.eval(values);
315                let b2 = b.eval(values);
316                if let (Val(av), Val(bv)) = (&a2, &b2) {
317                    Val(if av >= bv { 1 } else { 0 })
318                } else {
319                    Ge(b!(a2), b!(b2))
320                }
321            }
322            Eq(a, b) => {
323                let a2 = a.eval(values);
324                let b2 = b.eval(values);
325                if let (Val(av), Val(bv)) = (&a2, &b2) {
326                    Val(if av == bv { 1 } else { 0 })
327                } else {
328                    Eq(b!(a2), b!(b2))
329                }
330            }
331        }
332    }
333
334    pub fn eval_with_scenario(&self, scenario: &str) -> TDim {
335        if let Val(v) = self {
336            return Val(*v);
337        }
338        let scope = self.find_scope().unwrap();
339        let scope = scope.0;
340        let locked = scope.lock();
341        let scope = locked.borrow();
342        self.clone().simplify_rec(&scope, Some(scenario), &[])
343    }
344
345    pub fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self> {
346        self.substitute_all(&std::collections::HashMap::from([(from.clone(), to.clone())]))
347    }
348
349    pub fn substitute_all(
350        &self,
351        map: &std::collections::HashMap<Symbol, Self>,
352    ) -> TractResult<Self> {
353        match self {
354            Sym(sym) => Ok(map.get(sym).cloned().unwrap_or_else(|| self.clone())),
355            Val(v) => Ok(Val(*v)),
356            Add(terms) => terms.iter().try_fold(Val(0), |acc, it| -> TractResult<TDim> {
357                Ok(acc + it.substitute_all(map)?)
358            }),
359            Mul(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
360                Ok(acc * it.substitute_all(map)?)
361            }),
362            Broadcast(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
363                acc.broadcast(it.substitute_all(map)?)
364            }),
365            Min(terms) => terms.iter().try_fold(Val(i64::MAX), |acc, it| -> TractResult<TDim> {
366                Ok(acc.mini(it.substitute_all(map)?))
367            }),
368            Max(terms) => terms.iter().try_fold(Val(i64::MIN), |acc, it| -> TractResult<TDim> {
369                Ok(acc.maxi(it.substitute_all(map)?))
370            }),
371            Div(a, q) => Ok(a.substitute_all(map)? / *q as i64),
372            MulInt(p, a) => Ok(a.substitute_all(map)? * *p),
373            Ge(a, b) => Ok(Ge(b!(a.substitute_all(map)?), b!(b.substitute_all(map)?))),
374            Eq(a, b) => Ok(Eq(b!(a.substitute_all(map)?), b!(b.substitute_all(map)?))),
375        }
376    }
377
378    pub fn reduce(self) -> TDim {
379        self.simplify()
380            .wiggle()
381            .into_iter()
382            .sorted_by(tdim_lexi_order)
383            .unique()
384            .map(|e| e.simplify())
385            .min_by_key(|e| e.cost())
386            .unwrap()
387    }
388
389    fn cost(&self) -> usize {
390        use self::TDim::*;
391        match self {
392            Sym(_) | Val(_) => 1,
393            Add(terms) => 2 * terms.iter().map(TDim::cost).sum::<usize>(),
394            Mul(terms) => 3 * terms.iter().map(TDim::cost).sum::<usize>(),
395            Broadcast(terms) => 4 * terms.iter().map(TDim::cost).sum::<usize>(),
396            Min(terms) | Max(terms) => 5 * terms.iter().map(TDim::cost).sum::<usize>(),
397            Div(a, _) => 3 * a.cost(),
398            MulInt(_, a) => 2 * a.cost(),
399            Ge(a, b) | Eq(a, b) => 5 * (a.cost() + b.cost()),
400        }
401    }
402
403    fn wiggle(&self) -> Vec<TDim> {
404        use self::TDim::*;
405        match self {
406            Sym(_) | Val(_) | Mul(_) | Broadcast(_) | Min(_) | Max(_) | Ge(_, _) | Eq(_, _) => {
407                vec![self.clone()]
408            }
409            Add(terms) => {
410                let mut forms = vec![];
411                let sub_exprs = terms.iter().map(|e| e.wiggle()).multi_cartesian_product();
412
413                fn first_div_term(terms: &[TDim]) -> Option<(usize, &TDim, u64)> {
414                    terms.iter().enumerate().find_map(|(index, t)| match t {
415                        Div(numerator, quotient) => Some((index, &**numerator, *quotient)),
416                        _ => None,
417                    })
418                }
419
420                fn generate_new_numerator(
421                    div_index: usize,
422                    numerator: &TDim,
423                    quotient: u64,
424                    expr: &[TDim],
425                ) -> Vec<TDim> {
426                    expr.iter()
427                        .enumerate()
428                        .map(|(index, term)| {
429                            if index == div_index {
430                                numerator.clone()
431                            } else {
432                                MulInt(quotient as i64, Box::new(term.clone()))
433                            }
434                        })
435                        .collect()
436                }
437
438                for expr in sub_exprs {
439                    if let Some((div_index, numerator, quotient)) = first_div_term(&expr) {
440                        let new_numerator =
441                            generate_new_numerator(div_index, numerator, quotient, &expr);
442                        forms.push(Div(Box::new(Add(new_numerator)), quotient))
443                    }
444
445                    forms.push(Add(expr));
446                }
447                forms
448            }
449            MulInt(p, a) => a.wiggle().into_iter().map(|a| MulInt(*p, b!(a))).collect(),
450            Div(a, q) => {
451                let mut forms = vec![];
452                for num in a.wiggle() {
453                    if let Add(terms) = &num {
454                        let (integer, non_integer): (Vec<_>, Vec<_>) =
455                            terms.iter().cloned().partition(|a| a.gcd() % q == 0);
456                        // Skip when the non-integer bucket holds a constant:
457                        // under tract's truncating `/`, splitting (k·X+c)/k →
458                        // X + c/k is unsound for negative X (X=-1, k=2, c=1:
459                        // (-1)/2 = 0 ≠ X). The sound version, gated on
460                        // prove_positive_or_zero, lives in simplify_rec::Div
461                        // via try_divide_multiple_plus_remainder. Cases where
462                        // the remainder is purely symbolic (e.g. A%2 → /2
463                        // lowers to (A − 2·(A/2))/2, non_integer=[A]) stay
464                        // here: the emitted Div(non_integer, q) cancels with
465                        // the extracted quotient and reduces to zero.
466                        if !non_integer.iter().any(|t| matches!(t, Val(_))) {
467                            let mut new_terms =
468                                integer.iter().map(|i| i.div(*q)).collect::<Vec<_>>();
469                            if non_integer.len() > 0 {
470                                new_terms.push(Div(b!(Add(non_integer)), *q));
471                            }
472                            forms.push(Add(new_terms))
473                        }
474                    }
475                    forms.push(Div(b!(num), *q))
476                }
477                forms
478            }
479        }
480    }
481
482    fn find_any_sym(tdim: &TDim) -> Option<&Symbol> {
483        match tdim {
484            Val(_) => None,
485            Sym(s) => Some(s),
486            Add(terms) | Mul(terms) | Min(terms) | Max(terms) | Broadcast(terms) => {
487                terms.iter().find_map(Self::find_any_sym)
488            }
489            MulInt(_, t) | Div(t, _) => Self::find_any_sym(t),
490            Ge(a, b) | Eq(a, b) => Self::find_any_sym(a).or_else(|| Self::find_any_sym(b)),
491        }
492    }
493
494    pub fn find_scope(&self) -> Option<SymbolScope> {
495        Self::find_any_sym(self).and_then(|s| s.scope().clone())
496    }
497
498    /// Fully distribute every `Mul` of `Add`s in `self` into a flat sum of
499    /// products, then `simplify`.  Used to compare two algebraically equal
500    /// but differently-factored TDims for equality (e.g. Reshape volume
501    /// checks on graphs where the same dimension is built two ways).
502    ///
503    /// Cost can blow up combinatorially on very deeply factored expressions
504    /// — call this only at boundaries where structural equality is needed,
505    /// not as a general-purpose simplifier.
506    pub fn expand_polynomial(self) -> TDim {
507        use self::TDim::*;
508        match self {
509            Mul(terms) => {
510                let terms: Vec<TDim> = terms.into_iter().map(Self::expand_polynomial).collect();
511                if let Some(add_idx) = terms.iter().position(|t| matches!(t, Add(_))) {
512                    let Add(add_terms) = terms[add_idx].clone() else { unreachable!() };
513                    let others: Vec<TDim> = terms
514                        .iter()
515                        .enumerate()
516                        .filter(|(i, _)| *i != add_idx)
517                        .map(|(_, t)| t.clone())
518                        .collect();
519                    Add(add_terms
520                        .into_iter()
521                        .map(|t| {
522                            let mut product = others.clone();
523                            product.push(t);
524                            Mul(product).expand_polynomial()
525                        })
526                        .collect())
527                    .simplify()
528                } else {
529                    Mul(terms).simplify()
530                }
531            }
532            MulInt(c, inner) => MulInt(c, Box::new(inner.expand_polynomial())).simplify(),
533            Add(terms) => Add(terms.into_iter().map(Self::expand_polynomial).collect()).simplify(),
534            Div(a, q) => Div(Box::new(a.expand_polynomial()), q).simplify(),
535            Min(terms) => Min(terms.into_iter().map(Self::expand_polynomial).collect()).simplify(),
536            Max(terms) => Max(terms.into_iter().map(Self::expand_polynomial).collect()).simplify(),
537            Broadcast(terms) => {
538                Broadcast(terms.into_iter().map(Self::expand_polynomial).collect()).simplify()
539            }
540            Ge(a, b) => {
541                Ge(Box::new(a.expand_polynomial()), Box::new(b.expand_polynomial())).simplify()
542            }
543            Eq(a, b) => {
544                Eq(Box::new(a.expand_polynomial()), Box::new(b.expand_polynomial())).simplify()
545            }
546            it @ (Sym(_) | Val(_)) => it,
547        }
548    }
549
550    pub fn simplify(self) -> TDim {
551        use self::TDim::*;
552        if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) {
553            return Val(v);
554        }
555        let Some(scope) = self.find_scope() else {
556            return self;
557        };
558        let scope = scope.0;
559        let locked = scope.lock();
560        let scope = locked.borrow();
561        let it = self.simplify_rec(&scope, None, &[]);
562        let mut current: Option<TDim> = None;
563        for scenario in scope.scenarios() {
564            let v = it.clone().simplify_rec(&scope, Some(scenario), &[]);
565            if current.is_some_and(|c| c != v) {
566                return it;
567            } else {
568                current = Some(v);
569            }
570        }
571        current.unwrap_or(it)
572    }
573
574    pub fn simplify_with_extra_assertions(self, extra: &[Assertion]) -> TDim {
575        use self::TDim::*;
576        if extra.is_empty() {
577            return self.simplify();
578        }
579        if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) {
580            return Val(v);
581        }
582        let Some(scope) = self.find_scope() else {
583            return self;
584        };
585        let scope = scope.0;
586        let locked = scope.lock();
587        let scope = locked.borrow();
588        let it = self.simplify_rec(&scope, None, extra);
589        let mut current: Option<TDim> = None;
590        for scenario in scope.scenarios() {
591            let v = it.clone().simplify_rec(&scope, Some(scenario), extra);
592            if current.is_some_and(|c| c != v) {
593                return it;
594            } else {
595                current = Some(v);
596            }
597        }
598        current.unwrap_or(it)
599    }
600
601    fn simplify_rec(
602        self,
603        scope: &SymbolScopeData,
604        scenario: Option<&str>,
605        extra: &[Assertion],
606    ) -> TDim {
607        match self {
608            Add(mut terms) => {
609                #[allow(clippy::mutable_key_type)]
610                let mut simplified_terms: HashMap<TDim, i64> = HashMap::new();
611                // factorize common sub-expr
612                while let Some(term) = terms.pop() {
613                    let simplified = term.simplify_rec(scope, scenario, extra);
614                    match simplified {
615                        Val(0) => {} // ignore
616                        Add(members) => {
617                            terms.extend(members);
618                            continue;
619                        }
620                        Val(value) => *simplified_terms.entry(Val(1)).or_insert(0) += value,
621                        MulInt(value, factor) => {
622                            *simplified_terms.entry((*factor).clone()).or_insert(0) += value;
623                        }
624                        n => *simplified_terms.entry(n).or_insert(0) += 1,
625                    };
626                }
627
628                pub fn evaluate_count(term: TDim, count: i64) -> Option<TDim> {
629                    match count {
630                        0 => None,
631                        _ if term == TDim::Val(1) => Some(TDim::Val(count)),
632                        1 => Some(term),
633                        _ => Some(TDim::MulInt(count, Box::new(term))),
634                    }
635                }
636
637                // Pull the integer GCD of all term coefficients out as a
638                // common factor: e.g. Add([Val(6), MulInt(14, S)]) becomes
639                // MulInt(2, Add([Val(3), MulInt(7, S)])).  The downstream
640                // Div(MulInt(p, a), q) arm then cancels (p, q) gcd, so
641                // (6 + 14·S) / 8 reduces to (3 + 7·S) / 4 with no special
642                // Div-over-Add rule needed.
643                //
644                // Only consider entries with non-zero counts — zero-count
645                // entries (canceled-out factors) get filtered later, but
646                // would otherwise drag the gcd to spurious values.  Only
647                // factor when at least one surviving entry has a
648                // non-constant key, otherwise the Add reduces to a single
649                // `Val` and wrapping it in `MulInt(g, Val(c/g))` is a
650                // strict regression in canonical form.
651                let has_non_const =
652                    simplified_terms.iter().any(|(k, &c)| c != 0 && !matches!(k, Val(_)));
653                let coef_gcd = if has_non_const {
654                    simplified_terms
655                        .values()
656                        .filter(|&&c| c != 0)
657                        .map(|c| c.unsigned_abs() as i64)
658                        .reduce(|a, b| a.gcd(&b))
659                        .unwrap_or(0)
660                } else {
661                    0
662                };
663                let outer_factor = if coef_gcd > 1 {
664                    for v in simplified_terms.values_mut() {
665                        *v /= coef_gcd;
666                    }
667                    Some(coef_gcd)
668                } else {
669                    None
670                };
671
672                let mut members: Vec<TDim> = simplified_terms
673                    .into_iter()
674                    .filter_map(|(term, count)| evaluate_count(term, count))
675                    .collect();
676                members.sort_by(tdim_lexi_order);
677
678                let inner = match members.len() {
679                    0 => TDim::Val(0),
680                    1 => members.into_iter().next().unwrap(),
681                    _ => TDim::Add(members),
682                };
683                match outer_factor {
684                    None => inner,
685                    Some(_) if matches!(inner, TDim::Val(0)) => TDim::Val(0),
686                    Some(g) => TDim::MulInt(g, Box::new(inner)),
687                }
688            }
689            Mul(terms) => {
690                // Distribute over Add: if exactly one factor is an Add,
691                // expand Mul([a, Add([b, c])]) => Add([Mul([a, b]), Mul([a, c])]).
692                // This lets (T+1)*P simplify to T*P + P, which is needed for
693                // cancellation in expressions like (T+1)*P - T*P.
694                //
695                // Multi-Add Muls are *not* eagerly expanded here — keeping them
696                // factored matters for `maybe_div`'s bag-of-factors path, which
697                // cancels common Add factors symbolically (e.g. dividing
698                // 16B·(1+Y)² by 8B·(1+Y) needs the (1+Y)s to be visible as
699                // factors, not melted into a polynomial sum).  Use
700                // `expand_polynomial` if you need a fully distributed canonical
701                // form for equality comparisons (see Reshape volume check).
702                {
703                    let add_indices: Vec<usize> = terms
704                        .iter()
705                        .enumerate()
706                        .filter(|(_, t)| matches!(t, Add(_)))
707                        .map(|(i, _)| i)
708                        .collect();
709                    if add_indices.len() == 1 {
710                        let add_idx = add_indices[0];
711                        let Add(add_terms) = &terms[add_idx] else { unreachable!() };
712                        let other_factors: Vec<TDim> = terms
713                            .iter()
714                            .enumerate()
715                            .filter(|(i, _)| *i != add_idx)
716                            .map(|(_, t)| t.clone())
717                            .collect();
718                        let distributed: Vec<TDim> = add_terms
719                            .iter()
720                            .map(|at| {
721                                let mut product = other_factors.clone();
722                                product.push(at.clone());
723                                Mul(product)
724                            })
725                            .collect();
726                        return Add(distributed).simplify_rec(scope, scenario, extra);
727                    }
728                }
729
730                // in case a term is a multiplication itself, flatten it
731                // e.g., (a*b)*c => a*b*c, and MulInt(k, x) => Val(k)*x
732                let mut flattened_terms = vec![];
733                for t in terms {
734                    match t.clone().reduce() {
735                        Mul(inner_terms) => flattened_terms.extend(inner_terms),
736                        MulInt(k, inner) => {
737                            flattened_terms.push(Val(k));
738                            flattened_terms.push(*inner);
739                        }
740                        other => flattened_terms.push(other),
741                    }
742                }
743                let mut terms = flattened_terms;
744
745                let mut gcd = Mul(terms.clone()).gcd() as i64;
746                if gcd == 0 {
747                    return Val(0);
748                }
749                terms = if gcd != 1 {
750                    terms
751                        .into_iter()
752                        .map(|t| {
753                            let gcd = t.gcd();
754                            (t / gcd).simplify_rec(scope, scenario, extra)
755                        })
756                        .collect()
757                } else {
758                    terms
759                };
760                if terms.iter().filter(|t| t == &&Val(-1)).count() % 2 == 1 {
761                    gcd = -gcd;
762                }
763                terms.retain(|t| !t.is_one() && t != &Val(-1));
764                terms.sort_by(tdim_lexi_order);
765
766                match (gcd, terms.len()) {
767                    (_, 0) => Val(gcd), // Case #1: If 0 variables, return product
768                    (0, _) => Val(0),   // Case #2: Result is 0 if coef is 0 (actually
769                    // unreachable as we check at the beginning)
770                    (1, 1) => terms.remove(0), // Case #3: Product is 1, so return the only term
771                    (1, _) => Mul(terms), // Case #4: Product is 1, so return the non-integer terms
772                    (_, 1) => MulInt(gcd, Box::new(terms.remove(0))), // Case #5: Single variable, convert to 1 MulInt
773                    _ => MulInt(gcd, Box::new(Mul(terms))), // Case #6: Multiple variables, convert to MulInt
774                }
775            }
776            MulInt(coef, expr) => {
777                match *expr {
778                    MulInt(c2, inner) => {
779                        if let Some(c) = coef.checked_mul(c2) {
780                            return MulInt(c, inner).simplify_rec(scope, scenario, extra);
781                        } else {
782                            return MulInt(coef, Box::new(MulInt(c2, inner)));
783                        }
784                    }
785                    Val(v) => {
786                        return coef
787                            .checked_mul(v)
788                            .map(Val)
789                            .unwrap_or_else(|| MulInt(coef, Box::new(Val(v))));
790                    }
791                    _ => {}
792                }
793
794                let simplified = expr.simplify_rec(scope, scenario, extra);
795                match (coef, simplified) {
796                    (0, _) => Val(0), // Case #1: If coef is 0, return 0
797                    (1, s) => s,      // Case #2: If coef is 1, return the simplified expression
798                    (_, Add(terms)) => Add(terms
799                        .into_iter()
800                        .map(|term| {
801                            MulInt(coef, Box::new(term)).simplify_rec(scope, scenario, extra)
802                        })
803                        .collect()), // Case #3: If expression is an addition, distribute the coef
804                    (c, Val(v)) => {
805                        c.checked_mul(v).map(Val).unwrap_or_else(|| MulInt(c, Box::new(Val(v))))
806                    } // Case #4: If expression is a value, combine coefs
807                    (c, MulInt(v, inner)) => {
808                        if let Some(cv) = c.checked_mul(v) {
809                            MulInt(cv, inner) // Case #5: If expression is a MulInt, combine coefs
810                        } else {
811                            MulInt(c, Box::new(MulInt(v, inner)))
812                        }
813                    }
814                    (_, s) => MulInt(coef, Box::new(s)), // Case #6: Otherwise, return the original
815                }
816            }
817            Div(a, q) => {
818                if q == 1 {
819                    return a.simplify_rec(scope, scenario, extra);
820                } else if let Div(a, q2) = *a {
821                    return Div(a, q * q2).simplify_rec(scope, scenario, extra);
822                }
823                let a = a.simplify_rec(scope, scenario, extra);
824                if let Val(a) = a {
825                    Val(a / q as i64)
826                } else if let MulInt(-1, a) = a {
827                    MulInt(-1, b!(Div(a, q)))
828                } else if let Add(mut terms) = a {
829                    if terms
830                        .iter()
831                        .any(|t| if let MulInt(-1, s) = t { matches!(&**s, Sym(_)) } else { false })
832                    {
833                        MulInt(
834                            -1,
835                            b!(Div(
836                                b!(Add(terms.into_iter().map(|t| MulInt(-1, b!(t))).collect())
837                                    .simplify_rec(scope, scenario, extra)),
838                                q
839                            )),
840                        )
841                    } else if let Some(val) = terms
842                        .iter()
843                        .find_map(|t| if let Val(v) = t { Some(*v) } else { None })
844                        .and_then(|v| {
845                            if v >= q as i64 {
846                                Some(v / q as i64)
847                            } else if v < 0 {
848                                Some(-Integer::div_ceil(&-v, &(q as i64)))
849                            } else {
850                                None
851                            }
852                        })
853                    {
854                        terms.push(Val(-val * q as i64));
855                        // simplify_rec the inner Div too so that follow-up rules
856                        // (e.g. divide-multiple-plus-remainder below) can collapse
857                        // it once the Val extraction has tidied up the residual.
858                        let inner = Div(b!(Add(terms).simplify_rec(scope, scenario, extra)), q)
859                            .simplify_rec(scope, scenario, extra);
860                        Add(vec![Val(val), inner])
861                    } else if let Some(simplified) =
862                        try_divide_multiple_plus_remainder(&terms, q, scope, extra)
863                    {
864                        // Match `Div(Add([k·X, …, c]), k)` where:
865                        //   - one or more terms have a coefficient that is a multiple of q,
866                        //   - the rest sum to a constant in [0, q),
867                        //   - every extracted X is provably non-negative.
868                        // Then `(k·X + c)/k = X + 0 = X` under tract's truncating
869                        // division.  This is sound for X ≥ 0 only — at X = -1
870                        // the truncation rounds toward zero, not floor, breaking
871                        // the identity.  We use prove_positive_or_zero to gate.
872                        simplified.simplify_rec(scope, scenario, extra)
873                    } else if let Some(found_idx) = terms.iter().position(|term| {
874                        // Rule: (Y − q·(Y/q)) / q = 0  [i.e. (Y mod q) / q = 0]
875                        // Always sound: |Y mod q| < q so (Y mod q)/q = 0 under
876                        // truncating division regardless of sign of Y.
877                        matches!(term, MulInt(p, inner)
878                            if *p == -(q as i64)
879                            && matches!(inner.as_ref(), Div(_, q2) if *q2 == q))
880                    }) {
881                        let MulInt(_, inner) = &terms[found_idx] else { unreachable!() };
882                        let Div(y, _) = inner.as_ref() else { unreachable!() };
883                        let remaining: Vec<TDim> = terms
884                            .iter()
885                            .enumerate()
886                            .filter(|&(i, _)| i != found_idx)
887                            .map(|(_, t)| t.clone())
888                            .collect();
889                        let remaining_sum = match remaining.len() {
890                            0 => Val(0),
891                            1 => remaining.into_iter().next().unwrap(),
892                            _ => Add(remaining),
893                        };
894                        if eq_structural(&remaining_sum, y) {
895                            Val(0)
896                        } else {
897                            Div(b!(Add(terms)), q)
898                        }
899                    } else {
900                        Div(b!(Add(terms)), q)
901                    }
902                } else if let MulInt(p, a) = a {
903                    if p == q as i64 {
904                        a.simplify()
905                    } else {
906                        let gcd = p.abs().gcd(&(q as i64));
907                        if gcd == p {
908                            Div(a, q / gcd as u64)
909                        } else if gcd == q as i64 {
910                            MulInt(p / gcd, a)
911                        } else if gcd > 1 {
912                            Div(b!(MulInt(p / gcd, a)), q / gcd as u64)
913                                .simplify_rec(scope, scenario, extra)
914                        } else {
915                            Div(b!(MulInt(p, a)), q)
916                        }
917                    }
918                } else {
919                    Div(b!(a), q)
920                }
921            }
922            Broadcast(terms) => {
923                let mut terms: Vec<TDim> = terms
924                    .iter()
925                    .map(|s| s.clone().simplify_rec(scope, scenario, extra))
926                    .flat_map(|t| if let Broadcast(t) = t { t } else { vec![t] })
927                    .filter(|t| !t.is_one())
928                    .sorted_by(tdim_lexi_order)
929                    .dedup()
930                    .collect_vec();
931                // a#min(a,b) if a>0 && b>0 => a
932                match &*terms {
933                    [] => Val(1),
934                    [_] => terms.remove(0),
935                    [a, Min(m)] | [Min(m), a]
936                        if m.contains(a)
937                            && m.iter()
938                                .all(|t| scope.prove_strict_positive_with_extra(t, extra)) =>
939                    {
940                        a.clone()
941                    }
942                    _ => Broadcast(terms),
943                }
944            }
945
946            Min(terms) => {
947                let mut flatten: Vec<TDim> = terms
948                    .into_iter()
949                    .map(|t| t.simplify_rec(scope, scenario, extra))
950                    .flat_map(|t| if let Min(t) = t { t } else { vec![t] })
951                    .filter(|t| t != &Val(i64::MAX))
952                    .sorted_by(tdim_lexi_order)
953                    .dedup()
954                    .collect();
955                #[allow(clippy::mutable_key_type)]
956                let mut redundant = HashSet::<TDim>::default();
957                for pair in flatten.iter().permutations(2) {
958                    let (a, b) = (pair[0], pair[1]);
959                    if redundant.contains(a) || redundant.contains(b) {
960                        continue;
961                    }
962                    let diff = a.clone() - b;
963                    if diff.as_i64().is_some_and(|i| i >= 0)
964                        || scope.prove_positive_or_zero_with_extra(&diff, extra)
965                    {
966                        redundant.insert(a.clone());
967                    }
968                }
969                flatten.retain(|t| !redundant.contains(t));
970                if flatten.len() == 0 {
971                    i64::MAX.to_dim()
972                } else if flatten.len() == 1 {
973                    flatten.into_iter().next().unwrap()
974                } else {
975                    Min(flatten)
976                }
977            }
978            Max(terms) => {
979                let mut flatten: Vec<TDim> = terms
980                    .into_iter()
981                    .map(|t| t.simplify_rec(scope, scenario, extra))
982                    .flat_map(|t| if let Max(t) = t { t } else { vec![t] })
983                    .filter(|t| t != &Val(i64::MIN))
984                    .sorted_by(tdim_lexi_order)
985                    .dedup()
986                    .collect();
987                #[allow(clippy::mutable_key_type)]
988                let mut redundant = HashSet::<TDim>::default();
989                for pair in flatten.iter().permutations(2) {
990                    let (a, b) = (pair[0], pair[1]);
991                    if redundant.contains(a) || redundant.contains(b) {
992                        continue;
993                    }
994                    let diff = a.clone() - b;
995                    if diff.as_i64().is_some_and(|i| i >= 0)
996                        || scope.prove_positive_or_zero_with_extra(&diff, extra)
997                    {
998                        redundant.insert(b.clone());
999                    }
1000                }
1001                flatten.retain(|t| !redundant.contains(t));
1002                if flatten.len() == 0 {
1003                    i64::MIN.to_dim()
1004                } else if flatten.len() == 1 {
1005                    flatten.into_iter().next().unwrap()
1006                } else {
1007                    Max(flatten)
1008                }
1009            }
1010            Sym(s) => scope
1011                .assertions(scenario)
1012                .find_map(|a| match a {
1013                    Assertion::Eq(Sym(sym), v) if sym == &s => Some(v.clone()),
1014                    _ => None,
1015                })
1016                .unwrap_or(Sym(s)),
1017            Val(_) => self,
1018            Ge(a, b) => {
1019                let a = a.simplify_rec(scope, scenario, extra);
1020                let b = b.simplify_rec(scope, scenario, extra);
1021                match (&a, &b) {
1022                    (Val(av), Val(bv)) => Val(if av >= bv { 1 } else { 0 }),
1023                    _ => {
1024                        let diff = a.clone() - b.clone();
1025                        if scope.prove_positive_or_zero_with_extra(&diff, extra) {
1026                            Val(1)
1027                        } else if scope
1028                            .prove_strict_positive_with_extra(&(b.clone() - a.clone()), extra)
1029                        {
1030                            Val(0)
1031                        } else {
1032                            Ge(b!(a), b!(b))
1033                        }
1034                    }
1035                }
1036            }
1037            Eq(a, b) => {
1038                let a = a.simplify_rec(scope, scenario, extra);
1039                let b = b.simplify_rec(scope, scenario, extra);
1040                match (&a, &b) {
1041                    (Val(av), Val(bv)) => Val(if av == bv { 1 } else { 0 }),
1042                    _ => {
1043                        let diff = a.clone() - b.clone();
1044                        if scope.prove_strict_positive_with_extra(&diff, extra)
1045                            || scope
1046                                .prove_strict_positive_with_extra(&(b.clone() - a.clone()), extra)
1047                        {
1048                            Val(0)
1049                        } else {
1050                            // When one side is 0 or 1 and the other is
1051                            // provably in [0,1], reduce to boolean algebra:
1052                            //   Eq(expr, 0) → 1 - expr
1053                            //   Eq(expr, 1) → expr
1054                            let boolean_case = match (&a, &b) {
1055                                (Val(0), e) | (e, Val(0)) => Some((e, false)),
1056                                (Val(1), e) | (e, Val(1)) => Some((e, true)),
1057                                _ => None,
1058                            };
1059                            if let Some((expr, equals_one)) = boolean_case {
1060                                if scope.prove_positive_or_zero_with_extra(expr, extra)
1061                                    && scope.prove_positive_or_zero_with_extra(
1062                                        &(Val(1) - expr.clone()),
1063                                        extra,
1064                                    )
1065                                {
1066                                    return if equals_one {
1067                                        expr.clone()
1068                                    } else {
1069                                        (Val(1) - expr.clone()).simplify_rec(scope, scenario, extra)
1070                                    };
1071                                }
1072                            }
1073                            Eq(b!(a), b!(b))
1074                        }
1075                    }
1076                }
1077            }
1078        }
1079    }
1080
1081    pub(super) fn inclusive_bound(&self, scope: &SymbolScopeData, upper: bool) -> Option<i64> {
1082        use self::TDim::*;
1083        match self {
1084            Val(n) => Some(*n),
1085            Sym(_) => {
1086                if upper {
1087                    scope
1088                        .all_assertions()
1089                        .iter()
1090                        .filter_map(|assert| match &assert {
1091                            Assertion::LT(left, right)
1092                                if left == self && right.as_i64().is_some() =>
1093                            {
1094                                Some(right.as_i64().unwrap() - 1)
1095                            }
1096                            Assertion::LTE(left, right)
1097                                if left == self && right.as_i64().is_some() =>
1098                            {
1099                                Some(right.as_i64().unwrap())
1100                            }
1101                            _ => None,
1102                        })
1103                        .min()
1104                } else {
1105                    scope
1106                        .all_assertions()
1107                        .iter()
1108                        .filter_map(|assert| match &assert {
1109                            Assertion::GT(left, right)
1110                                if left == self && right.as_i64().is_some() =>
1111                            {
1112                                Some(right.as_i64().unwrap() + 1)
1113                            }
1114                            Assertion::GTE(left, right)
1115                                if left == self && right.as_i64().is_some() =>
1116                            {
1117                                Some(right.as_i64().unwrap())
1118                            }
1119                            _ => None,
1120                        })
1121                        .max()
1122                }
1123            }
1124            Add(terms) => {
1125                let mut bound: i64 = 0;
1126                for t in terms {
1127                    if let Some(b) = t.inclusive_bound(scope, upper) {
1128                        bound = bound.checked_add(b)?;
1129                    } else {
1130                        return None;
1131                    }
1132                }
1133                Some(bound)
1134            }
1135            MulInt(p, a) => match p.cmp(&0) {
1136                Ordering::Equal => Some(0),
1137                Ordering::Greater => {
1138                    a.inclusive_bound(scope, upper).and_then(|x| x.checked_mul(*p))
1139                }
1140                Ordering::Less => a.inclusive_bound(scope, !upper).and_then(|x| x.checked_mul(*p)),
1141            },
1142            Mul(terms) => {
1143                // If all factors have known non-negative bounds, we can bound the product.
1144                let mut lo: i64 = 1;
1145                let mut hi: i64 = 1;
1146                for t in terms {
1147                    let t_lo = t.inclusive_bound(scope, false)?;
1148                    let t_hi = t.inclusive_bound(scope, true)?;
1149                    if t_lo < 0 {
1150                        return None;
1151                    }
1152                    lo = lo.checked_mul(t_lo)?;
1153                    hi = hi.checked_mul(t_hi)?;
1154                }
1155                Some(if upper { hi } else { lo })
1156            }
1157            Min(terms) if !upper => {
1158                // All terms must have known lower bounds; if any is unknown,
1159                // the Min lower bound is unknown.
1160                let bounds: Option<Vec<i64>> =
1161                    terms.iter().map(|t| t.inclusive_bound(scope, false)).collect();
1162                bounds.map(|b| b.into_iter().min().unwrap_or(i64::MAX))
1163            }
1164            Max(terms) if upper => {
1165                // All terms must have known upper bounds; if any is unknown,
1166                // the Max upper bound is unknown.
1167                let bounds: Option<Vec<i64>> =
1168                    terms.iter().map(|t| t.inclusive_bound(scope, true)).collect();
1169                bounds.map(|b| b.into_iter().max().unwrap_or(i64::MIN))
1170            }
1171            Div(a, q) => a.inclusive_bound(scope, upper).map(|x| x / (*q as i64)),
1172            Broadcast(terms) => {
1173                if upper {
1174                    Max(terms.clone()).inclusive_bound(scope, true)
1175                } else {
1176                    Min(terms.clone()).inclusive_bound(scope, false)
1177                }
1178            }
1179            Ge(_, _) | Eq(_, _) => {
1180                if upper {
1181                    Some(1)
1182                } else {
1183                    Some(0)
1184                }
1185            }
1186            _ => None,
1187        }
1188    }
1189
1190    pub fn low_inclusive_bound(&self) -> Option<i64> {
1191        if let TDim::Val(v) = self {
1192            return Some(*v);
1193        }
1194        let scope = self.find_scope()?;
1195        let data = scope.0.lock();
1196        let data = data.borrow();
1197        self.inclusive_bound(&data, false)
1198    }
1199
1200    pub fn high_inclusive_bound(&self) -> Option<i64> {
1201        if let TDim::Val(v) = self {
1202            return Some(*v);
1203        }
1204        let scope = self.find_scope()?;
1205        let data = scope.0.lock();
1206        let data = data.borrow();
1207        self.inclusive_bound(&data, true)
1208    }
1209
1210    pub fn prove_positive_or_zero(&self) -> bool {
1211        if let TDim::Val(v) = self {
1212            return *v >= 0;
1213        }
1214        let Some(scope) = self.find_scope() else { return false };
1215        let data = scope.0.lock();
1216        let data = data.borrow();
1217        data.prove_positive_or_zero(self)
1218    }
1219
1220    pub fn prove_strict_positive(&self) -> bool {
1221        if let TDim::Val(v) = self {
1222            return *v > 0;
1223        }
1224        (self.clone() - 1).prove_positive_or_zero()
1225    }
1226
1227    pub fn prove_negative_or_zero(&self) -> bool {
1228        if let TDim::Val(v) = self {
1229            return *v <= 0;
1230        }
1231        self.clone().neg().prove_positive_or_zero()
1232    }
1233
1234    pub fn prove_strict_negative(&self) -> bool {
1235        if let TDim::Val(v) = self {
1236            return *v < 0;
1237        }
1238        self.clone().neg().prove_strict_positive()
1239    }
1240
1241    /// Least common multiple of two `TDim`s when both reduce to positive
1242    /// integers.
1243    ///
1244    /// Returns `Val(0)` if either operand is `0`, and `None` if either is
1245    /// symbolic, negative, or if the LCM would overflow `i64`. Callers
1246    /// that need a safe answer for symbolic operands should fall back at
1247    /// the call site.
1248    pub fn lcm(&self, other: &TDim) -> Option<TDim> {
1249        match (self.as_i64(), other.as_i64()) {
1250            (Some(a), Some(b)) if a > 0 && b > 0 => {
1251                let g = (a as u64).gcd(&(b as u64));
1252                let l = (a as u64 / g).saturating_mul(b as u64);
1253                if l > i64::MAX as u64 { None } else { Some(TDim::Val(l as i64)) }
1254            }
1255            (Some(0), _) | (_, Some(0)) => Some(TDim::Val(0)),
1256            _ => None,
1257        }
1258    }
1259
1260    pub fn gcd(&self) -> u64 {
1261        use self::TDim::*;
1262        match self {
1263            Val(v) => v.unsigned_abs(),
1264            Sym(_) => 1,
1265            Add(terms) => {
1266                let (head, tail) = terms.split_first().unwrap();
1267                tail.iter().fold(head.gcd(), |a, b| a.gcd(&b.gcd()))
1268            }
1269            MulInt(p, a) => a.gcd().saturating_mul(p.unsigned_abs()),
1270            Mul(terms) => terms.iter().map(|t| t.gcd()).fold(1u64, |a, b| a.saturating_mul(b)),
1271            Min(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
1272            Max(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
1273            Div(a, q) => {
1274                if a.gcd() % *q == 0 {
1275                    a.gcd() / *q
1276                } else {
1277                    1
1278                }
1279            }
1280            Broadcast(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap_or(1),
1281            Ge(_, _) | Eq(_, _) => 1,
1282        }
1283    }
1284
1285    fn div(&self, d: u64) -> TDim {
1286        use self::TDim::*;
1287        if d == 1 {
1288            return self.clone();
1289        }
1290        match self {
1291            Val(v) => Val(v / d as i64),
1292            Sym(_) => panic!(),
1293            Add(terms) => Add(terms.iter().map(|t| t.div(d)).collect()),
1294            Min(terms) => Min(terms.iter().map(|t| t.div(d)).collect()),
1295            Max(terms) => Max(terms.iter().map(|t| t.div(d)).collect()),
1296            Broadcast(terms) => Broadcast(terms.iter().map(|t| t.div(d)).collect()),
1297            Mul(_) => Div(Box::new(self.clone()), d),
1298            MulInt(p, a) => {
1299                if *p == d as i64 {
1300                    (**a).clone()
1301                } else {
1302                    let gcd = p.unsigned_abs().gcd(&d);
1303                    MulInt(p / gcd as i64, b!(a.div(d / gcd)))
1304                }
1305            }
1306            Div(a, q) => Div(a.clone(), q * d),
1307            Ge(_, _) | Eq(_, _) => Div(Box::new(self.clone()), d),
1308        }
1309    }
1310
1311    pub fn div_ceil(self, rhs: u64) -> TDim {
1312        TDim::Div(Box::new(Add(vec![self, Val(rhs as i64 - 1)])), rhs).reduce()
1313    }
1314
1315    pub fn guess_slope(&self, sym: &Symbol) -> (i64, u64) {
1316        fn slope_rec(d: &TDim, sym: &Symbol) -> (i64, i64) {
1317            match d {
1318                Val(_) => (0, 1),
1319                Sym(s) => ((sym == s) as i64, 1),
1320                Add(terms) => terms
1321                    .iter()
1322                    .map(|d| slope_rec(d, sym))
1323                    .fold((0, 1), |a, b| ((a.0 * b.1 + a.1 * b.0), (b.1 * a.1))),
1324                Mul(terms) => terms
1325                    .iter()
1326                    .map(|d| slope_rec(d, sym))
1327                    .fold((1, 1), |a, b| ((a.0 * b.0), (b.1 * a.1))),
1328                MulInt(p, a) => {
1329                    let (n, d) = slope_rec(a, sym);
1330                    (p * n, d)
1331                }
1332                Div(a, q) => {
1333                    let (n, d) = slope_rec(a, sym);
1334                    (n, d * *q as i64)
1335                }
1336                Broadcast(terms) => slope_rec(&terms[0], sym),
1337                Min(terms) => slope_rec(&terms[0], sym),
1338                Max(terms) => slope_rec(&terms[0], sym),
1339                Ge(_, _) | Eq(_, _) => (0, 1),
1340            }
1341        }
1342        let (p, q) = slope_rec(self, sym);
1343        reduce_ratio(p, q)
1344    }
1345
1346    #[allow(clippy::mutable_key_type)]
1347    pub fn symbols(&self) -> std::collections::HashSet<Symbol> {
1348        match self {
1349            Val(_) => maplit::hashset!(),
1350            Sym(s) => maplit::hashset!(s.clone()),
1351            Add(terms) | Mul(terms) | Broadcast(terms) | Min(terms) | Max(terms) => {
1352                terms.iter().fold(maplit::hashset!(), |mut set, v| {
1353                    set.extend(v.symbols());
1354                    set
1355                })
1356            }
1357            MulInt(_, a) => a.symbols(),
1358            Div(a, _) => a.symbols(),
1359            Ge(a, b) | Eq(a, b) => {
1360                let mut set = a.symbols();
1361                set.extend(b.symbols());
1362                set
1363            }
1364        }
1365    }
1366
1367    pub fn compatible_with(&self, other: &TDim) -> bool {
1368        if let Ok(x) = (self.clone() - other).to_i64() {
1369            return x == 0;
1370        }
1371        true // maybe ? :)
1372    }
1373}
1374
1375pub(super) fn reduce_ratio(mut p: i64, mut q: i64) -> (i64, u64) {
1376    let gcd = p.abs().gcd(&q.abs());
1377    if gcd > 1 {
1378        p /= gcd;
1379        q /= gcd;
1380    }
1381    if q < 0 { (-p, (-q) as u64) } else { (p, q as u64) }
1382}
1383
1384impl Zero for TDim {
1385    fn zero() -> Self {
1386        Val(0)
1387    }
1388    fn is_zero(&self) -> bool {
1389        matches!(self, Val(0))
1390    }
1391}
1392
1393impl Default for TDim {
1394    fn default() -> TDim {
1395        Val(0)
1396    }
1397}
1398
1399impl num_traits::Bounded for TDim {
1400    fn min_value() -> Self {
1401        TDim::Val(i64::MIN)
1402    }
1403
1404    fn max_value() -> Self {
1405        TDim::Val(i64::MAX)
1406    }
1407}
1408
1409impl num_traits::One for TDim {
1410    fn one() -> Self {
1411        TDim::Val(1)
1412    }
1413}
1414
1415impl ::std::iter::Sum for TDim {
1416    fn sum<I: Iterator<Item = TDim>>(iter: I) -> TDim {
1417        iter.fold(0.into(), |a, b| a + b)
1418    }
1419}
1420
1421impl<'a> ::std::iter::Sum<&'a TDim> for TDim {
1422    fn sum<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
1423        iter.fold(0.into(), |a, b| a + b)
1424    }
1425}
1426
1427impl std::iter::Product for TDim {
1428    fn product<I: Iterator<Item = TDim>>(iter: I) -> Self {
1429        iter.fold(TDim::Val(1), |a, b| a * b)
1430    }
1431}
1432
1433impl<'a> ::std::iter::Product<&'a TDim> for TDim {
1434    fn product<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
1435        iter.fold(1.into(), |a, b| a * b)
1436    }
1437}
1438
1439macro_rules! from_i {
1440    ($i: ty) => {
1441        impl From<$i> for TDim {
1442            fn from(v: $i) -> TDim {
1443                TDim::Val(v as _)
1444            }
1445        }
1446        impl<'a> From<&'a $i> for TDim {
1447            fn from(v: &'a $i) -> TDim {
1448                TDim::Val(*v as _)
1449            }
1450        }
1451    };
1452}
1453
1454from_i!(i32);
1455from_i!(i64);
1456from_i!(u64);
1457from_i!(isize);
1458from_i!(usize);
1459
1460impl From<Symbol> for TDim {
1461    fn from(it: Symbol) -> Self {
1462        TDim::Sym(it)
1463    }
1464}
1465
1466impl<'a> From<&'a Symbol> for TDim {
1467    fn from(it: &'a Symbol) -> Self {
1468        TDim::Sym(it.clone())
1469    }
1470}
1471
1472impl ops::Neg for TDim {
1473    type Output = Self;
1474    fn neg(self) -> Self {
1475        if let Val(v) = self { Val(-v) } else { TDim::MulInt(-1, Box::new(self)).reduce() }
1476    }
1477}
1478
1479impl<'a> ops::AddAssign<&'a TDim> for TDim {
1480    fn add_assign(&mut self, rhs: &'a TDim) {
1481        if rhs.is_zero() {
1482        } else if self.is_zero() {
1483            *self = rhs.clone();
1484        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1485            *s += o;
1486        } else {
1487            *self = TDim::Add(vec![std::mem::take(self), rhs.clone()]).reduce()
1488        }
1489    }
1490}
1491
1492impl<I> ops::AddAssign<I> for TDim
1493where
1494    I: Into<TDim>,
1495{
1496    fn add_assign(&mut self, rhs: I) {
1497        let rhs = rhs.into();
1498        if rhs.is_zero() {
1499        } else if self.is_zero() {
1500            *self = rhs;
1501        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1502            *s += o;
1503        } else {
1504            *self = TDim::Add(vec![std::mem::take(self), rhs]).reduce()
1505        }
1506    }
1507}
1508
1509impl<I> ops::Add<I> for TDim
1510where
1511    I: Into<TDim>,
1512{
1513    type Output = Self;
1514    fn add(mut self, rhs: I) -> Self {
1515        self += rhs;
1516        self
1517    }
1518}
1519
1520impl<'a> ops::Add<&'a TDim> for TDim {
1521    type Output = Self;
1522    fn add(mut self, rhs: &'a TDim) -> Self {
1523        self += rhs;
1524        self
1525    }
1526}
1527
1528#[allow(clippy::suspicious_op_assign_impl)]
1529impl<'a> ops::SubAssign<&'a TDim> for TDim {
1530    fn sub_assign(&mut self, rhs: &'a TDim) {
1531        if rhs.is_zero() {
1532        } else if self.is_zero() {
1533            *self = rhs.clone().neg();
1534        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1535            *s -= o;
1536        } else {
1537            *self = TDim::Add(vec![std::mem::take(self), rhs.clone().neg()]).reduce()
1538        }
1539    }
1540}
1541
1542impl<I> ops::SubAssign<I> for TDim
1543where
1544    I: Into<TDim>,
1545{
1546    fn sub_assign(&mut self, rhs: I) {
1547        let rhs = rhs.into();
1548        if rhs.is_zero() {
1549        } else if self.is_zero() {
1550            *self = rhs.neg();
1551        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1552            *s -= o;
1553        } else {
1554            *self = TDim::Add(vec![std::mem::take(self), rhs.neg()]).reduce()
1555        }
1556    }
1557}
1558
1559impl<I> ops::Sub<I> for TDim
1560where
1561    I: Into<TDim>,
1562{
1563    type Output = Self;
1564    fn sub(mut self, rhs: I) -> Self {
1565        self -= rhs;
1566        self
1567    }
1568}
1569
1570impl<'a> ops::Sub<&'a TDim> for TDim {
1571    type Output = Self;
1572    fn sub(mut self, rhs: &'a TDim) -> Self {
1573        self -= rhs;
1574        self
1575    }
1576}
1577
1578impl<I: Into<TDim>> ops::MulAssign<I> for TDim {
1579    fn mul_assign(&mut self, rhs: I) {
1580        let rhs = rhs.into();
1581        if self.is_one() {
1582            *self = rhs
1583        } else if rhs.is_one() {
1584        } else {
1585            *self = TDim::Mul(vec![rhs, std::mem::take(self)]).reduce()
1586        }
1587    }
1588}
1589
1590impl<'a> ops::MulAssign<&'a TDim> for TDim {
1591    fn mul_assign(&mut self, rhs: &'a TDim) {
1592        if self.is_one() {
1593            *self = rhs.clone()
1594        } else if rhs.is_one() {
1595        } else {
1596            *self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce()
1597        }
1598    }
1599}
1600
1601impl<I: Into<TDim>> ops::Mul<I> for TDim {
1602    type Output = Self;
1603    fn mul(mut self, rhs: I) -> Self {
1604        self *= rhs.into();
1605        self
1606    }
1607}
1608
1609impl<'a> ops::Mul<&'a TDim> for TDim {
1610    type Output = Self;
1611    fn mul(mut self, rhs: &'a TDim) -> Self {
1612        self *= rhs;
1613        self
1614    }
1615}
1616
1617impl<I: AsPrimitive<u64> + PrimInt> ops::DivAssign<I> for TDim {
1618    fn div_assign(&mut self, rhs: I) {
1619        *self = TDim::Div(Box::new(std::mem::take(self)), rhs.as_()).reduce()
1620    }
1621}
1622
1623impl<I: AsPrimitive<u64> + PrimInt> ops::Div<I> for TDim {
1624    type Output = Self;
1625    fn div(mut self, rhs: I) -> Self {
1626        self /= rhs.as_();
1627        self
1628    }
1629}
1630
1631impl<I: AsPrimitive<u64> + PrimInt> ops::RemAssign<I> for TDim {
1632    fn rem_assign(&mut self, rhs: I) {
1633        *self += -(self.clone() / rhs.as_() * rhs.as_());
1634    }
1635}
1636
1637impl<I: AsPrimitive<u64> + PrimInt> ops::Rem<I> for TDim {
1638    type Output = Self;
1639    fn rem(mut self, rhs: I) -> Self {
1640        self %= rhs;
1641        self
1642    }
1643}
1644
// Unit tests for `TDim`: reduction of hand-built trees, the arithmetic
// operators, simplification, slope guessing and the `Ge` / `Eq` variants.
#[cfg(test)]
mod tests {
    use super::*;

    // Local shorthand for `Box::new`.
    macro_rules! b( ($e:expr) => { Box::new($e) } );

    // Symbols `a` to `e`, all created from one shared scope.
    lazy_static::lazy_static! {
        static ref table: SymbolScope = SymbolScope::default();
        static ref A: Symbol = table.sym("a");
        static ref B: Symbol = table.sym("b");
        static ref C: Symbol = table.sym("c");
        static ref D: Symbol = table.sym("d");
        static ref E: Symbol = table.sym("e");
    }

    // The helpers below build raw, unreduced nodes so that tests can call
    // `reduce()` explicitly on a known tree shape.

    /// Raw `MulInt(-1, a)`.
    fn neg(a: &TDim) -> TDim {
        mul(-1, a)
    }

    /// Raw two-term `Add` node.
    fn add(a: &TDim, b: &TDim) -> TDim {
        TDim::Add(vec![a.clone(), b.clone()])
    }

    /// Raw `MulInt(a, b)` node.
    fn mul(a: i64, b: &TDim) -> TDim {
        TDim::MulInt(a, b![b.clone()])
    }

    /// Raw `Div(a, b)` node.
    fn div(a: &TDim, b: u64) -> TDim {
        TDim::Div(b!(a.clone()), b)
    }

    #[test]
    fn reduce_add() {
        // a + (-1 * a) reduces to the constant 0.
        assert_eq!(add(&A.to_dim(), &neg(&A.to_dim())).reduce(), Val(0))
    }

    #[test]
    fn lcm_basic() {
        assert_eq!(Val(16).lcm(&Val(32)), Some(Val(32)));
        assert_eq!(Val(32).lcm(&Val(16)), Some(Val(32)));
        assert_eq!(Val(6).lcm(&Val(8)), Some(Val(24)));
        assert_eq!(Val(7).lcm(&Val(7)), Some(Val(7)));
        // Symbolic: not computable; callers fall back.
        assert_eq!(Val(16).lcm(&A.to_dim()), None);
    }

    #[test]
    fn reduce_neg_mul() {
        // -1 * (2 * a) reduces to -2 * a.
        assert_eq!(neg(&mul(2, &A.to_dim())).reduce(), mul(-2, &A.to_dim()))
    }

    #[test]
    fn reduce_cplx_ex_2() {
        // -4 + -2*(a/4) + -2*(-1*(a/4)): the two a/4 terms cancel, leaving -4.
        assert_eq!(
            add(
                &add(&Val(-4), &mul(-2, &div(&A.to_dim(), 4))),
                &mul(-2, &mul(-1, &div(&A.to_dim(), 4)))
            )
            .reduce(),
            Val(-4)
        )
    }

    #[test]
    fn reduce_cplx_ex_3() {
        // (1 * (4 * a)) / 4 reduces to a.
        assert_eq!(div(&MulInt(1, b!(MulInt(4, b!(A.to_dim())))), 4).reduce(), A.to_dim())
    }

    #[test]
    fn reduce_cplx_ex_4() {
        // (S+1)/2 + (1-S)/2 == 1
        assert_eq!(
            add(&div(&add(&A.to_dim(), &Val(1)), 2), &div(&add(&neg(&A.to_dim()), &Val(1)), 2))
                .reduce(),
            1.into()
        );
    }

    #[test]
    fn reduce_mul_mul_1() {
        // 3 * (2 * a) reduces to 6 * a.
        assert_eq!(mul(3, &mul(2, &A.to_dim())).reduce(), mul(6, &A.to_dim()))
    }

    #[test]
    fn reduce_mul_mul_2() {
        // -2 * (-1 * a) reduces to 2 * a.
        assert_eq!(mul(-2, &mul(-1, &A.to_dim())).reduce(), mul(2, &A.to_dim()))
    }

    #[test]
    fn reduce_mul_div_1() {
        // 2 * ((-1 * a) / 3) reduces to -2 * (a / 3).
        assert_eq!(mul(2, &div(&mul(-1, &A.to_dim()), 3)).reduce(), mul(-2, &div(&A.to_dim(), 3)))
    }

    #[test]
    fn const_and_add() {
        // Constant expressions evaluate with an empty set of symbol values.
        let e: TDim = 2i64.into();
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 2);
        let e: TDim = TDim::from(2) + 3;
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 5);
        let e: TDim = TDim::from(2) - 3;
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -1);
        let e: TDim = -TDim::from(2);
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -2);
    }

    #[test]
    fn substitution() {
        // Evaluating with a = 2 gives 2 for `a` and 5 for `a + 3`.
        let a: TDim = A.to_dim();
        assert_eq!(a.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 2);
        let e = a + 3;
        assert_eq!(e.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 5);
    }

    #[test]
    fn reduce_adds() {
        // `+` on constants folds to a single constant.
        let e: TDim = TDim::from(2) + 1;
        assert_eq!(e, TDim::from(3));
        let e: TDim = TDim::from(3) + 2;
        assert_eq!(e, TDim::from(5));
        let e: TDim = TDim::from(3) + 0;
        assert_eq!(e, TDim::from(3));
        let e: TDim = TDim::from(3) + 2 + 1;
        assert_eq!(e, TDim::from(6));
    }

    #[test]
    fn reduce_muls() {
        // Multiplying by 1, on either side, leaves the other operand unchanged.
        let e: TDim = Val(1) * A.to_dim();
        assert_eq!(e, A.to_dim());
        let e: TDim = A.to_dim() * &B.to_dim() * 1;
        assert_eq!(e, A.to_dim() * &B.to_dim());
    }

    #[test]
    fn reduce_divs() {
        // `/` and `%` on constants fold to a single constant.
        let e: TDim = TDim::from(2) / 1;
        assert_eq!(e, TDim::from(2));
        let e: TDim = TDim::from(3) / 2;
        assert_eq!(e, TDim::from(1));
        let e: TDim = TDim::from(3) % 2;
        assert_eq!(e, TDim::from(1));
        let e: TDim = TDim::from(5) / 2;
        assert_eq!(e, TDim::from(2));
        let e: TDim = TDim::from(5) % 2;
        assert_eq!(e, TDim::from(1));
    }

    #[test]
    fn reduce_div_bug_0() {
        // (a + 23)/2 - 1 and (a + 21)/2 must compare equal.
        let e1: TDim = (A.to_dim() + 23) / 2 - 1;
        let e2: TDim = (A.to_dim() + 21) / 2;
        assert_eq!(e1, e2);
    }

    #[test]
    fn reduce_div_bug_1() {
        // (a - 1)/2 and (a + 1)/2 - 1 must compare equal.
        let e1: TDim = (A.to_dim() + -1) / 2;
        let e2: TDim = (A.to_dim() + 1) / 2 - 1;
        assert_eq!(e1, e2);
    }

    #[test]
    fn reduce_div_bug_2() {
        // ((a + 1)/2 + 1)/2 and (a + 3)/4 must compare equal.
        let e1: TDim = ((A.to_dim() + 1) / 2 + 1) / 2;
        let e2: TDim = (A.to_dim() + 3) / 4;
        assert_eq!(e1, e2);
    }

    #[test]
    fn divide_multiple_plus_remainder() {
        // (k·X + r)/k → X under truncating division when 0 ≤ r < k AND X ≥ 0.
        let scope = SymbolScope::default().with_assertion("S>=0").unwrap();
        let s = scope.sym("S");

        // (2S+1)/2 → S
        let e: TDim = (s.to_dim() * 2 + 1) / 2;
        assert_eq!(e.simplify(), s.to_dim());

        // -1 + (2S+1)/2 → S - 1
        let e: TDim = (s.to_dim() * 2 + 1) / 2 - 1;
        assert_eq!(e.simplify(), s.to_dim() - 1);

        // (2S-1)/2 → S - 1   (Val rule extracts -1 first, then our rule)
        let e: TDim = (s.to_dim() * 2 - 1) / 2;
        assert_eq!(e.simplify(), s.to_dim() - 1);

        // (4S+3)/2 → 2S + 1   (Val rule extracts 1 = 3/2, then our rule on (4S+1)/2 → 2S)
        let e: TDim = (s.to_dim() * 4 + 3) / 2;
        assert_eq!(e.simplify(), s.to_dim() * 2 + 1);
    }

    #[test]
    fn divide_multiple_plus_remainder_no_assertion() {
        // Without an X≥0 assertion the (k·X+c)/k → X identity does NOT hold
        // (X=-1, k=2, c=1 gives -1/2=0 ≠ X under truncating division). The
        // wiggle Div arm used to emit that variant unconditionally; reduce()
        // would then pick it on cost. Now wiggle skips the variant when the
        // remainder bucket contains a Val, leaving only the sound rule
        // gated on prove_positive_or_zero in simplify_rec.
        let scope = SymbolScope::default();
        let s = scope.sym("S");
        let e: TDim = (s.to_dim() * 2 + 1) / 2;
        assert_ne!(e.simplify(), s.to_dim());
    }

    #[test]
    fn modulo_div_is_zero() {
        // (Y − q·(Y/q)) / q = 0 for any Y and any q — the modulo remainder
        // divided by the modulus is always zero under truncating division.
        let scope = SymbolScope::default();
        let s = scope.sym("S");
        // Simple case: (S - 2*(S/2)) / 2 = (S mod 2) / 2 = 0
        let e: TDim = (s.to_dim() - s.to_dim() / 2 * 2) / 2;
        assert_eq!(e.simplify(), TDim::Val(0));
        // Composite case: ((S+1) - 2*((S+1)/2)) / 2 = 0
        // This is the exact pattern from SameUpper conv padding.
        let a = s.to_dim() + 1;
        let e2: TDim = (a.clone() - a.clone() / 2 * 2) / 2;
        assert_eq!(e2.simplify(), TDim::Val(0));
    }

    #[test]
    fn reduce_div_bug_3() {
        // Dividing by 1 must not change the expression.
        let e1: TDim = (A.to_dim() / 2) * -4;
        let e2: TDim = (A.to_dim() / 2) * -4 / 1;
        assert_eq!(e1, e2);
    }

    #[test]
    fn reduce_mul_div() {
        // (a * 2) / 2 is a.
        let e: TDim = A.to_dim() * 2 / 2;
        assert_eq!(e, A.to_dim());
    }

    #[test]
    fn expand_polynomial_two_add_factors() {
        // (a + 2*a*b) * (1 + b)  ==poly==  a * (1 + b) * (1 + 2*b)
        // Both fully expand to a + 3*a*b + 2*a*b*b.  We don't auto-expand in
        // simplify (it would block maybe_div on factored-form denominators),
        // but expand_polynomial does, and Reshape uses it for volume checks.
        let a = A.to_dim();
        let b = B.to_dim();
        let lhs = (a.clone() + a.clone() * &b * 2) * (TDim::from(1) + &b);
        let rhs = a.clone() * (TDim::from(1) + &b) * (TDim::from(1) + b.clone() * 2);
        assert_eq!(lhs.expand_polynomial(), rhs.expand_polynomial());
    }

    #[test]
    fn reduce_div_mul() {
        // (a / 2) * 2 must NOT collapse to a.
        let e: TDim = A.to_dim() / 2 * 2;
        assert_ne!(e, A.to_dim());
    }

    #[test]
    fn reduce_add_div() {
        // a/2 + 1 and (a + 2)/2 must compare equal.
        let e: TDim = A.to_dim() / 2 + 1;
        assert_eq!(e, ((A.to_dim() + 2) / 2));
    }

    #[test]
    fn reduce_neg_mul_() {
        // 1 - a*2 and 1 + a*(-2) must compare equal.
        let e: TDim = TDim::from(1) - A.to_dim() * 2;
        assert_eq!(e, TDim::from(1) + A.to_dim() * -2);
    }

    #[test]
    fn reduce_add_rem_1() {
        // Adding a multiple of the modulus does not change the remainder.
        assert_eq!(((A.to_dim() + 4) % 2), (A.to_dim() % 2));
    }

    #[test]
    fn reduce_add_rem_2() {
        // Subtracting a multiple of the modulus does not change the remainder.
        assert_eq!(((A.to_dim() - 4) % 2), (A.to_dim() % 2));
    }

    #[test]
    fn reduce_rem_div() {
        // (a % 2) / 2 is 0.
        let e: TDim = A.to_dim() % 2 / 2;
        assert_eq!(e, TDim::from(0));
    }

    #[test]
    fn conv2d_ex_1() {
        let e = (TDim::from(1) - 1 + 1).div_ceil(1);
        assert_eq!(e, TDim::from(1));
    }

    #[test]
    fn conv2d_ex_2() {
        let e = (A.to_dim() - 3 + 1).div_ceil(1);
        assert_eq!(e, A.to_dim() + -2);
    }

    #[test]
    fn extract_int_gcd_from_muls() {
        // (24t - 24) * (2t - 2) must compare equal to (t - 1) * (t - 1) * 48.
        let term = (A.to_dim() + 1) / 4;
        let mul = (term.clone() * 24 - 24) * (term.clone() * 2 - 2);
        let target = (term.clone() - 1) * (term.clone() - 1) * 48;
        assert_eq!(mul, target);
    }

    #[test]
    fn equality_of_muls() {
        // The same two factors multiplied in either order must compare equal.
        let term = (A.to_dim() + 1) / 4;
        let mul1 = (term.clone() * 2 - 3) * (term.clone() - 1);
        let mul2 = (term.clone() - 1) * (term.clone() * 2 - 3);
        assert_eq!(mul1, mul2);
    }

    #[test]
    fn factorize_complex_expr_times_int() {
        // 2t - t - 1 must compare equal to t - 1 for a non-trivial term t.
        let term = (A.to_dim() + 1) / 4;
        let e = term.clone() * 2 - &term - 1;
        assert_eq!(e, term - 1);
    }

    #[test]
    fn broadcast_over_min() {
        // assuming a>0, b>0 then a#min(a,b) can be replaced by a
        // proof:
        //    if b == 1 => min(a,b)=1 => a#1=a => ok
        //    if a <= b => min(a,b)=a => ok
        //    if 1 < B < A => expression was invalid, we're generalizing over the non-domain and ignoring the constraint
        for a in 1..5 {
            for b in 1..5 {
                if b > 1 && a > b {
                    assert!(a.broadcast(a.min(b)).is_err());
                } else {
                    assert_eq!(a.broadcast(a.min(b)).unwrap(), a);
                }
            }
        }
    }

    #[test]
    fn min_ints_1() {
        assert_eq!(2.to_dim().mini(1.to_dim()), 1.to_dim());
    }

    #[test]
    fn min_ints_2() {
        assert_eq!(1.to_dim().mini(2.to_dim()), 1.to_dim());
    }

    #[test]
    fn min_same() {
        assert_eq!(A.to_dim().mini(A.to_dim()), A.to_dim());
    }

    #[test]
    fn min_noop() {
        // Building the same `mini` expression twice gives equal values.
        assert_eq!(A.to_dim().mini(1.to_dim()), A.to_dim().mini(1.to_dim()));
    }

    #[test]
    fn min_diff_1() {
        assert_eq!((A.to_dim() + 1).mini(A.to_dim() + 2), A.to_dim() + 1);
    }

    // The slope tests compare `guess_slope` against an expected pair of
    // integers, e.g. (5, 2) for 2a + a/2 with respect to `a`.

    #[test]
    fn slope_0() {
        assert_eq!(12.to_dim().guess_slope(&A), (0, 1));
    }

    #[test]
    fn slope_1() {
        assert_eq!(A.to_dim().guess_slope(&A), (1, 1));
    }

    #[test]
    fn slope_2() {
        assert_eq!((A.to_dim() * 2).guess_slope(&A), (2, 1));
    }

    #[test]
    fn slope_3() {
        assert_eq!((A.to_dim() * 2 + A.to_dim() / 2).guess_slope(&A), (5, 2));
    }

    #[test]
    fn slope_4() {
        assert_eq!((A.to_dim()).guess_slope(&B), (0, 1));
    }

    #[test]
    fn slope_5() {
        assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
        assert_eq!((A.to_dim() + 1).guess_slope(&B), (0, 1));
    }

    #[test]
    fn slope_6() {
        assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
        assert_eq!((A.to_dim() + B.to_dim()).guess_slope(&B), (1, 1));
    }

    #[test]
    fn min_0() -> TractResult<()> {
        // min(S+3, S+2) simplifies to S+2.
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols.parse_tdim("min(S+3, S+2)").unwrap().simplify(),
            symbols.parse_tdim("S+2").unwrap(),
        );
        Ok(())
    }

    #[test]
    fn commutative_mul_parens() -> TractResult<()> {
        // Different groupings and orderings of the same product simplify alike.
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols.parse_tdim("A*(B*C)").unwrap().simplify(),
            symbols.parse_tdim("(B*A)*C").unwrap().simplify(),
        );
        Ok(())
    }

    #[test]
    fn commutative_in_nemo_parakeet_model() -> TractResult<()> {
        // Same factors as above but with `B` grouped differently in the product.
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols
                .parse_tdim("8*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8))*((B)*((S+7)/8))")
                .unwrap()
                .simplify(),
            symbols
                .parse_tdim("8*((B)*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8)))*((S+7)/8)")
                .unwrap()
                .simplify(),
        );
        Ok(())
    }

    #[test]
    fn commutative_mul_parens_deep() -> TractResult<()> {
        // A left-nested chain of binary `Mul` nodes simplifies like the parsed
        // flat product.
        let symbols = SymbolScope::default();
        let deep_tdim = Mul(vec![
            Mul(vec![Mul(vec![Mul(vec![A.to_dim(), B.to_dim()]), C.to_dim()]), D.to_dim()]),
            E.to_dim(),
        ])
        .simplify();
        assert_eq!(deep_tdim, symbols.parse_tdim("a*b*c*d*e").unwrap().simplify());
        Ok(())
    }

    // ---- Tests for new comparison/not TDim variants ----

    #[test]
    fn ge_concrete_true() {
        assert_eq!(Ge(b!(Val(5)), b!(Val(3))).reduce(), Val(1));
    }

    #[test]
    fn ge_concrete_false() {
        assert_eq!(Ge(b!(Val(2)), b!(Val(3))).reduce(), Val(0));
    }

    #[test]
    fn lt_concrete_true() {
        // Lt(2,3) normalizes to Ge(3, 2+1) = Ge(3, 3)
        assert_eq!(Ge(b!(Val(3)), b!(Val(3))).reduce(), Val(1));
    }

    #[test]
    fn lt_concrete_false() {
        // Lt(5,3) normalizes to Ge(3, 5+1) = Ge(3, 6)
        assert_eq!(Ge(b!(Val(3)), b!(Val(6))).reduce(), Val(0));
    }

    #[test]
    fn eq_concrete_true() {
        assert_eq!(Eq(b!(Val(3)), b!(Val(3))).reduce(), Val(1));
    }

    #[test]
    fn eq_concrete_false() {
        assert_eq!(Eq(b!(Val(3)), b!(Val(4))).reduce(), Val(0));
    }

    #[test]
    fn not_val_0() {
        // not(0) = 1 - 0 = 1
        assert_eq!((Val(1) - Val(0)).reduce(), Val(1));
    }

    #[test]
    fn not_val_1() {
        // not(1) = 1 - 1 = 0
        assert_eq!((Val(1) - Val(1)).reduce(), Val(0));
    }

    #[test]
    fn not_lt_becomes_ge() {
        // not(Lt(x1, T)) = 1 - Ge(T, x1+1); check it evaluates correctly at boundary
        let s = SymbolScope::default();
        let t = s.sym("T");
        let x1 = s.sym("x1");
        // at x1 = T (boundary), Ge(T, T+1) = 0, so 1 - 0 = 1 (not-lt is true when x1 >= T)
        let expr = Val(1) - Ge(b!(Sym(t.clone())), b!(Sym(x1.clone()) + Val(1)));
        let at_boundary = expr.substitute(&x1, &Sym(t.clone())).unwrap().simplify();
        assert_eq!(at_boundary, Val(1));
    }

    #[test]
    fn eq_with_assertion_proves_false() {
        // Eq(T, 0) should reduce to Val(0) when T >= 1
        let s = SymbolScope::default();
        s.add_assertion("T >= 1").unwrap();
        let t = s.sym("T");
        let expr = Eq(b!(Sym(t)), b!(Val(0)));
        assert_eq!(expr.simplify(), Val(0));
    }

    #[test]
    fn ge_coord_at_extremes() {
        // Ge(x1, T) should not simplify without coordinate substitution
        let s = SymbolScope::default();
        s.add_assertion("T >= 1").unwrap();
        let t = s.sym("T");
        let x1 = s.sym("x1");
        let expr = Ge(b!(Sym(x1.clone())), b!(Sym(t.clone())));
        // simplify() alone can't prove this false (x1 could be > T)
        // but with coordinate substitution (x1 = T-1), Ge(T-1, T) = 0
        let at_max = expr.substitute(&x1, &(Sym(t.clone()) - Val(1))).unwrap().simplify();
        assert_eq!(at_max, Val(0));
    }

    #[test]
    fn eval_to_i64_new_variants() {
        // `Ge` and `Eq` on constants evaluate to 1 (true) or 0 (false).
        use super::super::sym::SymbolValues;
        let sv = SymbolValues::default();
        assert_eq!(Ge(b!(Val(5)), b!(Val(3))).eval_to_i64(&sv).unwrap(), 1);
        assert_eq!(Ge(b!(Val(3)), b!(Val(5))).eval_to_i64(&sv).unwrap(), 0);
        assert_eq!(Eq(b!(Val(3)), b!(Val(3))).eval_to_i64(&sv).unwrap(), 1);
        assert_eq!(Eq(b!(Val(3)), b!(Val(4))).eval_to_i64(&sv).unwrap(), 0);
    }

    #[test]
    fn eq_boolean_simplifies() {
        // `cw` is asserted to lie in [0, 1], i.e. it behaves as a boolean.
        let s = SymbolScope::default();
        s.add_assertion("cw >= 0").unwrap();
        s.add_assertion("cw <= 1").unwrap();
        let cw = s.sym("cw");
        // Eq(1 - cw, 0) → cw
        assert_eq!(Eq(b!(Val(1) - Sym(cw.clone())), b!(Val(0))).simplify(), Sym(cw.clone()));
        // Eq(cw, 0) → 1 - cw
        assert_eq!(Eq(b!(Sym(cw.clone())), b!(Val(0))).simplify(), Val(1) - Sym(cw.clone()));
        // Eq(cw, 1) → cw
        assert_eq!(Eq(b!(Sym(cw.clone())), b!(Val(1))).simplify(), Sym(cw.clone()));
        // Eq(1 - cw, 1) → 1 - cw
        assert_eq!(Eq(b!(Val(1) - Sym(cw.clone())), b!(Val(1))).simplify(), Val(1) - Sym(cw));
    }

    #[test]
    fn eq_boolean_mul_of_ge() {
        // Product of Ge terms: Ge(a,b) * Ge(c,d) is in [0,1]
        // so Eq(product, 0) should simplify to 1 - product
        let s = SymbolScope::default();
        let x = s.sym("x");
        let product =
            Mul(vec![Ge(b!(Val(2)), b!(Sym(x.clone()))), Ge(b!(Sym(x.clone())), b!(Val(0)))]);
        let eq = Eq(b!(product.clone()), b!(Val(0)));
        assert_eq!(eq.simplify(), Val(1) - product);
    }

    #[test]
    fn min_1_max_0_sym() {
        // Min(1, Max(0, X)) must not simplify away the Min when X is unconstrained.
        let s = SymbolScope::default();
        let x = s.sym("X");
        let expr = Min(vec![Val(1), Max(vec![Val(0), Sym(x)])]);
        let simplified = expr.simplify();
        eprintln!("simplified: {simplified}");
        assert!(format!("{simplified}").contains("min"), "Min dropped: {simplified}");
    }

    #[test]
    fn min_preserved_in_subtraction_parts() {
        // Test that Min([1, X]) simplifies correctly in isolation
        let s = SymbolScope::default();
        let t = s.sym("T");
        let p = s.sym("P");
        let ss = s.sym("S");

        let cum_after =
            Max(vec![Val(0), (Sym(t.clone()) + Val(1)) * Sym(p.clone()) - Sym(ss.clone())]);
        let min_after = Min(vec![Val(1), cum_after.clone()]);
        let simplified = min_after.simplify();
        eprintln!("min_after simplified: {simplified}");
        // Must contain "min" — the Min must not be dropped
        assert!(format!("{simplified}").contains("min"), "Min wrapper was dropped: {simplified}");
    }

    #[test]
    fn min_preserved_in_subtraction() {
        // min(1, X) - min(1, Y) must preserve the min() wrappers.
        // This is the pattern used by PulseV2Pad's output_facts for after-padding.
        let s = SymbolScope::default();
        let t = s.sym("T");
        let p = s.sym("P");
        let ss = s.sym("S");

        let cum_after =
            Max(vec![Val(0), (Sym(t.clone()) + Val(1)) * Sym(p.clone()) - Sym(ss.clone())]);
        let cum_before = Max(vec![Val(0), Sym(t.clone()) * Sym(p.clone()) - Sym(ss.clone())]);

        let ap = Min(vec![Val(1), cum_after.clone()]) - Min(vec![Val(1), cum_before.clone()]);
        let simplified = ap.simplify();

        // At T=1, P=4, S=3: min(1, max(0, 8-3)) - min(1, max(0, 4-3)) = 1 - 1 = 0
        use super::super::sym::SymbolValues;
        let sv = SymbolValues::default().with(&t, 1).with(&p, 4).with(&ss, 3);
        assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 0, "simplified: {simplified}");

        // At T=0, P=4, S=3: min(1, max(0, 4-3)) - min(1, max(0, 0-3)) = 1 - 0 = 1
        let sv = SymbolValues::default().with(&t, 0).with(&p, 4).with(&ss, 3);
        assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 1, "simplified: {simplified}");

        // At T=0, P=1, S=1: min(1, max(0, 1-1)) - min(1, max(0, 0-1)) = 0 - 0 = 0
        let sv = SymbolValues::default().with(&t, 0).with(&p, 1).with(&ss, 1);
        assert_eq!(simplified.eval_to_i64(&sv).unwrap(), 0, "simplified: {simplified}");
    }

    #[test]
    fn mul_neg_b_by_8() {
        let s = SymbolScope::default();
        let b = Sym(s.sym("B"));
        // 8*(-1*B) should equal -8*B
        let a = Mul(vec![Val(8), MulInt(-1, Box::new(b.clone()))]);
        let c = MulInt(-8, Box::new(b.clone()));
        let a_s = a.simplify();
        let c_s = c.simplify();
        assert_eq!(a_s, c_s, "8*(-1*B) should simplify the same as -8*B");
    }

    /// Encoder-pulse case: (6 + 14·S) / 8 == (3 + 7·S) / 4.
    /// Both Add terms (Val(6) and MulInt(14, S)) share factor 2 with
    /// divisor 8, so the simplifier should reduce both sides by 2.
    #[test]
    fn reduce_div_by_common_factor_with_divisor() {
        let lhs = (A.to_dim() * 14 + 6) / 8;
        let rhs = (A.to_dim() * 7 + 3) / 4;
        assert_eq!(lhs, rhs);
    }

    /// Common factor that fully divides the divisor → drop the divisor.
    /// (4·a + 8) / 4  ==  a + 2.
    #[test]
    fn reduce_div_when_factor_equals_divisor() {
        let lhs = (A.to_dim() * 4 + 8) / 4;
        let rhs = A.to_dim() + 2;
        assert_eq!(lhs, rhs);
    }

    /// No common factor → no reduction (identity check).
    /// (3 + 7·a) / 4 stays as-is (gcd(3, 7, 4) = 1).
    #[test]
    fn no_reduce_when_terms_coprime_with_divisor() {
        let e = (A.to_dim() * 7 + 3) / 4;
        // We just check it didn't reduce to something weird; the
        // canonical form is `Div(Add(...), 4)`.
        match &e {
            Div(_, q) => assert_eq!(*q, 4),
            other => panic!("expected Div(_, 4), got {other:?}"),
        }
    }

    /// Sym without an explicit `MulInt` wrapper has implicit coefficient
    /// 1.  Any common factor gcd including 1 collapses to 1, so the
    /// reduction does nothing — the rule must not silently drop the Sym.
    #[test]
    fn no_reduce_when_sym_has_implicit_unit_coefficient() {
        // (a + 4) / 2 must stay non-trivial — gcd(1, 4, 2) = 1.
        let e = (A.to_dim() + 4) / 2;
        // It can simplify to other forms but it should still depend on `a`.
        // Eval at a=2 → (2+4)/2 = 3.  Eval at a=4 → (4+4)/2 = 4.
        let sv2 = SymbolValues::default().with(&A, 2);
        let sv4 = SymbolValues::default().with(&A, 4);
        assert_eq!(e.eval_to_i64(&sv2).unwrap(), 3);
        assert_eq!(e.eval_to_i64(&sv4).unwrap(), 4);
    }
}