use std::fmt;
use std::iter::FromIterator;
use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
use tract_num_traits::Zero;
use crate::internal::*;
pub trait Factoid: fmt::Debug + Clone + PartialEq + Default {
type Concrete: fmt::Debug;
fn concretize(&self) -> Option<Self::Concrete>;
fn is_concrete(&self) -> bool {
self.concretize().is_some()
}
fn unify(&self, other: &Self) -> TractResult<Self>;
fn unify_with(&mut self, other: &Self) -> TractResult<bool> {
let new = self.unify(&other)?;
let mut changed = false;
if &new != self {
changed = true;
*self = new.clone();
}
Ok(changed)
}
fn unify_with_mut(&mut self, other: &mut Self) -> TractResult<bool> {
let new = self.unify(&other)?;
let mut changed = false;
if &new != self {
changed = true;
*self = new.clone();
}
if &new != other {
changed = true;
*other = new;
}
Ok(changed)
}
fn unify_all(facts: &mut [&mut Self]) -> TractResult<bool> {
let mut overall_changed = false;
loop {
let mut changed = false;
for i in 0..facts.len() - 1 {
for j in i + 1..facts.len() {
let (left, right) = facts.split_at_mut(j);
let c = left[i].unify_with(right[0])?;
changed = changed || c;
overall_changed = changed || c;
}
}
if !changed {
return Ok(overall_changed);
}
}
}
}
/// A factoid over any comparable value: either the value is known (`Only`) or it
/// is not (`Any`).
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[derive(Clone, PartialEq)]
pub enum GenericFactoid<T: fmt::Debug + Clone + PartialEq> {
/// The value is known.
Only(T),
/// Nothing is known about the value.
Any,
}
/// A factoid over a `Copy` value is itself `Copy`.
impl<T: Copy + Clone + fmt::Debug + PartialEq> Copy for GenericFactoid<T> {}
impl<T: fmt::Debug + Clone + PartialEq> Factoid for GenericFactoid<T> {
    type Concrete = T;

    /// Yields a copy of the known value, or `None` for `Any`.
    fn concretize(&self) -> Option<T> {
        if let GenericFactoid::Only(value) = self {
            Some(value.clone())
        } else {
            None
        }
    }

    /// `Any` unifies with everything; two known values unify only when equal.
    fn unify(&self, other: &Self) -> TractResult<Self> {
        if let GenericFactoid::Any = other {
            return Ok(self.clone());
        }
        if let GenericFactoid::Any = self {
            return Ok(other.clone());
        }
        if self != other {
            bail!("Impossible to unify {:?} with {:?}.", self, other);
        }
        Ok(self.clone())
    }
}
impl<T: fmt::Debug + Clone + PartialEq> Default for GenericFactoid<T> {
fn default() -> Self {
GenericFactoid::Any
}
}
impl<T: fmt::Debug + Clone + PartialEq> From<T> for GenericFactoid<T> {
fn from(t: T) -> Self {
GenericFactoid::Only(t)
}
}
impl<T: fmt::Debug + Clone + PartialEq> fmt::Debug for GenericFactoid<T> {
    /// Renders `Any` as `?` and a known value through its own `Debug` form.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        if let GenericFactoid::Only(value) = self {
            write!(formatter, "{:?}", value)
        } else {
            write!(formatter, "?")
        }
    }
}
/// Partial knowledge of a [`DatumType`].
pub type TypeFactoid = GenericFactoid<DatumType>;
/// Partial knowledge about a shape, i.e. a list of [`TDim`] dimensions.
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[derive(Clone, PartialEq)]
pub struct ShapeFactoid {
/// When `true`, more dimensions may follow `dims`, so the rank is unknown.
pub(super) open: bool,
/// Dimensions as integers; `-1` is the sentinel for a streaming axis whose
/// symbolic length is kept in `stream`.
pub(super) dims: TVec<GenericFactoid<i32>>,
/// The streaming axis and its length, if any.
pub(super) stream: Option<StreamFact>,
}
impl ShapeFactoid {
    /// Builds an open shape (further dimensions may follow) from known or unknown dims.
    ///
    /// Streaming dimensions are stored as the `-1` sentinel in `dims`; the first one
    /// found is recorded, with its symbolic length, in `stream`.
    ///
    /// # Panics
    /// Panics when a known, non-streaming dimension cannot be converted with
    /// `to_integer`.
    pub fn open(dims: TVec<DimFact>) -> ShapeFactoid {
        let stream = dims.iter().enumerate().find_map(|(axis, d)| match d {
            GenericFactoid::Only(len) if len.is_stream() => {
                Some(StreamFact { axis, len: len.clone() })
            }
            _ => None,
        });
        // One mapping serves both cases: when no dim is streaming, the first arm
        // simply never matches.
        let dims = dims
            .iter()
            .map(|d| match d {
                GenericFactoid::Only(d) if d.is_stream() => GenericFactoid::Only(-1),
                GenericFactoid::Only(d) => GenericFactoid::Only(d.to_integer().unwrap()),
                GenericFactoid::Any => GenericFactoid::Any,
            })
            .collect();
        ShapeFactoid { open: true, dims, stream }
    }

    /// Returns `true` if more dimensions may follow the known ones.
    pub fn is_open(&self) -> bool {
        self.open
    }

    /// Builds a closed shape, whose rank is exactly `dims.len()`.
    ///
    /// # Panics
    /// Same conditions as [`ShapeFactoid::open`].
    pub fn closed(dims: TVec<DimFact>) -> ShapeFactoid {
        ShapeFactoid { open: false, ..Self::open(dims) }
    }

    /// Rank of the shape: known for closed shapes, `Any` for open ones.
    pub fn rank(&self) -> IntFactoid {
        if self.open {
            GenericFactoid::Any
        } else {
            GenericFactoid::Only(self.dims.len() as i32)
        }
    }

    /// Pads the shape with unknown dimensions until index `n` is addressable.
    ///
    /// NOTE(review): despite its name, this guarantees `n + 1` dimensions, not `n`.
    /// Confirm what callers expect before changing it. The `open` flag is untouched.
    ///
    /// Returns `true` if any dimension was added.
    pub fn ensure_rank_at_least(&mut self, n: usize) -> bool {
        let mut changed = false;
        while self.dims.len() <= n {
            self.dims.push(GenericFactoid::Any);
            changed = true;
        }
        changed
    }

    /// Returns dimension `i` (streaming length restored), or `None` if out of range.
    pub fn dim(&self, i: usize) -> Option<DimFact> {
        self.dims().nth(i)
    }

    /// Sets dimension `i` to `d`. Returns `true` if the shape changed.
    ///
    /// Integer dimensions are stored directly. Any other value is treated as the
    /// streaming dimension and replaces the recorded stream.
    ///
    /// # Panics
    /// Panics if `i` is out of range (see [`ShapeFactoid::ensure_rank_at_least`]).
    pub fn set_dim(&mut self, i: usize, d: TDim) -> bool {
        if self.dim(i).as_ref() == Some(&GenericFactoid::Only(d.clone())) {
            return false;
        }
        match d.to_integer() {
            Ok(n) => {
                self.dims[i] = GenericFactoid::Only(n);
                // Overwriting the streaming axis with a concrete value must not leave a
                // stale `stream` behind once no `-1` sentinel refers to it any more:
                // it would otherwise leak into `Debug` output and `PartialEq`.
                let was_stream_axis = self.stream.as_ref().map_or(false, |s| s.axis == i);
                if was_stream_axis && !self.dims.contains(&GenericFactoid::Only(-1)) {
                    self.stream = None;
                }
            }
            Err(_) => {
                self.dims[i] = GenericFactoid::Only(-1);
                self.stream = Some(StreamFact { axis: i, len: d })
            }
        }
        true
    }

    /// Iterates over the dimensions as [`DimFact`]s, replacing the `-1` sentinel by
    /// the streaming length.
    ///
    /// # Panics
    /// Panics, while iterating, on a `-1` sentinel when no stream is recorded.
    pub fn dims(&self) -> impl Iterator<Item = DimFact> {
        let stream = self.stream.clone();
        self.dims.clone().into_iter().map(move |d| match d {
            GenericFactoid::Only(-1) => GenericFactoid::Only(
                stream
                    .as_ref()
                    .expect("-1 dim found with no stream. This is a tract bug.")
                    .len
                    .clone(),
            ),
            GenericFactoid::Only(d) => GenericFactoid::Only(d.to_dim()),
            GenericFactoid::Any => GenericFactoid::Any,
        })
    }

    /// Locates the streaming axis of a fully known shape.
    ///
    /// # Errors
    /// Fails if the shape is not concrete, or if it has more than one streaming
    /// dimension.
    pub fn stream_info(&self) -> TractResult<Option<StreamFact>> {
        let concrete = self
            .concretize()
            .ok_or("Shape has unknown dims, can not find streaming dim for sure.")?;
        let count = concrete.iter().filter(|&d| d.is_stream()).count();
        if count > 1 {
            bail!("Shape has more than one streaming dim. This is terribly wrong.")
        }
        Ok(concrete
            .into_iter()
            .enumerate()
            .find(|(_, d)| d.is_stream())
            .map(|(axis, len)| StreamFact { axis, len }))
    }

    /// Returns the dimensions as `usize`s when the shape is concrete and has no
    /// streaming axis, `Ok(None)` otherwise.
    ///
    /// # Errors
    /// Propagates the error of [`ShapeFactoid::stream_info`] (several streaming dims).
    pub fn as_concrete_finite(&self) -> TractResult<Option<TVec<usize>>> {
        if !self.is_concrete() || self.stream_info()?.is_some() {
            return Ok(None);
        }
        Ok(Some(self.dims.iter().map(|i| i.concretize().unwrap() as usize).collect()))
    }
}
impl Factoid for ShapeFactoid {
type Concrete = TVec<TDim>;
fn concretize(self: &ShapeFactoid) -> Option<TVec<TDim>> {
if self.open {
return None;
}
let dims: TVec<_> = self.dims().filter_map(|d| d.concretize()).collect();
if dims.len() < self.dims.len() {
None
} else {
Some(dims)
}
}
fn unify(&self, other: &Self) -> TractResult<Self> {
let (x, y) = (self, other);
use tract_itertools::EitherOrBoth::{Both, Left, Right};
use tract_itertools::Itertools;
let xi = x.dims();
let yi = y.dims();
let dimensions: TVec<_> = xi
.zip_longest(yi)
.map(|r| match r {
Both(a, b) => a.unify(&b),
Left(ref d) if y.open => Ok(d.clone()),
Right(ref d) if x.open => Ok(d.clone()),
Left(_) | Right(_) => bail!(
"Impossible to unify closed shapes of different rank (found {:?} and {:?}).",
x,
y
),
})
.collect::<TractResult<_>>()
.map_err(|e| format!("Unifying shapes {:?} and {:?}, {}", x, y, e))?;
if x.open && y.open {
Ok(ShapeFactoid::open(dimensions))
} else {
Ok(ShapeFactoid::closed(dimensions))
}
}
}
impl Default for ShapeFactoid {
fn default() -> ShapeFactoid {
ShapeFactoid::open(tvec![])
}
}
impl FromIterator<TDim> for ShapeFactoid {
fn from_iter<I: IntoIterator<Item = TDim>>(iter: I) -> ShapeFactoid {
ShapeFactoid::closed(iter.into_iter().map(|d| GenericFactoid::Only(d)).collect())
}
}
impl FromIterator<usize> for ShapeFactoid {
fn from_iter<I: IntoIterator<Item = usize>>(iter: I) -> ShapeFactoid {
ShapeFactoid::closed(iter.into_iter().map(|d| GenericFactoid::Only(d.to_dim())).collect())
}
}
impl<D: ToDim, I: IntoIterator<Item = D>> From<I> for ShapeFactoid {
fn from(it: I) -> ShapeFactoid {
ShapeFactoid::closed(it.into_iter().map(|d| GenericFactoid::Only(d.to_dim())).collect())
}
}
impl fmt::Debug for ShapeFactoid {
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
for (ix, d) in self.dims.iter().enumerate() {
if ix != 0 {
write!(formatter, "x")?
}
if let Some(ref stream) = self.stream {
if stream.axis == ix {
write!(formatter, "{:?}", stream.len)?;
} else {
write!(formatter, "{:?}", d)?;
}
} else {
write!(formatter, "{:?}", d)?;
}
}
if self.open {
if self.dims.len() == 0 {
write!(formatter, "..")?;
} else {
write!(formatter, "x..")?;
}
}
Ok(())
}
}
/// Partial knowledge of a single dimension.
pub type DimFact = GenericFactoid<TDim>;
/// Partial knowledge of a tensor value.
pub type ValueFact = GenericFactoid<Arc<Tensor>>;
/// Partial knowledge of an integer (used e.g. for `ShapeFactoid::rank`).
pub type IntFactoid = GenericFactoid<i32>;
impl<T> Zero for GenericFactoid<T>
where
    T: Add<T, Output = T> + Zero + PartialEq + Clone + ::std::fmt::Debug,
{
    /// The known value zero.
    fn zero() -> GenericFactoid<T> {
        GenericFactoid::Only(T::zero())
    }

    /// Only a known zero is zero; an unknown value never is.
    fn is_zero(&self) -> bool {
        if let GenericFactoid::Only(value) = self {
            value.is_zero()
        } else {
            false
        }
    }
}
impl<T> Neg for GenericFactoid<T>
where
    T: Neg<Output = T> + PartialEq + Clone + ::std::fmt::Debug,
{
    type Output = GenericFactoid<T>;

    /// Negates a known value; an unknown value stays unknown.
    fn neg(self) -> GenericFactoid<T> {
        if let GenericFactoid::Only(value) = self {
            GenericFactoid::Only(-value)
        } else {
            GenericFactoid::Any
        }
    }
}
impl<T, I> Add<I> for GenericFactoid<T>
where
    T: Add<T, Output = T> + PartialEq + Clone + ::std::fmt::Debug,
    I: Into<GenericFactoid<T>>,
{
    type Output = GenericFactoid<T>;

    /// The sum is known only when both operands are known.
    fn add(self, rhs: I) -> Self::Output {
        let rhs: GenericFactoid<T> = rhs.into();
        match (self, rhs) {
            (GenericFactoid::Only(lhs), GenericFactoid::Only(rhs)) => {
                GenericFactoid::Only(lhs + rhs)
            }
            _ => GenericFactoid::Any,
        }
    }
}
impl<T> Sub<GenericFactoid<T>> for GenericFactoid<T>
where
    T: Sub<T, Output = T> + PartialEq + Clone + ::std::fmt::Debug,
{
    type Output = GenericFactoid<T>;

    /// The difference is known only when both operands are known.
    fn sub(self, rhs: GenericFactoid<T>) -> Self::Output {
        match (self, rhs) {
            (GenericFactoid::Only(lhs), GenericFactoid::Only(rhs)) => {
                GenericFactoid::Only(lhs - rhs)
            }
            _ => GenericFactoid::Any,
        }
    }
}
impl<T, R> Mul<R> for GenericFactoid<T>
where
    T: Mul<R, Output = T> + PartialEq + Clone + ::std::fmt::Debug,
{
    type Output = GenericFactoid<T>;

    /// Scales a known value by `rhs`; an unknown value stays unknown.
    fn mul(self, rhs: R) -> Self::Output {
        match self {
            GenericFactoid::Only(value) => GenericFactoid::Only(value * rhs),
            GenericFactoid::Any => GenericFactoid::Any,
        }
    }
}
impl<T, R> Div<R> for GenericFactoid<T>
where
    T: Div<R, Output = T> + PartialEq + Clone + ::std::fmt::Debug,
{
    type Output = GenericFactoid<T>;

    /// Divides a known value by `rhs`; an unknown value stays unknown.
    fn div(self, rhs: R) -> Self::Output {
        match self {
            GenericFactoid::Only(value) => GenericFactoid::Only(value / rhs),
            GenericFactoid::Any => GenericFactoid::Any,
        }
    }
}
impl<T, R> Rem<R> for GenericFactoid<T>
where
    T: Rem<R, Output = T> + PartialEq + Clone + ::std::fmt::Debug,
{
    type Output = GenericFactoid<T>;

    /// Remainder of a known value by `rhs`; an unknown value stays unknown.
    fn rem(self, rhs: R) -> Self::Output {
        match self {
            GenericFactoid::Only(value) => GenericFactoid::Only(value % rhs),
            GenericFactoid::Any => GenericFactoid::Any,
        }
    }
}
// Unit tests for `Factoid::unify` on type, shape and value factoids.
#[cfg(test)]
mod tests {
use super::GenericFactoid::*;
use super::*;
use tract_core::datum::DatumType;
// A known datum type unifies with itself.
#[test]
fn unify_same_datum_type() {
let dt = TypeFactoid::Only(DatumType::F32);
assert_eq!(dt.unify(&dt).unwrap(), dt);
}
// Two different known datum types cannot be unified.
#[test]
fn unify_different_datum_types_only() {
let dt1 = TypeFactoid::Only(DatumType::F32);
let dt2 = TypeFactoid::Only(DatumType::F64);
assert!(dt1.unify(&dt2).is_err());
}
// `Any` on the left yields the known right-hand side.
#[test]
fn unify_different_datum_types_any_left() {
let dt = TypeFactoid::Only(DatumType::F32);
assert_eq!(TypeFactoid::Any.unify(&dt).unwrap(), dt);
}
// `Any` on the right yields the known left-hand side.
#[test]
fn unify_different_datum_types_any_right() {
let dt = TypeFactoid::Only(DatumType::F32);
assert_eq!(dt.unify(&TypeFactoid::Any).unwrap(), dt);
}
// A closed rank-0 shape unifies with itself.
#[test]
fn unify_same_shape_1() {
let s = ShapeFactoid::closed(tvec![]);
assert_eq!(s.unify(&s).unwrap(), s);
}
// A closed shape with an unknown dim unifies with itself.
#[test]
fn unify_same_shape_2() {
let s = ShapeFactoid::closed(tvec![Any]);
assert_eq!(s.unify(&s).unwrap(), s);
}
// A fully known closed shape unifies with itself.
#[test]
fn unify_same_shape_3() {
let s = ShapeFactoid::closed(tvec![Only(1.into()), Only(2.into())]);
assert_eq!(s.unify(&s).unwrap(), s);
}
// Closed shapes of different ranks cannot be unified.
#[test]
fn unify_different_shapes_1() {
let s1 = ShapeFactoid::closed(tvec![Only(1.into()), Only(2.into())]);
let s2 = ShapeFactoid::closed(tvec![Only(1.into())]);
assert!(s1.unify(&s2).is_err());
}
// Rank mismatch fails even when the shorter closed shape's dims are unknown.
#[test]
fn unify_different_shapes_2() {
let s1 = ShapeFactoid::closed(tvec![Only(1.into()), Only(2.into())]);
let s2 = ShapeFactoid::closed(tvec![Any]);
assert!(s1.unify(&s2).is_err());
}
// An open shape with two dims cannot fit into a closed rank-1 shape.
#[test]
fn unify_different_shapes_3() {
let s1 = ShapeFactoid::open(tvec![Only(1.into()), Only(2.into())]);
let s2 = ShapeFactoid::closed(tvec![Any]);
assert!(s1.unify(&s2).is_err());
}
// Unknown with unknown stays unknown.
#[test]
fn unify_different_shapes_4() {
let s1 = ShapeFactoid::closed(tvec![Any]);
let s2 = ShapeFactoid::closed(tvec![Any]);
let sr = ShapeFactoid::closed(tvec![Any]);
assert_eq!(s1.unify(&s2).unwrap(), sr);
}
// An unknown dim takes the known value of the other shape.
#[test]
fn unify_different_shapes_5() {
let s1 = ShapeFactoid::closed(tvec![Any]);
let s2 = ShapeFactoid::closed(tvec![Only(1.into())]);
let sr = ShapeFactoid::closed(tvec![Only(1.into())]);
assert_eq!(s1.unify(&s2).unwrap(), sr);
}
// An empty open shape accepts any closed shape; the result is closed.
#[test]
fn unify_different_shapes_6() {
let s1 = ShapeFactoid::open(tvec![]);
let s2 = ShapeFactoid::closed(tvec![Only(1.into())]);
let sr = ShapeFactoid::closed(tvec![Only(1.into())]);
assert_eq!(s1.unify(&s2).unwrap(), sr);
}
// An open prefix merges axis by axis with a longer closed shape.
#[test]
fn unify_different_shapes_7() {
let s1 = ShapeFactoid::open(tvec![Any, Only(2.into())]);
let s2 = ShapeFactoid::closed(tvec![Only(1.into()), Any, Any]);
let sr = ShapeFactoid::closed(tvec![Only(1.into()), Only(2.into()), Any]);
assert_eq!(s1.unify(&s2).unwrap(), sr);
}
// A known tensor value unifies with itself.
#[test]
fn unify_same_value() {
let t = ValueFact::Only(rctensor0(12f32));
assert_eq!(t.unify(&t).unwrap(), t);
}
// Two different known tensor values cannot be unified.
#[test]
fn unify_different_values_only() {
let t1 = ValueFact::Only(rctensor1(&[12f32]));
let t2 = ValueFact::Only(rctensor1(&[12f32, 42.0]));
assert!(t1.unify(&t2).is_err());
}
// `Any` on the left yields the known tensor value.
#[test]
fn unify_different_values_any_left() {
let t1 = ValueFact::Only(rctensor1(&[12f32]));
assert_eq!(ValueFact::Any.unify(&t1).unwrap(), t1);
}
// `Any` on the right yields the known tensor value.
#[test]
fn unify_different_values_any_right() {
let t1 = ValueFact::Only(rctensor1(&[12f32]));
assert_eq!(t1.unify(&ValueFact::Any).unwrap(), t1);
}
}