use num_traits::Zero;
use std::fmt;
use std::ops;
mod parse;
mod resolve;
mod sym;
mod tree;
pub use self::parse::parse_tdim;
pub use self::resolve::solve_for;
pub use self::sym::{Symbol, SymbolTable, SymbolValues};
pub use self::tree::{TDim, UndeterminedSymbol};
use crate::{TractError, TractResult};
/// Common interface for values that can be used as a dimension.
///
/// In this file the trait is implemented for `usize` (a plain integer size)
/// and for [`TDim`] (an expression that may contain symbols). The supertrait
/// list requires arithmetic against `Self`, `&Self` and `usize`, conversion
/// from `usize` and from `&TDim`, plus the usual value-type traits.
pub trait DimLike:
    Clone
    + Default
    + PartialEq
    + From<usize>
    + for<'a> std::convert::TryFrom<&'a TDim, Error = TractError>
    + ::num_traits::Zero
    + fmt::Debug
    + fmt::Display
    + std::hash::Hash
    + ops::Add<Self, Output = Self>
    + ops::Add<usize, Output = Self>
    + for<'a> ops::Add<&'a Self, Output = Self>
    + ops::Sub<Self, Output = Self>
    + ops::Sub<usize, Output = Self>
    + for<'a> ops::Sub<&'a Self, Output = Self>
    + ops::Mul<Self, Output = Self>
    + ops::Mul<usize, Output = Self>
    + for<'a> ops::Mul<&'a Self, Output = Self>
    + ops::Div<usize, Output = Self>
    + ops::Rem<usize, Output = Self>
    + Send
    + Sync
    + 'static
    + std::iter::Sum
    + std::iter::Product
    + ToDim
{
    /// Tries to divide `self` by `other`, expressing the result as a fraction.
    ///
    /// On success the returned pair is `(numerator, denominator)`.
    ///
    /// # Errors
    ///
    /// Implementations return an error when they cannot perform the division.
    fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)>;

    /// Division by `other` rounded up, computed as `(self + other - 1) / other`.
    fn divceil(&self, other: usize) -> Self {
        let bumped = self.clone() + other - 1;
        bumped / other
    }

    /// Converts the dimension to an `i64`, or fails if that is not possible.
    fn to_i64(&self) -> TractResult<i64>;

    /// Converts the dimension to a `usize` by way of [`DimLike::to_i64`].
    ///
    /// The narrowing step is a plain `as` cast: an `i64` that does not fit in
    /// `usize` (a negative value, for instance) is not reported as an error.
    fn to_usize(&self) -> TractResult<usize> {
        Ok(self.to_i64()? as usize)
    }

    /// Converts the dimension to an `isize` by way of [`DimLike::to_i64`].
    ///
    /// The narrowing step is a plain `as` cast, so out-of-range values are not
    /// reported as errors.
    fn to_isize(&self) -> TractResult<isize> {
        Ok(self.to_i64()? as isize)
    }

    /// Converts the dimension to an `i32` by way of [`DimLike::to_i64`].
    ///
    /// The narrowing step is a plain `as` cast, so values that do not fit in
    /// 32 bits are truncated rather than reported as errors.
    fn to_i32(&self) -> TractResult<i32> {
        Ok(self.to_i64()? as i32)
    }

    /// Returns the dimension equal to one.
    fn one() -> Self;

    /// Evaluates the dimension against `values`, producing a new dimension.
    fn eval(&self, values: &SymbolValues) -> Self;

    /// Evaluates the dimension against `values` and converts the result to `i64`.
    fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64>;

    /// Returns a copy of the dimension where `from` has been replaced by `to`.
    fn substitute(&self, from: &Symbol, to: &Self) -> Self;
}
/// [`DimLike`] implementation for symbolic dimension expressions.
impl DimLike for TDim {
    /// Tries to divide `self` by `other`, returning a reduced fraction.
    ///
    /// Both operands are split into an integer coefficient and a list of
    /// non-integer factors. Every factor of `other` must also be a factor of
    /// `self`; the factors left over are multiplied by the reduced integer
    /// numerator, and the reduced integer denominator is returned alongside.
    ///
    /// A zero `self` short-circuits to `(0, 1)` before `other` is inspected.
    ///
    /// # Errors
    ///
    /// Fails with "Division by zero" when `other` is zero (and `self` is not),
    /// and with "Can't divide ..." when a factor of `other` is not found in
    /// `self`.
    fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)> {
        use num_integer::Integer;

        /// Splits `dim` into `(integer coefficient, non-integer factors)`.
        fn expand(dim: &TDim) -> (i64, Vec<TDim>) {
            match dim {
                // A plain integer is all coefficient.
                TDim::Val(value) => (*value, vec![]),
                // An integer-scaled expression: scale the inner coefficient.
                TDim::MulInt(scale, inner) => {
                    let (coef, factors) = expand(inner);
                    (scale * coef, factors)
                }
                // A product: multiply coefficients, concatenate factor lists.
                TDim::Mul(terms) => {
                    let mut coef = 1i64;
                    let mut factors = Vec::new();
                    for (term_coef, term_factors) in terms.iter().map(expand) {
                        coef *= term_coef;
                        factors.extend(term_factors);
                    }
                    (coef, factors)
                }
                // A sum: pull out the gcd of the terms' coefficients and keep
                // the divided, simplified sum as one opaque factor.
                TDim::Add(terms) => {
                    let common =
                        terms.iter().map(|t| expand(t).0).reduce(|x, y| x.gcd(&y)).unwrap();
                    let rest =
                        TDim::Add(terms.iter().map(|t| t.clone() / common).collect()).simplify();
                    (common, vec![rest])
                }
                // Anything else is a single factor with coefficient 1.
                opaque => (1, vec![opaque.clone()]),
            }
        }

        if self.is_zero() {
            return Ok((TDim::zero(), 1));
        }
        if other.is_zero() {
            anyhow::bail!("Division by zero");
        }

        let (mut int_num, mut sym_num) = expand(self);
        let (mut int_den, sym_den) = expand(other);

        if sym_num == sym_den {
            // Identical factor lists cancel out entirely.
            sym_num.clear();
        } else {
            // Otherwise cancel the divisor's factors one at a time.
            for factor in &sym_den {
                match sym_num.iter().position(|candidate| candidate == factor) {
                    Some(idx) => {
                        sym_num.remove(idx);
                    }
                    None => anyhow::bail!("Can't divide {} by {}", self, other),
                }
            }
        }

        // Keep the denominator positive so it can be returned as a u64.
        if int_den < 0 {
            int_num *= -1;
            int_den *= -1;
        }
        let common = int_num.gcd(&int_den);
        let quotient = (TDim::Mul(sym_num) * (int_num / common)).reduce();
        Ok((quotient, (int_den / common) as u64))
    }

    /// Delegates to `TDim::to_i64`.
    fn to_i64(&self) -> TractResult<i64> {
        self.to_i64()
    }

    /// The constant dimension `1`.
    fn one() -> Self {
        TDim::from(1)
    }

    /// Delegates to `TDim::eval`.
    fn eval(&self, values: &SymbolValues) -> Self {
        TDim::eval(self, values)
    }

    /// Delegates to `TDim::eval_to_i64`.
    fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64> {
        self.eval_to_i64(values)
    }

    /// Delegates to `TDim::substitute`.
    fn substitute(&self, from: &Symbol, to: &Self) -> Self {
        TDim::substitute(self, from, to)
    }
}
impl<'a> std::convert::TryFrom<&'a TDim> for TDim {
type Error = anyhow::Error;
fn try_from(d: &'a TDim) -> TractResult<TDim> {
Ok(d.clone())
}
}
impl DimLike for usize {
fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)> {
use num_integer::Integer;
let gcd = self.gcd(other);
Ok((self / gcd, (other / gcd) as u64))
}
fn to_i64(&self) -> TractResult<i64> {
Ok(*self as i64)
}
fn one() -> usize {
1
}
fn eval(&self, _values: &SymbolValues) -> Self {
*self
}
fn substitute(&self, _from: &Symbol, _to: &Self) -> Self {
*self
}
fn eval_to_i64(&self, _: &SymbolValues) -> TractResult<i64> {
Ok(*self as i64)
}
}
impl<'a> std::convert::TryFrom<&'a TDim> for usize {
type Error = anyhow::Error;
fn try_from(d: &'a TDim) -> anyhow::Result<usize> {
d.to_usize()
}
}
/// Conversion of a borrowed value into an owned [`TDim`].
pub trait ToDim {
    /// Builds a [`TDim`] from `self` without consuming it.
    fn to_dim(&self) -> TDim;
}
/// Blanket implementation: any clonable type convertible into a [`TDim`]
/// gets `to_dim` by cloning itself and converting the clone.
impl<I> ToDim for I
where
    I: Into<TDim> + Clone,
{
    fn to_dim(&self) -> TDim {
        let owned: I = self.clone();
        owned.into()
    }
}
/// `to_dim` on a reference to a [`TDim`] clones the referenced expression.
impl<'a> ToDim for &'a TDim {
    fn to_dim(&self) -> TDim {
        let inner: &TDim = self;
        inner.clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    lazy_static::lazy_static! {
        // Symbol table shared by all tests, together with the symbol created
        // from it with the "S" prefix.
        static ref S: (SymbolTable, Symbol) = {
            let symbols = SymbolTable::default();
            let sym = symbols.new_with_prefix("S");
            (symbols, sym)
        };
    }

    /// The shared test symbol, as a dimension expression.
    pub fn s() -> TDim {
        let sym: Symbol = S.1.clone();
        sym.into()
    }

    /// Runs `maybe_div` and unwraps the outcome; panics if the division fails.
    fn divide(num: TDim, denum: TDim) -> (TDim, u64) {
        num.maybe_div(&denum).unwrap()
    }

    // Integer divided by integer.
    #[test]
    fn div() {
        assert_eq!(divide(TDim::from(12), TDim::from(4)), (TDim::from(3), 1));
    }

    // Symbolic numerator, integer divisor.
    #[test]
    fn div_sym_int() {
        assert_eq!(divide(s() * 12, TDim::from(4)), (s() * 3, 1));
    }

    // The symbol cancels and the integer ratio is exact.
    #[test]
    fn div_sym_sym() {
        assert_eq!(divide(s() * 12, s() * 4), (TDim::from(3), 1));
    }

    // The symbol cancels and a non-trivial denominator remains.
    #[test]
    fn div_sym_sym_ratio() {
        assert_eq!(divide(s() * 13, s() * 4), (TDim::from(13), 4));
    }

    // The divisor's symbolic factor is not a factor of the numerator.
    #[test]
    fn div_sym_sym_rem() {
        let outcome = (s() + 1).maybe_div(&(s() * 4));
        assert!(outcome.is_err());
    }

    // An expression divided by itself is one.
    #[test]
    fn div_sym_sym_simply_1() {
        assert_eq!(divide(s(), s()), (TDim::Val(1), 1));
    }

    // Two distinct symbols cancel on both sides.
    #[test]
    fn div_sym_sym_complex() {
        let s = s();
        let b = S.0.sym("b");
        let quotient = divide(256.to_dim() * &s * &b, 1.to_dim() * &s * &b);
        assert_eq!(quotient, (TDim::from(256), 1));
    }

    // A sum whose terms share an integer factor is divisible by the reduced sum.
    #[test]
    fn div_sym_sym_with_add() {
        assert_eq!(divide(s() * 80 - 160, s() - 2), (TDim::from(80), 1));
    }
}