tensorflow/
buffer.rs

1use super::TensorType;
2use libc::size_t;
3use std::alloc;
4use std::borrow::Borrow;
5use std::borrow::BorrowMut;
6use std::marker::PhantomData;
7use std::mem;
8use std::ops::Deref;
9use std::ops::DerefMut;
10use std::ops::Index;
11use std::ops::IndexMut;
12use std::ops::Range;
13use std::ops::RangeFrom;
14use std::ops::RangeFull;
15use std::ops::RangeTo;
16use std::os::raw::c_void as std_c_void;
17use std::process;
18use std::slice;
19#[cfg(feature = "default")]
20use tensorflow_sys as tf;
21#[cfg(feature = "tensorflow_runtime_linking")]
22use tensorflow_sys_runtime as tf;
23
24/// Fixed-length heap-allocated vector.
25/// This is basically a `Box<[T]>`, except that that type can't actually be constructed.
26/// Furthermore, `[T; N]` can't be constructed if N is not a compile-time constant.
27#[derive(Debug)]
28pub(crate) struct Buffer<T: TensorType> {
29    inner: *mut tf::TF_Buffer,
30    owned: bool,
31    phantom: PhantomData<T>,
32}
33
34impl<T: TensorType> Buffer<T> {
35    /// Creates a new buffer initialized to zeros.
36    ///
37    /// `len` is the number of elements.
38    pub fn new(len: usize) -> Self {
39        let mut b = unsafe { Buffer::new_uninitialized(len) };
40        // TODO: Use libc::memset for primitives once we have impl specialization and
41        // memset is included
42        for i in 0..len {
43            b[i] = T::default();
44        }
45        b
46    }
47
48    /// Creates a new uninitialized buffer.
49    ///
50    /// `len` is the number of elements.
51    /// The caller is responsible for initializing the data.
52    pub unsafe fn new_uninitialized(len: usize) -> Self {
53        let inner = tf::TF_NewBuffer();
54        let align = mem::align_of::<T>();
55        let size = mem::size_of::<T>();
56        let ptr = alloc::alloc(alloc::Layout::from_size_align(size * len, align).unwrap());
57        assert!(!ptr.is_null(), "allocation failure");
58
59        (*inner).data_deallocator = Some(deallocator::<T>);
60        (*inner).data = ptr as *mut std_c_void;
61        (*inner).length = len;
62        Buffer {
63            inner,
64            owned: true,
65            phantom: PhantomData,
66        }
67    }
68
69    /// Creates a new buffer with no memory allocated.
70    pub unsafe fn new_unallocated() -> Self {
71        Buffer {
72            inner: tf::TF_NewBuffer(),
73            owned: true,
74            phantom: PhantomData,
75        }
76    }
77
78    /// Creates a buffer from data owned by the C API.
79    ///
80    /// `len` is the number of elements.
81    /// The underlying data is *not* freed when the buffer is destroyed.
82    pub unsafe fn from_ptr(ptr: *mut T, len: usize) -> Self {
83        let inner = tf::TF_NewBuffer();
84        (*inner).data = ptr as *const std_c_void;
85        (*inner).length = len;
86        Buffer {
87            inner,
88            owned: true,
89            phantom: PhantomData,
90        }
91    }
92
93    #[inline]
94    fn data(&self) -> *const T {
95        unsafe { (*self.inner).data as *const T }
96    }
97
98    #[inline]
99    fn data_mut(&mut self) -> *mut T {
100        unsafe { (*self.inner).data as *mut T }
101    }
102
103    #[inline]
104    fn length(&self) -> usize {
105        unsafe { (*self.inner).length }
106    }
107
108    /// Creates a buffer from data owned by the C API.
109    ///
110    /// `len` is the number of elements.
111    /// The underlying data is freed when the buffer is destroyed if `owned`
112    /// is true and the `buf` has a data deallocator.
113    pub unsafe fn from_c(buf: *mut tf::TF_Buffer, owned: bool) -> Self {
114        Buffer {
115            inner: buf,
116            owned,
117            phantom: PhantomData,
118        }
119    }
120
121    pub fn inner(&self) -> *const tf::TF_Buffer {
122        self.inner
123    }
124
125    pub fn inner_mut(&mut self) -> *mut tf::TF_Buffer {
126        self.inner
127    }
128}
129
130unsafe extern "C" fn deallocator<T>(data: *mut std_c_void, length: size_t) {
131    let align = mem::align_of::<T>();
132    let size = mem::size_of::<T>();
133    let layout = alloc::Layout::from_size_align(size * length, align).unwrap_or_else(|_| {
134        eprintln!("internal error: failed to construct layout");
135        // make sure not to unwind
136        process::abort();
137    });
138    alloc::dealloc(data as *mut _, layout);
139}
140
141impl<T: TensorType> Drop for Buffer<T> {
142    fn drop(&mut self) {
143        if self.owned {
144            unsafe {
145                tf::TF_DeleteBuffer(self.inner);
146            }
147        }
148    }
149}
150
151impl<T: TensorType> AsRef<[T]> for Buffer<T> {
152    #[inline]
153    fn as_ref(&self) -> &[T] {
154        unsafe { slice::from_raw_parts(self.data(), (*self.inner).length) }
155    }
156}
157
158impl<T: TensorType> AsMut<[T]> for Buffer<T> {
159    #[inline]
160    fn as_mut(&mut self) -> &mut [T] {
161        unsafe { slice::from_raw_parts_mut(self.data_mut(), (*self.inner).length) }
162    }
163}
164
165impl<T: TensorType> Deref for Buffer<T> {
166    type Target = [T];
167
168    #[inline]
169    fn deref(&self) -> &[T] {
170        self.as_ref()
171    }
172}
173
174impl<T: TensorType> DerefMut for Buffer<T> {
175    #[inline]
176    fn deref_mut(&mut self) -> &mut [T] {
177        self.as_mut()
178    }
179}
180
181impl<T: TensorType> Borrow<[T]> for Buffer<T> {
182    #[inline]
183    fn borrow(&self) -> &[T] {
184        self.as_ref()
185    }
186}
187
188impl<T: TensorType> BorrowMut<[T]> for Buffer<T> {
189    #[inline]
190    fn borrow_mut(&mut self) -> &mut [T] {
191        self.as_mut()
192    }
193}
194
195impl<T: TensorType> Clone for Buffer<T>
196where
197    T: Clone,
198{
199    #[inline]
200    fn clone(&self) -> Buffer<T> {
201        let mut b = unsafe { Buffer::new_uninitialized((*self.inner).length) };
202        // TODO: Use std::ptr::copy for primitives once we have impl specialization
203        for i in 0..self.length() {
204            b[i] = self[i].clone();
205        }
206        b
207    }
208
209    #[inline]
210    fn clone_from(&mut self, other: &Buffer<T>) {
211        assert!(
212            self.length() == other.length(),
213            "self.length() = {}, other.length() = {}",
214            self.length(),
215            other.length()
216        );
217        // TODO: Use std::ptr::copy for primitives once we have impl specialization
218        for i in 0..self.length() {
219            self[i] = other[i].clone();
220        }
221    }
222}
223
224impl<T: TensorType> Index<usize> for Buffer<T> {
225    type Output = T;
226
227    #[inline]
228    fn index(&self, index: usize) -> &T {
229        assert!(
230            index < self.length(),
231            "index = {}, length = {}",
232            index,
233            self.length()
234        );
235        unsafe { &*self.data().add(index) }
236    }
237}
238
239impl<T: TensorType> IndexMut<usize> for Buffer<T> {
240    #[inline]
241    fn index_mut(&mut self, index: usize) -> &mut T {
242        assert!(
243            index < self.length(),
244            "index = {}, length = {}",
245            index,
246            self.length()
247        );
248        unsafe { &mut *self.data_mut().add(index) }
249    }
250}
251
252impl<T: TensorType> Index<Range<usize>> for Buffer<T> {
253    type Output = [T];
254
255    #[inline]
256    fn index(&self, index: Range<usize>) -> &[T] {
257        assert!(
258            index.start <= index.end,
259            "index.start = {}, index.end = {}",
260            index.start,
261            index.end
262        );
263        assert!(
264            index.end <= self.length(),
265            "index.end = {}, length = {}",
266            index.end,
267            self.length()
268        );
269        unsafe { slice::from_raw_parts(&*self.data().add(index.start), index.len()) }
270    }
271}
272
273impl<T: TensorType> IndexMut<Range<usize>> for Buffer<T> {
274    #[inline]
275    fn index_mut(&mut self, index: Range<usize>) -> &mut [T] {
276        assert!(
277            index.start <= index.end,
278            "index.start = {}, index.end = {}",
279            index.start,
280            index.end
281        );
282        assert!(
283            index.end <= self.length(),
284            "index.end = {}, length = {}",
285            index.end,
286            self.length()
287        );
288        unsafe { slice::from_raw_parts_mut(&mut *self.data_mut().add(index.start), index.len()) }
289    }
290}
291
292impl<T: TensorType> Index<RangeTo<usize>> for Buffer<T> {
293    type Output = [T];
294
295    #[inline]
296    fn index(&self, index: RangeTo<usize>) -> &[T] {
297        assert!(
298            index.end <= self.length(),
299            "index.end = {}, length = {}",
300            index.end,
301            self.length()
302        );
303        unsafe { slice::from_raw_parts(&*self.data(), index.end) }
304    }
305}
306
307impl<T: TensorType> IndexMut<RangeTo<usize>> for Buffer<T> {
308    #[inline]
309    fn index_mut(&mut self, index: RangeTo<usize>) -> &mut [T] {
310        assert!(
311            index.end <= self.length(),
312            "index.end = {}, length = {}",
313            index.end,
314            self.length()
315        );
316        unsafe { slice::from_raw_parts_mut(&mut *self.data_mut(), index.end) }
317    }
318}
319
320impl<T: TensorType> Index<RangeFrom<usize>> for Buffer<T> {
321    type Output = [T];
322
323    #[inline]
324    fn index(&self, index: RangeFrom<usize>) -> &[T] {
325        assert!(
326            index.start <= self.length(),
327            "index.start = {}, length = {}",
328            index.start,
329            self.length()
330        );
331        unsafe {
332            slice::from_raw_parts(&*self.data().add(index.start), self.length() - index.start)
333        }
334    }
335}
336
337impl<T: TensorType> IndexMut<RangeFrom<usize>> for Buffer<T> {
338    #[inline]
339    fn index_mut(&mut self, index: RangeFrom<usize>) -> &mut [T] {
340        assert!(
341            index.start <= self.length(),
342            "index.start = {}, length = {}",
343            index.start,
344            self.length()
345        );
346        unsafe {
347            slice::from_raw_parts_mut(
348                &mut *self.data_mut().add(index.start),
349                self.length() - index.start,
350            )
351        }
352    }
353}
354
355impl<T: TensorType> Index<RangeFull> for Buffer<T> {
356    type Output = [T];
357
358    #[inline]
359    fn index(&self, _: RangeFull) -> &[T] {
360        unsafe { slice::from_raw_parts(&*self.data(), self.length()) }
361    }
362}
363
364impl<T: TensorType> IndexMut<RangeFull> for Buffer<T> {
365    #[inline]
366    fn index_mut(&mut self, _: RangeFull) -> &mut [T] {
367        unsafe { slice::from_raw_parts_mut(&mut *self.data_mut(), self.length()) }
368    }
369}
370
371impl<'a, T: TensorType> From<&'a [T]> for Buffer<T> {
372    fn from(data: &'a [T]) -> Buffer<T> {
373        let mut buffer = Buffer::new(data.len());
374        buffer.clone_from_slice(data);
375        buffer
376    }
377}
378
379impl<'a, T: TensorType> From<&'a Vec<T>> for Buffer<T> {
380    #[allow(trivial_casts)]
381    fn from(data: &'a Vec<T>) -> Buffer<T> {
382        Buffer::from(data as &[T])
383    }
384}
385
386impl<T: TensorType> From<Buffer<T>> for Vec<T> {
387    fn from(buffer: Buffer<T>) -> Vec<T> {
388        let mut vec = Vec::with_capacity(buffer.len());
389        vec.extend_from_slice(&buffer);
390        vec
391    }
392}
393
394////////////////////////
395
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic() {
        let mut buf = Buffer::new(10);
        assert_eq!(buf.len(), 10);
        buf[0] = 1;
        assert_eq!(buf[0], 1);
    }

    #[test]
    fn new_is_default_initialized() {
        let buf = Buffer::<i32>::new(4);
        assert_eq!(&buf[..], &[0, 0, 0, 0]);
    }

    #[test]
    fn range_indexing() {
        let buf = Buffer::from(&[1i32, 2, 3, 4][..]);
        assert_eq!(&buf[1..3], &[2, 3]);
        assert_eq!(&buf[..2], &[1, 2]);
        assert_eq!(&buf[2..], &[3, 4]);
    }

    #[test]
    fn clone_and_convert() {
        let buf = Buffer::from(&vec![1i64, 2, 3]);
        let copy = buf.clone();
        assert_eq!(Vec::from(copy), vec![1, 2, 3]);
    }

    #[test]
    fn clone_from_copies_elements() {
        let src = Buffer::from(&[5i32, 6][..]);
        let mut dst = Buffer::<i32>::new(2);
        dst.clone_from(&src);
        assert_eq!(&dst[..], &[5, 6]);
    }
}