tract_data/tensor/
storage.rs1use std::alloc::Layout;
2use std::fmt;
3use std::hash::Hash;
4
5use crate::TractResult;
6use crate::blob::Blob;
7use crate::exotic::ExoticFact;
8use downcast_rs::{Downcast, impl_downcast};
9use dyn_eq::DynEq;
10
11pub trait TensorStorage:
16 Send + Sync + fmt::Debug + fmt::Display + dyn_eq::DynEq + Downcast
17{
18 fn byte_len(&self) -> usize;
19 fn is_empty(&self) -> bool;
20 fn deep_clone(&self) -> Box<dyn TensorStorage>;
21 fn as_plain(&self) -> Option<&PlainStorage>;
22 fn as_plain_mut(&mut self) -> Option<&mut PlainStorage>;
23 fn into_plain(self: Box<Self>) -> Option<PlainStorage>;
24 fn dyn_hash(&self, state: &mut dyn std::hash::Hasher);
25 fn exotic_fact(&self, shape: &[usize]) -> TractResult<Option<Box<dyn ExoticFact>>>;
31}
32impl_downcast!(TensorStorage);
33dyn_eq::eq_trait_object!(TensorStorage);
34
35#[derive(Eq)]
37pub struct PlainStorage(pub(crate) Blob);
38
39impl PlainStorage {
40 #[inline]
41 pub fn layout(&self) -> &Layout {
42 self.0.layout()
43 }
44
45 #[inline]
46 pub fn as_bytes(&self) -> &[u8] {
47 self.0.as_bytes()
48 }
49
50 #[inline]
51 pub fn as_bytes_mut(&mut self) -> &mut [u8] {
52 self.0.as_bytes_mut()
53 }
54
55 #[inline]
56 pub fn as_ptr(&self) -> *const u8 {
57 self.0.as_bytes().as_ptr()
58 }
59
60 #[inline]
61 pub fn as_mut_ptr(&mut self) -> *mut u8 {
62 self.0.as_bytes_mut().as_mut_ptr()
63 }
64
65 #[inline]
66 pub fn into_blob(self) -> Blob {
67 self.0
68 }
69}
70
71impl Default for PlainStorage {
72 #[inline]
73 fn default() -> Self {
74 PlainStorage(Blob::default())
75 }
76}
77
78impl Clone for PlainStorage {
79 #[inline]
80 fn clone(&self) -> Self {
81 PlainStorage(self.0.clone())
82 }
83}
84
85impl Hash for PlainStorage {
86 #[inline]
87 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
88 self.0.hash(state);
89 }
90}
91
92impl PartialEq for PlainStorage {
93 #[inline]
94 fn eq(&self, other: &Self) -> bool {
95 self.0 == other.0
96 }
97}
98
99impl From<Blob> for PlainStorage {
100 #[inline]
101 fn from(blob: Blob) -> Self {
102 PlainStorage(blob)
103 }
104}
105
106impl std::ops::Deref for PlainStorage {
107 type Target = [u8];
108 #[inline]
109 fn deref(&self) -> &[u8] {
110 self.0.as_bytes()
111 }
112}
113
114impl std::ops::DerefMut for PlainStorage {
115 #[inline]
116 fn deref_mut(&mut self) -> &mut [u8] {
117 self.0.as_bytes_mut()
118 }
119}
120
121impl fmt::Debug for PlainStorage {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 fmt::Debug::fmt(&self.0, f)
124 }
125}
126
127impl fmt::Display for PlainStorage {
128 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129 fmt::Display::fmt(&self.0, f)
130 }
131}
132
133impl TensorStorage for PlainStorage {
134 #[inline]
135 fn is_empty(&self) -> bool {
136 self.0.is_empty()
137 }
138
139 #[inline]
140 fn byte_len(&self) -> usize {
141 self.0.len()
142 }
143
144 fn deep_clone(&self) -> Box<dyn TensorStorage> {
145 Box::new(PlainStorage(self.0.clone()))
146 }
147
148 fn as_plain(&self) -> Option<&PlainStorage> {
149 Some(self)
150 }
151
152 fn as_plain_mut(&mut self) -> Option<&mut PlainStorage> {
153 Some(self)
154 }
155
156 fn into_plain(self: Box<Self>) -> Option<PlainStorage> {
157 Some(*self)
158 }
159
160 fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
161 state.write_u8(0);
162 state.write(self.0.as_bytes());
163 }
164
165 fn exotic_fact(&self, _shape: &[usize]) -> TractResult<Option<Box<dyn ExoticFact>>> {
166 Ok(None)
167 }
168}
169
170#[derive(Debug, PartialEq, Eq)]
175#[allow(dead_code)]
176pub(crate) enum StorageKind {
177 Plain(PlainStorage),
178 Exotic(Box<dyn TensorStorage>),
179}
180
181impl StorageKind {
182 #[inline]
183 pub fn as_plain(&self) -> Option<&PlainStorage> {
184 match self {
185 StorageKind::Plain(d) => Some(d),
186 StorageKind::Exotic(o) => o.as_plain(),
187 }
188 }
189
190 #[inline]
191 pub fn as_plain_mut(&mut self) -> Option<&mut PlainStorage> {
192 match self {
193 StorageKind::Plain(d) => Some(d),
194 StorageKind::Exotic(o) => o.as_plain_mut(),
195 }
196 }
197
198 #[inline]
199 pub fn into_plain(self) -> Option<PlainStorage> {
200 match self {
201 StorageKind::Plain(d) => Some(d),
202 StorageKind::Exotic(o) => o.into_plain(),
203 }
204 }
205
206 #[inline]
207 pub fn byte_len(&self) -> usize {
208 match self {
209 StorageKind::Plain(d) => d.0.len(),
210 StorageKind::Exotic(o) => o.byte_len(),
211 }
212 }
213
214 #[inline]
215 pub fn is_empty(&self) -> bool {
216 match self {
217 StorageKind::Plain(d) => d.0.is_empty(),
218 StorageKind::Exotic(o) => o.is_empty(),
219 }
220 }
221
222 #[inline]
223 #[allow(dead_code)]
224 pub fn deep_clone(&self) -> StorageKind {
225 match self {
226 StorageKind::Plain(d) => StorageKind::Plain(d.clone()),
227 StorageKind::Exotic(o) => StorageKind::Exotic(o.deep_clone()),
228 }
229 }
230
231 #[inline]
232 pub fn as_storage(&self) -> &dyn TensorStorage {
233 match self {
234 StorageKind::Plain(d) => d,
235 StorageKind::Exotic(o) => o.as_ref(),
236 }
237 }
238
239 #[inline]
240 #[allow(dead_code)]
241 pub fn as_storage_mut(&mut self) -> &mut dyn TensorStorage {
242 match self {
243 StorageKind::Plain(d) => d,
244 StorageKind::Exotic(o) => o.as_mut(),
245 }
246 }
247
248 pub fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
249 match self {
250 StorageKind::Plain(d) => {
251 state.write_u8(0);
252 state.write(d.as_bytes())
253 }
254 StorageKind::Exotic(o) => o.dyn_hash(state),
255 }
256 }
257}