1use crate::internal::*;
3use num_traits::Zero;
4use std::fmt;
5use std::ops;
6
7mod assertion;
8mod parse;
9mod resolve;
10mod sym;
11mod tree;
12
13pub use self::assertion::Assertion;
14pub use self::parse::parse_tdim;
15pub use self::resolve::solve_for;
16pub use self::sym::{Symbol, SymbolScope, SymbolValues};
17pub use self::tree::{TDim, TooEarly};
18
19use crate::{TractError, TractResult};
20
21pub trait DimLike:
28 Clone
29 + Default
30 + PartialEq
31 + From<usize>
32 + for<'a> std::convert::TryFrom<&'a TDim, Error = TractError>
33 + ::num_traits::Zero
34 + fmt::Debug
35 + fmt::Display
36 + std::hash::Hash
37 + ops::Add<Self, Output = Self>
38 + ops::Add<usize, Output = Self>
39 + for<'a> ops::Add<&'a Self, Output = Self>
40 + ops::Sub<Self, Output = Self>
41 + ops::Sub<usize, Output = Self>
42 + for<'a> ops::Sub<&'a Self, Output = Self>
43 + ops::Mul<Self, Output = Self>
44 + ops::Mul<usize, Output = Self>
45 + for<'a> ops::Mul<&'a Self, Output = Self>
46 + ops::Div<usize, Output = Self>
47 + ops::Rem<usize, Output = Self>
48 + Send
49 + Sync
50 + 'static
51 + std::iter::Sum
52 + std::iter::Product
53 + ToDim
54{
55 fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)>;
56
57 fn divceil(&self, other: usize) -> Self {
59 (self.clone() + other - 1) / other
60 }
61
62 fn to_i64(&self) -> TractResult<i64>;
64
65 fn to_usize(&self) -> TractResult<usize> {
66 self.to_i64().map(|d| d as usize)
67 }
68
69 fn to_isize(&self) -> TractResult<isize> {
70 self.to_i64().map(|d| d as isize)
71 }
72
73 fn to_i32(&self) -> TractResult<i32> {
74 self.to_i64().map(|d| d as i32)
75 }
76
77 fn one() -> Self;
79
80 fn eval(&self, values: &SymbolValues) -> Self;
82
83 fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64>;
85
86 fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self>;
87
88 fn broadcast(self, other: Self) -> TractResult<Self>;
89 fn mini(self, other: Self) -> Self;
90 fn maxi(self, other: Self) -> Self;
91
92 fn compatible_with(&self, other: &Self) -> bool;
93}
94
/// `DimLike` for symbolic dimensions; most methods forward to `TDim`'s own API.
impl DimLike for TDim {
    /// Divides `self` by `other`, returning `(quotient, denominator)` with
    /// `self / other == quotient / denominator`.
    ///
    /// Each operand is decomposed into an integer coefficient and a list of
    /// symbolic factors. Every factor of `other` must appear in `self`, otherwise
    /// this bails. The integer parts are then reduced by their gcd, and the sign
    /// is normalized so that the denominator is positive.
    fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)> {
        // 0 / x is 0 (checked first, so 0 / 0 also yields (0, 1)).
        if self.is_zero() {
            return Ok((TDim::zero(), 1));
        } else if other.is_zero() {
            bail!("Division by zero")
        }
        // Splits a dimension into (integer coefficient, symbolic factors).
        fn expand(dim: &TDim) -> (i64, Vec<TDim>) {
            match dim {
                // Product: multiply coefficients, concatenate factor lists.
                TDim::Mul(terms) => terms.iter().map(expand).fold((1i64, vec![]), |acc, t| {
                    (acc.0 * t.0, acc.1.into_iter().chain(t.1).collect())
                }),
                TDim::MulInt(a, terms) => {
                    let (b, v) = expand(terms);
                    (a * b, v)
                }
                TDim::Val(x) => (*x, vec![]),
                // Sum: factor out the gcd of the terms' coefficients; the sum
                // divided by that gcd becomes a single symbolic factor.
                // NOTE(review): `unwrap` assumes the `Add` has at least one term.
                TDim::Add(terms) => {
                    let gcd =
                        terms.iter().map(expand).map(|(n, _)| n).reduce(|a, b| a.gcd(&b)).unwrap();
                    (
                        gcd,
                        vec![TDim::Add(terms.iter().map(|t| t.clone() / gcd).collect()).simplify()],
                    )
                }
                // Anything else is an opaque factor with coefficient 1.
                it => (1, vec![it.clone()]),
            }
        }
        let (mut num_int, mut num) = expand(self);
        let (mut denum_int, mut denum) = expand(other);
        // Shortcut: identical factor lists cancel out entirely.
        if num == denum {
            num = vec![];
            denum = vec![];
        }
        // Cancel each divisor factor against one matching dividend factor.
        for it in denum {
            if let Some(pos) = num.iter().position(|n| n == &it) {
                num.remove(pos);
            } else {
                bail!("Can't divide {} by {}", self, other)
            }
        }
        use num_integer::Integer;
        // Move the sign to the numerator so the denominator fits in a u64.
        if denum_int < 0 {
            num_int *= -1;
            denum_int *= -1;
        }
        // Reduce the integer fraction to lowest terms.
        let gcd = num_int.gcd(&denum_int);
        num_int /= gcd;
        denum_int /= gcd;
        Ok(((TDim::Mul(num) * num_int).reduce(), denum_int as u64))
    }

    /// Forwards to the inherent `TDim::to_i64`.
    fn to_i64(&self) -> TractResult<i64> {
        TDim::to_i64(self)
    }

    fn one() -> Self {
        Self::from(1)
    }

    /// Forwards to the inherent `TDim::eval`. Inherent methods take precedence
    /// over trait methods in method resolution (assumes such an inherent method
    /// exists; otherwise this would recurse).
    fn eval(&self, values: &SymbolValues) -> Self {
        self.eval(values)
    }

    /// Forwards to the inherent `TDim::substitute` (same resolution note as `eval`).
    fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self> {
        self.substitute(from, to)
    }

    /// Forwards to the inherent `TDim::eval_to_i64`.
    fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64> {
        TDim::eval_to_i64(self, values)
    }

    /// A dimension of 1 broadcasts to the other operand. Otherwise a simplified
    /// `Broadcast` node is built; this path never errors here.
    fn broadcast(self, other: Self) -> TractResult<Self> {
        if self.is_one() {
            Ok(other)
        } else if other.is_one() {
            Ok(self)
        } else {
            Ok(TDim::Broadcast(vec![self, other]).simplify())
        }
    }

    /// Forwards to the inherent `TDim::compatible_with` (same resolution note as `eval`).
    fn compatible_with(&self, other: &Self) -> bool {
        self.compatible_with(other)
    }

    /// Builds a simplified symbolic `Min` of the two dimensions.
    fn mini(self, other: Self) -> Self {
        TDim::Min(vec![self, other]).simplify()
    }

    /// Builds a simplified symbolic `Max` of the two dimensions.
    fn maxi(self, other: Self) -> Self {
        TDim::Max(vec![self, other]).simplify()
    }
}
189
190impl<'a> std::convert::TryFrom<&'a TDim> for TDim {
191 type Error = TractError;
192 fn try_from(d: &'a TDim) -> TractResult<TDim> {
193 Ok(d.clone())
194 }
195}
196
197impl DimLike for usize {
198 fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)> {
199 use num_integer::Integer;
200 let gcd = self.gcd(other);
201 Ok((self / gcd, (other / gcd) as u64))
202 }
203
204 fn to_i64(&self) -> TractResult<i64> {
205 Ok(*self as i64)
206 }
207
208 fn one() -> usize {
209 1
210 }
211
212 fn eval(&self, _values: &SymbolValues) -> Self {
213 *self
214 }
215
216 fn substitute(&self, _from: &Symbol, _to: &Self) -> TractResult<Self> {
217 Ok(*self)
218 }
219
220 fn eval_to_i64(&self, _: &SymbolValues) -> TractResult<i64> {
221 Ok(*self as i64)
222 }
223
224 fn broadcast(self, other: Self) -> TractResult<Self> {
225 if self == 1 || self == other {
226 Ok(other)
227 } else if other == 1 {
228 Ok(self)
229 } else {
230 bail!("Can not broadcast {self} against {other}")
231 }
232 }
233
234 fn compatible_with(&self, other: &Self) -> bool {
235 self == other
236 }
237
238 fn mini(self, other: Self) -> Self {
239 if self < other {
240 self
241 } else {
242 other
243 }
244 }
245
246 fn maxi(self, other: Self) -> Self {
247 if self > other {
248 self
249 } else {
250 other
251 }
252 }
253}
254
255impl<'a> std::convert::TryFrom<&'a TDim> for usize {
256 type Error = TractError;
257 fn try_from(d: &'a TDim) -> TractResult<usize> {
258 d.to_usize()
259 }
260}
261
262pub trait ToDim {
264 fn to_dim(&self) -> TDim;
266}
267
268impl<I: Into<TDim> + Clone> ToDim for I {
269 fn to_dim(&self) -> TDim {
270 self.clone().into()
271 }
272}
273
274impl ToDim for &TDim {
275 fn to_dim(&self) -> TDim {
276 (*self).clone()
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    lazy_static::lazy_static! {
        // One shared scope holding a single symbol with prefix "S".
        static ref S: (SymbolScope, Symbol) = {
            let scope = SymbolScope::default();
            let sym = scope.new_with_prefix("S");
            (scope, sym)
        };
    }

    /// The shared test symbol as a dimension.
    pub fn s() -> TDim {
        let sym = S.1.clone();
        sym.into()
    }

    #[test]
    fn div() {
        let got = TDim::from(12).maybe_div(&TDim::from(4)).unwrap();
        assert_eq!(got, (3.into(), 1));
    }

    #[test]
    fn div_sym_int() {
        let got = (s() * 12).maybe_div(&TDim::from(4)).unwrap();
        assert_eq!(got, (s() * 3, 1));
    }

    #[test]
    fn div_sym_sym() {
        let divisor = s() * 4;
        let got = (s() * 12).maybe_div(&divisor).unwrap();
        assert_eq!(got, (3.into(), 1));
    }

    #[test]
    fn div_sym_sym_ratio() {
        let divisor = s() * 4;
        let got = (s() * 13).maybe_div(&divisor).unwrap();
        assert_eq!(got, (13.into(), 4));
    }

    #[test]
    fn div_sym_sym_rem() {
        let divisor = s() * 4;
        let outcome = (s() + 1).maybe_div(&divisor);
        assert!(outcome.is_err());
    }

    #[test]
    fn div_sym_sym_simply_1() {
        let divisor = s();
        let got = (s()).maybe_div(&divisor).unwrap();
        assert_eq!(got, (TDim::Val(1), 1));
    }

    #[test]
    fn div_sym_sym_complex() {
        let s = s();
        let b = S.0.sym("b");
        let dividend = 256.to_dim() * &s * &b;
        let divisor = 1.to_dim() * &s * &b;
        assert_eq!(dividend.maybe_div(&divisor).unwrap(), (256.into(), 1));
    }

    #[test]
    fn div_sym_sym_with_add() {
        let divisor = s() - 2;
        let got = (s() * 80 - 160).maybe_div(&divisor).unwrap();
        assert_eq!(got, (80.into(), 1));
    }
}