use std::fmt;
use std::ops::Index;
use crate::tract_num_traits::ToPrimitive;
use crate::infer::factoid::*;
use self::super::cache::Cache;
use self::super::expr::Output;
use self::super::path::Path;
/// Common interface shared by every proxy in this module.
///
/// A proxy is addressed by the [`Path`] it exposes; child proxies are built by
/// appending one more index to their parent's path.
pub trait Proxy {
    /// Returns the path that locates this proxy.
    fn get_path(&self) -> &Path;
}
/// A [`Proxy`] that declares the kind of value it stands for.
///
/// For example, [`IntProxy`] uses `IntFactoid` and [`ShapeProxy`] uses
/// `ShapeFactoid` as their `Output`.
pub trait ComparableProxy: Proxy {
    /// Value type associated with this proxy.
    type Output: Output;
}
// Implements `Proxy` for `$name` and for `&$name`, both reading the struct's
// `path` field, plus a `Debug` impl that prints only that path.
macro_rules! impl_proxy {
    ($name:ident) => {
        impl Proxy for $name {
            fn get_path(&self) -> &Path {
                &self.path
            }
        }

        impl fmt::Debug for $name {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                // A proxy is shown as its path alone.
                write!(f, "{:?}", self.get_path())
            }
        }

        // Lets callers pass `&proxy` wherever a `Proxy` is expected.
        impl Proxy for &$name {
            fn get_path(&self) -> &Path {
                (**self).get_path()
            }
        }
    };
}
// Implements `ComparableProxy` with the given `Output` type for both `$name`
// and `&$name`. The matching `Proxy` impls come from `impl_proxy!`.
macro_rules! impl_comparable_proxy {
    ($name:ident, $out:ident) => {
        impl ComparableProxy for $name {
            type Output = $out;
        }

        impl ComparableProxy for &$name {
            type Output = $out;
        }
    };
}
/// Proxy whose associated output type is `IntFactoid`.
#[derive(new)]
pub struct IntProxy {
    /// Location of this proxy.
    path: Path,
}

impl_proxy!(IntProxy);
impl_comparable_proxy!(IntProxy, IntFactoid);
/// Proxy bundling the datum type, rank, shape and value proxies of one tensor.
///
/// Each component sits one step below the tensor's own path, at child index
/// 0 (datum type), 1 (rank), 2 (shape) and 3 (value).
pub struct TensorProxy {
    pub datum_type: TypeProxy,
    pub rank: IntProxy,
    pub shape: ShapeProxy,
    pub value: ValueProxy,
    path: Path,
}

impl TensorProxy {
    /// Builds a tensor proxy at `path` together with its four component proxies.
    pub fn new(path: Path) -> TensorProxy {
        // Path of the component at `step`: the tensor's path plus one index.
        let component = |step: isize| -> Path { [&path[..], &[step]].concat().into() };
        let datum_type = TypeProxy::new(component(0));
        let rank = IntProxy::new(component(1));
        let shape = ShapeProxy::new(component(2));
        let value = ValueProxy::new(component(3));
        TensorProxy { datum_type, rank, shape, value, path }
    }
}

impl_proxy!(TensorProxy);
/// Proxy whose associated output type is `TypeFactoid`.
#[derive(new)]
pub struct TypeProxy {
    /// Location of this proxy.
    path: Path,
}

impl_proxy!(TypeProxy);
impl_comparable_proxy!(TypeProxy, TypeFactoid);
/// Proxy for a shape. Individual dimensions are reached with `shape[i]`,
/// which yields a [`DimProxy`] kept in an internal cache.
pub struct ShapeProxy {
    // Dimension proxies created on demand, keyed by dimension index.
    dims: Cache<usize, DimProxy>,
    path: Path,
}

impl ShapeProxy {
    /// Creates a shape proxy at `path` with a fresh dimension cache.
    pub fn new(path: Path) -> ShapeProxy {
        let dims = Cache::new();
        ShapeProxy { path, dims }
    }
}

impl_proxy!(ShapeProxy);
impl_comparable_proxy!(ShapeProxy, ShapeFactoid);
impl Index<usize> for ShapeProxy {
    type Output = DimProxy;

    /// Returns the proxy for dimension `index`, creating it through the cache
    /// if it is not already present.
    ///
    /// The dimension's path is this shape's path followed by `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` does not fit in an `isize`.
    fn index(&self, index: usize) -> &DimProxy {
        // Build the child path inside the default closure so it is only
        // allocated on a cache miss, not on every lookup.
        self.dims.get(index, || {
            let step = index.to_isize().expect("dimension index does not fit in isize");
            DimProxy::new([&self.path[..], &[step]].concat().into())
        })
    }
}
/// Proxy for a single dimension of a shape; its associated output type is
/// `DimFact`.
#[derive(new)]
pub struct DimProxy {
    /// Location of this proxy.
    path: Path,
}

impl_proxy!(DimProxy);
impl_comparable_proxy!(DimProxy, DimFact);
/// Proxy for a tensor value.
///
/// `value[()]` yields an [`IntProxy`] at child index `-1`, and `value[i]`
/// yields cached [`ElementProxy`] children that can be indexed further.
pub struct ValueProxy {
    // Element proxies created on demand, keyed by index.
    sub: Cache<usize, ElementProxy>,
    // Proxy returned by `value[()]`, located at this path followed by `-1`.
    root: IntProxy,
    path: Path,
}

impl ValueProxy {
    /// Creates a value proxy at `path`, building its `[()]` proxy eagerly.
    pub fn new(path: Path) -> ValueProxy {
        let mut root_path = path[..].to_vec();
        root_path.push(-1);
        ValueProxy { sub: Cache::new(), root: IntProxy::new(root_path.into()), path }
    }
}
impl Index<()> for ValueProxy {
type Output = IntProxy;
fn index(&self, _: ()) -> &IntProxy {
&self.root
}
}
impl Index<usize> for ValueProxy {
    type Output = ElementProxy;

    /// Returns the element proxy at `index`, creating it through the cache if
    /// it is not already present.
    ///
    /// The element's path is this value's path followed by `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` does not fit in an `isize`.
    fn index(&self, index: usize) -> &ElementProxy {
        // Build the child path inside the default closure so it is only
        // allocated on a cache miss, not on every lookup.
        self.sub.get(index, || {
            let step = index.to_isize().expect("value index does not fit in isize");
            ElementProxy::new([&self.path[..], &[step]].concat().into())
        })
    }
}

impl_proxy!(ValueProxy);
impl_comparable_proxy!(ValueProxy, ValueFact);
/// Proxy for a (possibly nested) position inside a tensor value. Indexing it
/// again descends one more level by appending the index to the path.
pub struct ElementProxy {
    // Child proxies created on demand, keyed by index.
    sub: Cache<usize, ElementProxy>,
    path: Path,
}

impl ElementProxy {
    /// Creates an element proxy at `path` with a fresh child cache.
    pub fn new(path: Path) -> ElementProxy {
        let sub = Cache::new();
        ElementProxy { path, sub }
    }
}
impl Index<usize> for ElementProxy {
    type Output = ElementProxy;

    /// Returns the child element proxy at `index`, creating it through the
    /// cache if it is not already present.
    ///
    /// The child's path is this element's path followed by `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` does not fit in an `isize`.
    fn index(&self, index: usize) -> &ElementProxy {
        // Build the child path inside the default closure so it is only
        // allocated on a cache miss, not on every lookup.
        self.sub.get(index, || {
            let step = index.to_isize().expect("element index does not fit in isize");
            ElementProxy::new([&self.path[..], &[step]].concat().into())
        })
    }
}

impl_proxy!(ElementProxy);
impl_comparable_proxy!(ElementProxy, IntFactoid);
#[cfg(test)]
mod tests {
    use super::*;

    /// Tensor proxy rooted at `[0, 0]`, used by every test below.
    fn tensor() -> TensorProxy {
        TensorProxy::new(vec![0, 0].into())
    }

    /// Builds the expected `Path` from its steps.
    fn path(steps: Vec<isize>) -> Path {
        steps.into()
    }

    #[test]
    fn test_tensor_proxy_datum_type() {
        let t = tensor();
        assert_eq!(t.datum_type.get_path(), &path(vec![0, 0, 0]));
    }

    #[test]
    fn test_tensor_proxy_rank() {
        let t = tensor();
        assert_eq!(t.rank.get_path(), &path(vec![0, 0, 1]));
    }

    #[test]
    fn test_tensor_proxy_shape() {
        let t = tensor();
        assert_eq!(t.shape[0].get_path(), &path(vec![0, 0, 2, 0]));
        assert_eq!(t.shape[2].get_path(), &path(vec![0, 0, 2, 2]));
    }

    #[test]
    fn test_tensor_proxy_value() {
        let t = tensor();
        assert_eq!(t.value.get_path(), &path(vec![0, 0, 3]));
        assert_eq!(t.value[()].get_path(), &path(vec![0, 0, 3, -1]));
        assert_eq!(t.value[0].get_path(), &path(vec![0, 0, 3, 0]));
        assert_eq!(t.value[0][1].get_path(), &path(vec![0, 0, 3, 0, 1]));
        assert_eq!(t.value[1][2][3].get_path(), &path(vec![0, 0, 3, 1, 2, 3]));
    }
}