Skip to main content

tract_data/tensor/
storage.rs

1use std::alloc::Layout;
2use std::fmt;
3use std::hash::Hash;
4
5use crate::TractResult;
6use crate::blob::Blob;
7use crate::exotic::ExoticFact;
8use downcast_rs::{Downcast, impl_downcast};
9use dyn_eq::DynEq;
10
11/// Trait abstracting over tensor storage backends.
12///
13/// `PlainStorage` is the primary implementation backed by a contiguous `Blob`.
14/// Non-plain backends are held behind `StorageKind::Exotic(Box<dyn TensorStorage>)`.
15pub trait TensorStorage:
16    Send + Sync + fmt::Debug + fmt::Display + dyn_eq::DynEq + Downcast
17{
18    fn byte_len(&self) -> usize;
19    fn is_empty(&self) -> bool;
20    fn deep_clone(&self) -> Box<dyn TensorStorage>;
21    fn as_plain(&self) -> Option<&PlainStorage>;
22    fn as_plain_mut(&mut self) -> Option<&mut PlainStorage>;
23    fn into_plain(self: Box<Self>) -> Option<PlainStorage>;
24    fn dyn_hash(&self, state: &mut dyn std::hash::Hasher);
25    /// Build the `ExoticFact` that describes this storage for use in `TypedFact`.
26    ///
27    /// Plain storage returns `None`. Exotic storages should return the
28    /// appropriate fact so that `From<Arc<Tensor>> for TypedFact` preserves
29    /// exotic-ness.
30    fn exotic_fact(&self, shape: &[usize]) -> TractResult<Option<Box<dyn ExoticFact>>>;
31}
32impl_downcast!(TensorStorage);
33dyn_eq::eq_trait_object!(TensorStorage);
34
35/// Plain, contiguous storage backed by a `Blob`.
36#[derive(Eq)]
37pub struct PlainStorage(pub(crate) Blob);
38
39impl PlainStorage {
40    #[inline]
41    pub fn layout(&self) -> &Layout {
42        self.0.layout()
43    }
44
45    #[inline]
46    pub fn as_bytes(&self) -> &[u8] {
47        self.0.as_bytes()
48    }
49
50    #[inline]
51    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
52        self.0.as_bytes_mut()
53    }
54
55    #[inline]
56    pub fn as_ptr(&self) -> *const u8 {
57        self.0.as_bytes().as_ptr()
58    }
59
60    #[inline]
61    pub fn as_mut_ptr(&mut self) -> *mut u8 {
62        self.0.as_bytes_mut().as_mut_ptr()
63    }
64
65    #[inline]
66    pub fn into_blob(self) -> Blob {
67        self.0
68    }
69}
70
71impl Default for PlainStorage {
72    #[inline]
73    fn default() -> Self {
74        PlainStorage(Blob::default())
75    }
76}
77
78impl Clone for PlainStorage {
79    #[inline]
80    fn clone(&self) -> Self {
81        PlainStorage(self.0.clone())
82    }
83}
84
85impl Hash for PlainStorage {
86    #[inline]
87    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
88        self.0.hash(state);
89    }
90}
91
92impl PartialEq for PlainStorage {
93    #[inline]
94    fn eq(&self, other: &Self) -> bool {
95        self.0 == other.0
96    }
97}
98
99impl From<Blob> for PlainStorage {
100    #[inline]
101    fn from(blob: Blob) -> Self {
102        PlainStorage(blob)
103    }
104}
105
106impl std::ops::Deref for PlainStorage {
107    type Target = [u8];
108    #[inline]
109    fn deref(&self) -> &[u8] {
110        self.0.as_bytes()
111    }
112}
113
114impl std::ops::DerefMut for PlainStorage {
115    #[inline]
116    fn deref_mut(&mut self) -> &mut [u8] {
117        self.0.as_bytes_mut()
118    }
119}
120
121impl fmt::Debug for PlainStorage {
122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123        fmt::Debug::fmt(&self.0, f)
124    }
125}
126
127impl fmt::Display for PlainStorage {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        fmt::Display::fmt(&self.0, f)
130    }
131}
132
133impl TensorStorage for PlainStorage {
134    #[inline]
135    fn is_empty(&self) -> bool {
136        self.0.is_empty()
137    }
138
139    #[inline]
140    fn byte_len(&self) -> usize {
141        self.0.len()
142    }
143
144    fn deep_clone(&self) -> Box<dyn TensorStorage> {
145        Box::new(PlainStorage(self.0.clone()))
146    }
147
148    fn as_plain(&self) -> Option<&PlainStorage> {
149        Some(self)
150    }
151
152    fn as_plain_mut(&mut self) -> Option<&mut PlainStorage> {
153        Some(self)
154    }
155
156    fn into_plain(self: Box<Self>) -> Option<PlainStorage> {
157        Some(*self)
158    }
159
160    fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
161        state.write_u8(0);
162        state.write(self.0.as_bytes());
163    }
164
165    fn exotic_fact(&self, _shape: &[usize]) -> TractResult<Option<Box<dyn ExoticFact>>> {
166        Ok(None)
167    }
168}
169
/// Inline enum replacing `Box<dyn TensorStorage>`.
///
/// The common `Plain` case stays inline (no heap alloc, no vtable indirection).
/// `Exotic` covers non-plain backends behind a single Box indirection.
///
/// Equality is derived. Comparing two `Exotic` values goes through the
/// `PartialEq` impl for `dyn TensorStorage` (presumably the one generated by
/// `dyn_eq::eq_trait_object!`). Because derived enum equality compares variants
/// first, a `Plain` value never equals an `Exotic` one, even if the latter
/// wraps a `PlainStorage` with identical bytes.
#[derive(Debug, PartialEq, Eq)]
// NOTE(review): the reason for this allow is not visible here — presumably some
// variant or use is not yet exercised inside the crate; confirm, then narrow or drop it.
#[allow(dead_code)]
pub(crate) enum StorageKind {
    /// Plain, contiguous bytes stored inline in the enum.
    Plain(PlainStorage),
    /// Any other backend, boxed behind the `TensorStorage` trait object.
    Exotic(Box<dyn TensorStorage>),
}
180
181impl StorageKind {
182    #[inline]
183    pub fn as_plain(&self) -> Option<&PlainStorage> {
184        match self {
185            StorageKind::Plain(d) => Some(d),
186            StorageKind::Exotic(o) => o.as_plain(),
187        }
188    }
189
190    #[inline]
191    pub fn as_plain_mut(&mut self) -> Option<&mut PlainStorage> {
192        match self {
193            StorageKind::Plain(d) => Some(d),
194            StorageKind::Exotic(o) => o.as_plain_mut(),
195        }
196    }
197
198    #[inline]
199    pub fn into_plain(self) -> Option<PlainStorage> {
200        match self {
201            StorageKind::Plain(d) => Some(d),
202            StorageKind::Exotic(o) => o.into_plain(),
203        }
204    }
205
206    #[inline]
207    pub fn byte_len(&self) -> usize {
208        match self {
209            StorageKind::Plain(d) => d.0.len(),
210            StorageKind::Exotic(o) => o.byte_len(),
211        }
212    }
213
214    #[inline]
215    pub fn is_empty(&self) -> bool {
216        match self {
217            StorageKind::Plain(d) => d.0.is_empty(),
218            StorageKind::Exotic(o) => o.is_empty(),
219        }
220    }
221
222    #[inline]
223    #[allow(dead_code)]
224    pub fn deep_clone(&self) -> StorageKind {
225        match self {
226            StorageKind::Plain(d) => StorageKind::Plain(d.clone()),
227            StorageKind::Exotic(o) => StorageKind::Exotic(o.deep_clone()),
228        }
229    }
230
231    #[inline]
232    pub fn as_storage(&self) -> &dyn TensorStorage {
233        match self {
234            StorageKind::Plain(d) => d,
235            StorageKind::Exotic(o) => o.as_ref(),
236        }
237    }
238
239    #[inline]
240    #[allow(dead_code)]
241    pub fn as_storage_mut(&mut self) -> &mut dyn TensorStorage {
242        match self {
243            StorageKind::Plain(d) => d,
244            StorageKind::Exotic(o) => o.as_mut(),
245        }
246    }
247
248    pub fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
249        match self {
250            StorageKind::Plain(d) => {
251                state.write_u8(0);
252                state.write(d.as_bytes())
253            }
254            StorageKind::Exotic(o) => o.dyn_hash(state),
255        }
256    }
257}