1use crate::internal::*;
3use num_traits::{One, Zero};
4use std::fmt;
5use std::ops;
6
7mod assertion;
8mod parse;
9mod resolve;
10mod sym;
11mod tree;
12
13pub use self::assertion::Assertion;
14pub use self::parse::parse_tdim;
15pub use self::resolve::solve_for;
16pub use self::sym::{Symbol, SymbolScope, SymbolValues};
17pub use self::tree::{TDim, TooEarly};
18
19use crate::{TractError, TractResult};
20
21pub trait DimLike:
28 Clone
29 + Default
30 + PartialEq
31 + From<usize>
32 + for<'a> std::convert::TryFrom<&'a TDim, Error = TractError>
33 + ::num_traits::Zero
34 + fmt::Debug
35 + fmt::Display
36 + std::hash::Hash
37 + ops::Add<Self, Output = Self>
38 + ops::Add<usize, Output = Self>
39 + for<'a> ops::Add<&'a Self, Output = Self>
40 + ops::Sub<Self, Output = Self>
41 + ops::Sub<usize, Output = Self>
42 + for<'a> ops::Sub<&'a Self, Output = Self>
43 + ops::Mul<Self, Output = Self>
44 + ops::Mul<usize, Output = Self>
45 + for<'a> ops::Mul<&'a Self, Output = Self>
46 + ops::Div<usize, Output = Self>
47 + ops::Rem<usize, Output = Self>
48 + Send
49 + Sync
50 + 'static
51 + std::iter::Sum
52 + std::iter::Product
53 + ToDim
54 + One
55{
56 fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)>;
57
58 fn divceil(&self, other: usize) -> Self {
60 (self.clone() + other - 1) / other
61 }
62
63 fn to_i64(&self) -> TractResult<i64>;
65
66 fn to_usize(&self) -> TractResult<usize> {
67 self.to_i64().map(|d| d as usize)
68 }
69
70 fn to_isize(&self) -> TractResult<isize> {
71 self.to_i64().map(|d| d as isize)
72 }
73
74 fn to_i32(&self) -> TractResult<i32> {
75 self.to_i64().map(|d| d as i32)
76 }
77
78 fn eval(&self, values: &SymbolValues) -> Self;
80
81 fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64>;
83
84 fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self>;
85
86 fn broadcast(self, other: Self) -> TractResult<Self>;
87 fn mini(self, other: Self) -> Self;
88 fn maxi(self, other: Self) -> Self;
89
90 fn compatible_with(&self, other: &Self) -> bool;
91}
92
/// Symbolic implementation of `DimLike`.
impl DimLike for TDim {
    /// Exact division of two symbolic dimensions.
    ///
    /// Both operands are factored into an integer coefficient and a list of symbolic factors
    /// (see `expand`). Every factor of `other` must cancel a matching factor of `self`,
    /// otherwise an error is returned. The integer coefficients are then reduced to lowest
    /// terms, with the sign moved to the numerator, giving `(quotient, denominator)`.
    ///
    /// A zero `self` yields `(0, 1)`. A zero `other` (with non-zero `self`) is an error.
    fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)> {
        if self.is_zero() {
            return Ok((TDim::zero(), 1));
        } else if other.is_zero() {
            bail!("Division by zero")
        }
        // Splits `dim` into (integer coefficient, symbolic factors) so that `dim` equals the
        // coefficient times the product of the factors.
        fn expand(dim: &TDim) -> (i64, Vec<TDim>) {
            match dim {
                // Product: multiply coefficients, concatenate factor lists.
                TDim::Mul(terms) => terms.iter().map(expand).fold((1i64, vec![]), |acc, t| {
                    (acc.0 * t.0, acc.1.into_iter().chain(t.1).collect())
                }),
                // Integer multiple: fold `a` into the inner coefficient.
                TDim::MulInt(a, terms) => {
                    let (b, v) = expand(terms);
                    (a * b, v)
                }
                // Constant: pure coefficient, no symbolic factor.
                TDim::Val(x) => (*x, vec![]),
                // Sum: factor out the gcd of the terms' coefficients. The remaining sum (each
                // term divided by that gcd) becomes a single symbolic factor.
                // NOTE(review): `unwrap` panics on an `Add` with no terms; presumably
                // simplified expressions never contain one — confirm.
                TDim::Add(terms) => {
                    let gcd =
                        terms.iter().map(expand).map(|(n, _)| n).reduce(|a, b| a.gcd(&b)).unwrap();
                    (
                        gcd,
                        vec![TDim::Add(terms.iter().map(|t| t.clone() / gcd).collect()).simplify()],
                    )
                }
                // Anything else is an opaque factor with coefficient 1.
                it => (1, vec![it.clone()]),
            }
        }
        let (mut num_int, mut num) = expand(self);
        let (mut denum_int, mut denum) = expand(other);
        // Identical factor lists cancel entirely.
        if num == denum {
            num = vec![];
            denum = vec![];
        }
        // Each denominator factor must cancel exactly one equal numerator factor.
        for it in denum {
            if let Some(pos) = num.iter().position(|n| n == &it) {
                num.remove(pos);
            } else {
                bail!("Can't divide {} by {}", self, other)
            }
        }
        // A block-scoped `use` is visible throughout the whole function body, including the
        // nested `expand` above, which also relies on `Integer::gcd`.
        use num_integer::Integer;
        // Keep the denominator positive so it can be returned as `u64`.
        if denum_int < 0 {
            num_int *= -1;
            denum_int *= -1;
        }
        // Reduce the integer ratio to lowest terms.
        let gcd = num_int.gcd(&denum_int);
        num_int /= gcd;
        denum_int /= gcd;
        Ok(((TDim::Mul(num) * num_int).reduce(), denum_int as u64))
    }

    // NOTE(review): `to_i64`, `eval`, `substitute`, `eval_to_i64` and `compatible_with` below
    // rely on inherent `TDim` methods of the same name (not visible here) taking precedence
    // over these trait methods; otherwise they would recurse — confirm.
    fn to_i64(&self) -> TractResult<i64> {
        TDim::to_i64(self)
    }

    fn eval(&self, values: &SymbolValues) -> Self {
        self.eval(values)
    }

    fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self> {
        self.substitute(from, to)
    }

    fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64> {
        TDim::eval_to_i64(self, values)
    }

    /// A `1` on either side yields the other operand. Otherwise a simplified symbolic
    /// `Broadcast` node is built. Unlike the `usize` impl, this never errors.
    fn broadcast(self, other: Self) -> TractResult<Self> {
        if self.is_one() {
            Ok(other)
        } else if other.is_one() {
            Ok(self)
        } else {
            Ok(TDim::Broadcast(vec![self, other]).simplify())
        }
    }

    fn compatible_with(&self, other: &Self) -> bool {
        self.compatible_with(other)
    }

    /// Builds a simplified symbolic minimum.
    fn mini(self, other: Self) -> Self {
        TDim::Min(vec![self, other]).simplify()
    }

    /// Builds a simplified symbolic maximum.
    fn maxi(self, other: Self) -> Self {
        TDim::Max(vec![self, other]).simplify()
    }
}
183
184impl<'a> std::convert::TryFrom<&'a TDim> for TDim {
185 type Error = TractError;
186 fn try_from(d: &'a TDim) -> TractResult<TDim> {
187 Ok(d.clone())
188 }
189}
190
191impl DimLike for usize {
192 fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)> {
193 use num_integer::Integer;
194 let gcd = self.gcd(other);
195 Ok((self / gcd, (other / gcd) as u64))
196 }
197
198 fn to_i64(&self) -> TractResult<i64> {
199 Ok(*self as i64)
200 }
201
202 fn eval(&self, _values: &SymbolValues) -> Self {
203 *self
204 }
205
206 fn substitute(&self, _from: &Symbol, _to: &Self) -> TractResult<Self> {
207 Ok(*self)
208 }
209
210 fn eval_to_i64(&self, _: &SymbolValues) -> TractResult<i64> {
211 Ok(*self as i64)
212 }
213
214 fn broadcast(self, other: Self) -> TractResult<Self> {
215 if self == 1 || self == other {
216 Ok(other)
217 } else if other == 1 {
218 Ok(self)
219 } else {
220 bail!("Can not broadcast {self} against {other}")
221 }
222 }
223
224 fn compatible_with(&self, other: &Self) -> bool {
225 self == other
226 }
227
228 fn mini(self, other: Self) -> Self {
229 if self < other {
230 self
231 } else {
232 other
233 }
234 }
235
236 fn maxi(self, other: Self) -> Self {
237 if self > other {
238 self
239 } else {
240 other
241 }
242 }
243}
244
245impl<'a> std::convert::TryFrom<&'a TDim> for usize {
246 type Error = TractError;
247 fn try_from(d: &'a TDim) -> TractResult<usize> {
248 d.to_usize()
249 }
250}
251
252pub trait ToDim {
254 fn to_dim(&self) -> TDim;
256}
257
258impl<I: Into<TDim> + Clone> ToDim for I {
259 fn to_dim(&self) -> TDim {
260 self.clone().into()
261 }
262}
263
264impl ToDim for &TDim {
265 fn to_dim(&self) -> TDim {
266 (*self).clone()
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    lazy_static::lazy_static! {
        // Shared symbol scope with one symbol prefixed "S", reused across tests.
        static ref S: (SymbolScope, Symbol) = {
            let table = SymbolScope::default();
            let s = table.new_with_prefix("S");
            (table, s)
        };
    }

    /// The shared test symbol as a dimension.
    pub fn s() -> TDim {
        S.1.clone().into()
    }

    #[test]
    fn div() {
        assert_eq!(TDim::from(12).maybe_div(&TDim::from(4)).unwrap(), (3.into(), 1));
    }

    #[test]
    fn div_sym_int() {
        assert_eq!((s() * 12).maybe_div(&TDim::from(4)).unwrap(), (s() * 3, 1));
    }

    #[test]
    fn div_sym_sym() {
        assert_eq!((s() * 12).maybe_div(&(s() * 4)).unwrap(), (3.into(), 1));
    }

    #[test]
    fn div_sym_sym_ratio() {
        assert_eq!((s() * 13).maybe_div(&(s() * 4)).unwrap(), (13.into(), 4));
    }

    #[test]
    fn div_sym_sym_rem() {
        assert!((s() + 1).maybe_div(&(s() * 4)).is_err());
    }

    #[test]
    fn div_sym_sym_simply_1() {
        assert_eq!((s()).maybe_div(&(s())).unwrap(), (TDim::Val(1), 1));
    }

    #[test]
    fn div_sym_sym_complex() {
        let s = s();
        let b = S.0.sym("b");
        assert_eq!(
            (256.to_dim() * &s * &b).maybe_div(&(1.to_dim() * &s * &b)).unwrap(),
            (256.into(), 1)
        );
    }

    #[test]
    fn div_sym_sym_with_add() {
        assert_eq!((s() * 80 - 160).maybe_div(&(s() - 2)).unwrap(), (80.into(), 1));
    }

    // Zero dividend short-circuits to (0, 1), whatever the divisor.
    #[test]
    fn div_zero_dividend() {
        assert_eq!(TDim::from(0).maybe_div(&s()).unwrap(), (0.into(), 1));
    }

    // A zero divisor is rejected.
    #[test]
    fn div_by_zero() {
        assert!(s().maybe_div(&TDim::from(0)).is_err());
    }

    #[test]
    fn usize_div_ratio() {
        assert_eq!(12usize.maybe_div(&8).unwrap(), (3, 2));
        assert_eq!(12usize.maybe_div(&4).unwrap(), (3, 1));
    }

    #[test]
    fn usize_broadcast() {
        assert_eq!(1usize.broadcast(5).unwrap(), 5);
        assert_eq!(5usize.broadcast(1).unwrap(), 5);
        assert_eq!(5usize.broadcast(5).unwrap(), 5);
        assert!(2usize.broadcast(3).is_err());
    }

    #[test]
    fn usize_mini_maxi() {
        assert_eq!(2usize.mini(3), 2);
        assert_eq!(2usize.maxi(3), 3);
    }
}
330}