1use super::TensorType;
2use libc::size_t;
3use std::alloc;
4use std::borrow::Borrow;
5use std::borrow::BorrowMut;
6use std::marker::PhantomData;
7use std::mem;
8use std::ops::Deref;
9use std::ops::DerefMut;
10use std::ops::Index;
11use std::ops::IndexMut;
12use std::ops::Range;
13use std::ops::RangeFrom;
14use std::ops::RangeFull;
15use std::ops::RangeTo;
16use std::os::raw::c_void as std_c_void;
17use std::process;
18use std::slice;
19#[cfg(feature = "default")]
20use tensorflow_sys as tf;
21#[cfg(feature = "tensorflow_runtime_linking")]
22use tensorflow_sys_runtime as tf;
23
/// Rust-side handle to a TensorFlow `TF_Buffer` whose `data`/`length` fields are
/// interpreted as `length` contiguous elements of type `T`.
#[derive(Debug)]
pub(crate) struct Buffer<T: TensorType> {
    /// The wrapped C buffer.
    inner: *mut tf::TF_Buffer,
    /// Whether `Drop` should release `inner` via `TF_DeleteBuffer`.
    owned: bool,
    /// Records the element type `T` without storing a `T`.
    phantom: PhantomData<T>,
}
33
34impl<T: TensorType> Buffer<T> {
35 pub fn new(len: usize) -> Self {
39 let mut b = unsafe { Buffer::new_uninitialized(len) };
40 for i in 0..len {
43 b[i] = T::default();
44 }
45 b
46 }
47
48 pub unsafe fn new_uninitialized(len: usize) -> Self {
53 let inner = tf::TF_NewBuffer();
54 let align = mem::align_of::<T>();
55 let size = mem::size_of::<T>();
56 let ptr = alloc::alloc(alloc::Layout::from_size_align(size * len, align).unwrap());
57 assert!(!ptr.is_null(), "allocation failure");
58
59 (*inner).data_deallocator = Some(deallocator::<T>);
60 (*inner).data = ptr as *mut std_c_void;
61 (*inner).length = len;
62 Buffer {
63 inner,
64 owned: true,
65 phantom: PhantomData,
66 }
67 }
68
69 pub unsafe fn new_unallocated() -> Self {
71 Buffer {
72 inner: tf::TF_NewBuffer(),
73 owned: true,
74 phantom: PhantomData,
75 }
76 }
77
78 pub unsafe fn from_ptr(ptr: *mut T, len: usize) -> Self {
83 let inner = tf::TF_NewBuffer();
84 (*inner).data = ptr as *const std_c_void;
85 (*inner).length = len;
86 Buffer {
87 inner,
88 owned: true,
89 phantom: PhantomData,
90 }
91 }
92
93 #[inline]
94 fn data(&self) -> *const T {
95 unsafe { (*self.inner).data as *const T }
96 }
97
98 #[inline]
99 fn data_mut(&mut self) -> *mut T {
100 unsafe { (*self.inner).data as *mut T }
101 }
102
103 #[inline]
104 fn length(&self) -> usize {
105 unsafe { (*self.inner).length }
106 }
107
108 pub unsafe fn from_c(buf: *mut tf::TF_Buffer, owned: bool) -> Self {
114 Buffer {
115 inner: buf,
116 owned,
117 phantom: PhantomData,
118 }
119 }
120
121 pub fn inner(&self) -> *const tf::TF_Buffer {
122 self.inner
123 }
124
125 pub fn inner_mut(&mut self) -> *mut tf::TF_Buffer {
126 self.inner
127 }
128}
129
/// `data_deallocator` callback for buffers created by `Buffer::new_uninitialized`.
///
/// Returns the memory for `length` elements of `T` to the Rust global allocator. It rebuilds
/// the same size/alignment `Layout` that was used for the allocation. The elements
/// themselves are not dropped. Because this is an `extern "C"` callback, a failure to rebuild
/// the layout aborts the process instead of unwinding.
unsafe extern "C" fn deallocator<T>(data: *mut std_c_void, length: size_t) {
    if length == 0 || mem::size_of::<T>() == 0 {
        // A zero-sized block never came from the allocator, and calling `dealloc` with a
        // zero-sized layout is undefined behavior.
        return;
    }
    let layout = alloc::Layout::array::<T>(length).unwrap_or_else(|_| {
        eprintln!("internal error: failed to construct layout");
        process::abort();
    });
    alloc::dealloc(data as *mut u8, layout);
}
140
141impl<T: TensorType> Drop for Buffer<T> {
142 fn drop(&mut self) {
143 if self.owned {
144 unsafe {
145 tf::TF_DeleteBuffer(self.inner);
146 }
147 }
148 }
149}
150
151impl<T: TensorType> AsRef<[T]> for Buffer<T> {
152 #[inline]
153 fn as_ref(&self) -> &[T] {
154 unsafe { slice::from_raw_parts(self.data(), (*self.inner).length) }
155 }
156}
157
158impl<T: TensorType> AsMut<[T]> for Buffer<T> {
159 #[inline]
160 fn as_mut(&mut self) -> &mut [T] {
161 unsafe { slice::from_raw_parts_mut(self.data_mut(), (*self.inner).length) }
162 }
163}
164
165impl<T: TensorType> Deref for Buffer<T> {
166 type Target = [T];
167
168 #[inline]
169 fn deref(&self) -> &[T] {
170 self.as_ref()
171 }
172}
173
174impl<T: TensorType> DerefMut for Buffer<T> {
175 #[inline]
176 fn deref_mut(&mut self) -> &mut [T] {
177 self.as_mut()
178 }
179}
180
181impl<T: TensorType> Borrow<[T]> for Buffer<T> {
182 #[inline]
183 fn borrow(&self) -> &[T] {
184 self.as_ref()
185 }
186}
187
188impl<T: TensorType> BorrowMut<[T]> for Buffer<T> {
189 #[inline]
190 fn borrow_mut(&mut self) -> &mut [T] {
191 self.as_mut()
192 }
193}
194
195impl<T: TensorType> Clone for Buffer<T>
196where
197 T: Clone,
198{
199 #[inline]
200 fn clone(&self) -> Buffer<T> {
201 let mut b = unsafe { Buffer::new_uninitialized((*self.inner).length) };
202 for i in 0..self.length() {
204 b[i] = self[i].clone();
205 }
206 b
207 }
208
209 #[inline]
210 fn clone_from(&mut self, other: &Buffer<T>) {
211 assert!(
212 self.length() == other.length(),
213 "self.length() = {}, other.length() = {}",
214 self.length(),
215 other.length()
216 );
217 for i in 0..self.length() {
219 self[i] = other[i].clone();
220 }
221 }
222}
223
224impl<T: TensorType> Index<usize> for Buffer<T> {
225 type Output = T;
226
227 #[inline]
228 fn index(&self, index: usize) -> &T {
229 assert!(
230 index < self.length(),
231 "index = {}, length = {}",
232 index,
233 self.length()
234 );
235 unsafe { &*self.data().add(index) }
236 }
237}
238
239impl<T: TensorType> IndexMut<usize> for Buffer<T> {
240 #[inline]
241 fn index_mut(&mut self, index: usize) -> &mut T {
242 assert!(
243 index < self.length(),
244 "index = {}, length = {}",
245 index,
246 self.length()
247 );
248 unsafe { &mut *self.data_mut().add(index) }
249 }
250}
251
252impl<T: TensorType> Index<Range<usize>> for Buffer<T> {
253 type Output = [T];
254
255 #[inline]
256 fn index(&self, index: Range<usize>) -> &[T] {
257 assert!(
258 index.start <= index.end,
259 "index.start = {}, index.end = {}",
260 index.start,
261 index.end
262 );
263 assert!(
264 index.end <= self.length(),
265 "index.end = {}, length = {}",
266 index.end,
267 self.length()
268 );
269 unsafe { slice::from_raw_parts(&*self.data().add(index.start), index.len()) }
270 }
271}
272
273impl<T: TensorType> IndexMut<Range<usize>> for Buffer<T> {
274 #[inline]
275 fn index_mut(&mut self, index: Range<usize>) -> &mut [T] {
276 assert!(
277 index.start <= index.end,
278 "index.start = {}, index.end = {}",
279 index.start,
280 index.end
281 );
282 assert!(
283 index.end <= self.length(),
284 "index.end = {}, length = {}",
285 index.end,
286 self.length()
287 );
288 unsafe { slice::from_raw_parts_mut(&mut *self.data_mut().add(index.start), index.len()) }
289 }
290}
291
292impl<T: TensorType> Index<RangeTo<usize>> for Buffer<T> {
293 type Output = [T];
294
295 #[inline]
296 fn index(&self, index: RangeTo<usize>) -> &[T] {
297 assert!(
298 index.end <= self.length(),
299 "index.end = {}, length = {}",
300 index.end,
301 self.length()
302 );
303 unsafe { slice::from_raw_parts(&*self.data(), index.end) }
304 }
305}
306
307impl<T: TensorType> IndexMut<RangeTo<usize>> for Buffer<T> {
308 #[inline]
309 fn index_mut(&mut self, index: RangeTo<usize>) -> &mut [T] {
310 assert!(
311 index.end <= self.length(),
312 "index.end = {}, length = {}",
313 index.end,
314 self.length()
315 );
316 unsafe { slice::from_raw_parts_mut(&mut *self.data_mut(), index.end) }
317 }
318}
319
320impl<T: TensorType> Index<RangeFrom<usize>> for Buffer<T> {
321 type Output = [T];
322
323 #[inline]
324 fn index(&self, index: RangeFrom<usize>) -> &[T] {
325 assert!(
326 index.start <= self.length(),
327 "index.start = {}, length = {}",
328 index.start,
329 self.length()
330 );
331 unsafe {
332 slice::from_raw_parts(&*self.data().add(index.start), self.length() - index.start)
333 }
334 }
335}
336
337impl<T: TensorType> IndexMut<RangeFrom<usize>> for Buffer<T> {
338 #[inline]
339 fn index_mut(&mut self, index: RangeFrom<usize>) -> &mut [T] {
340 assert!(
341 index.start <= self.length(),
342 "index.start = {}, length = {}",
343 index.start,
344 self.length()
345 );
346 unsafe {
347 slice::from_raw_parts_mut(
348 &mut *self.data_mut().add(index.start),
349 self.length() - index.start,
350 )
351 }
352 }
353}
354
355impl<T: TensorType> Index<RangeFull> for Buffer<T> {
356 type Output = [T];
357
358 #[inline]
359 fn index(&self, _: RangeFull) -> &[T] {
360 unsafe { slice::from_raw_parts(&*self.data(), self.length()) }
361 }
362}
363
364impl<T: TensorType> IndexMut<RangeFull> for Buffer<T> {
365 #[inline]
366 fn index_mut(&mut self, _: RangeFull) -> &mut [T] {
367 unsafe { slice::from_raw_parts_mut(&mut *self.data_mut(), self.length()) }
368 }
369}
370
371impl<'a, T: TensorType> From<&'a [T]> for Buffer<T> {
372 fn from(data: &'a [T]) -> Buffer<T> {
373 let mut buffer = Buffer::new(data.len());
374 buffer.clone_from_slice(data);
375 buffer
376 }
377}
378
379impl<'a, T: TensorType> From<&'a Vec<T>> for Buffer<T> {
380 #[allow(trivial_casts)]
381 fn from(data: &'a Vec<T>) -> Buffer<T> {
382 Buffer::from(data as &[T])
383 }
384}
385
386impl<T: TensorType> From<Buffer<T>> for Vec<T> {
387 fn from(buffer: Buffer<T>) -> Vec<T> {
388 let mut vec = Vec::with_capacity(buffer.len());
389 vec.extend_from_slice(&buffer);
390 vec
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    /// A new buffer reports the requested length, and a written element reads back.
    #[test]
    fn basic() {
        let mut buffer: Buffer<i32> = Buffer::new(10);
        assert_eq!(10, buffer.len());
        buffer[0] = 1;
        assert_eq!(1, buffer[0]);
    }
}