Skip to main content

tract_data/dim/
mod.rs

1//! Extended dimension support
2use crate::internal::*;
3use num_traits::{One, Zero};
4use std::fmt;
5use std::ops;
6
7mod assertion;
8mod parse;
9mod resolve;
10mod sym;
11mod tree;
12
13pub use self::assertion::Assertion;
14pub use self::parse::parse_tdim;
15pub use self::resolve::solve_for;
16pub use self::sym::{Symbol, SymbolScope, SymbolValues};
17pub use self::tree::{TDim, TooEarly};
18
19use crate::{TractError, TractResult};
20
21/// A super-trait for value acting as tensor dimensions in tract.
22///
23/// Implemented by:
24///
25/// * `usize` for regular dimensions
26/// * `TDim` supporting regular and streaming dimensions
27pub trait DimLike:
28    Clone
29    + Default
30    + PartialEq
31    + From<usize>
32    + for<'a> std::convert::TryFrom<&'a TDim, Error = TractError>
33    + ::num_traits::Zero
34    + fmt::Debug
35    + fmt::Display
36    + std::hash::Hash
37    + ops::Add<Self, Output = Self>
38    + ops::Add<usize, Output = Self>
39    + for<'a> ops::Add<&'a Self, Output = Self>
40    + ops::Sub<Self, Output = Self>
41    + ops::Sub<usize, Output = Self>
42    + for<'a> ops::Sub<&'a Self, Output = Self>
43    + ops::Mul<Self, Output = Self>
44    + ops::Mul<usize, Output = Self>
45    + for<'a> ops::Mul<&'a Self, Output = Self>
46    + ops::Div<usize, Output = Self>
47    + ops::Rem<usize, Output = Self>
48    + Send
49    + Sync
50    + 'static
51    + std::iter::Sum
52    + std::iter::Product
53    + ToDim
54    + One
55{
56    fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)>;
57
58    /// Integer divise, rounding up to next integer.
59    fn divceil(&self, other: usize) -> Self {
60        (self.clone() + other - 1) / other
61    }
62
63    /// Convert to regular integer.
64    fn to_i64(&self) -> TractResult<i64>;
65
66    fn to_usize(&self) -> TractResult<usize> {
67        self.to_i64().map(|d| d as usize)
68    }
69
70    fn to_isize(&self) -> TractResult<isize> {
71        self.to_i64().map(|d| d as isize)
72    }
73
74    fn to_i32(&self) -> TractResult<i32> {
75        self.to_i64().map(|d| d as i32)
76    }
77
78    /// Substitute as many symbols as possible in the dim value.
79    fn eval(&self, values: &SymbolValues) -> Self;
80
81    /// Full evaluation of the symbol, failing if a symbol is missing
82    fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64>;
83
84    fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self>;
85    fn substitute_all(&self, map: &std::collections::HashMap<Symbol, Self>) -> TractResult<Self>;
86
87    fn broadcast(self, other: Self) -> TractResult<Self>;
88    fn mini(self, other: Self) -> Self;
89    fn maxi(self, other: Self) -> Self;
90
91    fn compatible_with(&self, other: &Self) -> bool;
92}
93
94impl DimLike for TDim {
95    fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)> {
96        if self.is_zero() {
97            return Ok((TDim::zero(), 1));
98        } else if other.is_zero() {
99            bail!("Division by zero")
100        }
101        // If self is a sum and every term divides by `other` with the same
102        // remaining divisor, distribute: (a + b) / d = a/d + b/d. This
103        // catches shared symbolic factors the bag-based path below misses,
104        // e.g. `8*slice*((T+7)/8) + 8*((T+7)/8)` divided by `8*((T+7)/8)`.
105        if let TDim::Add(terms) = self
106            && terms.len() >= 2
107            && let Some(parts) =
108                terms.iter().map(|t| t.maybe_div(other).ok()).collect::<Option<Vec<_>>>()
109            && let Some((_, q0)) = parts.first()
110            && parts.iter().all(|(_, q)| q == q0)
111        {
112            let q = *q0;
113            let sum = parts.into_iter().map(|(d, _)| d).fold(TDim::zero(), |acc, d| acc + d);
114            return Ok((sum.reduce(), q));
115        }
116        fn expand(dim: &TDim) -> (i64, Vec<TDim>) {
117            match dim {
118                TDim::Mul(terms) => terms.iter().map(expand).fold((1i64, vec![]), |acc, t| {
119                    (acc.0 * t.0, acc.1.into_iter().chain(t.1).collect())
120                }),
121                TDim::MulInt(a, terms) => {
122                    let (b, v) = expand(terms);
123                    (a * b, v)
124                }
125                TDim::Val(x) => (*x, vec![]),
126                TDim::Add(terms) => {
127                    let gcd =
128                        terms.iter().map(expand).map(|(n, _)| n).reduce(|a, b| a.gcd(&b)).unwrap();
129                    (
130                        gcd,
131                        vec![TDim::Add(terms.iter().map(|t| t.clone() / gcd).collect()).simplify()],
132                    )
133                }
134                it => (1, vec![it.clone()]),
135            }
136        }
137        let (mut num_int, mut num) = expand(self);
138        let (mut denum_int, mut denum) = expand(other);
139        if num == denum {
140            num = vec![];
141            denum = vec![];
142        }
143        for it in denum {
144            if let Some(pos) = num.iter().position(|n| n == &it) {
145                num.remove(pos);
146            } else {
147                bail!("Can't divide {} by {}", self, other)
148            }
149        }
150        use num_integer::Integer;
151        if denum_int < 0 {
152            num_int *= -1;
153            denum_int *= -1;
154        }
155        let gcd = num_int.gcd(&denum_int);
156        num_int /= gcd;
157        denum_int /= gcd;
158        Ok(((TDim::Mul(num) * num_int).reduce(), denum_int as u64))
159    }
160
161    fn to_i64(&self) -> TractResult<i64> {
162        TDim::to_i64(self)
163    }
164
165    fn eval(&self, values: &SymbolValues) -> Self {
166        self.eval(values)
167    }
168
169    fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self> {
170        self.substitute(from, to)
171    }
172
173    fn substitute_all(&self, map: &std::collections::HashMap<Symbol, Self>) -> TractResult<Self> {
174        TDim::substitute_all(self, map)
175    }
176
177    fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64> {
178        TDim::eval_to_i64(self, values)
179    }
180
181    fn broadcast(self, other: Self) -> TractResult<Self> {
182        if self.is_one() {
183            Ok(other)
184        } else if other.is_one() {
185            Ok(self)
186        } else {
187            Ok(TDim::Broadcast(vec![self, other]).simplify())
188        }
189    }
190
191    fn compatible_with(&self, other: &Self) -> bool {
192        self.compatible_with(other)
193    }
194
195    fn mini(self, other: Self) -> Self {
196        TDim::Min(vec![self, other]).simplify()
197    }
198
199    fn maxi(self, other: Self) -> Self {
200        TDim::Max(vec![self, other]).simplify()
201    }
202}
203
204impl<'a> std::convert::TryFrom<&'a TDim> for TDim {
205    type Error = TractError;
206    fn try_from(d: &'a TDim) -> TractResult<TDim> {
207        Ok(d.clone())
208    }
209}
210
211impl DimLike for usize {
212    fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)> {
213        use num_integer::Integer;
214        let gcd = self.gcd(other);
215        Ok((self / gcd, (other / gcd) as u64))
216    }
217
218    fn to_i64(&self) -> TractResult<i64> {
219        Ok(*self as i64)
220    }
221
222    fn eval(&self, _values: &SymbolValues) -> Self {
223        *self
224    }
225
226    fn substitute(&self, _from: &Symbol, _to: &Self) -> TractResult<Self> {
227        Ok(*self)
228    }
229
230    fn substitute_all(&self, _map: &std::collections::HashMap<Symbol, Self>) -> TractResult<Self> {
231        Ok(*self)
232    }
233
234    fn eval_to_i64(&self, _: &SymbolValues) -> TractResult<i64> {
235        Ok(*self as i64)
236    }
237
238    fn broadcast(self, other: Self) -> TractResult<Self> {
239        if self == 1 || self == other {
240            Ok(other)
241        } else if other == 1 {
242            Ok(self)
243        } else {
244            bail!("Can not broadcast {self} against {other}")
245        }
246    }
247
248    fn compatible_with(&self, other: &Self) -> bool {
249        self == other
250    }
251
252    fn mini(self, other: Self) -> Self {
253        if self < other { self } else { other }
254    }
255
256    fn maxi(self, other: Self) -> Self {
257        if self > other { self } else { other }
258    }
259}
260
261impl<'a> std::convert::TryFrom<&'a TDim> for usize {
262    type Error = TractError;
263    fn try_from(d: &'a TDim) -> TractResult<usize> {
264        d.to_usize()
265    }
266}
267
268/// Convenience trait to convert values to TDim.
269pub trait ToDim {
270    /// Convert self to a TDim.
271    fn to_dim(&self) -> TDim;
272}
273
274impl<I: Into<TDim> + Clone> ToDim for I {
275    fn to_dim(&self) -> TDim {
276        self.clone().into()
277    }
278}
279
280impl ToDim for &TDim {
281    fn to_dim(&self) -> TDim {
282        (*self).clone()
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    lazy_static::lazy_static! {
        static ref S: (SymbolScope, Symbol) = {
            let scope = SymbolScope::default();
            let sym = scope.new_with_prefix("S");
            (scope, sym)
        };
    }

    /// The shared test symbol `S`, as a dimension.
    pub fn s() -> TDim {
        let sym: Symbol = S.1.clone();
        sym.into()
    }

    #[test]
    fn div() {
        let (twelve, four) = (TDim::from(12), TDim::from(4));
        assert_eq!(twelve.maybe_div(&four).unwrap(), (3.into(), 1));
    }

    #[test]
    fn div_sym_int() {
        let num = s() * 12;
        assert_eq!(num.maybe_div(&TDim::from(4)).unwrap(), (s() * 3, 1));
    }

    #[test]
    fn div_sym_sym() {
        let num = s() * 12;
        let den = s() * 4;
        assert_eq!(num.maybe_div(&den).unwrap(), (3.into(), 1));
    }

    #[test]
    fn div_sym_sym_ratio() {
        let num = s() * 13;
        let den = s() * 4;
        assert_eq!(num.maybe_div(&den).unwrap(), (13.into(), 4));
    }

    #[test]
    fn div_sym_sym_rem() {
        let num = s() + 1;
        let den = s() * 4;
        assert!(num.maybe_div(&den).is_err());
    }

    #[test]
    fn div_sym_sym_simply_1() {
        let sym = s();
        assert_eq!(sym.maybe_div(&sym).unwrap(), (TDim::Val(1), 1));
    }

    #[test]
    fn div_sym_sym_complex() {
        let s = s();
        let b = S.0.sym("b");
        let num = 256.to_dim() * &s * &b;
        let den = 1.to_dim() * &s * &b;
        assert_eq!(num.maybe_div(&den).unwrap(), (256.into(), 1));
    }

    #[test]
    fn div_sym_sym_with_add() {
        let num = s() * 80 - 160;
        let den = s() - 2;
        assert_eq!(num.maybe_div(&den).unwrap(), (80.into(), 1));
    }

    #[test]
    fn div_with_shared_div_ceil_factor() {
        // Regression test for:
        //   Can't divide 8*(slice)*((T+7)/8)+8*(T+7)/8 by 8*(T+7)/8
        //
        // The numerator factors as (slice + 1) * 8 * ((T+7)/8) and the denominator is the
        // shared 8 * ((T+7)/8), so the division must yield slice + 1 exactly.
        let t: TDim = S.0.sym("T").into();
        let slice = S.0.sym("slice");
        let ceil8 = t.div_ceil(8); // (T+7)/8
        let num = 8.to_dim() * &slice * &ceil8 + 8.to_dim() * &ceil8;
        let den = 8.to_dim() * &ceil8;
        assert_eq!(num.maybe_div(&den).unwrap(), (slice.to_dim() + 1, 1));
    }
}