1use std::fmt;
2use std::ops::Index;
3
4use crate::tract_num_traits::ToPrimitive;
5
6use crate::infer::factoid::*;
7
8use self::super::cache::Cache;
9use self::super::expr::Output;
10use self::super::path::Path;
11
12pub trait Proxy {
14 fn get_path(&self) -> &Path;
21}
22
23pub trait ComparableProxy: Proxy {
25 type Output: Output;
26}
27
/// Implements `Proxy` and `fmt::Debug` for a proxy struct that stores its
/// path in a field named `path`, and also implements `Proxy` for shared
/// references to that struct.
///
/// The `Debug` output of the proxy is exactly the `Debug` output of its path.
macro_rules! impl_proxy {
    ($struct:ident) => {
        impl Proxy for $struct {
            fn get_path(&self) -> &Path {
                &self.path
            }
        }

        impl fmt::Debug for $struct {
            fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                // Delegate rather than re-format through `write!("{:?}")`, so the
                // caller's flags (e.g. `{:#?}` alternate/pretty mode) reach the
                // path's own `Debug` impl instead of being silently dropped.
                fmt::Debug::fmt(self.get_path(), formatter)
            }
        }

        impl<'a> Proxy for &'a $struct {
            fn get_path(&self) -> &Path {
                &self.path
            }
        }
    };
}
52
/// Implements `ComparableProxy` for a proxy struct and for shared references
/// to it, binding the associated `Output` type to `$fact` in both cases.
macro_rules! impl_comparable_proxy {
    ($proxy:ident, $fact:ident) => {
        impl<'r> ComparableProxy for &'r $proxy {
            type Output = $fact;
        }

        impl ComparableProxy for $proxy {
            type Output = $fact;
        }
    };
}
64
65#[derive(new)]
67pub struct IntProxy {
68 path: Path,
69}
70
71impl_proxy!(IntProxy);
72impl_comparable_proxy!(IntProxy, IntFactoid);
73
74pub struct TensorProxy {
84 pub datum_type: TypeProxy,
85 pub rank: IntProxy,
86 pub shape: ShapeProxy,
87 pub value: ValueProxy,
88 path: Path,
89}
90
91impl TensorProxy {
92 pub fn new(path: Path) -> TensorProxy {
94 TensorProxy {
95 datum_type: TypeProxy::new([&path[..], &[0]].concat().into()),
96 rank: IntProxy::new([&path[..], &[1]].concat().into()),
97 shape: ShapeProxy::new([&path[..], &[2]].concat().into()),
98 value: ValueProxy::new([&path[..], &[3]].concat().into()),
99 path,
100 }
101 }
102}
103
104impl_proxy!(TensorProxy);
105
106#[derive(new)]
108pub struct TypeProxy {
109 path: Path,
110}
111
112impl_proxy!(TypeProxy);
113impl_comparable_proxy!(TypeProxy, TypeFactoid);
114
115pub struct ShapeProxy {
117 dims: Cache<usize, DimProxy>,
118 path: Path,
119}
120
121impl ShapeProxy {
122 pub fn new(path: Path) -> ShapeProxy {
124 ShapeProxy { dims: Cache::new(), path }
125 }
126}
127
128impl_proxy!(ShapeProxy);
129impl_comparable_proxy!(ShapeProxy, ShapeFactoid);
130
131impl Index<usize> for ShapeProxy {
132 type Output = DimProxy;
133
134 fn index(&self, index: usize) -> &DimProxy {
136 let path = [&self.path[..], &[index.to_isize().unwrap()]].concat();
137 self.dims.get(index, || DimProxy::new(path.into()))
138 }
139}
140
141#[derive(new)]
143pub struct DimProxy {
144 path: Path,
145}
146
147impl_proxy!(DimProxy);
148impl_comparable_proxy!(DimProxy, DimFact);
149
150pub struct ValueProxy {
157 sub: Cache<usize, ElementProxy>,
158 root: IntProxy,
159 path: Path,
160}
161
162impl ValueProxy {
163 pub fn new(path: Path) -> ValueProxy {
165 let root = IntProxy::new([&path[..], &[-1]].concat().into());
166 ValueProxy { sub: Cache::new(), root, path }
167 }
168}
169
170impl Index<()> for ValueProxy {
171 type Output = IntProxy;
172
173 fn index(&self, _: ()) -> &IntProxy {
175 &self.root
176 }
177}
178
179impl Index<usize> for ValueProxy {
180 type Output = ElementProxy;
181
182 fn index(&self, index: usize) -> &ElementProxy {
184 let path = [&self.path[..], &[index.to_isize().unwrap()]].concat();
185 self.sub.get(index, || ElementProxy::new(path.into()))
186 }
187}
188
189impl_proxy!(ValueProxy);
190impl_comparable_proxy!(ValueProxy, ValueFact);
191
192pub struct ElementProxy {
194 sub: Cache<usize, ElementProxy>,
195 path: Path,
196}
197
198impl ElementProxy {
199 pub fn new(path: Path) -> ElementProxy {
201 ElementProxy { sub: Cache::new(), path }
202 }
203}
204
205impl Index<usize> for ElementProxy {
206 type Output = ElementProxy;
207
208 fn index(&self, index: usize) -> &ElementProxy {
210 let path = [&self.path[..], &[index.to_isize().unwrap()]].concat();
211 self.sub.get(index, || ElementProxy::new(path.into()))
212 }
213}
214
215impl_proxy!(ElementProxy);
216impl_comparable_proxy!(ElementProxy, IntFactoid);
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Path` from its components.
    fn path_of(components: &[isize]) -> Path {
        components.to_vec().into()
    }

    /// The tensor proxy rooted at `[0, 0]` that every test below inspects.
    fn tensor() -> TensorProxy {
        TensorProxy::new(path_of(&[0, 0]))
    }

    #[test]
    fn test_tensor_proxy_datum_type() {
        let input = tensor();
        assert_eq!(input.datum_type.get_path(), &path_of(&[0, 0, 0]));
    }

    #[test]
    fn test_tensor_proxy_rank() {
        let input = tensor();
        assert_eq!(input.rank.get_path(), &path_of(&[0, 0, 1]));
    }

    #[test]
    fn test_tensor_proxy_shape() {
        let input = tensor();
        for &dim in &[0usize, 2] {
            assert_eq!(input.shape[dim].get_path(), &path_of(&[0, 0, 2, dim as isize]));
        }
    }

    #[test]
    fn test_tensor_proxy_value() {
        let input = tensor();
        assert_eq!(input.value.get_path(), &path_of(&[0, 0, 3]));
        assert_eq!(input.value[()].get_path(), &path_of(&[0, 0, 3, -1]));
        assert_eq!(input.value[0].get_path(), &path_of(&[0, 0, 3, 0]));
        assert_eq!(input.value[0][1].get_path(), &path_of(&[0, 0, 3, 0, 1]));
        assert_eq!(input.value[1][2][3].get_path(), &path_of(&[0, 0, 3, 1, 2, 3]));
    }
}
250}