1use crate::internal::*;
3use num_traits::{One, Zero};
4use std::fmt;
5use std::ops;
6
7mod assertion;
8mod parse;
9mod resolve;
10mod sym;
11mod tree;
12
13pub use self::assertion::Assertion;
14pub use self::parse::parse_tdim;
15pub use self::resolve::solve_for;
16pub use self::sym::{Symbol, SymbolScope, SymbolValues};
17pub use self::tree::{TDim, TooEarly};
18
19use crate::{TractError, TractResult};
20
21pub trait DimLike:
28 Clone
29 + Default
30 + PartialEq
31 + From<usize>
32 + for<'a> std::convert::TryFrom<&'a TDim, Error = TractError>
33 + ::num_traits::Zero
34 + fmt::Debug
35 + fmt::Display
36 + std::hash::Hash
37 + ops::Add<Self, Output = Self>
38 + ops::Add<usize, Output = Self>
39 + for<'a> ops::Add<&'a Self, Output = Self>
40 + ops::Sub<Self, Output = Self>
41 + ops::Sub<usize, Output = Self>
42 + for<'a> ops::Sub<&'a Self, Output = Self>
43 + ops::Mul<Self, Output = Self>
44 + ops::Mul<usize, Output = Self>
45 + for<'a> ops::Mul<&'a Self, Output = Self>
46 + ops::Div<usize, Output = Self>
47 + ops::Rem<usize, Output = Self>
48 + Send
49 + Sync
50 + 'static
51 + std::iter::Sum
52 + std::iter::Product
53 + ToDim
54 + One
55{
56 fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)>;
57
58 fn divceil(&self, other: usize) -> Self {
60 (self.clone() + other - 1) / other
61 }
62
63 fn to_i64(&self) -> TractResult<i64>;
65
66 fn to_usize(&self) -> TractResult<usize> {
67 self.to_i64().map(|d| d as usize)
68 }
69
70 fn to_isize(&self) -> TractResult<isize> {
71 self.to_i64().map(|d| d as isize)
72 }
73
74 fn to_i32(&self) -> TractResult<i32> {
75 self.to_i64().map(|d| d as i32)
76 }
77
78 fn eval(&self, values: &SymbolValues) -> Self;
80
81 fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64>;
83
84 fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self>;
85 fn substitute_all(&self, map: &std::collections::HashMap<Symbol, Self>) -> TractResult<Self>;
86
87 fn broadcast(self, other: Self) -> TractResult<Self>;
88 fn mini(self, other: Self) -> Self;
89 fn maxi(self, other: Self) -> Self;
90
91 fn compatible_with(&self, other: &Self) -> bool;
92}
93
94impl DimLike for TDim {
95 fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)> {
96 if self.is_zero() {
97 return Ok((TDim::zero(), 1));
98 } else if other.is_zero() {
99 bail!("Division by zero")
100 }
101 if let TDim::Add(terms) = self
106 && terms.len() >= 2
107 && let Some(parts) =
108 terms.iter().map(|t| t.maybe_div(other).ok()).collect::<Option<Vec<_>>>()
109 && let Some((_, q0)) = parts.first()
110 && parts.iter().all(|(_, q)| q == q0)
111 {
112 let q = *q0;
113 let sum = parts.into_iter().map(|(d, _)| d).fold(TDim::zero(), |acc, d| acc + d);
114 return Ok((sum.reduce(), q));
115 }
116 fn expand(dim: &TDim) -> (i64, Vec<TDim>) {
117 match dim {
118 TDim::Mul(terms) => terms.iter().map(expand).fold((1i64, vec![]), |acc, t| {
119 (acc.0 * t.0, acc.1.into_iter().chain(t.1).collect())
120 }),
121 TDim::MulInt(a, terms) => {
122 let (b, v) = expand(terms);
123 (a * b, v)
124 }
125 TDim::Val(x) => (*x, vec![]),
126 TDim::Add(terms) => {
127 let gcd =
128 terms.iter().map(expand).map(|(n, _)| n).reduce(|a, b| a.gcd(&b)).unwrap();
129 (
130 gcd,
131 vec![TDim::Add(terms.iter().map(|t| t.clone() / gcd).collect()).simplify()],
132 )
133 }
134 it => (1, vec![it.clone()]),
135 }
136 }
137 let (mut num_int, mut num) = expand(self);
138 let (mut denum_int, mut denum) = expand(other);
139 if num == denum {
140 num = vec![];
141 denum = vec![];
142 }
143 for it in denum {
144 if let Some(pos) = num.iter().position(|n| n == &it) {
145 num.remove(pos);
146 } else {
147 bail!("Can't divide {} by {}", self, other)
148 }
149 }
150 use num_integer::Integer;
151 if denum_int < 0 {
152 num_int *= -1;
153 denum_int *= -1;
154 }
155 let gcd = num_int.gcd(&denum_int);
156 num_int /= gcd;
157 denum_int /= gcd;
158 Ok(((TDim::Mul(num) * num_int).reduce(), denum_int as u64))
159 }
160
161 fn to_i64(&self) -> TractResult<i64> {
162 TDim::to_i64(self)
163 }
164
165 fn eval(&self, values: &SymbolValues) -> Self {
166 self.eval(values)
167 }
168
169 fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self> {
170 self.substitute(from, to)
171 }
172
173 fn substitute_all(&self, map: &std::collections::HashMap<Symbol, Self>) -> TractResult<Self> {
174 TDim::substitute_all(self, map)
175 }
176
177 fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64> {
178 TDim::eval_to_i64(self, values)
179 }
180
181 fn broadcast(self, other: Self) -> TractResult<Self> {
182 if self.is_one() {
183 Ok(other)
184 } else if other.is_one() {
185 Ok(self)
186 } else {
187 Ok(TDim::Broadcast(vec![self, other]).simplify())
188 }
189 }
190
191 fn compatible_with(&self, other: &Self) -> bool {
192 self.compatible_with(other)
193 }
194
195 fn mini(self, other: Self) -> Self {
196 TDim::Min(vec![self, other]).simplify()
197 }
198
199 fn maxi(self, other: Self) -> Self {
200 TDim::Max(vec![self, other]).simplify()
201 }
202}
203
204impl<'a> std::convert::TryFrom<&'a TDim> for TDim {
205 type Error = TractError;
206 fn try_from(d: &'a TDim) -> TractResult<TDim> {
207 Ok(d.clone())
208 }
209}
210
211impl DimLike for usize {
212 fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)> {
213 use num_integer::Integer;
214 let gcd = self.gcd(other);
215 Ok((self / gcd, (other / gcd) as u64))
216 }
217
218 fn to_i64(&self) -> TractResult<i64> {
219 Ok(*self as i64)
220 }
221
222 fn eval(&self, _values: &SymbolValues) -> Self {
223 *self
224 }
225
226 fn substitute(&self, _from: &Symbol, _to: &Self) -> TractResult<Self> {
227 Ok(*self)
228 }
229
230 fn substitute_all(&self, _map: &std::collections::HashMap<Symbol, Self>) -> TractResult<Self> {
231 Ok(*self)
232 }
233
234 fn eval_to_i64(&self, _: &SymbolValues) -> TractResult<i64> {
235 Ok(*self as i64)
236 }
237
238 fn broadcast(self, other: Self) -> TractResult<Self> {
239 if self == 1 || self == other {
240 Ok(other)
241 } else if other == 1 {
242 Ok(self)
243 } else {
244 bail!("Can not broadcast {self} against {other}")
245 }
246 }
247
248 fn compatible_with(&self, other: &Self) -> bool {
249 self == other
250 }
251
252 fn mini(self, other: Self) -> Self {
253 if self < other { self } else { other }
254 }
255
256 fn maxi(self, other: Self) -> Self {
257 if self > other { self } else { other }
258 }
259}
260
261impl<'a> std::convert::TryFrom<&'a TDim> for usize {
262 type Error = TractError;
263 fn try_from(d: &'a TDim) -> TractResult<usize> {
264 d.to_usize()
265 }
266}
267
268pub trait ToDim {
270 fn to_dim(&self) -> TDim;
272}
273
274impl<I: Into<TDim> + Clone> ToDim for I {
275 fn to_dim(&self) -> TDim {
276 self.clone().into()
277 }
278}
279
280impl ToDim for &TDim {
281 fn to_dim(&self) -> TDim {
282 (*self).clone()
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    lazy_static::lazy_static! {
        // One scope and one "S"-prefixed symbol, shared by all tests.
        static ref S: (SymbolScope, Symbol) = {
            let scope = SymbolScope::default();
            let symbol = scope.new_with_prefix("S");
            (scope, symbol)
        };
    }

    /// The shared test symbol as a `TDim`.
    pub fn s() -> TDim {
        S.1.clone().into()
    }

    #[test]
    fn div() {
        let numerator = TDim::from(12);
        let denominator = TDim::from(4);
        assert_eq!(numerator.maybe_div(&denominator).unwrap(), (3.into(), 1));
    }

    #[test]
    fn div_sym_int() {
        let numerator = s() * 12;
        let denominator = TDim::from(4);
        assert_eq!(numerator.maybe_div(&denominator).unwrap(), (s() * 3, 1));
    }

    #[test]
    fn div_sym_sym() {
        let numerator = s() * 12;
        let denominator = s() * 4;
        assert_eq!(numerator.maybe_div(&denominator).unwrap(), (3.into(), 1));
    }

    #[test]
    fn div_sym_sym_ratio() {
        let numerator = s() * 13;
        let denominator = s() * 4;
        assert_eq!(numerator.maybe_div(&denominator).unwrap(), (13.into(), 4));
    }

    #[test]
    fn div_sym_sym_rem() {
        let numerator = s() + 1;
        let denominator = s() * 4;
        assert!(numerator.maybe_div(&denominator).is_err());
    }

    #[test]
    fn div_sym_sym_simply_1() {
        let numerator = s();
        let denominator = s();
        assert_eq!(numerator.maybe_div(&denominator).unwrap(), (TDim::Val(1), 1));
    }

    #[test]
    fn div_sym_sym_complex() {
        let s = s();
        let b = S.0.sym("b");
        let numerator = 256.to_dim() * &s * &b;
        let denominator = 1.to_dim() * &s * &b;
        assert_eq!(numerator.maybe_div(&denominator).unwrap(), (256.into(), 1));
    }

    #[test]
    fn div_sym_sym_with_add() {
        let numerator = s() * 80 - 160;
        let denominator = s() - 2;
        assert_eq!(numerator.maybe_div(&denominator).unwrap(), (80.into(), 1));
    }

    #[test]
    fn div_with_shared_div_ceil_factor() {
        // Both addends of the numerator share the factor `8 * ceil(T / 8)`.
        let t: TDim = S.0.sym("T").into();
        let slice = S.0.sym("slice");
        let c = t.div_ceil(8);
        let num = 8.to_dim() * &slice * &c + 8.to_dim() * &c;
        let denom = 8.to_dim() * &c;
        let expected = (slice.to_dim() + 1, 1);
        assert_eq!(num.maybe_div(&denom).unwrap(), expected);
    }
}