tract_hir/infer/rules/
proxies.rs

1use std::fmt;
2use std::ops::Index;
3
4use crate::tract_num_traits::ToPrimitive;
5
6use crate::infer::factoid::*;
7
8use self::super::cache::Cache;
9use self::super::expr::Output;
10use self::super::path::Path;
11
12/// A proxy for any value.
13pub trait Proxy {
14    /// Returns the symbolic path to the value.
15    ///
16    /// Take the `inputs[0].shape[1]` proxy for instance: it represents the
17    /// second dimension of the shape of the first input. Because we encode
18    /// the "inputs" vectors as `0`, and the `shape` field as `2`, the path
19    /// for this proxy will be `vec![0, 0, 2, 1]`.
20    fn get_path(&self) -> &Path;
21}
22
23/// A proxy which can be used in a solver rule.
24pub trait ComparableProxy: Proxy {
25    type Output: Output;
26}
27
/// For a struct holding a `path: Path` field, implements `Proxy` (for the
/// struct and for shared references to it) and a `Debug` that prints the path.
macro_rules! impl_proxy {
    ($struct:ident) => {
        impl fmt::Debug for $struct {
            fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                let path = self.get_path();
                write!(formatter, "{:?}", path)
            }
        }

        impl Proxy for $struct {
            /// Returns the symbolic path to the value.
            fn get_path(&self) -> &Path {
                &self.path
            }
        }

        impl<'a> Proxy for &'a $struct {
            /// Returns the symbolic path to the value.
            fn get_path(&self) -> &Path {
                &(**self).path
            }
        }
    };
}
52
/// Implements `ComparableProxy` with the given `Output` fact type, both for
/// the proxy struct itself and for shared references to it.
macro_rules! impl_comparable_proxy {
    ($struct:ident, $output:ident) => {
        impl<'a> ComparableProxy for &'a $struct {
            type Output = $output;
        }

        impl ComparableProxy for $struct {
            type Output = $output;
        }
    };
}
64
65/// A proxy for any integer-like value.
66#[derive(new)]
67pub struct IntProxy {
68    path: Path,
69}
70
71impl_proxy!(IntProxy);
72impl_comparable_proxy!(IntProxy, IntFactoid);
73
74/// A proxy for a tensor.
75///
76/// This is used for rules involving the datum_type, rank, shape or value of a
77/// tensor. Here are a few examples of constraints that can be expressed:
78/// ```text
79/// solver.equals(input.datum_type, DTYPE_I32)
80/// solver.equals(input.rank, 2)
81/// solver.equals(input.shape[1], output.value[0][1])
82/// ```
83pub struct TensorProxy {
84    pub datum_type: TypeProxy,
85    pub rank: IntProxy,
86    pub shape: ShapeProxy,
87    pub value: ValueProxy,
88    path: Path,
89}
90
91impl TensorProxy {
92    /// Creates a new TensorProxy instance.
93    pub fn new(path: Path) -> TensorProxy {
94        TensorProxy {
95            datum_type: TypeProxy::new([&path[..], &[0]].concat().into()),
96            rank: IntProxy::new([&path[..], &[1]].concat().into()),
97            shape: ShapeProxy::new([&path[..], &[2]].concat().into()),
98            value: ValueProxy::new([&path[..], &[3]].concat().into()),
99            path,
100        }
101    }
102}
103
104impl_proxy!(TensorProxy);
105
106/// A proxy for a tensor datum_type.
107#[derive(new)]
108pub struct TypeProxy {
109    path: Path,
110}
111
112impl_proxy!(TypeProxy);
113impl_comparable_proxy!(TypeProxy, TypeFactoid);
114
115/// A proxy for a tensor shape.
116pub struct ShapeProxy {
117    dims: Cache<usize, DimProxy>,
118    path: Path,
119}
120
121impl ShapeProxy {
122    /// Creates a new ShapeProxy instance.
123    pub fn new(path: Path) -> ShapeProxy {
124        ShapeProxy { dims: Cache::new(), path }
125    }
126}
127
128impl_proxy!(ShapeProxy);
129impl_comparable_proxy!(ShapeProxy, ShapeFactoid);
130
131impl Index<usize> for ShapeProxy {
132    type Output = DimProxy;
133
134    /// Returns the DimProxy corresponding to the given index.
135    fn index(&self, index: usize) -> &DimProxy {
136        let path = [&self.path[..], &[index.to_isize().unwrap()]].concat();
137        self.dims.get(index, || DimProxy::new(path.into()))
138    }
139}
140
141/// A proxy for a dimension of a shape.
142#[derive(new)]
143pub struct DimProxy {
144    path: Path,
145}
146
147impl_proxy!(DimProxy);
148impl_comparable_proxy!(DimProxy, DimFact);
149
150/// A proxy for the whole tensor value.
151///
152/// This proxy is a bit special as it allows arbitrarily nested indexing, so
153/// that writing something like ```input.value[1][6][2]``` will always work.
154/// To make this work, each ValueProxy holds a cache which will generate new
155/// ValueProxys for nested items on the fly and store them.
156pub struct ValueProxy {
157    sub: Cache<usize, ElementProxy>,
158    root: IntProxy,
159    path: Path,
160}
161
162impl ValueProxy {
163    /// Creates a new RootValueProxy instance.
164    pub fn new(path: Path) -> ValueProxy {
165        let root = IntProxy::new([&path[..], &[-1]].concat().into());
166        ValueProxy { sub: Cache::new(), root, path }
167    }
168}
169
170impl Index<()> for ValueProxy {
171    type Output = IntProxy;
172
173    /// Returns the RootValueProxy corresponding to the given index.
174    fn index(&self, _: ()) -> &IntProxy {
175        &self.root
176    }
177}
178
179impl Index<usize> for ValueProxy {
180    type Output = ElementProxy;
181
182    /// Returns the ElementProxy corresponding to the given index.
183    fn index(&self, index: usize) -> &ElementProxy {
184        let path = [&self.path[..], &[index.to_isize().unwrap()]].concat();
185        self.sub.get(index, || ElementProxy::new(path.into()))
186    }
187}
188
189impl_proxy!(ValueProxy);
190impl_comparable_proxy!(ValueProxy, ValueFact);
191
192/// A proxy for a tensor element.
193pub struct ElementProxy {
194    sub: Cache<usize, ElementProxy>,
195    path: Path,
196}
197
198impl ElementProxy {
199    /// Creates a new ElementProxy instance.
200    pub fn new(path: Path) -> ElementProxy {
201        ElementProxy { sub: Cache::new(), path }
202    }
203}
204
205impl Index<usize> for ElementProxy {
206    type Output = ElementProxy;
207
208    /// Returns the ElementProxy corresponding to the given index.
209    fn index(&self, index: usize) -> &ElementProxy {
210        let path = [&self.path[..], &[index.to_isize().unwrap()]].concat();
211        self.sub.get(index, || ElementProxy::new(path.into()))
212    }
213}
214
215impl_proxy!(ElementProxy);
216impl_comparable_proxy!(ElementProxy, IntFactoid);
217
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_proxy_path() {
        let input = TensorProxy::new(vec![0, 0].into());
        assert_eq!(input.get_path(), &vec![0, 0].into());
    }

    #[test]
    fn test_tensor_proxy_datum_type() {
        let input = TensorProxy::new(vec![0, 0].into());
        assert_eq!(input.datum_type.get_path(), &vec![0, 0, 0].into());
    }

    #[test]
    fn test_tensor_proxy_rank() {
        let input = TensorProxy::new(vec![0, 0].into());
        assert_eq!(input.rank.get_path(), &vec![0, 0, 1].into());
    }

    #[test]
    fn test_tensor_proxy_shape() {
        let input = TensorProxy::new(vec![0, 0].into());
        assert_eq!(input.shape[0].get_path(), &vec![0, 0, 2, 0].into());
        assert_eq!(input.shape[2].get_path(), &vec![0, 0, 2, 2].into());
    }

    #[test]
    fn test_tensor_proxy_value() {
        let input = TensorProxy::new(vec![0, 0].into());
        assert_eq!(input.value.get_path(), &vec![0, 0, 3].into());
        assert_eq!(input.value[()].get_path(), &vec![0, 0, 3, -1].into());
        assert_eq!(input.value[0].get_path(), &vec![0, 0, 3, 0].into());
        assert_eq!(input.value[0][1].get_path(), &vec![0, 0, 3, 0, 1].into());
        assert_eq!(input.value[1][2][3].get_path(), &vec![0, 0, 3, 1, 2, 3].into());
    }

    /// Indexing the same position twice must hand back the very same cached
    /// proxy rather than building a fresh one.
    #[test]
    fn test_indexing_reuses_cached_proxies() {
        let input = TensorProxy::new(vec![0, 0].into());
        assert!(std::ptr::eq(&input.shape[1], &input.shape[1]));
        assert!(std::ptr::eq(&input.value[2], &input.value[2]));
        assert!(std::ptr::eq(&input.value[2][0], &input.value[2][0]));
        assert_eq!(input.shape[1].get_path(), &vec![0, 0, 2, 1].into());
    }
}